Module hummingbird.ml.operator_converters.onnx.one_hot_encoder
Converter for ONNX-ML One Hot Encoder.
Expand source code Browse git
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
Converter for ONNX-ML One Hot Encoder.
"""
import numpy as np
from onnxconverter_common.registration import register_converter
from .._one_hot_encoder_implementations import OneHotEncoderString, OneHotEncoder
def convert_onnx_one_hot_encoder(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.ml.OneHotEncoder`.

    Args:
        operator: An operator wrapping an `ai.onnx.ml.OneHotEncoder` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy
            (not read by this converter)

    Returns:
        A PyTorch model

    Raises:
        NotImplementedError: If the encoder defines string categories (`cats_strings`).
        RuntimeError: If no non-empty integer categories (`cats_int64s`) are found.
    """
    raw_operator = operator.raw_operator

    categories = []
    for attr in raw_operator.origin.attribute:
        if attr.name == "cats_int64s":
            # Pin the dtype to the attribute's declared type: without it NumPy picks a
            # platform-dependent integer width, and an empty list would become float64.
            categories.append(np.array(attr.ints, dtype=np.int64))
        elif attr.name == "cats_strings":
            # TODO(Issue #209): decode `attr.strings` as UTF-8 and return OneHotEncoderString.
            raise NotImplementedError("OneHotEncoder does not yet support Strings (Issue #209)")

    # An absent or empty category list cannot produce a meaningful one-hot encoding.
    if not categories or any(cats.size == 0 for cats in categories):
        raise RuntimeError("Error parsing OneHotEncoder, no categories")

    return OneHotEncoder(categories, device)
# Register convert_onnx_one_hot_encoder as the converter for the "ONNXMLOneHotEncoder" operator type.
register_converter("ONNXMLOneHotEncoder", convert_onnx_one_hot_encoder)
Functions
def convert_onnx_one_hot_encoder(operator, device=None, extra_config={})
-
Converter for `ai.onnx.ml.OneHotEncoder`
Args
operator
- An operator wrapping an `ai.onnx.ml.OneHotEncoder` model
device
- String defining the type of device the converted operator should be run on
extra_config
- Extra configuration used to select the best conversion strategy
Returns
A PyTorch model
Expand source code Browse git
def convert_onnx_one_hot_encoder(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.ml.OneHotEncoder`.

    Args:
        operator: An operator wrapping an `ai.onnx.ml.OneHotEncoder` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy
            (not read by this converter)

    Returns:
        A PyTorch model

    Raises:
        NotImplementedError: If the encoder defines string categories (`cats_strings`).
        RuntimeError: If no non-empty integer categories (`cats_int64s`) are found.
    """
    raw_operator = operator.raw_operator

    categories = []
    for attr in raw_operator.origin.attribute:
        if attr.name == "cats_int64s":
            # Pin the dtype to the attribute's declared type: without it NumPy picks a
            # platform-dependent integer width, and an empty list would become float64.
            categories.append(np.array(attr.ints, dtype=np.int64))
        elif attr.name == "cats_strings":
            # TODO(Issue #209): decode `attr.strings` as UTF-8 and return OneHotEncoderString.
            raise NotImplementedError("OneHotEncoder does not yet support Strings (Issue #209)")

    # An absent or empty category list cannot produce a meaningful one-hot encoding.
    if not categories or any(cats.size == 0 for cats in categories):
        raise RuntimeError("Error parsing OneHotEncoder, no categories")

    return OneHotEncoder(categories, device)