
Commit 65dbd5c

Revert "[Inductor] Inplacing with Donated Buffer (pytorch#140113)"

This reverts commit eecc8e3. Reverted pytorch#140113 on behalf of https://github.com/BoyuanFeng because it breaks test_donated_buffer_inplace internally, since donated_buffer = False if is_fbcode() else True ([comment](pytorch#140113 (comment)))
1 parent 869d629 commit 65dbd5c
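
For context, the revert reason hinges on the Inductor config flag quoted above: donated-buffer inplacing defaults off in fbcode builds, so a test that assumes it fails there. Below is a minimal sketch of how a test could guard on that flag, assuming torch._inductor.config exposes a donated_buffer attribute as the commit message implies; the skip wiring is illustrative only and is not part of this commit.

    import unittest

    import torch
    import torch._inductor.config as inductor_config

    # Assumption: the config attribute quoted in the revert message exists, i.e.
    #   donated_buffer = False if is_fbcode() else True
    donated_buffer_enabled = getattr(inductor_config, "donated_buffer", False)


    @unittest.skipUnless(
        donated_buffer_enabled,
        "donated-buffer inplacing is disabled (e.g. in fbcode builds)",
    )
    class DonatedBufferSmokeTest(unittest.TestCase):
        def test_compile_runs(self):
            # Only checks that a compiled function with a backward pass runs;
            # it does not inspect generated kernels the way the reverted test did.
            fn = torch.compile(lambda x: (x * 2).sum())
            x = torch.randn(4, requires_grad=True)
            fn(x).backward()
            self.assertIsNotNone(x.grad)


    if __name__ == "__main__":
        unittest.main()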

File tree

7 files changed: +17 -180 lines

  test/inductor/test_torchinductor.py
  torch/_inductor/codegen/wrapper.py
  torch/_inductor/cudagraph_trees.py
  torch/_inductor/graph.py
  torch/_inductor/ir.py
  torch/_inductor/scheduler.py
  torch/_inductor/utils.py

test/inductor/test_torchinductor.py (-62)

@@ -5198,31 +5198,6 @@ def test_layer_norm(self):
         if self.device != "cpu":
             assertGeneratedKernelCountEqual(self, 1)
 
-    def test_matmul_layer_norm(self):
-        batch_size = 32
-        seq_length = 50
-        hidden_size = 256
-
-        inp = torch.randn(
-            batch_size,
-            seq_length,
-            hidden_size,
-            requires_grad=True,
-            device=self.device,
-        )
-        weight = torch.randn(
-            hidden_size, hidden_size, requires_grad=True, device=self.device
-        )
-
-        layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
-
-        def foo(inp, weight):
-            matmul_output = inp @ weight
-            final_output = layer_norm(matmul_output)
-            return final_output
-
-        self.common(foo, (inp, weight), check_lowp=False)
-
     def test_transpose_add(self):
         def fn(a, b):
             return a.t() + b
@@ -12880,43 +12855,6 @@ def fn(inp, weight):
         self.assertTrue(len(re.findall(r"in_out_ptr\d+", code)) > 0)
         self.assertEqual(fn_opt(*inps), fn(*inps))
 
-    def test_donated_buffer_inplace(self):
-        batch_size = 32
-        seq_length = 50
-        hidden_size = 256
-
-        inp = torch.randn(
-            batch_size,
-            seq_length,
-            hidden_size,
-            requires_grad=True,
-            device=self.device,
-        )
-        weight = torch.randn(
-            hidden_size, hidden_size, requires_grad=True, device=self.device
-        )
-
-        layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
-
-        def fn(inp, weight):
-            matmul_output = inp @ weight
-            final_output = layer_norm(matmul_output)
-            return final_output
-
-        fn_opt = torch.compile(fn)
-
-        def wrapper(inp, weight):
-            return fn_opt(inp, weight).sum().backward()
-
-        _, code = run_and_get_code(wrapper, inp, weight)
-
-        if config.cpp_wrapper:
-            # when using cpp_wrapper, backward triton code is in code[2]
-            self.assertTrue("in_out_ptr" in code[2])
-        else:
-            # when not using cpp_wrapper, backward triton code is in code[1]
-            self.assertTrue("in_out_ptr" in code[1])
-
 class RNNTest(TestCase):
     device_type = GPU_TYPE
 

torch/_inductor/codegen/wrapper.py (+2, -11)

@@ -2120,11 +2120,7 @@ def codegen_deferred_allocation(self, name, layout):
     def codegen_allocation(self, buffer: ir.Buffer):
         name = buffer.get_name()
 
-        if (
-            name in V.graph.removed_buffers
-            or name in self.allocated
-            or isinstance(buffer, ir.DonatedBuffer)
-        ):
+        if name in V.graph.removed_buffers or name in self.allocated:
             return
         self.allocated.add(name)
         if isinstance(
@@ -2178,12 +2174,7 @@ def can_reuse(self, input_buffer, output_buffer=None):
         name = input_buffer.get_name()
         return not (
             name in V.graph.removed_buffers
-            or (
-                name in V.graph.graph_inputs
-                and not isinstance(
-                    V.graph.graph_inputs_original[name], ir.DonatedBuffer
-                )
-            )
+            or name in V.graph.graph_inputs
             or name in V.graph.constants
             or name in V.graph.torchbind_constants
             or name in V.graph.never_reuse_buffers

torch/_inductor/cudagraph_trees.py (+2, -23)

@@ -832,20 +832,6 @@ def __init__(
             if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
         ]
 
-        # (depth, offset) of live tensors which are alias of previous graph outputs
-        self.live_cudagraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [
-            (
-                self._is_alias_of_live_recorded_tensor(t)
-                if isinstance(t, torch.Tensor)
-                else None
-            )
-            for t in inputs
-        ]
-
-        # when replay, preserve the liveness of an input if it AliasesPriorGraphOutput
-        # and also aliases an output of the current CUDAGraphNode
-        self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
-
         self.static_input_idxs: List[int] = list(
             set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
         )
@@ -1052,11 +1038,11 @@ def run(self, new_inputs: List[InputType]) -> OutputType:
         self.check_static_inputs_are_stable(new_inputs)
 
         self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
+        new_inputs.clear()
 
         self.run_graph()
 
         outputs = self.reconstruct_outputs()
-        new_inputs.clear()
 
         if config.triton.fast_path_cudagraph_asserts:
             self.debug_check_invariants_after_invocation()
@@ -1275,12 +1261,6 @@ def _add_first_outputs(
             path_ref = self._is_alias_of_live_recorded_tensor(o)
             if path_ref is not None:
                 self._mark_prior_graph_output_as_aliased(path_ref)
-
-                for idx, inp_path_ref in enumerate(
-                    self.live_cudagraph_managed_path_refs
-                ):
-                    if path_ref == inp_path_ref:
-                        self.preserved_aliased_inputs[idx] = True
                 self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
                 continue
 
@@ -1687,8 +1667,7 @@ def check_invariants(
         # this invocation. it is too late to check after we've replayed the graph,
         # because we would have already written over their memory.
         for idx in self.cudagraph_managed_idxs:
-            if not self.preserved_aliased_inputs[idx]:
-                inputs[idx] = None  # type: ignore[call-overload]
+            inputs[idx] = None  # type: ignore[call-overload]
 
         torch._check(
             self._check_liveness(

torch/_inductor/graph.py (+6, -28)

@@ -74,7 +74,6 @@
 )
 from .ir import (
     Constant,
-    DonatedBuffer,
     FixedLayout,
     get_device_type,
     InputBuffer,
@@ -104,7 +103,6 @@
     convert_shape_to_inductor,
     gather_origins,
     get_cloned_parameter_buffer_name,
-    get_donated_idxs,
     get_sympy_Expr_dtype,
     is_same_tensor,
     maybe_get_suppress_shape_guards_ctx,
@@ -488,11 +486,6 @@ def __init__(
         # state used by for Kernel.workspace
         self.workspace_id = itertools.count()
 
-        # track the current placeholder index that we are processing
-        self.placeholder_idx = -1
-
-        self.bw_donated_idxs = get_donated_idxs()
-
     def has_feature(
         self,
         device: Union[torch._inductor.ir.IRNode, device, None],
@@ -970,7 +963,6 @@ def constant_name(self, name: str, device_override: Optional[torch.device]) -> s
     def placeholder(
         self, target: str, args: Tuple[object], kwargs: Dict[str, object]  # type: ignore[override]
     ) -> Union[Expr, TensorBox, None]:
-        self.placeholder_idx += 1
         example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
         target = self.qualify_name(target)
         if isinstance(example, SymTypes):
@@ -1001,27 +993,13 @@ def placeholder(
             sizes, strides = self.static_sizes_strides(example)
         else:
             sizes, strides = self.symbolic_sizes_strides(example)  # type: ignore[assignment]
-
-        if (
-            self.is_backward
-            and self.bw_donated_idxs
-            and self.placeholder_idx in self.bw_donated_idxs
-        ):
-            tensor = TensorBox.create(
-                DonatedBuffer(
-                    name=target,
-                    layout=FixedLayout(example.device, example.dtype, sizes, strides),
-                )
+        # TODO(jansel): handle input aliasing
+        tensor = TensorBox.create(
+            InputBuffer(
+                name=target,
+                layout=FixedLayout(example.device, example.dtype, sizes, strides),
             )
-        else:
-            # TODO(jansel): handle input aliasing
-            tensor = TensorBox.create(
-                InputBuffer(
-                    name=target,
-                    layout=FixedLayout(example.device, example.dtype, sizes, strides),
-                )
-            )
-
+        )
         self.graph_inputs[target] = tensor
         self.graph_input_names.append(target)
         self.graph_inputs_original[target] = tensor.data.data

torch/_inductor/ir.py (-10)

@@ -3832,16 +3832,6 @@ def num_reads(self) -> int:
         return 1
 
 
-class DonatedBuffer(InputBuffer):
-    """
-    Represents a donated buffer which is a saved tensor that is not alias to any
-    fwd inputs, fwd user outputs, and bwd outputs. We generally cannot inplace
-    reuse the input tensor memory during backward since it might be used in another
-    function. However, donated buffer can be inplace reused during backward
-    to save memory.
-    """
-
-
 class ConstantBuffer(InputBuffer):
     override_device: Optional[torch.device] = None
 
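
The removed DonatedBuffer docstring above captures the invariant the reverted PR relied on: a saved tensor that nothing outside the backward graph will read again can have its storage reused in place. A toy illustration of that memory argument, entirely separate from Inductor's actual codegen (the function and variable names below are made up):

    import torch

    def backward_like(saved_activation: torch.Tensor, grad_out: torch.Tensor) -> torch.Tensor:
        # If saved_activation is "donated" (no other consumer will read it again),
        # its storage can be overwritten instead of allocating a fresh buffer.
        # This is roughly what an in_out_ptr kernel argument expresses in the
        # generated Triton code checked by the reverted test.
        saved_activation.mul_(grad_out)  # reuse the donated storage in place
        return saved_activation          # now holds the gradient contribution

    grad = backward_like(torch.randn(8), torch.randn(8))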

torch/_inductor/scheduler.py (+7, -39)

@@ -125,16 +125,10 @@ def allocate(self) -> None:
             hasattr(V.kernel, "args")
             and self.get_name() in V.kernel.inplace_update_buffers
         ):
-            input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
-            input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
-            if input_buffer_name in self.scheduler.name_to_donated_buffer:
-                input_buffer = self.scheduler.name_to_donated_buffer[
-                    input_buffer_name
-                ].node
-            else:
-                input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
             V.graph.wrapper_code.codegen_inplace_reuse(
-                input_buffer,
+                self.scheduler.name_to_buf[
+                    V.kernel.inplace_update_buffers[self.get_name()]
+                ].node,
                 self.node,
             )
         else:
@@ -169,11 +163,6 @@ def get_mutations(self) -> Sequence[str]:
         return self.node.get_mutation_names()
 
 
-@dataclasses.dataclass
-class SchedulerDonatedBuffer(SchedulerBuffer):
-    defining_op: Optional[BaseSchedulerNode] = None  # type: ignore[assignment]
-
-
 class BaseSchedulerNode:
     group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
     read_writes: dependencies.ReadWrites
@@ -453,12 +442,9 @@ def decide_inplace_update(self) -> None:
                 continue
 
             for read in self.read_writes.reads:
-                input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
-                if read.name in self.scheduler.name_to_donated_buffer:
-                    input_buf = self.scheduler.name_to_donated_buffer[read.name]
-                else:
-                    input_buf = self.scheduler.name_to_buf.get(read.name)
-
+                input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
+                    read.name
+                )
                 if (
                     input_buf
                     and V.graph.wrapper_code.can_reuse(input_buf, self)
@@ -484,8 +470,7 @@ def decide_inplace_update(self) -> None:
                         ),
                     )
                     and not (
-                        input_buf.defining_op
-                        and isinstance(
+                        isinstance(
                             input_buf.defining_op.node,
                             (ir.FallbackKernel, ir.MultiOutput),
                         )
@@ -1816,9 +1801,6 @@ def _init(self, nodes: List[ir.Operation]) -> None:
         for node in self.nodes:
             node.prune_deps()
 
-        self.name_to_donated_buffer: Dict[
-            str, SchedulerDonatedBuffer
-        ] = self.get_donated_buffers()
         self.name_to_node: Dict[str, BaseSchedulerNode] = {
            n.get_name(): n for n in self.nodes
        }
@@ -1902,17 +1884,6 @@ def _init(self, nodes: List[ir.Operation]) -> None:
            }
        )
 
-    def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
-        name_to_donated_buf = {}
-        for name in V.graph.graph_inputs_original:
-            if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
-                name_to_donated_buf[name] = SchedulerDonatedBuffer(
-                    self,
-                    V.graph.graph_inputs_original[name],
-                    defining_op=None,
-                )
-        return name_to_donated_buf
-
     @property
     def current_device(self) -> Optional[torch.device]:
         return V.graph.current_device
@@ -2189,9 +2160,6 @@ def add_user(
         for buf in node.get_outputs():
             buf.set_users(name_to_users[buf.get_name()].items)
 
-        for name in self.name_to_donated_buffer:
-            self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
-
     def dead_node_elimination(self) -> None:
         """
        Remove any nodes without users

torch/_inductor/utils.py (-7)

@@ -2200,10 +2200,3 @@ def wrap(cls: _T) -> _T:
     if cls is None:
         return wrap
     return wrap(cls)
-
-
-def get_donated_idxs() -> Optional[List[int]]:
-    tracing_context = torch._guards.TracingContext.try_get()
-    if tracing_context is not None and tracing_context.fw_metadata:
-        return tracing_context.fw_metadata.bw_donated_idxs
-    return None
