Module hummingbird.ml.operator_converters.onnx_operator
Converters for ONNX operators.
Expand source code
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
Converters for ONNX operators.
"""
import numpy as np
import torch
from onnxconverter_common.registration import register_converter
from . import constants
from ._base_operator import BaseOperator
class Cast(BaseOperator, torch.nn.Module):
    """
    PyTorch implementation of the ONNX `Cast` operator.

    The output element type is given as an ONNX `TensorProto.DataType` integer code
    (the value of the `to` attribute of the ONNX `Cast` node).
    """

    def __init__(self, to_type):
        """
        Args:
            to_type: ONNX `TensorProto.DataType` code of the element type to cast to
        """
        super(Cast, self).__init__()
        assert to_type is not None
        self.to_type = to_type

    def forward(self, x):
        """
        Cast `x` to the torch dtype corresponding to `self.to_type`.

        Raises:
            RuntimeError: if `self.to_type` is not a supported ONNX element type
        """
        # Dispatch on the ONNX TensorProto.DataType code. Unsupported codes raise
        # instead of silently returning None.
        if self.to_type == 1:  # FLOAT
            return x.float()
        if self.to_type == 2:  # UINT8
            return x.byte()
        if self.to_type == 3:  # INT8
            return x.char()
        if self.to_type == 5:  # INT16
            return x.short()
        if self.to_type == 6:  # INT32
            return x.int()
        if self.to_type == 7:  # INT64: cast to long
            return x.long()
        if self.to_type == 9:  # BOOL
            return x.bool()
        if self.to_type == 10:  # FLOAT16
            return x.half()
        if self.to_type == 11:  # DOUBLE
            return x.double()
        raise RuntimeError("Cast to ONNX type {} is not supported.".format(self.to_type))
def convert_onnx_cast(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.Cast`.

    Args:
        operator: An operator wrapping an `ai.onnx.Cast` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    assert operator is not None

    # The ONNX node stores its target element type in the integer attribute "to".
    # If several attributes share that name, the last one wins, matching a plain
    # overwrite loop. With no "to" attribute, to_type stays None.
    onnx_node = operator.raw_operator.origin
    to_attrs = [attribute for attribute in onnx_node.attribute if attribute.name == "to"]
    to_type = to_attrs[-1].i if to_attrs else None

    # Generate the model.
    return Cast(to_type)
# Register `convert_onnx_cast` with onnxconverter_common under the alias "ONNXMLCast"
# (presumably the Hummingbird operator type used for ONNX Cast nodes — confirm).
register_converter("ONNXMLCast", convert_onnx_cast)
Functions
def convert_onnx_cast(operator, device=None, extra_config={})
-
Converter for `ai.onnx.Cast`.
Args
operator
- An operator wrapping an `ai.onnx.Cast` model
device
- String defining the type of device the converted operator should be run on
extra_config
- Extra configuration used to select the best conversion strategy
Returns
A PyTorch model
Expand source code
def convert_onnx_cast(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.Cast`.

    Args:
        operator: An operator wrapping an `ai.onnx.Cast` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    assert operator is not None

    # The ONNX node stores its target element type in the integer attribute "to".
    # If several attributes share that name, the last one wins, matching a plain
    # overwrite loop. With no "to" attribute, to_type stays None.
    onnx_node = operator.raw_operator.origin
    to_attrs = [attribute for attribute in onnx_node.attribute if attribute.name == "to"]
    to_type = to_attrs[-1].i if to_attrs else None

    # Generate the model.
    return Cast(to_type)
Classes
class Cast (to_type)
-
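PyTorch module implementing the ONNX `Cast` operator; `to_type` is the ONNX element-type code (the `to` attribute of the Cast node) to cast the input to. (Description inherited from `BaseOperator`: abstract class defining the basic structure for operator implementations in Hummingbird.)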
Expand source code
class Cast(BaseOperator, torch.nn.Module):
    """
    PyTorch implementation of the ONNX `Cast` operator.

    The output element type is given as an ONNX `TensorProto.DataType` integer code
    (the value of the `to` attribute of the ONNX `Cast` node).
    """

    def __init__(self, to_type):
        """
        Args:
            to_type: ONNX `TensorProto.DataType` code of the element type to cast to
        """
        super(Cast, self).__init__()
        assert to_type is not None
        self.to_type = to_type

    def forward(self, x):
        """
        Cast `x` to the torch dtype corresponding to `self.to_type`.

        Raises:
            RuntimeError: if `self.to_type` is not a supported ONNX element type
        """
        # Dispatch on the ONNX TensorProto.DataType code. Unsupported codes raise
        # instead of silently returning None.
        if self.to_type == 1:  # FLOAT
            return x.float()
        if self.to_type == 2:  # UINT8
            return x.byte()
        if self.to_type == 3:  # INT8
            return x.char()
        if self.to_type == 5:  # INT16
            return x.short()
        if self.to_type == 6:  # INT32
            return x.int()
        if self.to_type == 7:  # INT64: cast to long
            return x.long()
        if self.to_type == 9:  # BOOL
            return x.bool()
        if self.to_type == 10:  # FLOAT16
            return x.half()
        if self.to_type == 11:  # DOUBLE
            return x.double()
        raise RuntimeError("Cast to ONNX type {} is not supported.".format(self.to_type))
Ancestors
- hummingbird.ml.operator_converters._base_operator.BaseOperator
- abc.ABC
- torch.nn.modules.module.Module
Methods
def forward(self, x)
-
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Expand source code
def forward(self, x):
    """
    Cast `x` to the torch dtype corresponding to `self.to_type`.

    `self.to_type` is an ONNX `TensorProto.DataType` integer code.

    Raises:
        RuntimeError: if `self.to_type` is not a supported ONNX element type
    """
    # Dispatch on the ONNX TensorProto.DataType code. Unsupported codes raise
    # instead of silently returning None.
    if self.to_type == 1:  # FLOAT
        return x.float()
    if self.to_type == 2:  # UINT8
        return x.byte()
    if self.to_type == 3:  # INT8
        return x.char()
    if self.to_type == 5:  # INT16
        return x.short()
    if self.to_type == 6:  # INT32
        return x.int()
    if self.to_type == 7:  # INT64: cast to long
        return x.long()
    if self.to_type == 9:  # BOOL
        return x.bool()
    if self.to_type == 10:  # FLOAT16
        return x.half()
    if self.to_type == 11:  # DOUBLE
        return x.double()
    raise RuntimeError("Cast to ONNX type {} is not supported.".format(self.to_type))