Skip to content

Commit 8d22bdd

Browse files
committed
fix(//cpp/api): Better inital condition for the dataloader iterator to
address datarace issue Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 5c0d737 commit 8d22bdd

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

Diff for: cpp/api/include/trtorch/ptq.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class Int8Calibrator : Algorithm {
2727
using Batch = typename DataLoader::super::BatchType;
2828
public:
2929
Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
30-
: dataloader_(dataloader.get()), it_(dataloader_->begin()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}
30+
: dataloader_(dataloader.get()), it_(dataloader_->end()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}
3131

3232
int getBatchSize() const override {
3333
// HACK: TRTorch only uses explict batch sizing, INT8 Calibrator does not

Diff for: cpp/ptq/main.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ int main(int argc, const char* argv[]) {
121121
auto execution_timer = timers::PreciseCPUTimer();
122122
auto images = (*(*eval_dataloader).begin()).data.to(torch::kCUDA);
123123

124+
execution_timer.start();
125+
mod.forward({images});
126+
execution_timer.stop();
127+
std::cout << "Latency of JIT model FP32 (Batch Size 32): " << execution_timer.milliseconds() << "ms" << std::endl;
128+
129+
execution_timer.reset();
130+
124131
execution_timer.start();
125132
trt_mod.forward({images});
126133
execution_timer.stop();

0 commit comments

Comments
 (0)