Program Listing for File ptq.h
Return to documentation for file (cpp/api/include/trtorch/ptq.h)
#pragma once
#include <algorithm>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
// The declarations below are hidden from Doxygen: they only introduce names
// from other libraries and are not part of this header's documented API.
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Forward declarations of the calibrator interfaces in namespace nvinfer1.
// Only the names are introduced here; no definition is provided by this file.
namespace nvinfer1 {
class IInt8Calibrator;
class IInt8EntropyCalibrator2;
}
// Forward declaration of the torch::data::Iterator class template, which
// Int8Calibrator (declared later in this file) uses as the type of its
// batch iterator member.
namespace torch {
namespace data {
template<typename Example>
class Iterator;
}
}
#endif //DOXYGEN_SHOULD_SKIP_THIS
namespace trtorch {
namespace ptq {
/**
 * @brief INT8 calibrator that feeds calibration batches from a dataloader and
 *        optionally reads/writes a calibration cache file.
 *
 * The class privately derives from @p Algorithm and overrides its
 * getBatchSize / getBatch / readCalibrationCache / writeCalibrationCache
 * virtual functions.
 *
 * @tparam Algorithm            Calibrator base class whose virtual interface is
 *                              overridden here.
 * @tparam DataLoaderUniquePtr  Smart-pointer type owning the dataloader. It must
 *                              expose `element_type`, and the pointee must
 *                              provide `begin()`, `end()` and
 *                              `super::BatchType` (with a `data` member).
 */
template<typename Algorithm, typename DataLoaderUniquePtr>
class Int8Calibrator : Algorithm {
  using DataLoader = typename DataLoaderUniquePtr::element_type;
  using Batch = typename DataLoader::super::BatchType;
  // Type of the `data` member of a batch (the tensor handed to the calibrator).
  using BatchData = decltype(Batch::data);

 public:
  /**
   * @brief Construct a calibrator that takes ownership of @p dataloader.
   *
   * @param dataloader       Owning pointer to the dataloader. Ownership is
   *                         transferred into the calibrator so the dataloader
   *                         stays alive for as long as batches are requested.
   * @param cache_file_path  Path of the calibration cache file (copied).
   * @param use_cache        When true, readCalibrationCache() tries to load
   *                         the cache file; when false it reports no cache.
   */
  Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
      : dataloader_(std::move(dataloader)),
        it_(dataloader_->end()),
        cache_file_path_(cache_file_path),
        use_cache_(use_cache) {}

  /**
   * @brief Batch size reported to the calibration algorithm.
   * @return Always 1.
   */
  int getBatchSize() const override {
    // HACK: TRTorch only uses explicit batch sizing; the INT8 calibrator does
    // not work when the real batch size is reported here together with
    // explicit batching. So we just report batch size 1 (warnings will still
    // be printed out).
    // Previously considered: static_cast<int>(dataloader_->options().batch_size)
    return 1;
  }

  /**
   * @brief Provide the next calibration batch.
   *
   * @param bindings    Output array; every entry is set to the device pointer
   *                    of the current batch's data.
   * @param names       Binding names (unused).
   * @param nbBindings  Number of entries in @p bindings.
   * @return true if a batch was provided, false once the dataloader is
   *         exhausted.
   */
  bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
    // Iteration is started lazily on the first request: the constructor only
    // positions the iterator at end().
    if (!it_created_) {
      it_ = dataloader_->begin();
      it_created_ = true;
    }
    if (it_ == dataloader_->end()) {
      return false;
    }

    // Keep the converted tensor in a member so the pointer written into
    // `bindings` stays valid after this function returns (until the next call
    // replaces it). The conversion is done once rather than once per binding.
    batch_data_ = (*it_).data.to(at::kCUDA).contiguous();
    void* const device_ptr = batch_data_.data_ptr();
    for (int i = 0; i < nbBindings; i++) {
      bindings[i] = device_ptr;
    }

    ++it_;
    return true;
  }

  /**
   * @brief Load the calibration cache from disk if caching is enabled.
   *
   * @param[out] length  Set to the number of bytes in the returned buffer
   *                     (0 when no cache is returned).
   * @return Pointer to the cached bytes (owned by this object), or nullptr if
   *         caching is disabled, the file could not be read, or it is empty.
   */
  const void* readCalibrationCache(size_t& length) override {
    length = 0;
    if (!use_cache_) {
      return nullptr;
    }

    std::stringstream ss;
    ss << "Reading Calibration Cache from " << cache_file_path_;
    logging::log(logging::Level::kINFO, ss.str());

    cache_.clear();
    std::ifstream cache_file(cache_file_path_, std::ios::binary);
    if (cache_file.good()) {
      // istreambuf_iterator reads raw bytes, so no whitespace is skipped.
      std::copy(std::istreambuf_iterator<char>(cache_file),
                std::istreambuf_iterator<char>(),
                std::back_inserter(cache_));
      logging::log(logging::Level::kDEBUG, std::string("Cache read"));
    }

    cache_size_ = cache_.size();
    length = cache_size_;
    return cache_size_ ? cache_.data() : nullptr;
  }

  /**
   * @brief Write a calibration cache to the configured cache file.
   *
   * @param cache   Pointer to the cache bytes.
   * @param length  Number of bytes to write.
   */
  void writeCalibrationCache(const void* cache, size_t length) override {
    std::ofstream cache_file(cache_file_path_, std::ios::binary);
    cache_file.write(static_cast<const char*>(cache), static_cast<std::streamsize>(length));
    std::stringstream ss;
    ss << "Saved Calibration Cache to " << cache_file_path_;
    logging::log(logging::Level::kINFO, ss.str());
  }

  /**
   * @brief Convert this object to a pointer to the nvinfer1 calibrator
   *        interface.
   *
   * NOTE(review): this relies on the Algorithm base subobject living at the
   * same address as `this`; confirm for any Algorithm with multiple bases.
   */
  operator nvinfer1::IInt8Calibrator* () {
    return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
  }

 private:
  // Owns the dataloader. A shared_ptr (built from the caller's owning pointer)
  // keeps the calibrator copy-constructible, as it was with a raw pointer.
  std::shared_ptr<DataLoader> dataloader_;
  torch::data::Iterator<Batch> it_;
  // Stored by value: a reference member would dangle if the caller passed a
  // temporary string.
  std::string cache_file_path_;
  size_t cache_size_ = 0;
  bool use_cache_;
  std::vector<char> cache_;
  bool it_created_ = false;
  // Data of the batch most recently handed out by getBatch().
  BatchData batch_data_;
};
/**
 * @brief INT8 calibrator that never supplies batches and relies only on a
 *        calibration cache file.
 *
 * The class privately derives from @p Algorithm and overrides its
 * getBatchSize / getBatch / readCalibrationCache / writeCalibrationCache
 * virtual functions.
 *
 * @tparam Algorithm  Calibrator base class whose virtual interface is
 *                    overridden here.
 */
template<typename Algorithm>
class Int8CacheCalibrator : Algorithm {
 public:
  /**
   * @brief Construct a cache-only calibrator.
   * @param cache_file_path  Path of the calibration cache file (copied).
   */
  Int8CacheCalibrator(const std::string& cache_file_path)
      : cache_file_path_(cache_file_path) {}

  /**
   * @brief Batch size reported to the calibration algorithm.
   * @return Always 1.
   */
  int getBatchSize() const override {
    // HACK: TRTorch only uses explicit batch sizing; the INT8 calibrator does
    // not work when the real batch size is reported here together with
    // explicit batching. So we just report batch size 1 (warnings will still
    // be printed out).
    return 1;
  }

  /**
   * @brief This calibrator has no data source, so no batch is ever provided.
   * @return Always false.
   */
  bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
    return false;
  }

  /**
   * @brief Load the calibration cache from the configured file.
   *
   * @param[out] length  Set to the number of bytes in the returned buffer
   *                     (0 when no cache is returned).
   * @return Pointer to the cached bytes (owned by this object), or nullptr if
   *         the file could not be read or is empty.
   */
  const void* readCalibrationCache(size_t& length) override {
    std::stringstream ss;
    ss << "Reading Calibration Cache from " << cache_file_path_;
    logging::log(logging::Level::kINFO, ss.str());

    cache_.clear();
    std::ifstream cache_file(cache_file_path_, std::ios::in | std::ios::binary);
    // Only touch the stream position when the file actually opened: on a
    // failed stream tellg() yields -1, which must not be passed to reserve().
    if (cache_file.good()) {
      cache_file.seekg(0, std::ios::end);
      const std::streamoff file_size = cache_file.tellg();
      cache_file.seekg(0, std::ios::beg);
      if (file_size > 0) {
        cache_.reserve(static_cast<size_t>(file_size));
      }
      if (cache_file.good()) {
        std::copy(std::istreambuf_iterator<char>(cache_file),
                  std::istreambuf_iterator<char>(),
                  std::back_inserter(cache_));
        logging::log(logging::Level::kDEBUG, std::string("Cache read"));
      }
    }

    cache_size_ = cache_.size();
    length = cache_size_;
    return cache_size_ ? cache_.data() : nullptr;
  }

  /**
   * @brief Write a calibration cache to the configured cache file.
   *
   * @param cache   Pointer to the cache bytes.
   * @param length  Number of bytes to write.
   */
  void writeCalibrationCache(const void* cache, size_t length) override {
    std::ofstream cache_file(cache_file_path_, std::ios::binary);
    cache_file.write(static_cast<const char*>(cache), static_cast<std::streamsize>(length));
    std::stringstream ss;
    ss << "Saved Calibration Cache to " << cache_file_path_;
    logging::log(logging::Level::kINFO, ss.str());
  }

  /**
   * @brief Convert this object to a pointer to the nvinfer1 calibrator
   *        interface.
   *
   * NOTE(review): this relies on the Algorithm base subobject living at the
   * same address as `this`; confirm for any Algorithm with multiple bases.
   */
  operator nvinfer1::IInt8Calibrator* () {
    return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
  }

 private:
  // Stored by value: a reference member would dangle if the caller passed a
  // temporary string.
  std::string cache_file_path_;
  size_t cache_size_ = 0;
  std::vector<char> cache_;
};
} // namespace ptq
} // namespace trtorch