Module hummingbird.operator_converters._gbdt_commons
A collection of classes and functions shared among the GBDT converters.
Expand source code
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
Collections of classes and functions shared among GBDT converters.
"""
import torch
import numpy as np
from . import constants
from ._tree_commons import get_tree_params_and_type, get_parameters_for_tree_trav_common, get_parameters_for_gemm_common
from ._tree_implementations import GEMMTreeImpl, TreeTraversalTreeImpl, PerfectTreeTraversalTreeImpl, TreeImpl
class GEMMGBDTImpl(GEMMTreeImpl):
    """
    Class implementing the GEMM strategy (in PyTorch) for GBDT models.
    """

    def __init__(self, net_parameters, n_features, classes=None, extra_config=None):
        """
        Args:
            net_parameters: The parameters defining the tree structure
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            extra_config: Extra configuration used to properly implement the source tree.
                Optional keys read here: `constants.LEARNING_RATE` (stored as `self.learning_rate`)
                and `constants.ALPHA` (stored as the non-trainable parameter `self.alpha`).
        """
        # Use a fresh dict per call instead of a shared mutable default argument.
        if extra_config is None:
            extra_config = {}
        # Trees are built with a single output each (n_classes=1); grouping the trees
        # by class is done later in `aggregation`.
        super(GEMMGBDTImpl, self).__init__(net_parameters, n_features, classes, 1)
        self.n_gbdt_classes = 1

        if constants.LEARNING_RATE in extra_config:
            self.learning_rate = extra_config[constants.LEARNING_RATE]
        if constants.ALPHA in extra_config:
            self.alpha = torch.nn.Parameter(torch.FloatTensor(extra_config[constants.ALPHA]), requires_grad=False)

        if classes is not None:
            # Two classes are modelled with a single score (one tree group); more than two
            # classes get one tree group per class.
            self.n_gbdt_classes = len(classes) if len(classes) > 2 else 1
            if self.n_gbdt_classes == 1:
                self.binary_classification = True

        self.n_trees_per_class = len(net_parameters) // self.n_gbdt_classes

    def aggregation(self, x):
        """
        Sum the tree outputs belonging to each class.

        The per-sample tree scores are viewed as (batch, n_gbdt_classes, n_trees_per_class),
        i.e. trees of the same class are expected to be contiguous, and summed over the last axis.
        NOTE(review): assumes `x` is (n_trees, 1, batch) so that squeeze + transpose yields
        (batch, n_trees); confirm against GEMMTreeImpl.forward.
        """
        return torch.squeeze(x).t().view(-1, self.n_gbdt_classes, self.n_trees_per_class).sum(2)

    def calibration(self, x):
        """
        Turn raw scores into class probabilities.

        Binary models apply a sigmoid to the single score and return [1 - p, p];
        multiclass models apply a softmax over dim 1.
        """
        if self.binary_classification:
            output = torch.sigmoid(x)
            return torch.cat([1 - output, output], dim=1)
        return torch.softmax(x, dim=1)
class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
    """
    Class implementing the Tree Traversal strategy in PyTorch.
    """

    def __init__(self, net_parameters, max_detph, n_features, classes=None, extra_config=None):
        """
        Args:
            net_parameters: The parameters defining the tree structure
            max_detph: The maximum tree-depth in the model. The misspelled parameter name is kept
                so existing keyword callers are not broken.
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            extra_config: Extra configuration used to properly implement the source tree.
                Optional keys read here: `constants.LEARNING_RATE` (stored as `self.learning_rate`)
                and `constants.ALPHA` (stored as the non-trainable parameter `self.alpha`).
        """
        # Use a fresh dict per call instead of a shared mutable default argument.
        if extra_config is None:
            extra_config = {}
        # Trees are built with a single output each (n_classes=1); grouping the trees
        # by class is done later in `aggregation`.
        super(TreeTraversalGBDTImpl, self).__init__(net_parameters, max_detph, n_features, classes, 1)
        self.n_gbdt_classes = 1

        if constants.LEARNING_RATE in extra_config:
            self.learning_rate = extra_config[constants.LEARNING_RATE]
        if constants.ALPHA in extra_config:
            self.alpha = torch.nn.Parameter(torch.FloatTensor(extra_config[constants.ALPHA]), requires_grad=False)

        if classes is not None:
            # Two classes are modelled with a single score (one tree group); more than two
            # classes get one tree group per class.
            self.n_gbdt_classes = len(classes) if len(classes) > 2 else 1
            if self.n_gbdt_classes == 1:
                self.binary_classification = True

        self.n_trees_per_class = len(net_parameters) // self.n_gbdt_classes

    def aggregation(self, x):
        """
        Sum the tree outputs belonging to each class.

        Scores are viewed as (batch, n_gbdt_classes, n_trees_per_class), i.e. trees of the
        same class are expected to be contiguous, and summed over the last axis.
        NOTE(review): assumes `x` holds one score per tree per sample in that order;
        confirm against TreeTraversalTreeImpl.forward.
        """
        return x.view(-1, self.n_gbdt_classes, self.n_trees_per_class).sum(2)

    def calibration(self, x):
        """
        Turn raw scores into class probabilities.

        Binary models apply a sigmoid to the single score and return [1 - p, p];
        multiclass models apply a softmax over dim 1.
        """
        if self.binary_classification:
            output = torch.sigmoid(x)
            return torch.cat([1 - output, output], dim=1)
        return torch.softmax(x, dim=1)
class PerfectTreeTraversalGBDTImpl(PerfectTreeTraversalTreeImpl):
    """
    Class implementing the Perfect Tree Traversal strategy in PyTorch.
    """

    def __init__(self, net_parameters, max_depth, n_features, classes=None, extra_config=None):
        """
        Args:
            net_parameters: The parameters defining the tree structure
            max_depth: The maximum tree-depth in the model
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            extra_config: Extra configuration used to properly implement the source tree.
                Optional keys read here: `constants.LEARNING_RATE` (stored as `self.learning_rate`)
                and `constants.ALPHA` (stored as the non-trainable parameter `self.alpha`).
        """
        # Use a fresh dict per call instead of a shared mutable default argument.
        if extra_config is None:
            extra_config = {}
        # Trees are built with a single output each (n_classes=1); grouping the trees
        # by class is done later in `aggregation`.
        super(PerfectTreeTraversalGBDTImpl, self).__init__(net_parameters, max_depth, n_features, classes, 1)
        self.n_gbdt_classes = 1

        if constants.LEARNING_RATE in extra_config:
            self.learning_rate = extra_config[constants.LEARNING_RATE]
        if constants.ALPHA in extra_config:
            self.alpha = torch.nn.Parameter(torch.FloatTensor(extra_config[constants.ALPHA]), requires_grad=False)

        if classes is not None:
            # Two classes are modelled with a single score (one tree group); more than two
            # classes get one tree group per class.
            self.n_gbdt_classes = len(classes) if len(classes) > 2 else 1
            if self.n_gbdt_classes == 1:
                self.binary_classification = True

        self.n_trees_per_class = len(net_parameters) // self.n_gbdt_classes

    def aggregation(self, x):
        """
        Sum the tree outputs belonging to each class.

        Scores are viewed as (batch, n_gbdt_classes, n_trees_per_class), i.e. trees of the
        same class are expected to be contiguous, and summed over the last axis.
        NOTE(review): assumes `x` holds one score per tree per sample in that order;
        confirm against PerfectTreeTraversalTreeImpl.forward.
        """
        return x.view(-1, self.n_gbdt_classes, self.n_trees_per_class).sum(2)

    def calibration(self, x):
        """
        Turn raw scores into class probabilities.

        Binary models apply a sigmoid to the single score and return [1 - p, p];
        multiclass models apply a softmax over dim 1.
        """
        if self.binary_classification:
            output = torch.sigmoid(x)
            return torch.cat([1 - output, output], dim=1)
        return torch.softmax(x, dim=1)
def convert_gbdt_classifier_common(tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config=None):
    """
    Common converter for GBDT classifiers.

    Args:
        tree_infos: The information representing a tree (ensemble)
        get_tree_parameters: A function specifying how to parse the tree_infos into parameters
        n_features: The number of features input to the model
        n_classes: How many classes are expected. 1 for regression tasks
        classes: The classes used for classification. None if implementing a regression model
        extra_config: Extra configuration used to properly implement the source tree

    Returns:
        A tree implementation in PyTorch

    Raises:
        ValueError: If `n_classes` is smaller than 1, or if the number of trees is not a
            multiple of the number of per-class tree groups.
    """
    assert tree_infos is not None
    assert get_tree_parameters is not None
    assert n_features is not None
    assert n_classes is not None
    # Use a fresh dict per call instead of a shared mutable default argument.
    if extra_config is None:
        extra_config = {}

    # Binary classification is modelled with a single group of trees.
    if n_classes == 2:
        n_classes -= 1
    if n_classes < 1:
        raise ValueError("n_classes must be at least 1, got {}.".format(n_classes))
    if classes is None:
        classes = [i for i in range(n_classes)]

    # The indexing below treats the input as iteration-major (the tree of class j at boosting
    # iteration i sits at index i * n_classes + j) and regroups it class-major, so trees of the
    # same class become contiguous. A tree count that is not a multiple of n_classes would
    # otherwise be silently truncated by the integer division.
    n_trees = len(tree_infos)
    if n_trees % n_classes != 0:
        raise ValueError(
            "The number of trees ({}) is not a multiple of the number of tree groups ({}).".format(n_trees, n_classes)
        )
    n_trees_per_class = n_trees // n_classes
    tree_infos = [tree_infos[i * n_classes + j] for j in range(n_classes) for i in range(n_trees_per_class)]

    return convert_gbdt_common(tree_infos, get_tree_parameters, n_features, classes, extra_config)
def convert_gbdt_common(tree_infos, get_tree_parameters, n_features, classes=None, extra_config=None):
    """
    Common converter for GBDT models.

    Args:
        tree_infos: The information representing a tree (ensemble)
        get_tree_parameters: A function specifying how to parse the tree_infos into parameters
        n_features: The number of features input to the model
        classes: The classes used for classification. None if implementing a regression model
        extra_config: Extra configuration used to properly implement the source tree.
            If it contains `constants.GET_PARAMETERS_FOR_TREE_TRAVERSAL`, that callable replaces
            `get_parameters_for_tree_trav_common` for the tree-traversal strategies.

    Returns:
        A tree implementation in PyTorch (GEMMGBDTImpl, TreeTraversalGBDTImpl or
        PerfectTreeTraversalGBDTImpl, depending on the selected strategy)
    """
    assert tree_infos is not None
    assert get_tree_parameters is not None
    assert n_features is not None
    # Use a fresh dict per call: it is handed to helpers and to the implementation
    # constructors, so it must not be a dict object shared across calls.
    if extra_config is None:
        extra_config = {}

    tree_parameters, max_depth, tree_type = get_tree_params_and_type(tree_infos, get_tree_parameters, extra_config)

    # Generate the tree implementation based on the selected strategy.
    if tree_type == TreeImpl.gemm:
        net_parameters = [
            get_parameters_for_gemm_common(
                tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, n_features
            )
            for tree_param in tree_parameters
        ]
        return GEMMGBDTImpl(net_parameters, n_features, classes, extra_config)

    # Some models require additional massaging of the parameters before generating the tree_trav implementation.
    get_parameters_for_tree_trav = get_parameters_for_tree_trav_common
    if constants.GET_PARAMETERS_FOR_TREE_TRAVERSAL in extra_config:
        get_parameters_for_tree_trav = extra_config[constants.GET_PARAMETERS_FOR_TREE_TRAVERSAL]
    net_parameters = [
        get_parameters_for_tree_trav(
            tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values
        )
        for tree_param in tree_parameters
    ]

    if tree_type == TreeImpl.tree_trav:
        return TreeTraversalGBDTImpl(net_parameters, max_depth, n_features, classes, extra_config)
    # Remaining possible case: tree_type == TreeImpl.perf_tree_trav.
    return PerfectTreeTraversalGBDTImpl(net_parameters, max_depth, n_features, classes, extra_config)
Functions
def convert_gbdt_classifier_common(tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={})
-
Common converter for GBDT classifiers.
Args
tree_infos
- The information representing a tree (ensemble)
get_tree_parameters
- A function specifying how to parse the tree_infos into parameters
n_features
- The number of features input to the model
n_classes
- How many classes are expected. 1 for regression tasks
classes
- The classes used for classification. None if implementing a regression model
extra_config
- Extra configuration used to properly implement the source tree
Returns
A tree implementation in PyTorch
Expand source code
def convert_gbdt_classifier_common(tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config=None):
    """
    Common converter for GBDT classifiers.

    Args:
        tree_infos: The information representing a tree (ensemble)
        get_tree_parameters: A function specifying how to parse the tree_infos into parameters
        n_features: The number of features input to the model
        n_classes: How many classes are expected. 1 for regression tasks
        classes: The classes used for classification. None if implementing a regression model
        extra_config: Extra configuration used to properly implement the source tree

    Returns:
        A tree implementation in PyTorch

    Raises:
        ValueError: If `n_classes` is smaller than 1, or if the number of trees is not a
            multiple of the number of per-class tree groups.
    """
    assert tree_infos is not None
    assert get_tree_parameters is not None
    assert n_features is not None
    assert n_classes is not None
    # Use a fresh dict per call instead of a shared mutable default argument.
    if extra_config is None:
        extra_config = {}

    # Binary classification is modelled with a single group of trees.
    if n_classes == 2:
        n_classes -= 1
    if n_classes < 1:
        raise ValueError("n_classes must be at least 1, got {}.".format(n_classes))
    if classes is None:
        classes = [i for i in range(n_classes)]

    # The indexing below treats the input as iteration-major (the tree of class j at boosting
    # iteration i sits at index i * n_classes + j) and regroups it class-major, so trees of the
    # same class become contiguous. A tree count that is not a multiple of n_classes would
    # otherwise be silently truncated by the integer division.
    n_trees = len(tree_infos)
    if n_trees % n_classes != 0:
        raise ValueError(
            "The number of trees ({}) is not a multiple of the number of tree groups ({}).".format(n_trees, n_classes)
        )
    n_trees_per_class = n_trees // n_classes
    tree_infos = [tree_infos[i * n_classes + j] for j in range(n_classes) for i in range(n_trees_per_class)]

    return convert_gbdt_common(tree_infos, get_tree_parameters, n_features, classes, extra_config)
def convert_gbdt_common(tree_infos, get_tree_parameters, n_features, classes=None, extra_config={})
-
Common converter for GBDT models.
Args
tree_infos
- The information representing a tree (ensemble)
get_tree_parameters
- A function specifying how to parse the tree_infos into parameters
n_features
- The number of features input to the model
classes
- The classes used for classification. None if implementing a regression model
extra_config
- Extra configuration used to properly implement the source tree
Returns
A tree implementation in PyTorch
Expand source code
def convert_gbdt_common(tree_infos, get_tree_parameters, n_features, classes=None, extra_config=None):
    """
    Common converter for GBDT models.

    Args:
        tree_infos: The information representing a tree (ensemble)
        get_tree_parameters: A function specifying how to parse the tree_infos into parameters
        n_features: The number of features input to the model
        classes: The classes used for classification. None if implementing a regression model
        extra_config: Extra configuration used to properly implement the source tree.
            If it contains `constants.GET_PARAMETERS_FOR_TREE_TRAVERSAL`, that callable replaces
            `get_parameters_for_tree_trav_common` for the tree-traversal strategies.

    Returns:
        A tree implementation in PyTorch (GEMMGBDTImpl, TreeTraversalGBDTImpl or
        PerfectTreeTraversalGBDTImpl, depending on the selected strategy)
    """
    assert tree_infos is not None
    assert get_tree_parameters is not None
    assert n_features is not None
    # Use a fresh dict per call: it is handed to helpers and to the implementation
    # constructors, so it must not be a dict object shared across calls.
    if extra_config is None:
        extra_config = {}

    tree_parameters, max_depth, tree_type = get_tree_params_and_type(tree_infos, get_tree_parameters, extra_config)

    # Generate the tree implementation based on the selected strategy.
    if tree_type == TreeImpl.gemm:
        net_parameters = [
            get_parameters_for_gemm_common(
                tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, n_features
            )
            for tree_param in tree_parameters
        ]
        return GEMMGBDTImpl(net_parameters, n_features, classes, extra_config)

    # Some models require additional massaging of the parameters before generating the tree_trav implementation.
    get_parameters_for_tree_trav = get_parameters_for_tree_trav_common
    if constants.GET_PARAMETERS_FOR_TREE_TRAVERSAL in extra_config:
        get_parameters_for_tree_trav = extra_config[constants.GET_PARAMETERS_FOR_TREE_TRAVERSAL]
    net_parameters = [
        get_parameters_for_tree_trav(
            tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values
        )
        for tree_param in tree_parameters
    ]

    if tree_type == TreeImpl.tree_trav:
        return TreeTraversalGBDTImpl(net_parameters, max_depth, n_features, classes, extra_config)
    # Remaining possible case: tree_type == TreeImpl.perf_tree_trav.
    return PerfectTreeTraversalGBDTImpl(net_parameters, max_depth, n_features, classes, extra_config)
Classes
class GEMMGBDTImpl (net_parameters, n_features, classes=None, extra_config={})
-
Class implementing the GEMM strategy (in PyTorch) for GBDT models.
Args
net_parameters
- The parameters defining the tree structure
n_features
- The number of features input to the model
classes
- The classes used for classification. None if implementing a regression model
extra_config
- Extra configuration used to properly implement the source tree
Expand source code
class GEMMGBDTImpl(GEMMTreeImpl):
    """
    Class implementing the GEMM strategy (in PyTorch) for GBDT models.
    """

    def __init__(self, net_parameters, n_features, classes=None, extra_config=None):
        """
        Args:
            net_parameters: The parameters defining the tree structure
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            extra_config: Extra configuration used to properly implement the source tree.
                Optional keys read here: `constants.LEARNING_RATE` (stored as `self.learning_rate`)
                and `constants.ALPHA` (stored as the non-trainable parameter `self.alpha`).
        """
        # Use a fresh dict per call instead of a shared mutable default argument.
        if extra_config is None:
            extra_config = {}
        # Trees are built with a single output each (n_classes=1); grouping the trees
        # by class is done later in `aggregation`.
        super(GEMMGBDTImpl, self).__init__(net_parameters, n_features, classes, 1)
        self.n_gbdt_classes = 1

        if constants.LEARNING_RATE in extra_config:
            self.learning_rate = extra_config[constants.LEARNING_RATE]
        if constants.ALPHA in extra_config:
            self.alpha = torch.nn.Parameter(torch.FloatTensor(extra_config[constants.ALPHA]), requires_grad=False)

        if classes is not None:
            # Two classes are modelled with a single score (one tree group); more than two
            # classes get one tree group per class.
            self.n_gbdt_classes = len(classes) if len(classes) > 2 else 1
            if self.n_gbdt_classes == 1:
                self.binary_classification = True

        self.n_trees_per_class = len(net_parameters) // self.n_gbdt_classes

    def aggregation(self, x):
        """
        Sum the tree outputs belonging to each class.

        The per-sample tree scores are viewed as (batch, n_gbdt_classes, n_trees_per_class),
        i.e. trees of the same class are expected to be contiguous, and summed over the last axis.
        NOTE(review): assumes `x` is (n_trees, 1, batch) so that squeeze + transpose yields
        (batch, n_trees); confirm against GEMMTreeImpl.forward.
        """
        return torch.squeeze(x).t().view(-1, self.n_gbdt_classes, self.n_trees_per_class).sum(2)

    def calibration(self, x):
        """
        Turn raw scores into class probabilities.

        Binary models apply a sigmoid to the single score and return [1 - p, p];
        multiclass models apply a softmax over dim 1.
        """
        if self.binary_classification:
            output = torch.sigmoid(x)
            return torch.cat([1 - output, output], dim=1)
        return torch.softmax(x, dim=1)
Ancestors
- GEMMTreeImpl
- AbstractPyTorchTreeImpl
- AbstracTreeImpl
- abc.ABC
- torch.nn.modules.module.Module
Inherited members
class PerfectTreeTraversalGBDTImpl (net_parameters, max_depth, n_features, classes=None, extra_config={})
-
Class implementing the Perfect Tree Traversal strategy in PyTorch.
Args
net_parameters
- The parameters defining the tree structure
max_depth
- The maximum tree-depth in the model
n_features
- The number of features input to the model
classes
- The classes used for classification. None if implementing a regression model
extra_config
- Extra configuration used to properly implement the source tree
Expand source code
class PerfectTreeTraversalGBDTImpl(PerfectTreeTraversalTreeImpl):
    """
    Class implementing the Perfect Tree Traversal strategy in PyTorch.
    """

    def __init__(self, net_parameters, max_depth, n_features, classes=None, extra_config=None):
        """
        Args:
            net_parameters: The parameters defining the tree structure
            max_depth: The maximum tree-depth in the model
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            extra_config: Extra configuration used to properly implement the source tree.
                Optional keys read here: `constants.LEARNING_RATE` (stored as `self.learning_rate`)
                and `constants.ALPHA` (stored as the non-trainable parameter `self.alpha`).
        """
        # Use a fresh dict per call instead of a shared mutable default argument.
        if extra_config is None:
            extra_config = {}
        # Trees are built with a single output each (n_classes=1); grouping the trees
        # by class is done later in `aggregation`.
        super(PerfectTreeTraversalGBDTImpl, self).__init__(net_parameters, max_depth, n_features, classes, 1)
        self.n_gbdt_classes = 1

        if constants.LEARNING_RATE in extra_config:
            self.learning_rate = extra_config[constants.LEARNING_RATE]
        if constants.ALPHA in extra_config:
            self.alpha = torch.nn.Parameter(torch.FloatTensor(extra_config[constants.ALPHA]), requires_grad=False)

        if classes is not None:
            # Two classes are modelled with a single score (one tree group); more than two
            # classes get one tree group per class.
            self.n_gbdt_classes = len(classes) if len(classes) > 2 else 1
            if self.n_gbdt_classes == 1:
                self.binary_classification = True

        self.n_trees_per_class = len(net_parameters) // self.n_gbdt_classes

    def aggregation(self, x):
        """
        Sum the tree outputs belonging to each class.

        Scores are viewed as (batch, n_gbdt_classes, n_trees_per_class), i.e. trees of the
        same class are expected to be contiguous, and summed over the last axis.
        NOTE(review): assumes `x` holds one score per tree per sample in that order;
        confirm against PerfectTreeTraversalTreeImpl.forward.
        """
        return x.view(-1, self.n_gbdt_classes, self.n_trees_per_class).sum(2)

    def calibration(self, x):
        """
        Turn raw scores into class probabilities.

        Binary models apply a sigmoid to the single score and return [1 - p, p];
        multiclass models apply a softmax over dim 1.
        """
        if self.binary_classification:
            output = torch.sigmoid(x)
            return torch.cat([1 - output, output], dim=1)
        return torch.softmax(x, dim=1)
Ancestors
- PerfectTreeTraversalTreeImpl
- AbstractPyTorchTreeImpl
- AbstracTreeImpl
- abc.ABC
- torch.nn.modules.module.Module
Inherited members
class TreeTraversalGBDTImpl (net_parameters, max_detph, n_features, classes=None, extra_config={})
-
Class implementing the Tree Traversal strategy in PyTorch.
Args
net_parameters
- The parameters defining the tree structure
max_depth
- The maximum tree-depth in the model
n_features
- The number of features input to the model
classes
- The classes used for classification. None if implementing a regression model
extra_config
- Extra configuration used to properly implement the source tree
Expand source code
class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
    """
    Class implementing the Tree Traversal strategy in PyTorch.
    """

    def __init__(self, net_parameters, max_detph, n_features, classes=None, extra_config=None):
        """
        Args:
            net_parameters: The parameters defining the tree structure
            max_detph: The maximum tree-depth in the model. The misspelled parameter name is kept
                so existing keyword callers are not broken.
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            extra_config: Extra configuration used to properly implement the source tree.
                Optional keys read here: `constants.LEARNING_RATE` (stored as `self.learning_rate`)
                and `constants.ALPHA` (stored as the non-trainable parameter `self.alpha`).
        """
        # Use a fresh dict per call instead of a shared mutable default argument.
        if extra_config is None:
            extra_config = {}
        # Trees are built with a single output each (n_classes=1); grouping the trees
        # by class is done later in `aggregation`.
        super(TreeTraversalGBDTImpl, self).__init__(net_parameters, max_detph, n_features, classes, 1)
        self.n_gbdt_classes = 1

        if constants.LEARNING_RATE in extra_config:
            self.learning_rate = extra_config[constants.LEARNING_RATE]
        if constants.ALPHA in extra_config:
            self.alpha = torch.nn.Parameter(torch.FloatTensor(extra_config[constants.ALPHA]), requires_grad=False)

        if classes is not None:
            # Two classes are modelled with a single score (one tree group); more than two
            # classes get one tree group per class.
            self.n_gbdt_classes = len(classes) if len(classes) > 2 else 1
            if self.n_gbdt_classes == 1:
                self.binary_classification = True

        self.n_trees_per_class = len(net_parameters) // self.n_gbdt_classes

    def aggregation(self, x):
        """
        Sum the tree outputs belonging to each class.

        Scores are viewed as (batch, n_gbdt_classes, n_trees_per_class), i.e. trees of the
        same class are expected to be contiguous, and summed over the last axis.
        NOTE(review): assumes `x` holds one score per tree per sample in that order;
        confirm against TreeTraversalTreeImpl.forward.
        """
        return x.view(-1, self.n_gbdt_classes, self.n_trees_per_class).sum(2)

    def calibration(self, x):
        """
        Turn raw scores into class probabilities.

        Binary models apply a sigmoid to the single score and return [1 - p, p];
        multiclass models apply a softmax over dim 1.
        """
        if self.binary_classification:
            output = torch.sigmoid(x)
            return torch.cat([1 - output, output], dim=1)
        return torch.softmax(x, dim=1)
Ancestors
- TreeTraversalTreeImpl
- AbstractPyTorchTreeImpl
- AbstracTreeImpl
- abc.ABC
- torch.nn.modules.module.Module
Inherited members