Skip to content

Commit 6c025c6

Browse files
Merge branch 'main' into rl-2352
2 parents fc6afea + a668406 commit 6c025c6

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

Diff for: advanced_source/static_quantization_tutorial.rst

+5-4
Original file line number | Diff line number | Diff line change
@@ -206,14 +206,15 @@ Note: this code is taken from
206206
207207
# Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
208208
# This operation does not change the numerics
209-
def fuse_model(self):
209+
def fuse_model(self, is_qat=False):
210+
fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
210211
for m in self.modules():
211212
if type(m) == ConvBNReLU:
212-
torch.ao.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
213+
fuse_modules(m, ['0', '1', '2'], inplace=True)
213214
if type(m) == InvertedResidual:
214215
for idx in range(len(m.conv)):
215216
if type(m.conv[idx]) == nn.Conv2d:
216-
torch.ao.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
217+
fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
217218
218219
2. Helper functions
219220
-------------------
@@ -533,7 +534,7 @@ We fuse modules as before
533534
.. code:: python
534535
535536
qat_model = load_model(saved_model_dir + float_model_file)
536-
qat_model.fuse_model()
537+
qat_model.fuse_model(is_qat=True)
537538
538539
optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
539540
# The old 'fbgemm' is still available but 'x86' is the recommended default.
[NOTE(review): this line carries no diff line-number marker, unlike every other code line above — it does not appear to belong to either hunk of this commit; verify against the original page.]

0 commit comments

Comments (0)