Module hummingbird._utils

Collection of utility functions used throughout Hummingbird.

Expand source code
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

"""
Collection of utility functions used throughout Hummingbird.
"""

from distutils.version import LooseVersion
import warnings

from .exceptions import ConstantError


def torch_installed():
    """
    Report whether *PyTorch* can be imported in the current environment.

    Returns `True` when `import torch` succeeds and `False` when it raises
    `ImportError`. Any other exception raised by the import propagates.
    """
    try:
        import torch  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    else:
        return True


def sklearn_installed():
    """
    Report whether *Sklearn* can be imported in the current environment.

    Returns `True` when `import sklearn` succeeds and `False` when it raises
    `ImportError`. Any other exception raised by the import propagates.
    """
    available = True
    try:
        import sklearn  # noqa: F401 -- imported only to probe availability
    except ImportError:
        available = False
    return available


def lightgbm_installed():
    """
    Report whether *LightGBM* can be imported in the current environment.

    Returns `True` when `import lightgbm` succeeds and `False` when it raises
    `ImportError`. Any other exception raised by the import propagates.
    """
    try:
        import lightgbm  # noqa: F401 -- imported only to probe availability
    except ImportError:
        # The package is missing (or its import failed with ImportError).
        return False
    return True


def xgboost_installed():
    """
    Report whether a usable *XGBoost* is available.

    Returns `False` when `import xgboost` raises `ImportError`, or when the
    `_LIB` object exported by `xgboost.core` has no `XGBoosterDumpModelEx`
    attribute. Otherwise returns `True`, after emitting a warning if the
    installed version lies outside the 0.70 - 0.90 range.
    """
    try:
        import xgboost  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False

    from xgboost.core import _LIB as native_lib

    if not hasattr(native_lib, "XGBoosterDumpModelEx"):
        # The build is too old to expose this entry point, even if it
        # reports version 0.6; such builds need xgboost installed from
        # github rather than from pypi.
        return False

    from xgboost import __version__ as installed_version

    current = LooseVersion(installed_version)
    lowest = LooseVersion("0.70")
    highest = LooseVersion("0.90")
    # An out-of-range version only triggers a warning; it is still reported
    # as installed.
    if current < lowest or current > highest:
        warnings.warn("The converter works for xgboost >= 0.7 and <= 0.9. Different versions might not.")
    return True


class _Constants(object):
    """
    Class enabling the proper definition of constants.
    """

    def __init__(self, constants, other_constants=None):
        for constant in dir(constants):
            if constant.isupper():
                setattr(self, constant, getattr(constants, constant))
        for constant in dir(other_constants):
            if constant.isupper():
                setattr(self, constant, getattr(other_constants, constant))

    def __setattr__(self, name, value):
        if name in self.__dict__:
            raise ConstantError("Overwriting a constant is not allowed {}".format(name))
        self.__dict__[name] = value

Functions

def lightgbm_installed()

Checks that LightGBM is available.

Expand source code
def lightgbm_installed():
    """
    Report whether *LightGBM* can be imported in the current environment.

    Returns `True` when `import lightgbm` succeeds and `False` when it raises
    `ImportError`. Any other exception raised by the import propagates.
    """
    try:
        import lightgbm  # noqa: F401 -- imported only to probe availability
    except ImportError:
        # The package is missing (or its import failed with ImportError).
        return False
    return True
def sklearn_installed()

Checks that Sklearn is available.

Expand source code
def sklearn_installed():
    """
    Report whether *Sklearn* can be imported in the current environment.

    Returns `True` when `import sklearn` succeeds and `False` when it raises
    `ImportError`. Any other exception raised by the import propagates.
    """
    available = True
    try:
        import sklearn  # noqa: F401 -- imported only to probe availability
    except ImportError:
        available = False
    return available
def torch_installed()

Checks that PyTorch is available.

Expand source code
def torch_installed():
    """
    Report whether *PyTorch* can be imported in the current environment.

    Returns `True` when `import torch` succeeds and `False` when it raises
    `ImportError`. Any other exception raised by the import propagates.
    """
    try:
        import torch  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    else:
        return True
def xgboost_installed()

Checks that XGBoost is available.

Expand source code
def xgboost_installed():
    """
    Report whether a usable *XGBoost* is available.

    Returns `False` when `import xgboost` raises `ImportError`, or when the
    `_LIB` object exported by `xgboost.core` has no `XGBoosterDumpModelEx`
    attribute. Otherwise returns `True`, after emitting a warning if the
    installed version lies outside the 0.70 - 0.90 range.
    """
    try:
        import xgboost  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False

    from xgboost.core import _LIB as native_lib

    if not hasattr(native_lib, "XGBoosterDumpModelEx"):
        # The build is too old to expose this entry point, even if it
        # reports version 0.6; such builds need xgboost installed from
        # github rather than from pypi.
        return False

    from xgboost import __version__ as installed_version

    current = LooseVersion(installed_version)
    lowest = LooseVersion("0.70")
    highest = LooseVersion("0.90")
    # An out-of-range version only triggers a warning; it is still reported
    # as installed.
    if current < lowest or current > highest:
        warnings.warn("The converter works for xgboost >= 0.7 and <= 0.9. Different versions might not.")
    return True