diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..359435d 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -45,10 +45,7 @@ def hook(module, input, output): params += torch.prod(torch.LongTensor(list(module.bias.size()))) summary[m_key]["nb_params"] = params - if ( - not isinstance(module, nn.Sequential) - and not isinstance(module, nn.ModuleList) - ): + if not any(module.children()): hooks.append(module.register_forward_hook(hook)) # multiple inputs to the network