Module hummingbird.ml.operator_converters.onnx_operator

Converters for ONNX operators.

Expand source code
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

"""
Converters for ONNX operators.
"""

import numpy as np
import torch

from onnxconverter_common.registration import register_converter

from . import constants
from ._base_operator import BaseOperator


class Cast(BaseOperator, torch.nn.Module):
    """
    PyTorch implementation of the ONNX `Cast` operator.

    Converts the input tensor to the element type identified by an ONNX
    `TensorProto.DataType` code (the value of the ONNX node's `to` attribute).
    """

    def __init__(self, to_type):
        """
        Args:
            to_type: Integer ONNX `TensorProto.DataType` code of the target element type
        """
        super(Cast, self).__init__()

        assert to_type is not None

        self.to_type = to_type

    def forward(self, x):
        """
        Cast `x` to the element type selected by `self.to_type`.

        Args:
            x: Input tensor

        Returns:
            `x` converted to the requested element type

        Raises:
            RuntimeError: If `self.to_type` is not one of the codes handled below
                (e.g. STRING or the unsigned 16/32/64-bit types)
        """
        # Codes follow the ONNX TensorProto.DataType enum.
        if self.to_type == 1:  # FLOAT
            return x.float()
        if self.to_type == 2:  # UINT8
            return x.byte()
        if self.to_type == 3:  # INT8
            return x.char()
        if self.to_type == 5:  # INT16
            return x.short()
        if self.to_type == 6:  # INT32
            return x.int()
        if self.to_type == 7:  # INT64
            return x.long()
        if self.to_type == 9:  # BOOL
            return x.bool()
        if self.to_type == 10:  # FLOAT16
            return x.half()
        if self.to_type == 11:  # DOUBLE
            return x.double()
        if self.to_type == 16:  # BFLOAT16
            return x.bfloat16()
        # Fail loudly instead of implicitly returning None for unsupported targets.
        raise RuntimeError("Cast to ONNX type {} is not supported.".format(self.to_type))


def convert_onnx_cast(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.Cast`.

    Reads the target element type from the ONNX node's `to` attribute and wraps it
    in a `Cast` module. `device` and `extra_config` are accepted but not used by
    this converter.

    Args:
        operator: An operator wrapping an `ai.onnx.Cast` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    assert operator is not None

    # Index the node's attributes by name; if "to" appears more than once the
    # last occurrence wins, exactly as in a sequential scan.
    node_attributes = operator.raw_operator.origin.attribute
    attributes_by_name = {attribute.name: attribute for attribute in node_attributes}
    cast_target = attributes_by_name["to"].i if "to" in attributes_by_name else None

    return Cast(cast_target)


# Register `convert_onnx_cast` with the converter registry under the "ONNXMLCast" alias.
register_converter("ONNXMLCast", convert_onnx_cast)

Functions

def convert_onnx_cast(operator, device=None, extra_config={})

Converter for ai.onnx.Cast.

Args

operator
An operator wrapping an `ai.onnx.Cast` model
device
String defining the type of device the converted operator should be run on
extra_config
Extra configuration used to select the best conversion strategy

Returns

A PyTorch model
 
Expand source code
def convert_onnx_cast(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.Cast`.

    Reads the target element type from the ONNX node's `to` attribute and wraps it
    in a `Cast` module. `device` and `extra_config` are accepted but not used by
    this converter.

    Args:
        operator: An operator wrapping an `ai.onnx.Cast` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    assert operator is not None

    # Index the node's attributes by name; if "to" appears more than once the
    # last occurrence wins, exactly as in a sequential scan.
    node_attributes = operator.raw_operator.origin.attribute
    attributes_by_name = {attribute.name: attribute for attribute in node_attributes}
    cast_target = attributes_by_name["to"].i if "to" in attributes_by_name else None

    return Cast(cast_target)

Classes

class Cast (to_type)

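PyTorch implementation of the ONNX `Cast` operator: converts its input tensor to the element type identified by the ONNX `TensorProto.DataType` code passed as `to_type` (the value of the ONNX node's `to` attribute).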

Expand source code
class Cast(BaseOperator, torch.nn.Module):
    """
    PyTorch implementation of the ONNX `Cast` operator.

    Converts the input tensor to the element type identified by an ONNX
    `TensorProto.DataType` code (the value of the ONNX node's `to` attribute).
    """

    def __init__(self, to_type):
        """
        Args:
            to_type: Integer ONNX `TensorProto.DataType` code of the target element type
        """
        super(Cast, self).__init__()

        assert to_type is not None

        self.to_type = to_type

    def forward(self, x):
        """
        Cast `x` to the element type selected by `self.to_type`.

        Args:
            x: Input tensor

        Returns:
            `x` converted to the requested element type

        Raises:
            RuntimeError: If `self.to_type` is not one of the codes handled below
                (e.g. STRING or the unsigned 16/32/64-bit types)
        """
        # Codes follow the ONNX TensorProto.DataType enum.
        if self.to_type == 1:  # FLOAT
            return x.float()
        if self.to_type == 2:  # UINT8
            return x.byte()
        if self.to_type == 3:  # INT8
            return x.char()
        if self.to_type == 5:  # INT16
            return x.short()
        if self.to_type == 6:  # INT32
            return x.int()
        if self.to_type == 7:  # INT64
            return x.long()
        if self.to_type == 9:  # BOOL
            return x.bool()
        if self.to_type == 10:  # FLOAT16
            return x.half()
        if self.to_type == 11:  # DOUBLE
            return x.double()
        if self.to_type == 16:  # BFLOAT16
            return x.bfloat16()
        # Fail loudly instead of implicitly returning None for unsupported targets.
        raise RuntimeError("Cast to ONNX type {} is not supported.".format(self.to_type))

Ancestors

  • hummingbird.ml.operator_converters._base_operator.BaseOperator
  • abc.ABC
  • torch.nn.modules.module.Module

Methods

def forward(self, x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this method directly, since calling the instance runs the registered hooks while calling `forward` directly silently ignores them.

Expand source code
def forward(self, x):
    """
    Cast `x` to the element type selected by `self.to_type`.

    Args:
        x: Input tensor

    Returns:
        `x` converted to the requested element type

    Raises:
        RuntimeError: If `self.to_type` is not one of the codes handled below
            (e.g. STRING or the unsigned 16/32/64-bit types)
    """
    # Codes follow the ONNX TensorProto.DataType enum.
    if self.to_type == 1:  # FLOAT
        return x.float()
    if self.to_type == 2:  # UINT8
        return x.byte()
    if self.to_type == 3:  # INT8
        return x.char()
    if self.to_type == 5:  # INT16
        return x.short()
    if self.to_type == 6:  # INT32
        return x.int()
    if self.to_type == 7:  # INT64
        return x.long()
    if self.to_type == 9:  # BOOL
        return x.bool()
    if self.to_type == 10:  # FLOAT16
        return x.half()
    if self.to_type == 11:  # DOUBLE
        return x.double()
    if self.to_type == 16:  # BFLOAT16
        return x.bfloat16()
    # Fail loudly instead of implicitly returning None for unsupported targets.
    raise RuntimeError("Cast to ONNX type {} is not supported.".format(self.to_type))