@@ -125,16 +125,10 @@ def allocate(self) -> None:
             hasattr(V.kernel, "args")
             and self.get_name() in V.kernel.inplace_update_buffers
         ):
-            input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
-            input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
-            if input_buffer_name in self.scheduler.name_to_donated_buffer:
-                input_buffer = self.scheduler.name_to_donated_buffer[
-                    input_buffer_name
-                ].node
-            else:
-                input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
             V.graph.wrapper_code.codegen_inplace_reuse(
-                input_buffer,
+                self.scheduler.name_to_buf[
+                    V.kernel.inplace_update_buffers[self.get_name()]
+                ].node,
                 self.node,
             )
         else:
@@ -169,11 +163,6 @@ def get_mutations(self) -> Sequence[str]:
         return self.node.get_mutation_names()
 
 
-@dataclasses.dataclass
-class SchedulerDonatedBuffer(SchedulerBuffer):
-    defining_op: Optional[BaseSchedulerNode] = None  # type: ignore[assignment]
-
-
 class BaseSchedulerNode:
     group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
     read_writes: dependencies.ReadWrites
@@ -453,12 +442,9 @@ def decide_inplace_update(self) -> None:
                 continue
 
             for read in self.read_writes.reads:
-                input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
-                if read.name in self.scheduler.name_to_donated_buffer:
-                    input_buf = self.scheduler.name_to_donated_buffer[read.name]
-                else:
-                    input_buf = self.scheduler.name_to_buf.get(read.name)
-
+                input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
+                    read.name
+                )
                 if (
                     input_buf
                     and V.graph.wrapper_code.can_reuse(input_buf, self)
@@ -484,8 +470,7 @@ def decide_inplace_update(self) -> None:
                             ),
                         )
                         and not (
-                            input_buf.defining_op
-                            and isinstance(
+                            isinstance(
                                 input_buf.defining_op.node,
                                 (ir.FallbackKernel, ir.MultiOutput),
                             )
@@ -1816,9 +1801,6 @@ def _init(self, nodes: List[ir.Operation]) -> None:
         for node in self.nodes:
             node.prune_deps()
 
-        self.name_to_donated_buffer: Dict[
-            str, SchedulerDonatedBuffer
-        ] = self.get_donated_buffers()
         self.name_to_node: Dict[str, BaseSchedulerNode] = {
             n.get_name(): n for n in self.nodes
         }
@@ -1902,17 +1884,6 @@ def _init(self, nodes: List[ir.Operation]) -> None:
             }
         )
 
-    def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
-        name_to_donated_buf = {}
-        for name in V.graph.graph_inputs_original:
-            if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
-                name_to_donated_buf[name] = SchedulerDonatedBuffer(
-                    self,
-                    V.graph.graph_inputs_original[name],
-                    defining_op=None,
-                )
-        return name_to_donated_buf
-
     @property
     def current_device(self) -> Optional[torch.device]:
         return V.graph.current_device
@@ -2189,9 +2160,6 @@ def add_user(
         for buf in node.get_outputs():
             buf.set_users(name_to_users[buf.get_name()].items)
 
-        for name in self.name_to_donated_buffer:
-            self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
-
     def dead_node_elimination(self) -> None:
         """
        Remove any nodes without users