Program Listing for File trtorch.h

Return to documentation for file (cpp/api/include/trtorch/trtorch.h)

/*
 * Copyright (c) NVIDIA Corporation.
 * All rights reserved.
 *
 * This library is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

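// Forward declarations of the external types used by this API.
// TODO: consider just including the corresponding headers instead?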
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Opaque declarations of the TorchScript types named by this header, so the
// types can be mentioned in signatures without their full definitions.
namespace torch {
namespace jit {
struct Module;
struct Graph;
} // namespace jit
} // namespace torch

// Opaque declarations of the c10 types named by this header. The underlying
// types spelled out for the two enums have to agree with the real c10
// definitions, otherwise the redeclaration is ill-formed.
namespace c10 {
template <typename T>
class ArrayRef;
enum class ScalarType : int8_t;
enum class DeviceType : int16_t;
} // namespace c10

// Opaque declaration of the TensorRT INT8 calibrator interface. The visible
// part of this header only ever holds a pointer to it, so the full TensorRT
// definition is not needed here.
namespace nvinfer1 {
class IInt8Calibrator;
} // namespace nvinfer1
#endif //DOXYGEN_SHOULD_SKIP_THIS

#include "trtorch/macros.h"
namespace trtorch {
struct TRTORCH_API CompileSpec {
    struct TRTORCH_API InputRange {
        std::vector<int64_t> min;
        std::vector<int64_t> opt;
        std::vector<int64_t> max;
        InputRange(std::vector<int64_t> opt);
        InputRange(c10::ArrayRef<int64_t> opt);
        InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max);
        InputRange(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max);
    };

    class TRTORCH_API DataType {
    public:
        enum Value : int8_t {
            kFloat,
            kHalf,
            kChar,
        };

        DataType() = default;
        constexpr DataType(Value t) : value(t) {}
        DataType(c10::ScalarType t);
        operator Value() const  { return value; }
        explicit operator bool() = delete;
        constexpr bool operator==(DataType other) const { return value == other.value; }
        constexpr bool operator==(DataType::Value other) const { return value == other; }
        constexpr bool operator!=(DataType other) const { return value != other.value; }
        constexpr bool operator!=(DataType::Value other) const { return value != other; }
    private:
        Value value;
    };

    class DeviceType {
    public:
        enum Value : int8_t {
            kGPU,
            kDLA,
        };

        DeviceType() = default;
        constexpr DeviceType(Value t) : value(t) {}
        DeviceType(c10::DeviceType t);
        operator Value() const { return value; }
        explicit operator bool() = delete;
        constexpr bool operator==(DeviceType other) const { return value == other.value; }
        constexpr bool operator!=(DeviceType other) const { return value != other.value; }
    private:
        Value value;
    };

    enum class EngineCapability : int8_t {
        kDEFAULT,
        kSAFE_GPU,
        kSAFE_DLA,
    };

    CompileSpec(std::vector<InputRange> input_ranges)
        : input_ranges(std::move(input_ranges)) {}
    CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
    CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

    // Defaults should reflect TensorRT defaults for BuilderConfig

    std::vector<InputRange> input_ranges;

    DataType op_precision = DataType::kFloat;

    bool refit = false;

    bool debug = false;

    bool strict_types = false;

    bool allow_gpu_fallback = true;

    DeviceType device = DeviceType::kGPU;

    EngineCapability capability = EngineCapability::kDEFAULT;

    uint64_t num_min_timing_iters = 2;
    uint64_t num_avg_timing_iters = 1;

    uint64_t workspace_size = 0;

    uint64_t max_batch_size = 0;

    nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
};

/**
 * Returns build information as a string.
 *
 * NOTE(review): the exact contents are defined in the implementation, which
 * is not visible from this header - confirm what is reported.
 *
 * @return std::string holding the build information.
 */
TRTORCH_API std::string get_build_info();


/**
 * Emits build information; takes no arguments and returns nothing.
 *
 * NOTE(review): presumably prints the same information get_build_info()
 * returns - the output destination is not visible from this header, confirm.
 */
TRTORCH_API void dump_build_info();

/**
 * Checks a method of a TorchScript module for operator support.
 *
 * @param module       Module to inspect; taken by const reference.
 * @param method_name  Name of the method within the module to check.
 * @return bool result of the check.
 *         NOTE(review): presumably true when every operator in the method is
 *         supported by the compiler - confirm against the implementation.
 */
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);

/**
 * Compiles a TorchScript module according to the given settings.
 *
 * @param module  Source module; taken by const reference.
 * @param info    Compilation settings (input ranges and build options);
 *                taken by value.
 * @return A torch::jit::Module produced by the compilation.
 *         NOTE(review): presumably a new module that runs the compiled graph
 *         through TensorRT - confirm against the implementation.
 */
TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, CompileSpec info);

/**
 * Converts one method of a TorchScript module to a TensorRT engine.
 *
 * @param module       Source module; taken by const reference.
 * @param method_name  Name of the method within the module to convert.
 * @param info         Compilation settings (input ranges and build options);
 *                     taken by value.
 * @return std::string holding the result of the conversion.
 *         NOTE(review): presumably the serialized TensorRT engine - confirm
 *         against the implementation.
 */
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, CompileSpec info);
} // namespace trtorch