From 060dc6ab4432a6ff432801cad1ef7f299c286ca2 Mon Sep 17 00:00:00 2001
From: BJ Hargrave
Date: Tue, 7 Nov 2023 18:09:48 +0000
Subject: [PATCH] Fix static_quantization_tutorial error in qat_model

We need to use the QAT variant of the fuse_modules method. After this
fix, the tutorial runs to completion on a Linux x86 system.

Fixes https://github.com/pytorch/tutorials/issues/1269

Signed-off-by: BJ Hargrave
---
 advanced_source/static_quantization_tutorial.rst | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/advanced_source/static_quantization_tutorial.rst b/advanced_source/static_quantization_tutorial.rst
index 6f4118079be..7b0df08362a 100644
--- a/advanced_source/static_quantization_tutorial.rst
+++ b/advanced_source/static_quantization_tutorial.rst
@@ -206,14 +206,15 @@ Note: this code is taken from
 
         # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
         # This operation does not change the numerics
-        def fuse_model(self):
+        def fuse_model(self, is_qat=False):
+            fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
             for m in self.modules():
                 if type(m) == ConvBNReLU:
-                    torch.ao.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
+                    fuse_modules(m, ['0', '1', '2'], inplace=True)
                 if type(m) == InvertedResidual:
                     for idx in range(len(m.conv)):
                         if type(m.conv[idx]) == nn.Conv2d:
-                            torch.ao.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
+                            fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
 
 2. Helper functions
 -------------------
@@ -533,7 +534,7 @@ We fuse modules as before
 .. code:: python
 
     qat_model = load_model(saved_model_dir + float_model_file)
-    qat_model.fuse_model()
+    qat_model.fuse_model(is_qat=True)
 
     optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
     # The old 'fbgemm' is still available but 'x86' is the recommended default.
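
For context, below is a minimal sketch of the QAT flow this fix enables. It is a hedged sketch, not the tutorial verbatim: `load_model`, `saved_model_dir`, and `float_model_file` are helpers and paths defined earlier in the tutorial, and the qconfig/prepare_qat steps come from the surrounding tutorial text rather than this patch. The key point is that `torch.ao.quantization.fuse_modules` fuses modules for inference (eval mode), while `fuse_modules_qat` produces trainable fused modules, which is what quantization-aware training requires.

    import torch

    # Load the float model and fuse using the QAT-aware variant added above.
    qat_model = load_model(saved_model_dir + float_model_file)
    qat_model.train()  # QAT fusion expects the model in train mode
    qat_model.fuse_model(is_qat=True)  # dispatches to torch.ao.quantization.fuse_modules_qat

    optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.0001)
    # The old 'fbgemm' backend is still available, but 'x86' is the recommended default.
    qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

    # Insert fake-quantization observers so training sees quantization effects;
    # after training, torch.ao.quantization.convert produces the quantized model.
    torch.ao.quantization.prepare_qat(qat_model, inplace=True)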