Module hummingbird.ml.operator_converters.skl_normalizer

Expand source code
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import torch
from onnxconverter_common.registration import register_converter

from ._base_operator import BaseOperator


class Normalizer(BaseOperator, torch.nn.Module):
    """
    PyTorch implementation of `sklearn.preprocessing.Normalizer`.

    Each row of the input is rescaled independently so that its l1, l2 or max norm equals 1.
    Rows whose norm is zero are returned unchanged (as scikit-learn does) instead of becoming NaN.
    """

    def __init__(self, norm, device):
        """
        Args:
            norm: Name of the norm used for rescaling; one of "l1", "l2" or "max"
            device: String defining the type of device the operator should be run on (not used here)
        """
        super(Normalizer, self).__init__()
        self.norm = norm
        # NOTE(review): flag read elsewhere in Hummingbird (presumably by BaseOperator/the executor) — confirm.
        self.transformer = True

    def forward(self, x):
        """
        Rescale every row of `x` to unit norm according to `self.norm`.

        Args:
            x: Input tensor; norms are reduced over dimension 1
               (assumes x is (n_samples, n_features) — TODO confirm)

        Returns:
            A tensor with the same shape as `x` where each row is divided by its norm

        Raises:
            RuntimeError: If `self.norm` is not one of "l1", "l2" or "max"
        """
        if self.norm == "l1":
            norms = torch.abs(x).sum(1, keepdim=True)
        elif self.norm == "l2":
            norms = torch.pow(torch.pow(x, 2).sum(1, keepdim=True), 0.5)
        elif self.norm == "max":
            norms = torch.max(torch.abs(x), dim=1, keepdim=True)[0]
        else:
            raise RuntimeError("Unsupported norm: {0}".format(self.norm))

        # scikit-learn replaces zero norms with 1 so all-zero rows stay zero;
        # dividing by the raw norm would yield 0 / 0 = NaN for those rows.
        norms = torch.where(norms == 0, torch.ones_like(norms), norms)
        return x / norms


def convert_sklearn_normalizer(operator, device, extra_config):
    """
    Build the Hummingbird operator equivalent to a fitted `sklearn.preprocessing.Normalizer`.

    Args:
        operator: An operator wrapping a `sklearn.preprocessing.Normalizer` model, exposed as `operator.raw_operator`
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy (not read by this converter)

    Returns:
        A PyTorch model (a `Normalizer` module)
    """
    sklearn_model = operator.raw_operator
    # The only attribute taken from the scikit-learn model is its `norm` setting.
    selected_norm = sklearn_model.norm
    return Normalizer(selected_norm, device)


# Register `convert_sklearn_normalizer` under the "SklearnNormalizer" operator name
# in onnxconverter_common's converter registry (module-level side effect on import).
register_converter("SklearnNormalizer", convert_sklearn_normalizer)

Functions

def convert_sklearn_normalizer(operator, device, extra_config)

Converter for `sklearn.preprocessing.Normalizer`.

Args

operator
An operator wrapping a `sklearn.preprocessing.Normalizer` model.
device
String defining the type of device the converted operator should be run on
extra_config
Extra configuration used to select the best conversion strategy

Returns

A PyTorch model
 
Expand source code
def convert_sklearn_normalizer(operator, device, extra_config):
    """
    Build the Hummingbird operator equivalent to a fitted `sklearn.preprocessing.Normalizer`.

    Args:
        operator: An operator wrapping a `sklearn.preprocessing.Normalizer` model, exposed as `operator.raw_operator`
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy (not read by this converter)

    Returns:
        A PyTorch model (a `Normalizer` module)
    """
    sklearn_model = operator.raw_operator
    # The only attribute taken from the scikit-learn model is its `norm` setting.
    selected_norm = sklearn_model.norm
    return Normalizer(selected_norm, device)

Classes

class Normalizer (norm, device)

PyTorch module implementing `sklearn.preprocessing.Normalizer`: rescales each row of the input by its l1, l2, or max norm, as selected by `norm`.

Expand source code
class Normalizer(BaseOperator, torch.nn.Module):
    """
    PyTorch implementation of `sklearn.preprocessing.Normalizer`.

    Each row of the input is rescaled independently so that its l1, l2 or max norm equals 1.
    Rows whose norm is zero are returned unchanged (as scikit-learn does) instead of becoming NaN.
    """

    def __init__(self, norm, device):
        """
        Args:
            norm: Name of the norm used for rescaling; one of "l1", "l2" or "max"
            device: String defining the type of device the operator should be run on (not used here)
        """
        super(Normalizer, self).__init__()
        self.norm = norm
        # NOTE(review): flag read elsewhere in Hummingbird (presumably by BaseOperator/the executor) — confirm.
        self.transformer = True

    def forward(self, x):
        """
        Rescale every row of `x` to unit norm according to `self.norm`.

        Args:
            x: Input tensor; norms are reduced over dimension 1
               (assumes x is (n_samples, n_features) — TODO confirm)

        Returns:
            A tensor with the same shape as `x` where each row is divided by its norm

        Raises:
            RuntimeError: If `self.norm` is not one of "l1", "l2" or "max"
        """
        if self.norm == "l1":
            norms = torch.abs(x).sum(1, keepdim=True)
        elif self.norm == "l2":
            norms = torch.pow(torch.pow(x, 2).sum(1, keepdim=True), 0.5)
        elif self.norm == "max":
            norms = torch.max(torch.abs(x), dim=1, keepdim=True)[0]
        else:
            raise RuntimeError("Unsupported norm: {0}".format(self.norm))

        # scikit-learn replaces zero norms with 1 so all-zero rows stay zero;
        # dividing by the raw norm would yield 0 / 0 = NaN for those rows.
        norms = torch.where(norms == 0, torch.ones_like(norms), norms)
        return x / norms

Ancestors

  • hummingbird.ml.operator_converters._base_operator.BaseOperator
  • abc.ABC
  • torch.nn.modules.module.Module

Methods

def forward(self, x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, x):
    """
    Rescale every row of `x` to unit norm according to `self.norm`.

    Args:
        x: Input tensor; norms are reduced over dimension 1
           (assumes x is (n_samples, n_features) — TODO confirm)

    Returns:
        A tensor with the same shape as `x` where each row is divided by its norm;
        rows whose norm is zero are returned unchanged, matching scikit-learn

    Raises:
        RuntimeError: If `self.norm` is not one of "l1", "l2" or "max"
    """
    if self.norm == "l1":
        norms = torch.abs(x).sum(1, keepdim=True)
    elif self.norm == "l2":
        norms = torch.pow(torch.pow(x, 2).sum(1, keepdim=True), 0.5)
    elif self.norm == "max":
        norms = torch.max(torch.abs(x), dim=1, keepdim=True)[0]
    else:
        raise RuntimeError("Unsupported norm: {0}".format(self.norm))

    # scikit-learn replaces zero norms with 1 so all-zero rows stay zero;
    # dividing by the raw norm would yield 0 / 0 = NaN for those rows.
    norms = torch.where(norms == 0, torch.ones_like(norms), norms)
    return x / norms