Module hummingbird.ml.operator_converters.skl_one_hot_encoder

Source code
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import numpy as np
from onnxconverter_common.registration import register_converter
import torch

from ._base_operator import BaseOperator


class OneHotEncoderString(BaseOperator, torch.nn.Module):
    """
    Class implementing OneHotEncoder operators for strings in PyTorch.

    Because we are dealing with tensors, strings require additional length information for processing.
    """

    def __init__(self, categories, device):
        super(OneHotEncoderString, self).__init__(transformer=True)

        self.num_columns = len(categories)
        self.max_word_length = max([max([len(c) for c in cat]) for cat in categories])

        # Strings are cast to int32, therefore we need to pad the max word length so it is divisible by 4.
        while self.max_word_length % 4 != 0:
            self.max_word_length += 1

        # We build condition tensors as a 2d tensor of integers.
        # The first dimension is the total number of words, the second dimension is fixed to max_word_length // 4.
        condition_tensors = []
        categories_idx = [0]
        for arr in categories:
            cats = (
                np.array(arr, dtype="|S" + str(self.max_word_length))  # Encode objects as fixed-length byte strings.
                .view("int32")
                .reshape(-1, self.max_word_length // 4)
                .tolist()
            )
            # We merge the categories of all columns into a single tensor.
            condition_tensors.extend(cats)
            # Since all categories are merged together, we need to keep track of the indexes to retrieve them at inference time.
            categories_idx.append(categories_idx[-1] + len(cats))
        self.condition_tensors = torch.nn.Parameter(torch.IntTensor(condition_tensors), requires_grad=False)
        self.categories_idx = categories_idx

    def forward(self, x):
        encoded_tensors = []
        for i in range(self.num_columns):
            # First we fetch the condition for the particular column.
            conditions = self.condition_tensors[self.categories_idx[i] : self.categories_idx[i + 1], :].view(
                1, -1, self.max_word_length // 4
            )
            # Unlike the numeric case, where eq is enough, here we need to aggregate per object (dim = 2)
            # because objects can span multiple integers. We use prod since all ints must match to get an encoding of 1.
            encoded_tensors.append(torch.prod(torch.eq(x[:, i : i + 1, :], conditions), dim=2))

        return torch.cat(encoded_tensors, dim=1).float()


class OneHotEncoder(BaseOperator, torch.nn.Module):
    """
    Class implementing OneHotEncoder operators for ints in PyTorch.
    """

    def __init__(self, categories, device):
        super(OneHotEncoder, self).__init__(transformer=True)

        self.num_columns = len(categories)

        condition_tensors = []
        for arr in categories:
            condition_tensors.append(torch.nn.Parameter(torch.LongTensor(arr), requires_grad=False))
        self.condition_tensors = torch.nn.ParameterList(condition_tensors)

    def forward(self, x):
        if x.dtype != torch.int64:
            x = x.long()

        encoded_tensors = []
        for i in range(self.num_columns):
            encoded_tensors.append(torch.eq(x[:, i : i + 1], self.condition_tensors[i]))
        return torch.cat(encoded_tensors, dim=1).float()


def convert_sklearn_one_hot_encoder(operator, device, extra_config):
    """
    Converter for `sklearn.preprocessing.OneHotEncoder`

    Args:
        operator: An operator wrapping a `sklearn.preprocessing.OneHotEncoder` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    if all([np.array(c).dtype == object for c in operator.raw_operator.categories_]):
        categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_]
        return OneHotEncoderString(categories, device)
    else:
        return OneHotEncoder(operator.raw_operator.categories_, device)


register_converter("SklearnOneHotEncoder", convert_sklearn_one_hot_encoder)
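
This converter is registered under the name SklearnOneHotEncoder and is normally reached through Hummingbird's convert entry point rather than called directly. A minimal end-to-end sketch (assuming a standard Hummingbird installation; the "pytorch" backend name and the container's transform method follow the project README):

import numpy as np
from sklearn.preprocessing import OneHotEncoder
from hummingbird.ml import convert

X = np.array([[1, 10], [2, 20], [1, 20]])
skl_model = OneHotEncoder().fit(X)

# Translate the fitted scikit-learn encoder into a PyTorch-backed model.
torch_model = convert(skl_model, "pytorch")
print(torch_model.transform(X))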

Functions

def convert_sklearn_one_hot_encoder(operator, device, extra_config)

Converter for sklearn.preprocessing.OneHotEncoder

Args

operator
An operator wrapping a sklearn.preprocessing.OneHotEncoder model
device
String defining the type of device the converted operator should be run on
extra_config
Extra configuration used to select the best conversion strategy

Returns

A PyTorch model
 
Source code
def convert_sklearn_one_hot_encoder(operator, device, extra_config):
    """
    Converter for `sklearn.preprocessing.OneHotEncoder`

    Args:
        operator: An operator wrapping a `sklearn.preprocessing.OneHotEncoder` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    if all([np.array(c).dtype == object for c in operator.raw_operator.categories_]):
        categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_]
        return OneHotEncoderString(categories, device)
    else:
        return OneHotEncoder(operator.raw_operator.categories_, device)
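
The all(...) check above selects between the two implementations: scikit-learn stores string categories as object-dtype arrays, so when every entry of categories_ has dtype object the string encoder is chosen, and the numeric one otherwise. A small sketch of the condition (assuming scikit-learn is installed):

import numpy as np
from sklearn.preprocessing import OneHotEncoder

str_enc = OneHotEncoder().fit(np.array([["a", "x"], ["b", "y"]], dtype=object))
int_enc = OneHotEncoder().fit(np.array([[1, 10], [2, 20]]))

print(all(np.array(c).dtype == object for c in str_enc.categories_))  # True  -> OneHotEncoderString
print(all(np.array(c).dtype == object for c in int_enc.categories_))  # False -> OneHotEncoder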

Classes

class OneHotEncoder (categories, device)

Class implementing OneHotEncoder operators for ints in PyTorch.

Source code
class OneHotEncoder(BaseOperator, torch.nn.Module):
    """
    Class implementing OneHotEncoder operators for ints in PyTorch.
    """

    def __init__(self, categories, device):
        super(OneHotEncoder, self).__init__(transformer=True)

        self.num_columns = len(categories)

        condition_tensors = []
        for arr in categories:
            condition_tensors.append(torch.nn.Parameter(torch.LongTensor(arr), requires_grad=False))
        self.condition_tensors = torch.nn.ParameterList(condition_tensors)

    def forward(self, x):
        if x.dtype != torch.int64:
            x = x.long()

        encoded_tensors = []
        for i in range(self.num_columns):
            encoded_tensors.append(torch.eq(x[:, i : i + 1], self.condition_tensors[i]))
        return torch.cat(encoded_tensors, dim=1).float()

Ancestors

  • hummingbird.ml.operator_converters._base_operator.BaseOperator
  • abc.ABC
  • torch.nn.modules.module.Module

Methods

def forward(self, x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code
def forward(self, x):
    if x.dtype != torch.int64:
        x = x.long()

    encoded_tensors = []
    for i in range(self.num_columns):
        encoded_tensors.append(torch.eq(x[:, i : i + 1], self.condition_tensors[i]))
    return torch.cat(encoded_tensors, dim=1).float()
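
As a concrete illustration of the broadcasting above, the module can be instantiated standalone (a sketch; in normal use it is built by the converter): each input column is compared element-wise against that column's category list, and the per-column indicator vectors are concatenated. A value that matches no category produces all zeros for that column's block.

import torch
from hummingbird.ml.operator_converters.skl_one_hot_encoder import OneHotEncoder

enc = OneHotEncoder(categories=[[0, 1, 2], [10, 20]], device="cpu")
x = torch.tensor([[1, 20], [2, 10]])
print(enc(x))
# tensor([[0., 1., 0., 0., 1.],
#         [0., 0., 1., 1., 0.]])
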
class OneHotEncoderString (categories, device)

Class implementing OneHotEncoder operators for strings in PyTorch.

Because we are dealing with tensors, strings require additional length information for processing.

Source code
class OneHotEncoderString(BaseOperator, torch.nn.Module):
    """
    Class implementing OneHotEncoder operators for strings in PyTorch.

    Because we are dealing with tensors, strings require additional length information for processing.
    """

    def __init__(self, categories, device):
        super(OneHotEncoderString, self).__init__(transformer=True)

        self.num_columns = len(categories)
        self.max_word_length = max([max([len(c) for c in cat]) for cat in categories])

        # Strings are cast to int32, therefore we need to pad the max word length so it is divisible by 4.
        while self.max_word_length % 4 != 0:
            self.max_word_length += 1

        # We build condition tensors as a 2d tensor of integers.
        # The first dimension is the total number of words, the second dimension is fixed to max_word_length // 4.
        condition_tensors = []
        categories_idx = [0]
        for arr in categories:
            cats = (
                np.array(arr, dtype="|S" + str(self.max_word_length))  # Encode objects as fixed-length byte strings.
                .view("int32")
                .reshape(-1, self.max_word_length // 4)
                .tolist()
            )
            # We merge the categories of all columns into a single tensor.
            condition_tensors.extend(cats)
            # Since all categories are merged together, we need to keep track of the indexes to retrieve them at inference time.
            categories_idx.append(categories_idx[-1] + len(cats))
        self.condition_tensors = torch.nn.Parameter(torch.IntTensor(condition_tensors), requires_grad=False)
        self.categories_idx = categories_idx

    def forward(self, x):
        encoded_tensors = []
        for i in range(self.num_columns):
            # First we fetch the condition for the particular column.
            conditions = self.condition_tensors[self.categories_idx[i] : self.categories_idx[i + 1], :].view(
                1, -1, self.max_word_length // 4
            )
            # Unlike the numeric case, where eq is enough, here we need to aggregate per object (dim = 2)
            # because objects can span multiple integers. We use prod since all ints must match to get an encoding of 1.
            encoded_tensors.append(torch.prod(torch.eq(x[:, i : i + 1, :], conditions), dim=2))

        return torch.cat(encoded_tensors, dim=1).float()
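
The key trick in __init__ above is that fixed-width byte strings can be reinterpreted as rows of int32, which turns string equality into integer equality that torch.eq can evaluate. A minimal NumPy sketch of the packing:

import numpy as np

max_len = 8  # max word length, already rounded up to a multiple of 4
packed = (
    np.array(["cat", "mouse"], dtype="|S" + str(max_len))  # null-padded to 8 bytes each
    .view("int32")
    .reshape(-1, max_len // 4)
)
print(packed.shape)  # (2, 2): one row per category, max_len // 4 int32 values per row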

Ancestors

  • hummingbird.ml.operator_converters._base_operator.BaseOperator
  • abc.ABC
  • torch.nn.modules.module.Module

Methods

def forward(self, x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code
def forward(self, x):
    encoded_tensors = []
    for i in range(self.num_columns):
        # First we fetch the condition for the particular column.
        conditions = self.condition_tensors[self.categories_idx[i] : self.categories_idx[i + 1], :].view(
            1, -1, self.max_word_length // 4
        )
        # Unlike the numeric case, where eq is enough, here we need to aggregate per object (dim = 2)
        # because objects can span multiple integers. We use prod since all ints must match to get an encoding of 1.
        encoded_tensors.append(torch.prod(torch.eq(x[:, i : i + 1, :], conditions), dim=2))

    return torch.cat(encoded_tensors, dim=1).float()
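
Putting the pieces together, inputs to the string encoder must be packed with exactly the same byte layout as the stored conditions. A usage sketch (again instantiating the module directly, outside the converter):

import numpy as np
import torch
from hummingbird.ml.operator_converters.skl_one_hot_encoder import OneHotEncoderString

enc = OneHotEncoderString([["cat", "dog"]], device="cpu")
# Pack the inputs the same way: shape (N, num_columns, max_word_length // 4), int32.
x = np.array(["dog"], dtype="|S" + str(enc.max_word_length)).view("int32").reshape(1, 1, -1)
print(enc(torch.from_numpy(x)))  # tensor([[0., 1.]])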