From ae2de4591254002e511aaef0d6aba5b695187ef8 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 6 Oct 2021 16:12:47 +0100
Subject: [PATCH 1/9] Adding first implementation of the util.

---
 torchvision/prototype/ops/__init__.py |  0
 torchvision/prototype/ops/_utils.py   | 79 +++++++++++++++++++++++++++
 2 files changed, 79 insertions(+)
 create mode 100644 torchvision/prototype/ops/__init__.py
 create mode 100644 torchvision/prototype/ops/_utils.py

diff --git a/torchvision/prototype/ops/__init__.py b/torchvision/prototype/ops/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/torchvision/prototype/ops/_utils.py b/torchvision/prototype/ops/_utils.py
new file mode 100644
index 00000000000..9f00f00dbbe
--- /dev/null
+++ b/torchvision/prototype/ops/_utils.py
@@ -0,0 +1,79 @@
+import copy
+import operator
+import torch
+import warnings
+
+from typing import Callable, Tuple
+
+
+# TODO: create a util to undo the change
+# TODO: it shouldn't have a _regularized_shortcut when it has a downsample
+
+
+class RegularizedShortcut(torch.nn.Module):
+    def __init__(self, regularizer_layer: Callable[..., torch.nn.Module]):
+        super().__init__()
+        self._regularizer = regularizer_layer()
+
+    def forward(self, input, result):
+        return input + self._regularizer(result)
+
+
+def add_regularized_shortcut(
+    model: torch.nn.Module,
+    block_types: Tuple[type, ...],
+    regularizer_layer: Callable[..., torch.nn.Module],
+    inplace: bool = True
+) -> torch.nn.Module:
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    ATTR_NAME = "_regularized_shortcut"
+    tracer = torch.fx.Tracer()
+    changed = False
+    for m in model.modules():
+        if isinstance(m, block_types):
+            # Add the layer directly on the submodule prior to tracing
+            # workaround due to https://github.com/pytorch/pytorch/issues/66197
+            m.add_module(ATTR_NAME, RegularizedShortcut(regularizer_layer))
+
+            graph = tracer.trace(m)
+            patterns = {operator.add, torch.add, "add"}
+
+            input = None
+            for node in graph.nodes:
+                if node.op == 'call_function':
+                    if node.target in patterns and len(node.args) == 2 and input in node.args:
+                        with graph.inserting_after(node):
+                            # Always put the shortcut value first
+                            args = node.args if node.args[0] == input else node.args[::-1]
+                            node.replace_all_uses_with(graph.call_module(ATTR_NAME, args))
+                        graph.erase_node(node)
+                        changed = True
+                        break
+                elif node.op == "placeholder":
+                    input = node
+
+    graph.lint()
+    if not changed:
+        warnings.warn("No shortcut was detected. Please ensure you have provided the correct `block_types` parameter "
+                      "for this model.")
+
+    return model
+
+
+if __name__ == "__main__":
+    from torchvision.models.resnet import resnet18, resnet50, BasicBlock, Bottleneck, load_state_dict_from_url
+    from torchvision.ops.stochastic_depth import StochasticDepth
+    from functools import partial
+
+    regularizer_layer = partial(StochasticDepth, p=0.1, mode="row")
+    model = resnet50()
+    model = add_regularized_shortcut(model, (BasicBlock, Bottleneck), regularizer_layer)
+    # print(model)
+    out = model(torch.randn((7, 3, 224, 224)))
+    print(out)
+
+    # state_dict = load_state_dict_from_url("https://download.pytorch.org/models/resnet50-0676ba61.pth")
+    # model.load_state_dict(state_dict)
+
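A note on what this first patch actually does: the pass attaches an (as yet unused) `_regularized_shortcut` submodule to each matching block and rewrites a traced copy of the block's graph, but the rewritten graph is never installed back onto the model, so the network still runs its original forward; [PATCH 2/9] below reconstructs the submodules from the modified graphs. The intended transformation is easiest to see on a toy residual block. The sketch below uses a hypothetical ToyBlock that is not part of the patch, with RegularizedShortcut imported from the new file:

from functools import partial

import torch
from torchvision.ops.stochastic_depth import StochasticDepth
from torchvision.prototype.ops._utils import RegularizedShortcut  # added by this patch


class ToyBlock(torch.nn.Module):
    """Hypothetical stand-in for BasicBlock/Bottleneck, for illustration only."""

    def __init__(self):
        super().__init__()
        self.branch = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        # The traced graph contains a single `add` node fed by the placeholder x.
        # The pass is meant to swap it for a call to the `_regularized_shortcut`
        # submodule, turning `x + self.branch(x)` into
        # `self._regularized_shortcut(x, self.branch(x))`, i.e.
        # `x + StochasticDepth(p=0.1, mode="row")(self.branch(x))`.
        return x + self.branch(x)


shortcut = RegularizedShortcut(partial(StochasticDepth, p=0.1, mode="row"))
x, branch_out = torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8)
print(shortcut(x, branch_out).shape)  # torch.Size([2, 3, 8, 8])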
From e0b6aa41f481caf9f26accc48ab167d878d04535 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 6 Oct 2021 17:12:09 +0100
Subject: [PATCH 2/9] Update the model graph.

---
 torchvision/prototype/ops/_utils.py | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/torchvision/prototype/ops/_utils.py b/torchvision/prototype/ops/_utils.py
index 9f00f00dbbe..c20e72fad4c 100644
--- a/torchvision/prototype/ops/_utils.py
+++ b/torchvision/prototype/ops/_utils.py
@@ -3,6 +3,7 @@
 import torch
 import warnings
 
+from torch import fx
 from typing import Callable, Tuple
 
 
@@ -29,9 +30,9 @@ def add_regularized_shortcut(
         model = copy.deepcopy(model)
 
     ATTR_NAME = "_regularized_shortcut"
-    tracer = torch.fx.Tracer()
-    changed = False
-    for m in model.modules():
+    tracer = fx.Tracer()
+    modifications = {}
+    for name, m in model.named_modules():
         if isinstance(m, block_types):
             # Add the layer directly on the submodule prior to tracing
             # workaround due to https://github.com/pytorch/pytorch/issues/66197
@@ -49,13 +50,20 @@ def add_regularized_shortcut(
                             args = node.args if node.args[0] == input else node.args[::-1]
                             node.replace_all_uses_with(graph.call_module(ATTR_NAME, args))
                         graph.erase_node(node)
-                        changed = True
-                        break
+                        graph.lint()
+                        modifications[name] = graph
                 elif node.op == "placeholder":
                     input = node
 
-    graph.lint()
-    if not changed:
+    if modifications:
+        # Update the model by overwriting its modules
+        for name, graph in modifications.items():
+            parent_name, child_name = name.rsplit(".", 1)
+            parent = model.get_submodule(parent_name)
+            previous_child = parent.get_submodule(child_name)
+            new_child = fx.GraphModule(previous_child, graph)
+            parent.register_module(child_name, new_child)
+    else:
         warnings.warn("No shortcut was detected. Please ensure you have provided the correct `block_types` parameter "
                       "for this model.")
@@ -77,3 +85,5 @@ def add_regularized_shortcut(
 
     # state_dict = load_state_dict_from_url("https://download.pytorch.org/models/resnet50-0676ba61.pth")
    # model.load_state_dict(state_dict)
+
+    fx.symbolic_trace(model).graph.print_tabular()
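The key addition in [PATCH 2/9] is the reinstallation step: once a block's graph has been rewritten, an fx.GraphModule built from the original submodule and the new graph is registered on the parent in place of the old child. The mechanism is generic and worth seeing in isolation; below is a minimal standalone sketch on a plain nn.Sequential (names are illustrative, and rpartition is used here so that top-level children without a dotted path also work, a case the patch's rsplit(".", 1) does not handle):

import torch
from torch import fx

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
name = "0"  # fully qualified name of the child to replace

parent_name, _, child_name = name.rpartition(".")
parent = model.get_submodule(parent_name) if parent_name else model
previous_child = parent.get_submodule(child_name)

graph = fx.Tracer().trace(previous_child)  # ... rewrite nodes here as needed ...
new_child = fx.GraphModule(previous_child, graph)
parent.register_module(child_name, new_child)

print(isinstance(model[0], fx.GraphModule))  # True
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 4])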
From b2faf3f73b9bc937abd2d3f96e8c6c1b0fb806e9 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 6 Oct 2021 18:13:32 +0100
Subject: [PATCH 3/9] Adding delete method.

---
 torchvision/prototype/ops/_utils.py | 49 +++++++++++++++++++++++------
 1 file changed, 40 insertions(+), 9 deletions(-)

diff --git a/torchvision/prototype/ops/_utils.py b/torchvision/prototype/ops/_utils.py
index c20e72fad4c..3148df07f3c 100644
--- a/torchvision/prototype/ops/_utils.py
+++ b/torchvision/prototype/ops/_utils.py
@@ -4,11 +4,11 @@
 import warnings
 
 from torch import fx
+from torchvision.models.feature_extraction import LeafModuleAwareTracer
 from typing import Callable, Tuple
 
 
-# TODO: create a util to undo the change
-# TODO: it shouldn't have a _regularized_shortcut when it has a downsample
+_MODULE_NAME = "_regularized_shortcut"
 
 
 class RegularizedShortcut(torch.nn.Module):
@@ -29,14 +29,13 @@ def add_regularized_shortcut(
     if not inplace:
         model = copy.deepcopy(model)
 
-    ATTR_NAME = "_regularized_shortcut"
     tracer = fx.Tracer()
     modifications = {}
     for name, m in model.named_modules():
         if isinstance(m, block_types):
             # Add the layer directly on the submodule prior to tracing
             # workaround due to https://github.com/pytorch/pytorch/issues/66197
-            m.add_module(ATTR_NAME, RegularizedShortcut(regularizer_layer))
+            m.add_module(_MODULE_NAME, RegularizedShortcut(regularizer_layer))
 
             graph = tracer.trace(m)
             patterns = {operator.add, torch.add, "add"}
@@ -48,7 +47,7 @@ def add_regularized_shortcut(
                         with graph.inserting_after(node):
                             # Always put the shortcut value first
                             args = node.args if node.args[0] == input else node.args[::-1]
-                            node.replace_all_uses_with(graph.call_module(ATTR_NAME, args))
+                            node.replace_all_uses_with(graph.call_module(_MODULE_NAME, args))
                         graph.erase_node(node)
                         graph.lint()
                         modifications[name] = graph
@@ -70,20 +69,52 @@ def add_regularized_shortcut(
     return model
 
 
+def del_regularized_shortcut(
+    model: torch.nn.Module,
+    inplace: bool = True
+) -> torch.nn.Module:
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    tracer = LeafModuleAwareTracer(leaf_modules=[RegularizedShortcut])
+    graph = tracer.trace(model)
+    for node in graph.nodes:
+        if node.op == "call_module" and node.target.rsplit(".", 1)[-1] == _MODULE_NAME:
+            with graph.inserting_before(node):
+                new_node = graph.call_function(operator.add, node.args)
+                node.replace_all_uses_with(new_node)
+            graph.erase_node(node)
+
+    return fx.GraphModule(model, graph)
+
+
 if __name__ == "__main__":
     from torchvision.models.resnet import resnet18, resnet50, BasicBlock, Bottleneck, load_state_dict_from_url
     from torchvision.ops.stochastic_depth import StochasticDepth
     from functools import partial
 
-    regularizer_layer = partial(StochasticDepth, p=0.1, mode="row")
+    out = []
+    batch = torch.randn((7, 3, 224, 224))
+
+    print("Before")
     model = resnet50()
+    out.append(model(batch))
+    fx.symbolic_trace(model).graph.print_tabular()
+
+    print("After addition")
+    regularizer_layer = partial(StochasticDepth, p=0.0, mode="row")
     model = add_regularized_shortcut(model, (BasicBlock, Bottleneck), regularizer_layer)
+    fx.symbolic_trace(model).graph.print_tabular()
     # print(model)
-    out = model(torch.randn((7, 3, 224, 224)))
-    print(out)
-
+    out.append(model(batch))
     # state_dict = load_state_dict_from_url("https://download.pytorch.org/models/resnet50-0676ba61.pth")
     # model.load_state_dict(state_dict)
 
+    print("After deletion")
+    model = del_regularized_shortcut(model)
     fx.symbolic_trace(model).graph.print_tabular()
+    out.append(model(batch))
+
+    for v in out[1:]:
+        torch.testing.assert_allclose(out[0], v)
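The round-trip assertion at the end of [PATCH 3/9] only holds because StochasticDepth with p=0.0 is an exact identity: no rows are dropped and the survival rescaling is 1 / (1 - 0.0) = 1, so adding and then deleting the shortcuts must reproduce the baseline outputs exactly. A quick standalone check of that assumption:

import torch
from torchvision.ops.stochastic_depth import StochasticDepth

x = torch.randn(7, 3, 224, 224)
reg = StochasticDepth(p=0.0, mode="row")
# Identity in train mode as well: with p=0.0 nothing is zeroed and the
# 1 / (1 - p) rescaling factor equals 1.
torch.testing.assert_allclose(reg(x), x)
reg.eval()
torch.testing.assert_allclose(reg(x), x)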
From 74b005c767051749630f79e59b50496aff24ec12 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 6 Oct 2021 18:21:15 +0100
Subject: [PATCH 4/9] Fixing linter.

---
 torchvision/prototype/ops/_utils.py | 31 ++++++++++++++---------------
 1 file changed, 15 insertions(+), 16 deletions(-)

diff --git a/torchvision/prototype/ops/_utils.py b/torchvision/prototype/ops/_utils.py
index 3148df07f3c..f2576af12d5 100644
--- a/torchvision/prototype/ops/_utils.py
+++ b/torchvision/prototype/ops/_utils.py
@@ -1,11 +1,11 @@
 import copy
 import operator
-import torch
 import warnings
+from typing import Callable, Tuple
 
+import torch
 from torch import fx
 from torchvision.models.feature_extraction import LeafModuleAwareTracer
-from typing import Callable, Tuple
 
 
 _MODULE_NAME = "_regularized_shortcut"
@@ -21,10 +21,10 @@ def forward(self, input, result):
 
 
 def add_regularized_shortcut(
-    model: torch.nn.Module,
-    block_types: Tuple[type, ...],
-    regularizer_layer: Callable[..., torch.nn.Module],
-    inplace: bool = True
+    model: torch.nn.Module,
+    block_types: Tuple[type, ...],
+    regularizer_layer: Callable[..., torch.nn.Module],
+    inplace: bool = True,
 ) -> torch.nn.Module:
     if not inplace:
         model = copy.deepcopy(model)
@@ -42,7 +42,7 @@ def add_regularized_shortcut(
 
             input = None
             for node in graph.nodes:
-                if node.op == 'call_function':
+                if node.op == "call_function":
                     if node.target in patterns and len(node.args) == 2 and input in node.args:
                         with graph.inserting_after(node):
                             # Always put the shortcut value first
@@ -63,16 +63,15 @@ def add_regularized_shortcut(
             new_child = fx.GraphModule(previous_child, graph)
             parent.register_module(child_name, new_child)
     else:
-        warnings.warn("No shortcut was detected. Please ensure you have provided the correct `block_types` parameter "
-                      "for this model.")
+        warnings.warn(
+            "No shortcut was detected. Please ensure you have provided the correct `block_types` parameter "
+            "for this model."
+        )
 
     return model
 
 
-def del_regularized_shortcut(
-    model: torch.nn.Module,
-    inplace: bool = True
-) -> torch.nn.Module:
+def del_regularized_shortcut(model: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
     if not inplace:
         model = copy.deepcopy(model)
 
@@ -89,10 +88,11 @@ def del_regularized_shortcut(
 
 
 if __name__ == "__main__":
-    from torchvision.models.resnet import resnet18, resnet50, BasicBlock, Bottleneck, load_state_dict_from_url
-    from torchvision.ops.stochastic_depth import StochasticDepth
     from functools import partial
 
+    from torchvision.models.resnet import resnet50, BasicBlock, Bottleneck
+    from torchvision.ops.stochastic_depth import StochasticDepth
+
     out = []
     batch = torch.randn((7, 3, 224, 224))
 
@@ -117,4 +117,3 @@ def del_regularized_shortcut(
 
     for v in out[1:]:
         torch.testing.assert_allclose(out[0], v)
-

From 568445923278051b99d66dd945227c16a3ea5814 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 6 Oct 2021 18:34:04 +0100
Subject: [PATCH 5/9] Fixing types.
---
 torchvision/prototype/ops/_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/ops/_utils.py b/torchvision/prototype/ops/_utils.py
index f2576af12d5..0af72fa40c9 100644
--- a/torchvision/prototype/ops/_utils.py
+++ b/torchvision/prototype/ops/_utils.py
@@ -1,7 +1,7 @@
 import copy
 import operator
 import warnings
-from typing import Callable, Tuple
+from typing import Callable, Tuple, Union
 
 import torch
 from torch import fx
@@ -22,7 +22,7 @@ def forward(self, input, result):
 
 def add_regularized_shortcut(
     model: torch.nn.Module,
-    block_types: Tuple[type, ...],
+    block_types: Union[type, Tuple[type, ...]],
     regularizer_layer: Callable[..., torch.nn.Module],
     inplace: bool = True,
 ) -> torch.nn.Module:

From 3ecf8aa7bdf84bc194cf0dd404619e6c660ca2f4 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 6 Oct 2021 20:10:24 +0100
Subject: [PATCH 6/9] Restoring break and moving lint.

---
 torchvision/prototype/ops/_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchvision/prototype/ops/_utils.py b/torchvision/prototype/ops/_utils.py
index 0af72fa40c9..70fc26092e6 100644
--- a/torchvision/prototype/ops/_utils.py
+++ b/torchvision/prototype/ops/_utils.py
@@ -49,14 +49,15 @@ def add_regularized_shortcut(
                             args = node.args if node.args[0] == input else node.args[::-1]
                             node.replace_all_uses_with(graph.call_module(_MODULE_NAME, args))
                         graph.erase_node(node)
-                        graph.lint()
                         modifications[name] = graph
+                        break
                 elif node.op == "placeholder":
                     input = node
 
     if modifications:
         # Update the model by overwriting its modules
         for name, graph in modifications.items():
+            graph.lint()
             parent_name, child_name = name.rsplit(".", 1)
             parent = model.get_submodule(parent_name)
             previous_child = parent.get_submodule(child_name)

From 51bf33fe4713df32b59606384355ab7bd9505934 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 8 Oct 2021 17:15:16 +0100
Subject: [PATCH 7/9] Allow delete to remove custom ops.

---
 torchvision/prototype/ops/_utils.py | 44 ++++++++++++++++++++---------
 1 file changed, 31 insertions(+), 13 deletions(-)

diff --git a/torchvision/prototype/ops/_utils.py b/torchvision/prototype/ops/_utils.py
index 70fc26092e6..428cf17f719 100644
--- a/torchvision/prototype/ops/_utils.py
+++ b/torchvision/prototype/ops/_utils.py
@@ -1,14 +1,14 @@
 import copy
 import operator
 import warnings
-from typing import Callable, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
 
 import torch
 from torch import fx
 from torchvision.models.feature_extraction import LeafModuleAwareTracer
 
 
-_MODULE_NAME = "_regularized_shortcut"
+# TODO: Investigate what happens in the scenario of y = x + f1(x) + f2(x).
 
 
 class RegularizedShortcut(torch.nn.Module):
@@ -29,13 +29,14 @@ def add_regularized_shortcut(
     if not inplace:
         model = copy.deepcopy(model)
 
+    reg_name = RegularizedShortcut.__name__.lower()
     tracer = fx.Tracer()
     modifications = {}
     for name, m in model.named_modules():
         if isinstance(m, block_types):
             # Add the layer directly on the submodule prior to tracing
             # workaround due to https://github.com/pytorch/pytorch/issues/66197
-            m.add_module(_MODULE_NAME, RegularizedShortcut(regularizer_layer))
+            m.add_module(reg_name, RegularizedShortcut(regularizer_layer))
 
             graph = tracer.trace(m)
             patterns = {operator.add, torch.add, "add"}
@@ -47,7 +48,7 @@ def add_regularized_shortcut(
                         with graph.inserting_after(node):
                             # Always put the shortcut value first
                             args = node.args if node.args[0] == input else node.args[::-1]
-                            node.replace_all_uses_with(graph.call_module(_MODULE_NAME, args))
+                            node.replace_all_uses_with(graph.call_module(reg_name, args))
                         graph.erase_node(node)
                         modifications[name] = graph
                         break
@@ -72,19 +73,33 @@ def add_regularized_shortcut(
     return model
 
 
-def del_regularized_shortcut(model: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
+def del_regularized_shortcut(
+    model: torch.nn.Module,
+    block_types: Union[type, Tuple[type, ...]] = RegularizedShortcut,
+    op: Optional[Callable] = operator.add,
+    inplace: bool = True,
+) -> torch.nn.Module:
+    if isinstance(block_types, type):
+        block_types = (block_types,)
     if not inplace:
         model = copy.deepcopy(model)
 
-    tracer = LeafModuleAwareTracer(leaf_modules=[RegularizedShortcut])
+    tracer = LeafModuleAwareTracer(leaf_modules=block_types)
     graph = tracer.trace(model)
     for node in graph.nodes:
-        if node.op == "call_module" and node.target.rsplit(".", 1)[-1] == _MODULE_NAME:
-            with graph.inserting_before(node):
-                new_node = graph.call_function(operator.add, node.args)
-                node.replace_all_uses_with(new_node)
+        if node.op == "call_module" and issubclass(model.get_submodule(node.target).__class__, block_types):
+            if op is not None:
+                with graph.inserting_before(node):
+                    new_node = graph.call_function(op, node.args)
+                    node.replace_all_uses_with(new_node)
+            else:
+                if len(node.args) == 1:
+                    node.replace_all_uses_with(node.args[0])
+                else:
+                    raise ValueError("Can't eliminate an operator that receives more than 1 argument.")
             graph.erase_node(node)
 
+    # BUG: When we reconstruct efficientnet, custom classes like MBConv are replaced with Module and lose their names.
     return fx.GraphModule(model, graph)
 
 
 if __name__ == "__main__":
@@ -99,7 +114,8 @@ def del_regularized_shortcut(model: torch.nn.Module, inplace: bool = True) -> to
 
     print("Before")
     model = resnet50()
-    out.append(model(batch))
+    with torch.no_grad():
+        out.append(model(batch))
     fx.symbolic_trace(model).graph.print_tabular()
 
     print("After addition")
@@ -107,14 +123,16 @@ def del_regularized_shortcut(model: torch.nn.Module, inplace: bool = True) -> to
     model = add_regularized_shortcut(model, (BasicBlock, Bottleneck), regularizer_layer)
     fx.symbolic_trace(model).graph.print_tabular()
     # print(model)
-    out.append(model(batch))
+    with torch.no_grad():
+        out.append(model(batch))
     # state_dict = load_state_dict_from_url("https://download.pytorch.org/models/resnet50-0676ba61.pth")
     # model.load_state_dict(state_dict)
 
     print("After deletion")
     model = del_regularized_shortcut(model)
     fx.symbolic_trace(model).graph.print_tabular()
-    out.append(model(batch))
+    with torch.no_grad():
+        out.append(model(batch))
 
     for v in out[1:]:
         torch.testing.assert_allclose(out[0], v)

From c6fbe63153f6a3f50a43bd6336c0e6e5e0048dce Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 8 Oct 2021 20:26:35 +0100
Subject: [PATCH 8/9] Minor refactoring

---
 torchvision/prototype/ops/_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/ops/_utils.py b/torchvision/prototype/ops/_utils.py
index 428cf17f719..b9f23f36340 100644
--- a/torchvision/prototype/ops/_utils.py
+++ b/torchvision/prototype/ops/_utils.py
@@ -87,7 +87,9 @@ def del_regularized_shortcut(
     tracer = LeafModuleAwareTracer(leaf_modules=block_types)
     graph = tracer.trace(model)
     for node in graph.nodes:
-        if node.op == "call_module" and issubclass(model.get_submodule(node.target).__class__, block_types):
+        # The isinstance() check won't work if the model has already been traced before,
+        # because tracing loses the class info of submodules. See https://github.com/pytorch/pytorch/issues/66335
+        if node.op == "call_module" and isinstance(model.get_submodule(node.target), block_types):
             if op is not None:
                 with graph.inserting_before(node):
                     new_node = graph.call_function(op, node.args)
@@ -99,7 +101,6 @@ def del_regularized_shortcut(
                     raise ValueError("Can't eliminate an operator that receives more than 1 argument.")
             graph.erase_node(node)
 
-    # BUG: When we reconstruct efficientnet, custom classes like MBConv are replaced with Module and lose their names.
     return fx.GraphModule(model, graph)

From da3a5e3a63fff61e5f9cb27227dd54f4e53da5b0 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 11 Oct 2021 16:03:47 +0100
Subject: [PATCH 9/9] Pass model names.
---
 torchvision/prototype/ops/_utils.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/torchvision/prototype/ops/_utils.py b/torchvision/prototype/ops/_utils.py
index b9f23f36340..037d90fd304 100644
--- a/torchvision/prototype/ops/_utils.py
+++ b/torchvision/prototype/ops/_utils.py
@@ -45,6 +45,7 @@ def add_regularized_shortcut(
             for node in graph.nodes:
                 if node.op == "call_function":
                     if node.target in patterns and len(node.args) == 2 and input in node.args:
+                        # TODO: ensure that arg2 has "input" as its ancestor
                         with graph.inserting_after(node):
                             # Always put the shortcut value first
                             args = node.args if node.args[0] == input else node.args[::-1]
@@ -62,7 +63,7 @@ def add_regularized_shortcut(
             parent_name, child_name = name.rsplit(".", 1)
             parent = model.get_submodule(parent_name)
             previous_child = parent.get_submodule(child_name)
-            new_child = fx.GraphModule(previous_child, graph)
+            new_child = fx.GraphModule(previous_child, graph, previous_child.__class__.__name__)
             parent.register_module(child_name, new_child)
     else:
         warnings.warn(
@@ -101,7 +102,7 @@ def del_regularized_shortcut(
                     raise ValueError("Can't eliminate an operator that receives more than 1 argument.")
             graph.erase_node(node)
 
-    return fx.GraphModule(model, graph)
+    return fx.GraphModule(model, graph, model.__class__.__name__)
 
 
 if __name__ == "__main__":
@@ -126,8 +127,6 @@ def del_regularized_shortcut(
     # print(model)
     with torch.no_grad():
         out.append(model(batch))
-    # state_dict = load_state_dict_from_url("https://download.pytorch.org/models/resnet50-0676ba61.pth")
-    # model.load_state_dict(state_dict)
 
     print("After deletion")
     model = del_regularized_shortcut(model)
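With the series complete, the intended workflow is: build the model, attach the regularized shortcuts, optionally load the original pretrained weights (the added modules hold no parameters or buffers, so the checkpoint keys should still line up), train with the regularizer active, and strip the util back out for inference. A sketch of that flow, assuming the prototype module path introduced by this series:

from functools import partial

import torch
from torchvision.models.resnet import resnet50, BasicBlock, Bottleneck
from torchvision.ops.stochastic_depth import StochasticDepth
from torchvision.prototype.ops._utils import add_regularized_shortcut, del_regularized_shortcut

model = resnet50()
model = add_regularized_shortcut(model, (BasicBlock, Bottleneck), partial(StochasticDepth, p=0.1, mode="row"))

# The added modules are parameter-free, so the original checkpoint keys line up:
# state_dict = torch.hub.load_state_dict_from_url("https://download.pytorch.org/models/resnet50-0676ba61.pth")
# model.load_state_dict(state_dict)

# ... train with stochastic depth applied to every residual branch ...

model = del_regularized_shortcut(model)  # restore plain `add` nodes for inference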