trtorch

Functions

trtorch.set_device(gpu_id)

Sets the active CUDA device (by id) used to run TensorRT engines
trtorch.compile(module: torch.jit._script.ScriptModule, compile_spec: Any) → torch.jit._script.ScriptModule

Compile a TorchScript module for NVIDIA GPUs using TensorRT

Takes an existing TorchScript module and a set of settings to configure the compiler, and converts methods to JIT graphs that call equivalent TensorRT engines

Specifically converts the forward method of the TorchScript module

Parameters
  • module (torch.jit.ScriptModule) – Source module, a result of tracing or scripting a PyTorch torch.nn.Module

  • compile_spec (dict) –

    Compilation settings including operating precision, target device, etc. One key, input_shapes, is required; it describes the input sizes or ranges for the inputs to the graph. All other keys are optional:

    compile_spec = {
        "input_shapes": [
            (1, 3, 224, 224), # Static input shape for input #1
            {
                "min": (1, 3, 224, 224),
                "opt": (1, 3, 512, 512),
                "max": (1, 3, 1024, 1024)
            } # Dynamic input shape for input #2
        ],
        "device": {
            "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
            "gpu_id": 0, # Target gpu id to run engine (use the Xavier gpu id for DLA)
            "dla_core": 0, # (DLA only) Target dla core id to run engine
            "allow_gpu_fallback": False, # (DLA only) Allow layers unsupported on DLA to run on GPU
        },
        "op_precision": torch.half, # Operating precision set to FP16
        "refit": False, # enable refit
        "debug": False, # enable debuggable engine
        "strict_types": False, # kernels should strictly run in operating precision
        "capability": trtorch.EngineCapability.default, # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
        "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
        "workspace_size": 0, # Maximum size of workspace given to TensorRT
        "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
    }
    

    Input sizes can be specified as torch sizes, tuples, or lists. Operating precision can be specified using either torch or trtorch datatypes, and the target device can be selected with either a torch device or the trtorch device type enum.

Returns

Compiled TorchScript module; when run, it executes via TensorRT

Return type

torch.jit.ScriptModule
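
For example, a minimal end-to-end sketch (torchvision's ResNet-18 stands in for any scriptable module; all names here are illustrative):

    import torch
    import torchvision
    import trtorch

    # Any scriptable torch.nn.Module works here; ResNet-18 is just a stand-in
    model = torchvision.models.resnet18(pretrained=True).eval().cuda()
    script_model = torch.jit.script(model)

    trt_model = trtorch.compile(script_model, {
        "input_shapes": [(1, 3, 224, 224)],
        "op_precision": torch.half,  # build and run the engine in FP16
    })

    # The result is a regular TorchScript module that dispatches to TensorRT
    x = torch.randn(1, 3, 224, 224).half().cuda()
    out = trt_model(x)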

trtorch.convert_method_to_trt_engine(module: torch.jit._script.ScriptModule, method_name: str, compile_spec: Any) → bytes

Convert a TorchScript module method to a serialized TensorRT engine

Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings

Parameters
  • module (torch.jit.ScriptModule) – Source module, a result of tracing or scripting a PyTorch torch.nn.Module

  • method_name (str) – Name of method to convert

  • compile_spec (dict) –

    Compilation settings including operating precision, target device, etc. One key, input_shapes, is required; it describes the input sizes or ranges for the inputs to the graph. All other keys are optional:

    compile_spec = {
        "input_shapes": [
            (1, 3, 224, 224), # Static input shape for input #1
            {
                "min": (1, 3, 224, 224),
                "opt": (1, 3, 512, 512),
                "max": (1, 3, 1024, 1024)
            } # Dynamic input shape for input #2
        ],
        "device": {
            "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
            "gpu_id": 0, # Target gpu id to run engine (use the Xavier gpu id for DLA)
            "dla_core": 0, # (DLA only) Target dla core id to run engine
            "allow_gpu_fallback": False, # (DLA only) Allow layers unsupported on DLA to run on GPU
        },
        "op_precision": torch.half, # Operating precision set to FP16
        "disable_tf32": False, # Force FP32 layers to use traditional FP32 format instead of the default TF32 behavior (round inputs to 10-bit mantissas before multiplying, accumulate sums with 23-bit mantissas)
        "refit": False, # enable refit
        "debug": False, # enable debuggable engine
        "strict_types": False, # kernels should strictly run in operating precision
        "capability": trtorch.EngineCapability.default, # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
        "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
        "workspace_size": 0, # Maximum size of workspace given to TensorRT
        "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
    }
    

    Input sizes can be specified as torch sizes, tuples, or lists. Operating precision can be specified using either torch or trtorch datatypes, and the target device can be selected with either a torch device or the trtorch device type enum.

Returns

Serialized TensorRT engine; can be saved to a file or deserialized via the TensorRT APIs

Return type

bytes
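
A sketch of serializing the forward method and writing the engine to disk (the model and file name are illustrative):

    import torch
    import torchvision
    import trtorch

    model = torchvision.models.resnet18(pretrained=True).eval().cuda()
    script_model = torch.jit.script(model)

    # Serialize the forward method into a standalone TensorRT engine
    engine = trtorch.convert_method_to_trt_engine(script_model, "forward", {
        "input_shapes": [(1, 3, 224, 224)],
        "op_precision": torch.float,
    })

    # The serialized engine can be written out and later deserialized with the TensorRT runtime APIs
    with open("resnet18_forward.engine", "wb") as f:
        f.write(engine)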

trtorch.check_method_op_support(module: torch.jit._script.ScriptModule, method_name: str) → bool

Checks to see if a method is fully supported by TRTorch

Checks whether a method of a TorchScript module can be compiled by TRTorch. If it cannot, a list of the unsupported operators is printed and the function returns False; otherwise it returns True.

Parameters
  • module (torch.jit.ScriptModule) – Source module, a result of tracing or scripting a PyTorch torch.nn.Module

  • method_name (str) – Name of method to check

Returns

True if the method is fully supported

Return type

bool
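
A sketch of gating compilation on the support check (script_model built as in the earlier examples):

    import torch
    import torchvision
    import trtorch

    script_model = torch.jit.script(torchvision.models.resnet18(pretrained=True).eval().cuda())

    # Unsupported operators, if any, are printed before the function returns
    if trtorch.check_method_op_support(script_model, "forward"):
        trt_model = trtorch.compile(script_model, {"input_shapes": [(1, 3, 224, 224)]})
    else:
        trt_model = script_model  # fall back to plain TorchScript execution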

trtorch.get_build_info() → str

Returns a string containing the build information of the TRTorch distribution

Returns

String containing the build information for the TRTorch distribution

Return type

str

trtorch.dump_build_info()

Prints build information about the TRTorch distribution to stdout
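
Both build-info helpers are useful when filing bug reports, for example:

    import trtorch

    # Capture the build string (library versions and build flags) programmatically
    info = trtorch.get_build_info()
    print(info)

    # Or write the same information directly to stdout
    trtorch.dump_build_info()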

trtorch.TensorRTCompileSpec(compile_spec: Dict[str, Any]) → torch.classes.tensorrt.CompileSpec

Utility to create a formatted spec dictionary for use with the PyTorch TensorRT backend

Parameters

compile_spec (dict) –

Compilation settings including operating precision, target device, etc. One key, input_shapes, is required; it describes the input sizes or ranges for the inputs to the graph. All other keys are optional. In the example below, the outer dictionary holds one entry per method to be compiled:

CompileSpec = {
    "forward" : trtorch.TensorRTCompileSpec({
        "input_shapes": [
            (1, 3, 224, 224), # Static input shape for input #1
            {
                "min": (1, 3, 224, 224),
                "opt": (1, 3, 512, 512),
                "max": (1, 3, 1024, 1024)
            } # Dynamic input shape for input #2
        ],
        "device": {
            "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
            "gpu_id": 0, # Target gpu id to run engine (use the Xavier gpu id for DLA)
            "dla_core": 0, # (DLA only) Target dla core id to run engine
            "allow_gpu_fallback": False, # (DLA only) Allow layers unsupported on DLA to run on GPU
        },
        "op_precision": torch.half, # Operating precision set to FP16
        "disable_tf32": False, # Force FP32 layers to use traditional FP32 format instead of the default TF32 behavior (round inputs to 10-bit mantissas before multiplying, accumulate sums with 23-bit mantissas)
        "refit": False, # enable refit
        "debug": False, # enable debuggable engine
        "strict_types": False, # kernels should strictly run in operating precision
        "capability": trtorch.EngineCapability.default, # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
        "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
        "workspace_size": 0, # Maximum size of workspace given to TensorRT
        "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
    })
}

Input sizes can be specified as torch sizes, tuples, or lists. Operating precision can be specified using either torch or trtorch datatypes, and the target device can be selected with either a torch device or the trtorch device type enum.

Returns

Methods mapped to formatted spec objects to be provided to torch._C._jit_to_tensorrt

Return type

torch.classes.tensorrt.CompileSpec
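
A sketch of handing the resulting spec to PyTorch's backend API. The exact entry point depends on the PyTorch version: recent releases expose the generic torch._C._jit_to_backend hook used below, while the text above names torch._C._jit_to_tensorrt.

    import torch
    import torchvision
    import trtorch

    model = torchvision.models.resnet18(pretrained=True).eval().cuda()
    script_model = torch.jit.script(model)

    spec = {
        "forward": trtorch.TensorRTCompileSpec({
            "input_shapes": [(1, 3, 224, 224)],
            "op_precision": torch.half,
        })
    }

    # Lower the scripted module through the "tensorrt" backend
    trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)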

Enums

class trtorch.dtype

Enum to specify operating precision for engine execution

Members:

float : 32-bit floating point number

float32 : 32-bit floating point number

half : 16-bit floating point number

float16 : 16-bit floating point number

int8 : 8-bit integer number

class trtorch.DeviceType

Enum to specify device kinds to build TensorRT engines for

Members:

GPU : Specify using GPU to execute TensorRT Engine

DLA : Specify using DLA to execute TensorRT Engine (Jetson Only)

class trtorch.EngineCapability

Enum to specify engine capability settings (selections of kernels to meet safety requirements)

Members:

safe_gpu : Use safety GPU kernels only

safe_dla : Use safety DLA kernels only

default : Use default behavior
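
These enums are interchangeable with their torch counterparts in a compile_spec; a sketch:

    import trtorch

    compile_spec = {
        "input_shapes": [(1, 3, 224, 224)],
        "op_precision": trtorch.dtype.half,               # equivalent to torch.half
        "device": {
            "device_type": trtorch.DeviceType.GPU,        # or torch.device("cuda")
            "gpu_id": 0,
        },
        "capability": trtorch.EngineCapability.safe_gpu,  # safety GPU kernels only
    }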

Submodules