
Commit cd1db55

yushangdi authored and pytorchmergebot committed
Fix tensor_constant name collision in aot_export_module (pytorch#151123)
Summary:

When we have an exported program that looks like this:

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b__tensor_constant0: "f32[1]", ..., c_lifted_tensor_0: "i64[925]", ..., tupleized_input_0_0: "f32[10, 2139]"):
            clone: "i64[925]" = torch.ops.aten.clone.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
            index_select: "f32[10, 925]" = torch.ops.aten.index_select.default(tupleized_input_0_0, 1, clone);  clone = None
```

the graph produced by `aot_export_module` can have a name collision. Notice that the `_tensor_constant0` used to build the index below is different from the `_tensor_constant0` on the input module:

```
def forward(self):
    arg9_1: "f32[10, 2139]"
    # this should be int64; it conflicts with the original _tensor_constant0
    # (a clone was taken of this constant before lifting)
    _tensor_constant0: "f32[1]" = self._tensor_constant0
    index: "f32[10, 925]" = torch.ops.aten.index.Tensor(arg9_1, [None, _tensor_constant0]);  _tensor_constant0 = None
```

This caused the `tensors used as indices must be long, int...` AOTI error on the PT2I dashboard, because `clone` is later used as an index. The error happens because a new `_tensor_constant0` is created [here](https://github.com/pytorch/pytorch/blob/main/torch/fx/_symbolic_trace.py#L403-L412), and in `_unlift_graph` that new `_tensor_constant0` overrides the original `_tensor_constant0` on the input Module. (The `arg` for `clone` is created in `create_proxy` in `proxy.py`.)

To fix this, we run a graph pass before unlifting the graph inputs to avoid the name collision.

Test Plan:

```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r aot_compile_constant_folding
buck2 run mode/dev-nosan caffe2/test/inductor:test_aot_inductor -- -r aoti_constant_tensor_name_collision
```

Differential Revision: D72761937

Pull Request resolved: pytorch#151123

Approved by: https://github.com/tugsbayasgalan, https://github.com/jingsh
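For context on where the generated names come from, here is a minimal sketch using plain `torch.fx.symbolic_trace` (not code from this commit; `aot_export_module` names lifted constants the same way via `Tracer.create_arg`):

```python
import torch
import torch.fx

class Demo(torch.nn.Module):
    def forward(self, x):
        # each inline constant is lifted onto the traced module as a
        # fresh "_tensor_constant{i}" attribute during symbolic tracing
        return x + torch.tensor([1.0]) + torch.tensor([2.0])

gm = torch.fx.symbolic_trace(Demo())
print([n.target for n in gm.graph.nodes if n.op == "get_attr"])
# ['_tensor_constant0', '_tensor_constant1']
```

If the module the graph is later unlifted onto already owns a different buffer named `_tensor_constant0`, the generated attribute silently overrides it; that is the collision this commit resolves.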
1 parent bf92c98 · commit cd1db55

2 files changed: +108 −4 lines

test/inductor/test_aot_inductor.py (+42)
```diff
@@ -1200,6 +1200,48 @@ def forward(self, x):
         example_inputs = (torch.ones(4, 4, device=self.device),)
         self.check_model(Foo(self.device), example_inputs)
 
+    def test_aoti_constant_tensor_name_collision(self):
+        class SubModule(torch.nn.Module):
+            def __init__(self, device):
+                super().__init__()
+                self.register_buffer(
+                    "_tensor_constant1",
+                    torch.ones(1, device=device, dtype=torch.float32),
+                    persistent=True,
+                )
+
+            def forward(self, x):
+                return self.linear(x)
+
+        class Foo(torch.nn.Module):
+            def __init__(self, user_float_feature_idx, device):
+                super().__init__()
+                self.user_float_feature_idx = user_float_feature_idx
+                self.register_buffer(
+                    "_tensor_constant0",
+                    torch.ones(1, device=device, dtype=torch.float32),
+                    persistent=True,
+                )
+                self.sub_mod = SubModule(device)
+
+            def forward(self, x):
+                return (
+                    torch.index_select(
+                        x, 1, torch.tensor(self.user_float_feature_idx, device=x.device)
+                    ),
+                    self._tensor_constant0,
+                    self.sub_mod._tensor_constant1,
+                )
+
+        example_inputs = (torch.ones(4, 4, device=self.device),)
+        user_float_feature_idx = [1]
+        # we have to run run_decompositions() first to trigger the name collision
+        ep = torch.export.export(
+            Foo(user_float_feature_idx, self.device), example_inputs, strict=False
+        ).run_decompositions()
+        gm = ep.module()
+        self.check_model(gm, example_inputs)
+
     def test_large_grid(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
```
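Besides the buck targets in the Test Plan, the new test can presumably also be run directly from a source checkout with PyTorch's standard test runner (the invocation below is an assumption, adjust to your setup):

```
python test/inductor/test_aot_inductor.py -k aoti_constant_tensor_name_collision
```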

torch/_inductor/compile_fx.py (+66 −4)
```diff
@@ -16,6 +16,7 @@
 from contextlib import AbstractContextManager
 from inspect import currentframe
 from itertools import count
+from operator import attrgetter
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack
 from unittest import mock
```
```diff
@@ -81,6 +82,7 @@
     should_use_remote_fx_graph_cache,
     tensor_is_aligned,
 )
+from torch._library.fake_class_registry import FakeScriptObject
 from torch._logging import trace_structured
 from torch._utils_internal import compile_time_strobelight_meta
 from torch.fx import GraphModule
```
```diff
@@ -246,11 +248,62 @@ def _warn_tf32_disabled() -> None:
     )
 
 
+def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None:
+    """
+    In aot_export_module (make_fx), we create get_attr nodes with name prefix
+    "_tensor_constant" and "_torchbind_obj". See Tracer.create_arg() in
+    torch/fx/_symbolic_trace.py
+
+    However, this might result in a name collision if the original mod already
+    has a different buffer with the same name.
+
+    We resolve this potential name collision here by changing the target name
+    with a new number postfix.
+    """
+
+    def find_smallest_i(graph: fx.Graph, prefix: str) -> int:
+        i = 0
+        for node in graph.nodes:
+            if node.op == "get_attr" and node.target.startswith(prefix):
+                i = max(i, int(node.target.split(prefix)[-1]))
+        return i + 1
+
+    for node in gm.graph.nodes:
+        if node.op == "get_attr":
+            target_name = node.target
+            if not target_name.startswith(
+                "_tensor_constant"
+            ) and not target_name.startswith("_torchbind_obj"):
+                continue
+
+            if not hasattr(mod, target_name):
+                continue
+            gm_target = attrgetter(target_name)(gm)
+            model_target = attrgetter(target_name)(mod)
+            if (
+                torch.equal(gm_target, model_target)
+                and gm_target.dtype == model_target.dtype
+            ):
+                continue
+
+            prefix = (
+                "_tensor_constant"
+                if target_name.startswith("_tensor_constant")
+                else "_torchbind_obj"
+            )
+            new_id = find_smallest_i(gm.graph, prefix)
+            new_target_name = f"{prefix}{new_id}"
+            node.target = new_target_name
+            setattr(gm, new_target_name, gm_target)
+
+
 def _unlift_graph(
     mod: GraphModule, gm: GraphModule, graph_signature: GraphSignature
 ) -> GraphModule:
     from torch.export.unflatten import _assign_attr, _AttrKind
 
+    _resolve_name_collision(mod, gm)
+
     state_dict: dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
     for name, param in mod.named_parameters(remove_duplicate=False):
         state_dict[name] = param
```
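For intuition, here is a small standalone sketch of the rename that `_resolve_name_collision` performs on a colliding node (a hypothetical setup, not code from this commit):

```python
import torch
from torch.fx import symbolic_trace

class Sub(torch.nn.Module):
    def forward(self, x):
        return x + torch.tensor([1.0])  # lifted as self._tensor_constant0

gm = symbolic_trace(Sub())
node = next(n for n in gm.graph.nodes if n.op == "get_attr")

# Suppose the module we unlift onto already owns a *different*
# _tensor_constant0. The pass re-registers the traced value under a
# fresh name and points the node at it, leaving the original intact:
new_name = "_tensor_constant1"  # the suffix find_smallest_i would pick
setattr(gm, new_name, getattr(gm, node.target))
node.target = new_name
gm.recompile()
print(gm.code)  # the generated forward now reads self._tensor_constant1
```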
```diff
@@ -1138,13 +1191,14 @@ def log_graph_runnable() -> str:
     if aot_mode and config.aot_inductor.use_runtime_constant_folding:
         # torchbind objects have names that start with _torchbind_obj
         # See caffe2/torch/fx/_symbolic_trace.py?lines=406
-        # We don't use node.meta["val"] because we don't typically
-        # attach meta["val"] for get_attr nodes.
         const_gm, const_output_index = split_const_gm(
             gm,
             skip_folding_node_fn=lambda node: node.op == "get_attr"
             and isinstance(node.target, str)
-            and node.target.startswith("_torchbind_obj"),
+            and (
+                node.target.startswith("_torchbind_obj")
+                or isinstance(node.meta.get("val", None), FakeScriptObject)
+            ),
         )
 
         const_graph = GraphLowering(
```
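Restated outside the lambda, the extended skip condition is roughly the following sketch: torchbind objects are now recognized either by the `_torchbind_obj` name prefix or by the type of their `meta["val"]`, which the hunk below starts populating for `get_attr` nodes.

```python
from torch._library.fake_class_registry import FakeScriptObject

def skip_folding_torchbind(node) -> bool:
    # keep torchbind objects out of runtime constant folding
    return (
        node.op == "get_attr"
        and isinstance(node.target, str)
        and (
            node.target.startswith("_torchbind_obj")
            or isinstance(node.meta.get("val", None), FakeScriptObject)
        )
    )
```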
```diff
@@ -2161,11 +2215,19 @@ def bw_compiler(
             # this will go away.
             for node in gm.graph.nodes:
                 if node.op == "get_attr" and "val" not in node.meta:
-                    target = getattr(gm, node.target)
+                    target = attrgetter(node.target)(gm)
                     if isinstance(target, torch.Tensor):
                         node.meta["val"] = fake_mode.from_tensor(
                             target, static_shapes=True
                         )
+                    elif isinstance(target, torch.ScriptObject):
+                        node.meta["val"] = (
+                            torch._library.fake_class_registry.maybe_to_fake_obj(
+                                fake_mode, target
+                            )
+                        )
+                    elif isinstance(target, FakeScriptObject):
+                        node.meta["val"] = target
 
             unlifted_gm = _unlift_graph(model_, gm, graph_signature)
             if "dynamo_flat_name_to_original_fqn" in model_.meta:
```
