Program Listing for File ptq.h

File: cpp/api/include/trtorch/ptq.h

/*
 * Copyright (c) NVIDIA Corporation.
 * All rights reserved.
 *
 * This library is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <string>
#include <vector>
#include <memory>
#include <iostream>
#include <fstream>
#include <iterator>
#include <sstream>

#include "torch/torch.h"
#include "trtorch/logging.h"
#include "NvInfer.h"

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace nvinfer1 {
class IInt8Calibrator;
class IInt8EntropyCalibrator2;
}

namespace trtorch {
namespace ptq {
bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
}
}
#endif //DOXYGEN_SHOULD_SKIP_THIS

namespace trtorch {
namespace ptq {

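// Generic INT8 calibrator that wraps one of TensorRT's calibration
// algorithm interfaces (e.g. nvinfer1::IInt8EntropyCalibrator2) and feeds
// it batches drawn from a libtorch DataLoader, optionally reading and
// writing a calibration cache file.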
template<typename Algorithm, typename DataLoaderUniquePtr>
class Int8Calibrator : Algorithm {
    using DataLoader = typename DataLoaderUniquePtr::element_type;
    using Batch = typename DataLoader::super::BatchType;
public:
    Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
      : dataloader_(dataloader.get()), cache_file_path_(cache_file_path), use_cache_(use_cache) {
          // Materialize every batch up front; the dataloader is only
          // consumed here, so getBatch() later walks a stable vector.
          for (auto batch : *dataloader_) {
            batched_data_.push_back(batch.data);
          }
          it_ = batched_data_.begin();
      }

    int getBatchSize() const override {
        // HACK: TRTorch only uses explicit batch sizing, and the INT8 calibrator
        // does not work when the real batch size is reported here while explicit
        // batching is in use. So we just report a batch size of 1 (warnings will
        // still be printed out).
        return 1;
        //return static_cast<int>(dataloader_->options().batch_size);
    }

    bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
        if (it_ != batched_data_.end()) {
            auto status = get_batch_impl(bindings, names, nbBindings, *it_);
            ++it_;
            return status;
        } else {
            // Reset the iterator in case the calibrator is used again
            it_ = batched_data_.begin();
            return false;
        }
    }

    const void* readCalibrationCache(size_t& length) override {
        if (use_cache_) {
            std::stringstream ss;
            ss << "Reading Calibration Cache from " << cache_file_path_;
            logging::log(logging::Level::kINFO, ss.str());

            cache_.clear();
            std::ifstream input(cache_file_path_, std::ios::binary);
            input >> std::noskipws;
            if (input.good()) {
                std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
                    std::back_inserter(cache_));
                logging::log(logging::Level::kDEBUG, "Cache read");
            }
            length = cache_.size();
            return length ? cache_.data() : nullptr;
        }
        return nullptr;
    }

    void writeCalibrationCache(const void* cache, size_t length) override {
        std::ofstream cache_file(cache_file_path_, std::ios::binary);
        cache_file.write(reinterpret_cast<const char*>(cache), length);
        std::stringstream ss;
        ss << "Saved Calibration Cache to " << cache_file_path_;
        logging::log(logging::Level::kINFO, ss.str());
    }

    // Conversion so the calibrator can be passed directly where TensorRT
    // expects an nvinfer1::IInt8Calibrator*.
    operator nvinfer1::IInt8Calibrator* () {
        return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
    }

private:
    DataLoader* dataloader_;
    const std::string& cache_file_path_;
    size_t cache_size_ = 0;
    bool use_cache_;
    std::vector<char> cache_;
    std::vector<torch::Tensor> batched_data_;
    std::vector<torch::Tensor>::iterator it_;

};

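// Calibrator for the case where calibration data is unavailable and a
// previously generated cache file supplies the scales; getBatch() never
// provides data, so everything comes from readCalibrationCache().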
template<typename Algorithm>
class Int8CacheCalibrator : Algorithm {
public:
    Int8CacheCalibrator(const std::string& cache_file_path)
      : cache_file_path_(cache_file_path) {}

    int getBatchSize() const override {
        // HACK: TRTorch only uses explicit batch sizing, and the INT8 calibrator
        // does not work when the real batch size is reported here while explicit
        // batching is in use. So we just report a batch size of 1 (warnings will
        // still be printed out).
        return 1;
    }

    bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
        // Cache-only mode has no calibration data to supply.
        return false;
    }

    const void* readCalibrationCache(size_t& length) override {
        std::stringstream ss;
        ss << "Reading Calibration Cache from " << cache_file_path_;
        logging::log(logging::Level::kINFO, ss.str());

        cache_.clear();
        std::ifstream input(cache_file_path_, std::ios::binary);
        input >> std::noskipws;
        if (input.good()) {
            std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
                std::back_inserter(cache_));
            logging::log(logging::Level::kDEBUG, "Cache read");
        }
        length = cache_.size();
        return length ? cache_.data() : nullptr;
    }

    void writeCalibrationCache(const void* cache, size_t length) override {
        std::ofstream cache_file(cache_file_path_, std::ios::binary);
        cache_file.write(reinterpret_cast<const char*>(cache), length);
        std::stringstream ss;
        ss << "Saved Calibration Cache to " << cache_file_path_;
        logging::log(logging::Level::kINFO, ss.str());
    }

    // Conversion so the calibrator can be passed directly where TensorRT
    // expects an nvinfer1::IInt8Calibrator*.
    operator nvinfer1::IInt8Calibrator* () {
        return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
    }

private:
    const std::string& cache_file_path_;
    size_t cache_size_ = 0;
    std::vector<char> cache_;
};

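// Convenience factory for Int8Calibrator; Algorithm defaults to
// nvinfer1::IInt8EntropyCalibrator2. The dataloader is moved into the
// calibrator, which records all of its batches during construction.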
template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
    return Int8Calibrator<Algorithm, DataLoader>(std::move(dataloader), cache_file_path, use_cache);
}

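// Convenience factory for Int8CacheCalibrator, for when calibration data
// is unavailable but a cache file from a previous run exists.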
template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
    return Int8CacheCalibrator<Algorithm>(cache_file_path);
}

} // namespace ptq
} // namespace trtorch
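
A minimal usage sketch for these factories follows (not part of ptq.h). It assumes the libtorch MNIST dataset with its standard transforms as calibration data; the dataset directory, cache file name, and batch size are placeholders. The cache path is kept in a named std::string because the calibrators store a reference to it.

#include <string>

#include "torch/torch.h"
#include "trtorch/ptq.h"

void build_calibrators() {
    // Calibration dataloader over a placeholder dataset.
    auto dataset = torch::data::datasets::MNIST("calibration_data_dir")
                       .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
                       .map(torch::data::transforms::Stack<>());
    auto dataloader = torch::data::make_data_loader(
        std::move(dataset), torch::data::DataLoaderOptions().batch_size(32));

    // The calibrators keep a reference to this path, so keep it alive.
    std::string cache_file = "calibration.cache";

    // First run: calibrate from data; TensorRT writes the cache file.
    auto calibrator = trtorch::ptq::make_int8_calibrator(
        std::move(dataloader), cache_file, /*use_cache=*/false);

    // Subsequent runs: reuse the cache with no calibration data.
    auto cache_calibrator = trtorch::ptq::make_int8_cache_calibrator(cache_file);

    // Both convert implicitly to the interface TensorRT consumes.
    nvinfer1::IInt8Calibrator* trt_calibrator = calibrator;
    (void)trt_calibrator;
    (void)cache_calibrator;
}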