// Program listing for file trtorch.h (cpp/api/include/trtorch/trtorch.h)
/*
* Copyright (c) NVIDIA Corporation.
* All rights reserved.
*
* This library is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
// TODO: should the upstream headers simply be included here instead of forward-declaring these types?
// Forward declarations of external types that the API below refers to by
// name. The whole section is compiled out when DOXYGEN_SHOULD_SKIP_THIS is
// defined.
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace torch {
namespace jit {
struct Graph;
struct Module;
} // namespace jit
} // namespace torch
namespace c10 {
// Opaque enum declarations: the underlying type written here has to agree with
// the real definition.
// NOTE(review): verify int16_t / int8_t against the c10 headers actually in use.
enum class DeviceType : int16_t;
enum class ScalarType : int8_t;
template <class>
class ArrayRef;
} // namespace c10
namespace nvinfer1 {
// Used as the default calibration algorithm of the ptq factory templates below.
class IInt8EntropyCalibrator2;
} // namespace nvinfer1
#endif // DOXYGEN_SHOULD_SKIP_THIS
#include "trtorch/macros.h"
#include "trtorch/logging.h"
#include "trtorch/ptq.h"
namespace trtorch {
/**
 * Settings bundle passed by value to CompileGraph() and
 * ConvertGraphToTRTEngine().
 *
 * Every field except input_ranges has an in-class default, so a caller only
 * has to supply the input description through one of the constructors.
 */
struct TRTORCH_API ExtraInfo {
  /**
   * Shape description for one input, stored as three dimension lists:
   * a lower bound (min), a preferred shape (opt) and an upper bound (max).
   */
  struct TRTORCH_API InputRange {
    /// Lower-bound dimensions.
    std::vector<int64_t> min;
    /// Preferred ("optimal") dimensions.
    std::vector<int64_t> opt;
    /// Upper-bound dimensions.
    std::vector<int64_t> max;

    /**
     * Build a range from a single shape.
     * NOTE(review): presumably min, opt and max are all set to that shape --
     * confirm in the implementation, which is not part of this header.
     */
    InputRange(std::vector<int64_t> opt);
    /// Same as above, taking the shape as a c10::ArrayRef.
    InputRange(c10::ArrayRef<int64_t> opt);

    /// Build a range from explicit min / opt / max shapes.
    InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max);
    /// Same as above, taking the shapes as c10::ArrayRef.
    InputRange(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max);
  };

  /**
   * Small value wrapper around the Value enum so that it can also be built
   * from a c10::ScalarType.
   */
  class DataType {
   public:
    enum Value : int8_t {
      kFloat,
      kHalf,
      kChar,
    };

    /// Default-constructed objects hold kFloat (see the member initializer).
    DataType() = default;
    /// Implicit on purpose, so enumerators can be used wherever a DataType is expected.
    constexpr DataType(Value t) : value(t) {}
    /// Conversion from a c10::ScalarType; defined out of line.
    DataType(c10::ScalarType t);

    /// Lets the wrapper be used directly in `switch` statements.
    constexpr operator Value() const {
      return value;
    }
    /// Blocks accidental use in boolean contexts such as `if (dtype)`.
    explicit operator bool() = delete;

    constexpr bool operator==(DataType other) const {
      return value == other.value;
    }
    constexpr bool operator!=(DataType other) const {
      return value != other.value;
    }

   private:
    // Initialized so that `DataType d;` never exposes an indeterminate value.
    // kFloat is the zero enumerator, i.e. what value-initialization produced
    // before, and it is also the default of ExtraInfo::op_precision.
    Value value = kFloat;
  };

  /**
   * Small value wrapper around the Value enum so that it can also be built
   * from a c10::DeviceType.
   */
  class DeviceType {
   public:
    enum Value : int8_t {
      kGPU,
      kDLA,
    };

    /// Default-constructed objects hold kGPU (see the member initializer).
    DeviceType() = default;
    /// Implicit on purpose, so enumerators can be used wherever a DeviceType is expected.
    constexpr DeviceType(Value t) : value(t) {}
    /// Conversion from a c10::DeviceType; defined out of line.
    DeviceType(c10::DeviceType t);

    /// Lets the wrapper be used directly in `switch` statements.
    constexpr operator Value() const {
      return value;
    }
    /// Blocks accidental use in boolean contexts such as `if (device)`.
    explicit operator bool() = delete;

    constexpr bool operator==(DeviceType other) const {
      return value == other.value;
    }
    constexpr bool operator!=(DeviceType other) const {
      return value != other.value;
    }

   private:
    // Initialized so that `DeviceType d;` never exposes an indeterminate
    // value. kGPU is the zero enumerator and the default of ExtraInfo::device.
    Value value = kGPU;
  };

  /// Engine capability selector; see the `capability` field.
  enum class EngineCapability : int8_t {
    kDEFAULT,
    kSAFE_GPU,
    kSAFE_DLA,
  };

  /// Takes the input ranges by value and moves them into the member.
  ExtraInfo(std::vector<InputRange> input_ranges) : input_ranges(std::move(input_ranges)) {}
  /**
   * Build from plain shapes.
   * NOTE(review): presumably creates one fixed-shape InputRange per entry --
   * confirm in the implementation.
   */
  ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
  /// Same as above, taking the shapes as c10::ArrayRef.
  ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

  // Defaults should reflect TensorRT defaults for BuilderConfig
  // NOTE(review): the per-field meanings are not defined in this header; the
  // names follow the TensorRT builder options -- confirm against the
  // implementation before relying on a particular interpretation.

  /// One entry per described input.
  std::vector<InputRange> input_ranges;
  DataType op_precision = DataType::kFloat;
  bool refit = false;
  bool debug = false;
  bool strict_types = false;
  bool allow_gpu_fallback = true;
  DeviceType device = DeviceType::kGPU;
  EngineCapability capability = EngineCapability::kDEFAULT;
  uint64_t num_min_timing_iters = 2;
  uint64_t num_avg_timing_iters = 1;
  uint64_t workspace_size = 0;
  uint64_t max_batch_size = 0;
  /// Optional calibrator. Raw pointer: this struct declares no destructor and
  /// therefore never frees it.
  nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
};
/**
 * @return Build information as a string.
 *
 * NOTE(review): the exact contents are produced by the implementation, which
 * is not visible in this header.
 */
TRTORCH_API std::string get_build_info();
/**
 * Emit the build information; nothing is returned.
 *
 * NOTE(review): presumably prints what get_build_info() returns -- confirm the
 * destination (stdout vs. logger) in the implementation.
 */
TRTORCH_API void dump_build_info();
/**
 * @param module      TorchScript module to inspect (not modified: const reference).
 * @param method_name Name of the method within `module` to check.
 * @return A bool result.
 *
 * NOTE(review): presumably true when every operator used by the method is
 * supported for conversion -- confirm in the implementation.
 */
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);
/**
 * @param module Source TorchScript module (taken by const reference).
 * @param info   Compilation settings, passed by value.
 * @return A new torch::jit::Module.
 *
 * NOTE(review): presumably the returned module runs the compiled TensorRT
 * engine in place of the original graph -- confirm in the implementation.
 */
TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, ExtraInfo info);
/**
 * @param module      Source TorchScript module (taken by const reference).
 * @param method_name Name of the method within `module` to convert.
 * @param info        Compilation settings, passed by value.
 * @return The result as a std::string.
 *
 * NOTE(review): presumably the string holds the serialized TensorRT engine
 * (binary data, not text) -- confirm in the implementation.
 */
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, ExtraInfo info);
namespace ptq {
/**
 * Factory for Int8Calibrator that deduces the DataLoader type from the
 * argument, so callers only spell out Algorithm when they want something
 * other than the default nvinfer1::IInt8EntropyCalibrator2.
 *
 * @param dataloader      Taken by value and moved into the calibrator.
 * @param cache_file_path Forwarded unchanged to the calibrator constructor.
 * @param use_cache       Forwarded unchanged to the calibrator constructor.
 * @return The newly constructed Int8Calibrator<Algorithm, DataLoader>.
 */
template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
  using Calibrator = Int8Calibrator<Algorithm, DataLoader>;
  return Calibrator(std::move(dataloader), cache_file_path, use_cache);
}
/**
 * Factory for Int8CacheCalibrator; Algorithm defaults to
 * nvinfer1::IInt8EntropyCalibrator2.
 *
 * @param cache_file_path Forwarded unchanged to the calibrator constructor.
 * @return The newly constructed Int8CacheCalibrator<Algorithm>.
 */
template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
  using Calibrator = Int8CacheCalibrator<Algorithm>;
  return Calibrator(cache_file_path);
}
} // namespace ptq
} // namespace trtorch