@@ -206,14 +206,15 @@ Note: this code is taken from
     # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
     # This operation does not change the numerics
-    def fuse_model(self):
+    def fuse_model(self, is_qat=False):
+        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
         for m in self.modules():
             if type(m) == ConvBNReLU:
-                torch.ao.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
+                fuse_modules(m, ['0', '1', '2'], inplace=True)
             if type(m) == InvertedResidual:
                 for idx in range(len(m.conv)):
                     if type(m.conv[idx]) == nn.Conv2d:
-                        torch.ao.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
+                        fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)


 2. Helper functions
 -------------------
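For context, a minimal standalone sketch (not part of the diff) of why the dispatch above is needed: fuse_modules is meant for models in eval mode (post-training quantization), while fuse_modules_qat handles models kept in train mode for QAT. The toy nn.Sequential below stands in for the tutorial's ConvBNReLU block; the '0', '1', '2' names are just the indices of its children.

    import torch
    from torch import nn

    # Toy stand-in for the tutorial's ConvBNReLU block.
    m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    # Post-training quantization: fuse in eval mode with fuse_modules.
    m.eval()
    fused = torch.ao.quantization.fuse_modules(m, ['0', '1', '2'])

    # Quantization-aware training: fuse in train mode with fuse_modules_qat,
    # which is why fuse_model() now dispatches on is_qat.
    m.train()
    fused_qat = torch.ao.quantization.fuse_modules_qat(m, ['0', '1', '2'])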
@@ -533,7 +534,7 @@ We fuse modules as before
 .. code:: python

     qat_model = load_model(saved_model_dir + float_model_file)
-    qat_model.fuse_model()
+    qat_model.fuse_model(is_qat=True)

     optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.0001)
     # The old 'fbgemm' is still available but 'x86' is the recommended default.
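For readers following along, a hedged sketch of the steps that typically come next in the eager-mode QAT flow; it assumes a PyTorch release where the 'x86' backend string is accepted by get_default_qat_qconfig.

    # Attach a QAT qconfig, then insert fake-quantization observers.
    qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

    # prepare_qat requires the model to be in training mode.
    qat_model.train()
    torch.ao.quantization.prepare_qat(qat_model, inplace=True)

    # ... fine-tune qat_model with the optimizer defined above ...

    # After fine-tuning, switch to eval mode and convert to the
    # quantized inference model.
    quantized_model = torch.ao.quantization.convert(qat_model.eval())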