@@ -16,6 +16,7 @@
 from contextlib import AbstractContextManager
 from inspect import currentframe
 from itertools import count
+from operator import attrgetter
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack
 from unittest import mock
@@ -81,6 +82,7 @@
     should_use_remote_fx_graph_cache,
     tensor_is_aligned,
 )
+from torch._library.fake_class_registry import FakeScriptObject
 from torch._logging import trace_structured
 from torch._utils_internal import compile_time_strobelight_meta
 from torch.fx import GraphModule
@@ -246,11 +248,62 @@ def _warn_tf32_disabled() -> None:
     )


+def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None:
+    """
+    In aot_export_module (make_fx), we create get_attr nodes whose names are
+    prefixed with "_tensor_constant" or "_torchbind_obj". See
+    Tracer.create_arg() in torch/fx/_symbolic_trace.py.
+
+    However, this can collide with the original mod if it already has a
+    different buffer under the same name.
+
+    We resolve the potential name collision here by renaming the graph's
+    target with a fresh numeric suffix.
+    """
+
+    def find_smallest_i(graph: fx.Graph, prefix: str) -> int:
+        # Return an unused index: one more than the largest integer suffix
+        # among the existing "{prefix}<i>" get_attr targets in the graph.
+        i = 0
+        for node in graph.nodes:
+            if node.op == "get_attr" and node.target.startswith(prefix):
+                suffix = node.target.split(prefix)[-1]
+                if suffix.isdigit():
+                    i = max(i, int(suffix))
+        return i + 1
+
+    for node in gm.graph.nodes:
+        if node.op == "get_attr":
+            target_name = node.target
+            if not target_name.startswith(
+                "_tensor_constant"
+            ) and not target_name.startswith("_torchbind_obj"):
+                continue
+
+            if not hasattr(mod, target_name):
+                continue
+            gm_target = attrgetter(target_name)(gm)
+            model_target = attrgetter(target_name)(mod)
+            # Only rename on a genuine collision: an identical tensor (or the
+            # very same script object) can keep sharing the name.
+            if isinstance(gm_target, torch.Tensor) and isinstance(
+                model_target, torch.Tensor
+            ):
+                if (
+                    torch.equal(gm_target, model_target)
+                    and gm_target.dtype == model_target.dtype
+                ):
+                    continue
+            elif gm_target is model_target:
+                continue
+
+            prefix = (
+                "_tensor_constant"
+                if target_name.startswith("_tensor_constant")
+                else "_torchbind_obj"
+            )
+            new_id = find_smallest_i(gm.graph, prefix)
+            new_target_name = f"{prefix}{new_id}"
+            node.target = new_target_name
+            setattr(gm, new_target_name, gm_target)
+
+
 def _unlift_graph(
     mod: GraphModule, gm: GraphModule, graph_signature: GraphSignature
 ) -> GraphModule:
     from torch.export.unflatten import _assign_attr, _AttrKind

+    _resolve_name_collision(mod, gm)
+
     state_dict: dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
     for name, param in mod.named_parameters(remove_duplicate=False):
         state_dict[name] = param
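To make the collision concrete, here is a minimal repro sketch. The setup is hypothetical: `Holder` and the hand-inserted `get_attr` node stand in for what `aot_export_module` produces, and `_resolve_name_collision` is the private helper this diff adds to `torch._inductor.compile_fx`.

```python
import torch
import torch.fx as fx
from torch._inductor.compile_fx import _resolve_name_collision  # private helper added by this diff

class Holder(torch.nn.Module):
    def forward(self, x):
        return x

# The original module owns a buffer-like attribute "_tensor_constant0" ...
mod = fx.symbolic_trace(Holder())
mod._tensor_constant0 = torch.zeros(2)

# ... while the independently traced graph captured a *different* tensor
# under the same auto-generated name.
gm = fx.symbolic_trace(Holder())
gm._tensor_constant0 = torch.ones(2)
out = next(n for n in gm.graph.nodes if n.op == "output")
with gm.graph.inserting_before(out):
    const = gm.graph.get_attr("_tensor_constant0")
gm.recompile()

_resolve_name_collision(mod, gm)
print(const.target)          # "_tensor_constant1" -- renamed, not clobbered
print(gm._tensor_constant1)  # tensor([1., 1.]) -- the graph's value is preserved
```

The rename ensures the unlifting step above does not overwrite the user's `_tensor_constant0` with the graph's unrelated constant.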
@@ -1138,13 +1191,14 @@ def log_graph_runnable() -> str:
         if aot_mode and config.aot_inductor.use_runtime_constant_folding:
             # torchbind objects have names that start with _torchbind_obj.
             # See caffe2/torch/fx/_symbolic_trace.py?lines=406
-            # We don't use node.meta["val"] because we don't typically
-            # attach meta["val"] for get_attr nodes.
             const_gm, const_output_index = split_const_gm(
                 gm,
                 skip_folding_node_fn=lambda node: node.op == "get_attr"
                 and isinstance(node.target, str)
-                and node.target.startswith("_torchbind_obj"),
+                and (
+                    node.target.startswith("_torchbind_obj")
+                    or isinstance(node.meta.get("val", None), FakeScriptObject)
+                ),
             )

             const_graph = GraphLowering(
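As a reading aid (not part of the change), the updated skip predicate can be written as a standalone function; `_skip_torchbind` is our name for it. The second clause is what this hunk adds: get_attr nodes may now carry a `FakeScriptObject` in `node.meta["val"]` (populated in the `bw_compiler` hunk below), so torchbind objects are excluded from constant folding even when their target does not follow the `_torchbind_obj` naming convention.

```python
import torch.fx as fx
from torch._library.fake_class_registry import FakeScriptObject

def _skip_torchbind(node: fx.Node) -> bool:
    # Keep torchbind get_attr nodes out of constant folding. They are
    # recognized either by naming convention or by a faked script object
    # that an earlier pass recorded in node.meta["val"].
    return (
        node.op == "get_attr"
        and isinstance(node.target, str)
        and (
            node.target.startswith("_torchbind_obj")
            or isinstance(node.meta.get("val", None), FakeScriptObject)
        )
    )
```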
@@ -2161,11 +2215,19 @@ def bw_compiler(
         # this will go away.
         for node in gm.graph.nodes:
             if node.op == "get_attr" and "val" not in node.meta:
-                target = getattr(gm, node.target)
+                target = attrgetter(node.target)(gm)
                 if isinstance(target, torch.Tensor):
                     node.meta["val"] = fake_mode.from_tensor(
                         target, static_shapes=True
                     )
+                elif isinstance(target, torch.ScriptObject):
+                    node.meta["val"] = (
+                        torch._library.fake_class_registry.maybe_to_fake_obj(
+                            fake_mode, target
+                        )
+                    )
+                elif isinstance(target, FakeScriptObject):
+                    node.meta["val"] = target

         unlifted_gm = _unlift_graph(model_, gm, graph_signature)
         if "dynamo_flat_name_to_original_fqn" in model_.meta:
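Two details in this hunk are easy to miss. First, `attrgetter(node.target)(gm)` replaces `getattr(gm, node.target)` because `get_attr` targets can be dotted paths into submodules, which plain `getattr` cannot resolve; a quick illustration (the module layout is ours):

```python
from operator import attrgetter

import torch

root = torch.nn.Module()
root.sub = torch.nn.Module()
root.sub.w = torch.ones(2)

print(attrgetter("sub.w")(root))  # follows the dotted path: tensor([1., 1.])
try:
    getattr(root, "sub.w")  # treats "sub.w" as a single attribute name
except AttributeError as exc:
    print(exc)
```

Second, `maybe_to_fake_obj` converts a real `torch.ScriptObject` into a `FakeScriptObject` under the active fake mode, which is what lets the constant-folding skip predicate above identify torchbind objects from `node.meta["val"]` alone.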