Module hummingbird.ml.operator_converters.onnx.onnx_operator
Converters for ONNX operators.
Source code
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
Converters for ONNX operators.
"""
import numpy as np
from onnxconverter_common.registration import register_converter
import torch
from .. import constants
from .._base_operator import BaseOperator
class Cast(BaseOperator, torch.nn.Module):
    def __init__(self, to_type):
        super(Cast, self).__init__()

        assert to_type is not None

        self.to_type = to_type

    def forward(self, x):
        if self.to_type == 1:  # Cast to float
            return x.float()
        if self.to_type == 7:  # Cast to long
            return x.long()


class Concat(BaseOperator, torch.nn.Module):
    def __init__(self):
        super(Concat, self).__init__()

    def forward(self, *x):
        if len(x[0].shape) > 1:
            return torch.cat(x, dim=1)
        else:
            return torch.stack(x, dim=1)


class Reshape(BaseOperator, torch.nn.Module):
    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return torch.reshape(x, self.shape)


def convert_onnx_cast(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.Cast`.

    Args:
        operator: An operator wrapping a `ai.onnx.Cast` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    assert operator is not None

    to_type = None
    for attr in operator.raw_operator.origin.attribute:
        if attr.name == "to":
            to_type = attr.i

    # Generate the model.
    return Cast(to_type)


def convert_onnx_concat(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.Concat`.

    Args:
        operator: An operator wrapping a `ai.onnx.Concat` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    assert operator is not None

    # Generate the model.
    return Concat()


def convert_onnx_reshape(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.Reshape`.

    Args:
        operator: An operator wrapping a `ai.onnx.Reshape` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    assert operator is not None

    shape = []
    initializers = extra_config[constants.ONNX_INITIALIZERS]
    shape = list(initializers[operator.raw_operator.origin.input[1]].int64_data)

    # Generate the model.
    return Reshape(shape)
register_converter("ONNXMLCast", convert_onnx_cast)
register_converter("ONNXMLConcat", convert_onnx_concat)
register_converter("ONNXMLReshape", convert_onnx_reshape)
Functions
def convert_onnx_cast(operator, device=None, extra_config={})
Converter for `ai.onnx.Cast`.
Args
operator - An operator wrapping a `ai.onnx.Cast` model
device - String defining the type of device the converted operator should be run on
extra_config - Extra configuration used to select the best conversion strategy
Returns
A PyTorch model
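A minimal sketch of the attribute extraction, using `types.SimpleNamespace` stand-ins for the wrapped ONNX `NodeProto` (the real `operator.raw_operator.origin` object comes from Hummingbird's ONNX parser; the mock below is purely illustrative):

from types import SimpleNamespace
from hummingbird.ml.operator_converters.onnx.onnx_operator import convert_onnx_cast

# Fake Cast node with a single attribute "to" set to 1 (TensorProto.FLOAT).
node = SimpleNamespace(attribute=[SimpleNamespace(name="to", i=1)])
operator = SimpleNamespace(raw_operator=SimpleNamespace(origin=node))

cast_module = convert_onnx_cast(operator)  # returns Cast(1), a torch.nn.Module that casts inputs to float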
def convert_onnx_concat(operator, device=None, extra_config={})
Converter for `ai.onnx.Concat`.
Args
operator - An operator wrapping a `ai.onnx.Concat` model
device - String defining the type of device the converted operator should be run on
extra_config - Extra configuration used to select the best conversion strategy
Returns
A PyTorch model
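`Concat` carries no ONNX attributes, so the converter only checks that `operator` is not `None` before returning the module. A minimal, illustrative sketch:

from hummingbird.ml.operator_converters.onnx.onnx_operator import convert_onnx_concat

concat_module = convert_onnx_concat(operator=object())  # any non-None placeholder satisfies the assert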
def convert_onnx_reshape(operator, device=None, extra_config={})
Converter for `ai.onnx.Reshape`.
Args
operator - An operator wrapping a `ai.onnx.Reshape` model
device - String defining the type of device the converted operator should be run on
extra_config - Extra configuration used to select the best conversion strategy
Returns
A PyTorch model
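Here the target shape is read from the graph initializer bound to the node's second input, which callers pass in through `extra_config[constants.ONNX_INITIALIZERS]`. A minimal sketch with `SimpleNamespace` stand-ins (the input names `data` and `shape_const` are made up for illustration):

from types import SimpleNamespace
from hummingbird.ml.operator_converters import constants
from hummingbird.ml.operator_converters.onnx.onnx_operator import convert_onnx_reshape

# Fake Reshape node whose second input points at an initializer holding the target shape.
shape_init = SimpleNamespace(int64_data=[-1, 4])
node = SimpleNamespace(input=["data", "shape_const"])
operator = SimpleNamespace(raw_operator=SimpleNamespace(origin=node))

reshape_module = convert_onnx_reshape(
    operator, extra_config={constants.ONNX_INITIALIZERS: {"shape_const": shape_init}}
)  # returns Reshape([-1, 4])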
Classes
class Cast (to_type)
Abstract class defining the basic structure for operator implementations in Hummingbird.
Ancestors
- hummingbird.ml.operator_converters._base_operator.BaseOperator
- abc.ABC
- torch.nn.modules.module.Module
Methods
def forward(self, x) -> Callable[..., Any]
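A quick, illustrative usage (the `to_type` codes are the ONNX `TensorProto` data-type values; only FLOAT = 1 and INT64 = 7 are handled here):

import torch
from hummingbird.ml.operator_converters.onnx.onnx_operator import Cast

Cast(to_type=1)(torch.tensor([1, 2, 3])).dtype    # torch.float32
Cast(to_type=7)(torch.tensor([1.5, 2.5])).dtype   # torch.int64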
class Concat
Abstract class defining the basic structure for operator implementations in Hummingbird.
Ancestors
- hummingbird.ml.operator_converters._base_operator.BaseOperator
- abc.ABC
- torch.nn.modules.module.Module
Methods
def forward(self, *x) -> Callable[..., Any]
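A quick, illustrative usage showing the two branches of `forward`: 2-D inputs are concatenated along dim 1, while 1-D inputs are stacked into a 2-D tensor.

import torch
from hummingbird.ml.operator_converters.onnx.onnx_operator import Concat

concat = Concat()
concat(torch.zeros(3, 2), torch.ones(3, 4)).shape   # torch.Size([3, 6])
concat(torch.zeros(3), torch.ones(3)).shape         # torch.Size([3, 2])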
class Reshape (shape)
Abstract class defining the basic structure for operator implementations in Hummingbird.
Ancestors
- hummingbird.ml.operator_converters._base_operator.BaseOperator
- abc.ABC
- torch.nn.modules.module.Module
Methods
def forward(self, x) -> Callable[..., Any]
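A quick, illustrative usage (in practice the shape list is extracted from the ONNX initializer by `convert_onnx_reshape`):

import torch
from hummingbird.ml.operator_converters.onnx.onnx_operator import Reshape

Reshape([-1, 4])(torch.zeros(2, 6)).shape   # torch.Size([3, 4])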