Module hummingbird.ml.operator_converters.sklearn.skl_one_hot_encoder
Converter for scikit-learn one hot encoder.
Expand source code Browse git
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
Converter for scikit-learn one hot encoder.
"""
import numpy as np
from onnxconverter_common.registration import register_converter
from .._one_hot_encoder_implementations import OneHotEncoderString, OneHotEncoder
def convert_sklearn_one_hot_encoder(operator, device, extra_config):
"""
Converter for `sklearn.preprocessing.OneHotEncoder`
Args:
operator: An operator wrapping a `sklearn.preprocessing.OneHotEncoder` model
device: String defining the type of device the converted operator should be run on
extra_config: Extra configuration used to select the best conversion strategy
Returns:
A PyTorch model
"""
if all([np.array(c).dtype == object for c in operator.raw_operator.categories_]):
categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_]
return OneHotEncoderString(categories, device)
else:
return OneHotEncoder(operator.raw_operator.categories_, device)
register_converter("SklearnOneHotEncoder", convert_sklearn_one_hot_encoder)
Functions
def convert_sklearn_one_hot_encoder(operator, device, extra_config)
-
Converter for
sklearn.preprocessing.OneHotEncoder
Args
operator
- An operator wrapping a
sklearn.preprocessing.OneHotEncoder
model device
- String defining the type of device the converted operator should be run on
extra_config
- Extra configuration used to select the best conversion strategy
Returns
A PyTorch model
Expand source code Browse git
def convert_sklearn_one_hot_encoder(operator, device, extra_config): """ Converter for `sklearn.preprocessing.OneHotEncoder` Args: operator: An operator wrapping a `sklearn.preprocessing.OneHotEncoder` model device: String defining the type of device the converted operator should be run on extra_config: Extra configuration used to select the best conversion strategy Returns: A PyTorch model """ if all([np.array(c).dtype == object for c in operator.raw_operator.categories_]): categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_] return OneHotEncoderString(categories, device) else: return OneHotEncoder(operator.raw_operator.categories_, device)