Program Listing for File trtorch.h

Return to documentation for file (cpp/api/include/trtorch/trtorch.h)

/*
 * Copyright (c) NVIDIA Corporation.
 * All rights reserved.
 *
 * This library is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cuda_runtime.h>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>

// TODO: consider including the real Torch/TensorRT headers instead of forward-declaring these types.
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace torch {
namespace jit {
// Forward declarations of the TorchScript types named by this header, so the
// full Torch headers do not have to be pulled in here.
struct Module;
struct Graph;
} // namespace jit
} // namespace torch

namespace c10 {
// Opaque enum declarations (fixed underlying type, so they are complete types)
// and a class-template declaration for the c10 types named in this header.
// NOTE(review): the underlying types must match the real c10 definitions of
// the PyTorch version being built against — confirm when upgrading.
enum class DeviceType : int8_t;
enum class ScalarType : int8_t;
enum class MemoryFormat : int8_t;
template <class>
class ArrayRef;
} // namespace c10

namespace at {
// CompileSpec::TensorFormat declares a constructor taking at::MemoryFormat.
// Expose the c10 enum under that name so this header parses on its own,
// without requiring a Torch header to have been included first.
using c10::MemoryFormat;
} // namespace at

namespace nvinfer1 {
// Forward declaration only: this header just stores a pointer to the calibrator.
class IInt8Calibrator;
} // namespace nvinfer1
#endif // DOXYGEN_SHOULD_SKIP_THIS

#include "trtorch/macros.h"
namespace trtorch {
struct TRTORCH_API CompileSpec {
  /**
   * Data type selector used by CompileSpec; a thin value wrapper around the
   * nested Value enum.
   *
   * Objects convert implicitly to Value (so they can be used in a switch and
   * ordered inside std::set), can be built from a Value or from a
   * c10::ScalarType (conversion defined out of line), and can be compared
   * against another DataType or a raw Value. Conversion to bool is deleted so
   * a DataType cannot be used as a condition by accident.
   */
  class TRTORCH_API DataType {
   public:
    /// Enumerators selectable through DataType.
    enum Value : int8_t {
      kFloat,
      kHalf,
      kChar,
      kInt,
      kBool,
      kUnknown
    };

    /// Default-constructs to kFloat (see the member initializer on `value`).
    DataType() = default;
    /// Wraps a raw enumerator.
    constexpr DataType(Value t) : value(t) {}
    /// Builds from a c10::ScalarType; the mapping is defined out of line.
    DataType(c10::ScalarType t);
    /// Implicit conversion back to the raw enumerator.
    constexpr operator Value() const {
      return value;
    }
    explicit operator bool() = delete;
    constexpr bool operator==(DataType other) const {
      return value == other.value;
    }
    // The Value overloads keep `dtype == DataType::kFloat` unambiguous: without
    // them the member operator (converting the right operand) and the built-in
    // enum comparison (converting the left operand) would tie.
    constexpr bool operator==(DataType::Value other) const {
      return value == other;
    }
    constexpr bool operator!=(DataType other) const {
      return value != other.value;
    }
    constexpr bool operator!=(DataType::Value other) const {
      return value != other;
    }

   private:
    friend std::ostream& operator<<(std::ostream& os, const DataType& dtype);
    // Initialized so that `DataType d;` never holds an indeterminate value.
    // kFloat is the zero enumerator, i.e. what `DataType()` / `DataType{}`
    // already produced through value-initialization.
    Value value = kFloat;
  };

  /**
   * Settings describing the target device.
   */
  struct Device {
    /**
     * Kind of target device; a thin value wrapper around the nested Value enum.
     *
     * Converts implicitly to Value, can be built from a Value or from a
     * c10::DeviceType (conversion defined out of line), and can be compared
     * against another DeviceType or a raw Value. Conversion to bool is deleted.
     */
    class DeviceType {
     public:
      /// Enumerators selectable through DeviceType.
      enum Value : int8_t {
        kGPU,
        kDLA,
      };

      /// Default-constructs to kGPU (see the member initializer on `value`).
      DeviceType() = default;
      /// Wraps a raw enumerator.
      constexpr DeviceType(Value t) : value(t) {}
      /// Builds from a c10::DeviceType; the mapping is defined out of line.
      DeviceType(c10::DeviceType t);
      /// Implicit conversion back to the raw enumerator.
      constexpr operator Value() const {
        return value;
      }
      explicit operator bool() = delete;
      constexpr bool operator==(DeviceType other) const {
        return value == other.value;
      }
      // Value overloads (mirroring DataType / TensorFormat) so that
      // `device_type == DeviceType::kGPU` resolves unambiguously instead of
      // tying between the member operator and the built-in enum comparison.
      constexpr bool operator==(DeviceType::Value other) const {
        return value == other;
      }
      constexpr bool operator!=(DeviceType other) const {
        return value != other.value;
      }
      constexpr bool operator!=(DeviceType::Value other) const {
        return value != other;
      }

     private:
      // Initialized so a default-constructed DeviceType is never indeterminate;
      // kGPU is the zero enumerator and also what Device() selects.
      Value value = kGPU;
    };

    /// Kind of device to target; Device() sets this to kGPU.
    DeviceType device_type;

    /*
     * Target gpu id
     */
    int64_t gpu_id;

    /*
     * DLA core id. When using a DLA core on NVIDIA AGX platforms, gpu_id
     * should be set to the Xavier device.
     */
    int64_t dla_core;

    /// Defaults to false. NOTE(review): presumably allows work that cannot run
    /// on DLA to run on the GPU instead — confirm against the implementation.
    bool allow_gpu_fallback;

    /// Defaults: GPU device type, gpu_id 0, dla_core 0, GPU fallback disabled.
    Device() : device_type(DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
  };

  /**
   * Engine capability selector, stored in one byte.
   * NOTE(review): the names suggest this mirrors TensorRT's engine capability
   * setting — confirm against the implementation.
   */
  enum class EngineCapability : int8_t {
    kSTANDARD = 0,
    kSAFETY = 1,
    kDLA_STANDALONE = 2,
  };

  /**
   * Tensor memory-format selector; a thin value wrapper around the nested
   * Value enum.
   *
   * Converts implicitly to Value, can be built from a Value or from an
   * at::MemoryFormat (conversion defined out of line), and can be compared
   * against another TensorFormat or a raw Value. Conversion to bool is deleted.
   */
  class TRTORCH_API TensorFormat {
   public:
    /// Enumerators selectable through TensorFormat.
    enum Value : int8_t {
      kContiguous,
      kChannelsLast,
      kUnknown,
    };

    /// Default-constructs to kContiguous (see the member initializer on `value`).
    TensorFormat() = default;
    /// Wraps a raw enumerator.
    constexpr TensorFormat(Value t) : value(t) {}
    /// Builds from an at::MemoryFormat; the mapping is defined out of line.
    TensorFormat(at::MemoryFormat t);
    /// Implicit conversion back to the raw enumerator.
    constexpr operator Value() const {
      return value;
    }
    explicit operator bool() = delete;
    constexpr bool operator==(TensorFormat other) const {
      return value == other.value;
    }
    // The Value overloads keep `format == TensorFormat::kContiguous`
    // unambiguous (member operator vs. built-in enum comparison).
    constexpr bool operator==(TensorFormat::Value other) const {
      return value == other;
    }
    constexpr bool operator!=(TensorFormat other) const {
      return value != other.value;
    }
    constexpr bool operator!=(TensorFormat::Value other) const {
      return value != other;
    }

   private:
    friend std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
    // Initialized so that `TensorFormat f;` never holds an indeterminate value.
    // kContiguous is the zero enumerator, i.e. what `TensorFormat()` already
    // produced through value-initialization, and it is the default argument
    // used by the Input constructors.
    Value value = kContiguous;
  };

  struct TRTORCH_API Input {
    std::vector<int64_t> min_shape;
    std::vector<int64_t> opt_shape;
    std::vector<int64_t> max_shape;
    std::vector<int64_t> shape;
    DataType dtype;
    TensorFormat format;

    Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

    Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

    Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

    Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

    Input(
        std::vector<int64_t> min_shape,
        std::vector<int64_t> opt_shape,
        std::vector<int64_t> max_shape,
        TensorFormat format = TensorFormat::kContiguous);

    Input(
        std::vector<int64_t> min_shape,
        std::vector<int64_t> opt_shape,
        std::vector<int64_t> max_shape,
        DataType dtype,
        TensorFormat format = TensorFormat::kContiguous);

    Input(
        c10::ArrayRef<int64_t> min_shape,
        c10::ArrayRef<int64_t> opt_shape,
        c10::ArrayRef<int64_t> max_shape,
        TensorFormat format = TensorFormat::kContiguous);

    Input(
        c10::ArrayRef<int64_t> min_shape,
        c10::ArrayRef<int64_t> opt_shape,
        c10::ArrayRef<int64_t> max_shape,
        DataType dtype,
        TensorFormat format = TensorFormat::kContiguous);

    bool get_explicit_set_dtype() {
      return explicit_set_dtype;
    }

   private:
    friend std::ostream& operator<<(std::ostream& os, const Input& input);
    bool input_is_dynamic;
    bool explicit_set_dtype;
  };

  /**
   * Deprecated input description holding min / opt / max shape vectors.
   *
   * Every constructor is marked [[deprecated]]; the attribute text directs
   * users to trtorch::CompileSpec::Input and announces removal in TRTorch
   * v0.5.0. Constructors are defined out of line.
   */
  struct TRTORCH_API InputRange {
    /// Shape vectors; how the single-argument constructors fill min and max is
    /// not visible in this header.
    std::vector<int64_t> min;
    std::vector<int64_t> opt;
    std::vector<int64_t> max;
    /// Deprecated: construct from a single `opt` shape (std::vector).
    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
        std::vector<int64_t> opt);
    /// Deprecated: construct from a single `opt` shape (c10::ArrayRef).
    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
        c10::ArrayRef<int64_t> opt);
    /// Deprecated: construct from explicit min / opt / max shapes (std::vector).
    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
        std::vector<int64_t> min,
        std::vector<int64_t> opt,
        std::vector<int64_t> max);
    /// Deprecated: construct from explicit min / opt / max shapes (c10::ArrayRef).
    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
        c10::ArrayRef<int64_t> min,
        c10::ArrayRef<int64_t> opt,
        c10::ArrayRef<int64_t> max);
  };

  /**
   * Fallback settings: an on/off switch, a minimum block size, and lists of
   * operator and module names forced to fall back.
   * NOTE(review): presumably "fallback" means running the listed or
   * unsupported parts in Torch rather than TensorRT — confirm against the
   * implementation.
   */
  struct TRTORCH_API TorchFallback {
    /// Off unless explicitly enabled.
    bool enabled{false};

    /// Defaults to 1.
    uint64_t min_block_size{1};

    /// Operator names forced to fall back; empty by default.
    std::vector<std::string> forced_fallback_ops;

    /// Module names forced to fall back; empty by default.
    std::vector<std::string> forced_fallback_modules;

    /// Disabled, min_block_size 1, empty lists.
    TorchFallback() = default;

    /// Sets only the on/off switch; the other members keep their defaults.
    TorchFallback(bool enabled) : enabled{enabled} {}

    /// Sets the on/off switch and the minimum block size.
    TorchFallback(bool enabled, uint64_t min_size) : enabled{enabled}, min_block_size{min_size} {}
  };

  /// Deprecated: takes ownership of `input_ranges` by moving it into the
  /// deprecated input_ranges member. Use CompileSpec(std::vector<Input>) instead.
  [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). Please use CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]] CompileSpec(
      std::vector<InputRange> input_ranges)
      : input_ranges(std::move(input_ranges)) {}
  /// Builds a spec from one fixed size per input (std::vector form); defined out of line.
  CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);

  /// Builds a spec from one fixed size per input (c10::ArrayRef form); defined out of line.
  CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

  /// Builds a spec from Input descriptions, moving the vector into `inputs`.
  CompileSpec(std::vector<Input> inputs) : inputs(std::move(inputs)) {}

  // Defaults should reflect TensorRT defaults for BuilderConfig.
  // NOTE(review): the per-field remarks below are inferred from the member
  // names only — confirm against the implementation.
  // Declaration order is kept exactly as it was, since it fixes the object layout.

  /// Input descriptions.
  std::vector<Input> inputs;

  /// Deprecated predecessor of `inputs`.
  [[deprecated(
      "trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]]
  std::vector<InputRange> input_ranges;

  /// Deprecated predecessor of `enabled_precisions`; defaults to kFloat.
  [[deprecated(
      "trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]]
  DataType op_precision{DataType::kFloat};

  /// Set of enabled precisions; contains only kFloat by default.
  std::set<DataType> enabled_precisions{DataType::kFloat};

  /// Off by default.
  bool disable_tf32{false};

  /// Off by default.
  bool sparse_weights{false};

  /// Off by default.
  bool refit{false};

  /// Off by default.
  bool debug{false};

  /// Off by default.
  bool truncate_long_and_double{false};

  /// Off by default.
  bool strict_types{false};

  /// Target device settings; default-constructed.
  Device device;

  /// Fallback settings; default-constructed.
  TorchFallback torch_fallback;

  /// Defaults to kSTANDARD.
  EngineCapability capability{EngineCapability::kSTANDARD};

  /// Defaults to 2.
  uint64_t num_min_timing_iters{2};
  /// Defaults to 1.
  uint64_t num_avg_timing_iters{1};

  /// Defaults to 0.
  uint64_t workspace_size{0};

  /// Defaults to 0.
  uint64_t max_batch_size{0};

  /// Calibrator pointer; null by default.
  nvinfer1::IInt8Calibrator* ptq_calibrator{nullptr};
};

// Exported free functions; all are defined out of line, so the remarks below
// are inferred from names and signatures only — confirm against the
// implementation.

/// Returns a string. NOTE(review): presumably build/version information.
TRTORCH_API std::string get_build_info();

/// Takes no arguments and returns nothing. NOTE(review): presumably prints the
/// information returned by get_build_info().
TRTORCH_API void dump_build_info();

/// Takes a module and a method name and returns a bool. NOTE(review):
/// presumably reports whether the named method is supported for conversion.
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);

/// Takes a module and a CompileSpec (by value) and returns a module.
/// NOTE(review): presumably the compiled version of `module`.
TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, CompileSpec info);

/// Takes a module, a method name and a CompileSpec (by value) and returns a
/// string. NOTE(review): presumably a serialized TensorRT engine.
TRTORCH_API std::string ConvertGraphToTRTEngine(
    const torch::jit::Module& module,
    std::string method_name,
    CompileSpec info);

/// Takes an engine string and a Device and returns a module. NOTE(review):
/// presumably `engine` is a serialized engine wrapped in a new module.
TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec::Device device);

/// Takes a GPU id and returns nothing. NOTE(review): presumably selects the
/// active device.
TRTORCH_API void set_device(const int gpu_id);

} // namespace trtorch