Skip to content

Commit b244423

Browse files
committed
fix: Minor fixes to qat scripts
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent b7f6d8a commit b244423

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

Diff for: examples/int8/training/vgg16/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def main():
124124

125125
print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
126126

127-
if epoch % 10 == 9:
127+
if epoch % 10 == 9 or epoch==args.epochs-1:
128128
save_checkpoint(
129129
{
130130
'epoch': epoch + 1,

Diff for: examples/int8/training/vgg16/train_qat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def main():
183183

184184
crit = nn.CrossEntropyLoss()
185185
opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
186-
import pdb; pdb.set_trace()
186+
187187
if args.start_from != 0:
188188
ckpt_file = args.ckpt_dir + '/ckpt_epoch' + str(args.start_from) + '.pth'
189189
print('Loading from checkpoint {}'.format(ckpt_file))

Diff for: py/trtorch/csrc/tensorrt_backend.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
2727
mod = core::lowering::LowerModule(mod);
2828

2929
auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
30-
lowering::LowerInfo lower_info;
30+
core::lowering::LowerInfo lower_info;
3131
for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
3232
const auto& method_name = it->key();
3333
auto method = mod.get_method(method_name);

0 commit comments

Comments
 (0)