Module hummingbird.ml.ir_converters.topology

Converters for topology IR are stored in this file.

Expand source code
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

"""
Converters for topology IR are stored in this file.
"""
from onnxconverter_common.registration import get_converter

from ..exceptions import MissingConverter
from .._container import PyTorchBackendModelRegression, PyTorchBackendModelClassification, PyTorchBackendModelTransformer


def convert(topology, device=None, extra_config={}):
    """
    This function is used to convert an `onnxconverter_common.topology.Topology` object into a *PyTorch* model.

    Every operator yielded by `topology.topological_operator_iterator()` is translated by the converter
    registered for its type. The converted operators are then wrapped into a regression, transformer or
    classification container, chosen from the `regression` / `transformer` flags of the last converted operator.

    Args:
        topology: The `onnxconverter_common.topology.Topology` object that will be converted into PyTorch
        device: Which device the translated model will be run on
        extra_config: Extra configurations to be used by individual operator converters

    Returns:
        A *PyTorch* model

    Raises:
        AssertionError: If `topology` is None.
        MissingConverter: If a `ValueError` is raised while looking up or running the converter of an operator.
    """
    # Raise explicitly rather than with `assert`, so the check is not stripped when Python runs with -O.
    if topology is None:
        raise AssertionError("Cannot convert a Topology object of type None.")

    operator_map = {}
    operators = []

    # Convert the operators and record their order in the same pass, so the topology is traversed only once.
    for operator in topology.topological_operator_iterator():
        try:
            converter = get_converter(operator.type)
            operator_map[operator.full_name] = converter(operator, device, extra_config)
        except ValueError as error:
            # Keep the original ValueError attached as the cause to ease debugging.
            raise MissingConverter(
                "Unable to find converter for {} type {} with extra config: {}.".format(
                    operator.type, type(getattr(operator, "raw_model", None)), extra_config
                )
            ) from error
        operators.append(operator)

    # The last operator in topological order decides which kind of container wraps the model.
    last_converted = operator_map[operators[-1].full_name]
    if last_converted.regression:
        # We are doing a regression task.
        pytorch_container = PyTorchBackendModelRegression
    elif last_converted.transformer:
        # We are just transforming the input data.
        pytorch_container = PyTorchBackendModelTransformer
    else:
        # We are doing a classification task.
        pytorch_container = PyTorchBackendModelClassification

    pytorch_model = pytorch_container(
        topology.raw_model.input_names, topology.raw_model.output_names, operator_map, operators, extra_config
    ).eval()

    if device is not None:
        pytorch_model = pytorch_model.to(device)
    return pytorch_model

Functions

def convert(topology, device=None, extra_config={})

This function is used to convert an onnxconverter_common.topology.Topology object into a PyTorch model.

Args

topology
The onnxconverter_common.topology.Topology object that will be converted into PyTorch
device
Which device the translated model will be run on
extra_config
Extra configurations to be used by individual operator converters

Returns

A PyTorch model

Expand source code
def convert(topology, device=None, extra_config={}):
    """
    This function is used to convert an `onnxconverter_common.topology.Topology` object into a *PyTorch* model.

    Every operator yielded by `topology.topological_operator_iterator()` is translated by the converter
    registered for its type. The converted operators are then wrapped into a regression, transformer or
    classification container, chosen from the `regression` / `transformer` flags of the last converted operator.

    Args:
        topology: The `onnxconverter_common.topology.Topology` object that will be converted into PyTorch
        device: Which device the translated model will be run on
        extra_config: Extra configurations to be used by individual operator converters

    Returns:
        A *PyTorch* model

    Raises:
        AssertionError: If `topology` is None.
        MissingConverter: If a `ValueError` is raised while looking up or running the converter of an operator.
    """
    # Raise explicitly rather than with `assert`, so the check is not stripped when Python runs with -O.
    if topology is None:
        raise AssertionError("Cannot convert a Topology object of type None.")

    operator_map = {}
    operators = []

    # Convert the operators and record their order in the same pass, so the topology is traversed only once.
    for operator in topology.topological_operator_iterator():
        try:
            converter = get_converter(operator.type)
            operator_map[operator.full_name] = converter(operator, device, extra_config)
        except ValueError as error:
            # Keep the original ValueError attached as the cause to ease debugging.
            raise MissingConverter(
                "Unable to find converter for {} type {} with extra config: {}.".format(
                    operator.type, type(getattr(operator, "raw_model", None)), extra_config
                )
            ) from error
        operators.append(operator)

    # The last operator in topological order decides which kind of container wraps the model.
    last_converted = operator_map[operators[-1].full_name]
    if last_converted.regression:
        # We are doing a regression task.
        pytorch_container = PyTorchBackendModelRegression
    elif last_converted.transformer:
        # We are just transforming the input data.
        pytorch_container = PyTorchBackendModelTransformer
    else:
        # We are doing a classification task.
        pytorch_container = PyTorchBackendModelClassification

    pytorch_model = pytorch_container(
        topology.raw_model.input_names, topology.raw_model.output_names, operator_map, operators, extra_config
    ).eval()

    if device is not None:
        pytorch_model = pytorch_model.to(device)
    return pytorch_model