Module hummingbird.ml.operator_converters.skl_normalizer
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import torch
from onnxconverter_common.registration import register_converter

from ._base_operator import BaseOperator


class Normalizer(BaseOperator, torch.nn.Module):
    def __init__(self, norm, device):
        super(Normalizer, self).__init__()
        self.norm = norm
        self.transformer = True

    def forward(self, x):
        if self.norm == "l1":
            return x / torch.abs(x).sum(1, keepdim=True)
        elif self.norm == "l2":
            return x / torch.pow(torch.pow(x, 2).sum(1, keepdim=True), 0.5)
        elif self.norm == "max":
            return x / torch.max(torch.abs(x), dim=1, keepdim=True)[0]
        else:
            raise RuntimeError("Unsupported norm: {0}".format(self.norm))
def convert_sklearn_normalizer(operator, device, extra_config):
    """
    Converter for `sklearn.preprocessing.Normalizer`

    Args:
        operator: An operator wrapping a `sklearn.preprocessing.Normalizer` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    return Normalizer(operator.raw_operator.norm, device)


register_converter("SklearnNormalizer", convert_sklearn_normalizer)
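In practice this converter is reached through Hummingbird's top-level convert entry point rather than invoked directly. A minimal end-to-end sketch, assuming the public hummingbird.ml.convert API and made-up input data:

import numpy as np
from sklearn.preprocessing import Normalizer
from hummingbird.ml import convert

X = np.random.rand(100, 4).astype(np.float32)  # illustrative input
skl_model = Normalizer(norm="l2").fit(X)       # Normalizer is stateless; fit only validates X

# Translate the sklearn transformer into a PyTorch model.
torch_model = convert(skl_model, "torch")

# Both should produce the same unit-l2-norm rows.
np.testing.assert_allclose(skl_model.transform(X), torch_model.transform(X), rtol=1e-5)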
Functions
def convert_sklearn_normalizer(operator, device, extra_config)
- Converter for `sklearn.preprocessing.Normalizer`.
Args
operator
- An operator wrapping a `sklearn.preprocessing.Normalizer` model
device
- String defining the type of device the converted operator should be run on
extra_config
- Extra configuration used to select the best conversion strategy
Returns
A PyTorch model
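Since the converter only reads operator.raw_operator.norm (see the source above) and the device argument is unused by the resulting module, it can also be exercised in isolation; the SimpleNamespace below is a hypothetical stand-in for Hummingbird's internal operator wrapper:

from types import SimpleNamespace

import torch
from sklearn.preprocessing import Normalizer as SklNormalizer

from hummingbird.ml.operator_converters.skl_normalizer import convert_sklearn_normalizer

# Hypothetical stand-in: the converter only dereferences operator.raw_operator.norm.
operator = SimpleNamespace(raw_operator=SklNormalizer(norm="max"))

torch_op = convert_sklearn_normalizer(operator, device=None, extra_config={})
print(torch_op(torch.tensor([[1.0, -4.0, 2.0]])))  # tensor([[ 0.2500, -1.0000,  0.5000]])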
Classes
class Normalizer (norm, device)
- PyTorch implementation of `sklearn.preprocessing.Normalizer`, scaling each input row to unit l1, l2, or max norm. Inherits its docstring from BaseOperator, the abstract class defining the basic structure for operator implementations in Hummingbird.
Ancestors
- hummingbird.ml.operator_converters._base_operator.BaseOperator
- abc.ABC
- torch.nn.modules.module.Module
Methods
def forward(self, x)
- Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
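As a sanity check, the sketch below (with made-up inputs) calls the module directly for each supported norm and compares against sklearn's output. All-zero rows are avoided: sklearn leaves them untouched, whereas this module would divide by zero.

import numpy as np
import torch
from sklearn.preprocessing import Normalizer as SklNormalizer

from hummingbird.ml.operator_converters.skl_normalizer import Normalizer

X = np.array([[1.0, -2.0, 2.0], [3.0, 0.0, -4.0]], dtype=np.float32)

for norm in ("l1", "l2", "max"):
    # The device argument is accepted but unused by this module.
    torch_out = Normalizer(norm, device=None)(torch.from_numpy(X)).numpy()
    skl_out = SklNormalizer(norm=norm).fit_transform(X)
    np.testing.assert_allclose(torch_out, skl_out, rtol=1e-6)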