Module hummingbird.ml.operator_converters.gbdt

Converters for Sklearn's GradientBoosting models.

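In practice this converter is not called directly: it is registered under the
SklearnGradientBoostingClassifier alias (see the bottom of the source) and is
dispatched by Hummingbird's top-level `convert` API. A minimal sketch of
triggering it end to end:

import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from hummingbird.ml import convert

# Train a small binary classifier on synthetic data.
X = np.random.rand(200, 4).astype(np.float32)
y = np.random.randint(2, size=200)
skl_model = GradientBoostingClassifier(n_estimators=10).fit(X, y)

# "pytorch" selects the PyTorch backend; under the hood this dispatches to
# convert_sklearn_gbdt_classifier below.
hb_model = convert(skl_model, "pytorch")
print(hb_model.predict_proba(X[:3]))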
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

"""
Converters for Sklearn's GradientBoosting models.
"""

import warnings

import numpy as np
from onnxconverter_common.registration import register_converter

from . import constants
from ._gbdt_commons import convert_gbdt_classifier_common
from ._tree_commons import get_parameters_for_sklearn_common, get_parameters_for_tree_trav_sklearn


def convert_sklearn_gbdt_classifier(operator, device, extra_config):
    """
    Converter for `sklearn.ensemble.GradientBoostingClassifier`.

    Args:
        operator: An operator wrapping a `sklearn.ensemble.GradientBoostingClassifier` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model
    """
    assert operator is not None

    # Get tree information out of the operator.
    tree_infos = operator.raw_operator.estimators_
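    # NOTE: newer scikit-learn releases expose this value as n_features_in_;
    # n_features_ was deprecated in scikit-learn 1.0 and later removed.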
    n_features = operator.raw_operator.n_features_
    learning_rate = operator.raw_operator.learning_rate
    classes = operator.raw_operator.classes_.tolist()
    n_classes = len(classes)

    # Analyze classes.
    if not all(isinstance(c, int) for c in classes):
        raise RuntimeError("GBDT Classifier translation only supports integer class labels.")
    if n_classes == 2:
        n_classes -= 1

    # Reshape the tree_infos to a more generic format.
    tree_infos = [tree_infos[i][j] for j in range(n_classes) for i in range(len(tree_infos))]
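    # estimators_ is an (n_estimators, n_classes) grid of regression trees;
    # the comprehension above flattens it class-major, e.g. for 3 estimators
    # and 2 columns: [[t00, t01], [t10, t11], [t20, t21]] becomes
    # [t00, t10, t20, t01, t11, t21] (all class-0 trees first).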

    # Get the value for Alpha.
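    # Alpha is the base score contributed by the init_ estimator: the
    # log-odds of the positive-class prior in the binary case, and the log of
    # each class prior in the multiclass case. For example,
    # class_prior_ = [0.25, 0.75] yields alpha = [[log(0.75 / 0.25)]] ~ [[1.0986]].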
    if operator.raw_operator.init == "zero":
        alpha = [[0.0]]
    elif operator.raw_operator.init is None:
        if n_classes == 1:
            prior = operator.raw_operator.init_.class_prior_[1]
            alpha = [[np.log(prior / (1 - prior))]]
        else:
            alpha = [[np.log(operator.raw_operator.init_.class_prior_[i]) for i in range(n_classes)]]
    else:
        raise RuntimeError("Custom initializers for GBDT are not yet supported in Hummingbird.")

    extra_config[constants.ALPHA] = alpha
    extra_config[constants.LEARNING_RATE] = learning_rate
    # For sklearn models we need to massage the parameters a bit before generating the parameters for tree_trav.
    extra_config[constants.GET_PARAMETERS_FOR_TREE_TRAVERSAL] = get_parameters_for_tree_trav_sklearn

    return convert_gbdt_classifier_common(
        tree_infos, get_parameters_for_sklearn_common, n_features, n_classes, classes, extra_config
    )


# Register the converter.
register_converter("SklearnGradientBoostingClassifier", convert_sklearn_gbdt_classifier)

Functions

def convert_sklearn_gbdt_classifier(operator, device, extra_config)

Converter for `sklearn.ensemble.GradientBoostingClassifier`.

Args

operator
    An operator wrapping a `sklearn.ensemble.GradientBoostingClassifier` model
device
    String defining the type of device the converted operator should be run on
extra_config
    Extra configuration used to select the best conversion strategy

Returns

A PyTorch model
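The device and extra_config arguments surface through the same top-level API.
A hedged sketch, reusing skl_model and X from the sketch above and assuming
the "tree_implementation" extra_config key (values such as "gemm",
"tree_trav", and "perf_tree_trav") is available in your Hummingbird version:

from hummingbird.ml import convert

# Force the GEMM tree-compilation strategy instead of letting Hummingbird
# choose heuristically ("tree_implementation" is an assumed key; check the
# constants module of your installed version).
hb_model = convert(skl_model, "pytorch", extra_config={"tree_implementation": "gemm"})

# The returned container wraps a PyTorch model and can be moved to a GPU.
hb_model.to("cuda")  # requires a CUDA-enabled PyTorch build
preds = hb_model.predict(X)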