Module hummingbird.ml.operator_converters.sklearn.binarizer

Converter for scikit-learn Binarizer.

Expand source code Browse git
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

"""
Converter for scikit-learn Binarizer.
"""
from .._base_operator import BaseOperator
from onnxconverter_common.registration import register_converter
import torch


class Binarizer(BaseOperator, torch.nn.Module):
    """
    PyTorch re-implementation of scikit-learn's `Binarizer` transformer.

    Elements strictly greater than the threshold map to 1.0 and all
    remaining elements map to 0.0; the output is always float32.
    """

    def __init__(self, threshold, device):
        """
        Args:
            threshold: Scalar cut-off value, stored as a one-element float32 tensor
            device: Device string; accepted for converter-interface compatibility but unused here
        """
        super(Binarizer, self).__init__()
        # Flag this operator as a transformer.
        self.transformer = True
        # Kept as a frozen nn.Parameter so the threshold is part of the module's
        # state and moves together with it (e.g. on `.to(device)`).
        cutoff = torch.FloatTensor([threshold])
        self.threshold = torch.nn.Parameter(cutoff, requires_grad=False)

    def forward(self, x):
        """
        Binarize `x` against the stored threshold.

        Returns:
            float32 tensor holding 1.0 where `x > threshold` and 0.0 elsewhere
        """
        is_above = torch.gt(x, self.threshold)
        return is_above.to(torch.float32)


def convert_sklearn_binarizer(operator, device, extra_config):
    """
    Build the PyTorch counterpart of a fitted `sklearn.preprocessing.Binarizer`.

    Args:
        operator: An operator wrapping a `sklearn.preprocessing.Binarizer` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy
            (not consulted by this converter)

    Returns:
        A PyTorch model (a `Binarizer` module carrying the sklearn threshold)
    """
    sklearn_model = operator.raw_operator
    return Binarizer(threshold=sklearn_model.threshold, device=device)


# Register convert_sklearn_binarizer with onnxconverter_common under the
# "SklearnBinarizer" alias (presumably the alias used for
# sklearn.preprocessing.Binarizer operators — confirm against the parser).
register_converter("SklearnBinarizer", convert_sklearn_binarizer)

Functions

def convert_sklearn_binarizer(operator, device, extra_config)

Converter for `sklearn.preprocessing.Binarizer`.

Args

operator
An operator wrapping a sklearn.preprocessing.Binarizer model
device
String defining the type of device the converted operator should be run on
extra_config
Extra configuration used to select the best conversion strategy

Returns

A PyTorch model
 
Expand source code Browse git
def convert_sklearn_binarizer(operator, device, extra_config):
    """
    Build the PyTorch counterpart of a fitted `sklearn.preprocessing.Binarizer`.

    Args:
        operator: An operator wrapping a `sklearn.preprocessing.Binarizer` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy
            (not consulted by this converter)

    Returns:
        A PyTorch model (a `Binarizer` module carrying the sklearn threshold)
    """
    sklearn_model = operator.raw_operator
    return Binarizer(threshold=sklearn_model.threshold, device=device)

Classes

class Binarizer (threshold, device)

Class implementing Binarizer operators in PyTorch.

Expand source code Browse git
class Binarizer(BaseOperator, torch.nn.Module):
    """
    PyTorch re-implementation of scikit-learn's `Binarizer` transformer.

    Elements strictly greater than the threshold map to 1.0 and all
    remaining elements map to 0.0; the output is always float32.
    """

    def __init__(self, threshold, device):
        """
        Args:
            threshold: Scalar cut-off value, stored as a one-element float32 tensor
            device: Device string; accepted for converter-interface compatibility but unused here
        """
        super(Binarizer, self).__init__()
        # Flag this operator as a transformer.
        self.transformer = True
        # Kept as a frozen nn.Parameter so the threshold is part of the module's
        # state and moves together with it (e.g. on `.to(device)`).
        cutoff = torch.FloatTensor([threshold])
        self.threshold = torch.nn.Parameter(cutoff, requires_grad=False)

    def forward(self, x):
        """
        Binarize `x` against the stored threshold.

        Returns:
            float32 tensor holding 1.0 where `x > threshold` and 0.0 elsewhere
        """
        is_above = torch.gt(x, self.threshold)
        return is_above.to(torch.float32)

Ancestors

  • hummingbird.ml.operator_converters._base_operator.BaseOperator
  • abc.ABC
  • torch.nn.modules.module.Module

Methods

def forward(self, x)
Expand source code Browse git
def forward(self, x):
    """
    Binarize `x` against `self.threshold`.

    Returns:
        float32 tensor holding 1.0 where `x > threshold` and 0.0 elsewhere
    """
    is_above = torch.gt(x, self.threshold)
    return is_above.to(torch.float32)