forked from codeplaysoftware/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon.py
2598 lines (2223 loc) · 90.4 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import contextlib
import dataclasses
import enum
import functools
import itertools
import logging
import math
import operator
import re
import typing
from enum import auto, Enum
from itertools import chain
from typing import (
Any,
Callable,
cast,
ClassVar,
Generic,
NamedTuple,
Optional,
TYPE_CHECKING,
Union,
)
from typing_extensions import TypeVar
import sympy
import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from .. import config, metrics
from ..dtype_propagation import DtypePropagationOpsHandler
from ..ops_handler import BasicMathOpsMixin, DefaultHandler
from ..utils import (
boolean_ops,
DeferredLineBase,
generate_assert,
IndentedBuffer,
ir_dataclass,
ScopedDict,
sympy_dot,
sympy_index_symbol,
sympy_subs,
triton_type,
unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
if TYPE_CHECKING:
    # Names needed only for type annotations. With `from __future__ import
    # annotations` they are never evaluated at runtime; presumably they are
    # kept out of the runtime import graph to avoid import cycles (note that
    # get_layout/index_prevent_reordering import from ..ir lazily instead).
    from collections.abc import Iterator, MutableMapping, Sequence

    from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
    from ..loop_body import LoopBody
    from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
    from .wrapper import PythonWrapperCodegen

    # The aliases below reference Scheduler / BaseScheduling /
    # PythonWrapperCodegen, which are only imported inside this block, so the
    # aliases must live here as well.
    _T = TypeVar("_T")
    # Callable building a BaseScheduling from an optional Scheduler
    # (get_backend_features calls it with None).
    SchedulingConstructor = Callable[[Optional[Scheduler]], BaseScheduling]
    # A wrapper-codegen *class* (not an instance).
    WrapperConstructor = type[PythonWrapperCodegen]
    SymbolLike = Union[str, sympy.Symbol]
    # OpVarT should really be Union[CSEVariable, str], however this
    # causes typing errors in subclasses (defined in other files).
    OpVarT = str
# Artifact logger registered under the "schedule" artifact name; used by
# data_type_logger for data type propagation messages.
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
# Plain module-level logger.
log = logging.getLogger(__name__)
def data_type_logger(msg: str) -> None:
    """Log ``msg`` at DEBUG level on the "schedule" artifact logger.

    Does nothing unless that logger has DEBUG enabled.
    """
    if not schedule_log.isEnabledFor(logging.DEBUG):
        return
    schedule_log.debug("Data type propagation: %s", msg)
class WorkspaceZeroMode(enum.Enum):
    """Zero-initialization requirement attached to a kernel workspace."""

    UNINITIALIZED = 0
    ZERO_ON_CALL = 1  # kernel may leave workspace dirty
    ZERO_PER_GRAPH = 2  # must be re-zeroed by kernel

    @staticmethod
    def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode:
        """Merge two modes into one.

        UNINITIALIZED is the neutral element, and a mode merged with itself is
        unchanged.

        Raises:
            NotImplementedError: when ``a`` and ``b`` are two different
                initialized modes.
        """
        if b in (a, WorkspaceZeroMode.UNINITIALIZED):
            return a
        if a is WorkspaceZeroMode.UNINITIALIZED:
            return b
        raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})")

    @staticmethod
    def from_bool(zero_fill: bool) -> WorkspaceZeroMode:
        """Map a legacy ``zero_fill`` flag: truthy -> ZERO_ON_CALL, falsy -> UNINITIALIZED."""
        return (
            WorkspaceZeroMode.ZERO_ON_CALL
            if zero_fill
            else WorkspaceZeroMode.UNINITIALIZED
        )
@ir_dataclass(frozen=True)
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.

    Attributes:
        count: Number of ``dtype`` elements in the buffer. With the default
            ``torch.uint8`` dtype this is the size in bytes.
        zero_mode: Zero-initialization requirement (see WorkspaceZeroMode).
        device: Device the buffer is allocated on.
        outer_name: Buffer name returned by ``get_name()``.
        inner_name: Argument name for the buffer, presumably as used inside
            the generated kernel (default ``"ws_ptr"``).
        dtype: Element dtype of the buffer.
    """

    count: sympy.Expr
    zero_mode: WorkspaceZeroMode
    device: torch.device
    outer_name: str
    inner_name: str = "ws_ptr"
    dtype: torch.dtype = torch.uint8

    @staticmethod
    def unique_name(prefix: str = "workspace_") -> str:
        """Return ``prefix`` followed by the next id from ``V.graph.workspace_id``."""
        return f"{prefix}{next(V.graph.workspace_id)}"

    @staticmethod
    def can_join(a: WorkspaceArg, b: WorkspaceArg) -> bool:
        """Whether ``a`` and ``b`` agree on inner name, dtype and device."""
        return (
            a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device
        )

    @staticmethod
    def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg:
        """Combine two workspaces into one holding ``a.count + b.count`` elements.

        ``zero_mode`` is merged with WorkspaceZeroMode.combine; every other
        field (including ``outer_name``) is taken from ``a``. Callers are
        expected to have checked ``can_join(a, b)`` first.
        """
        return WorkspaceArg(
            count=a.count + b.count,
            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
            dtype=a.dtype,
            device=a.device,
            inner_name=a.inner_name,
            outer_name=a.outer_name,
        )

    @staticmethod
    def maximum(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg:
        """Return a workspace large enough for either ``a`` or ``b``.

        ``count`` is ``sympy.Max(a.count, b.count)``, ``zero_mode`` is merged
        with WorkspaceZeroMode.combine, and the other fields come from ``a``.

        Raises:
            AssertionError: if ``a`` and ``b`` cannot be joined (``can_join``).
        """
        # Reuse can_join so the compatibility rule is defined in one place.
        assert WorkspaceArg.can_join(a, b), (
            f"incompatible workspaces for maximum: {a!r} vs {b!r}"
        )
        return WorkspaceArg(
            count=sympy.Max(a.count, b.count),
            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
            dtype=a.dtype,
            device=a.device,
            inner_name=a.inner_name,
            outer_name=a.outer_name,
        )

    # These methods let WorkspaceArg pretend it is a buffer to reuse allocation code
    def get_device(self) -> torch.device:
        return self.device

    get_device_or_error = get_device

    def get_dtype(self) -> torch.dtype:
        return self.dtype

    def get_layout(self) -> FixedLayout:
        """A 1-D contiguous layout of ``count`` elements of ``dtype``."""
        from ..ir import FixedLayout

        return FixedLayout(
            device=self.device,
            dtype=self.dtype,
            size=[self.count],
            stride=[1],
        )

    @property
    def layout(self) -> FixedLayout:
        return self.get_layout()

    get_output_spec = get_layout
    maybe_get_output_spec = get_layout
    maybe_get_layout = get_layout

    def get_size(self) -> list[sympy.Expr]:
        return [self.count]

    def get_stride(self) -> list[sympy.Expr]:
        return [sympy.S.One]

    def get_name(self) -> str:
        return self.outer_name

    def get_inputs_that_alias_output(self) -> list[str]:
        # A workspace never aliases another buffer.
        return []
@dataclasses.dataclass
class TensorArg:
    """A tensor (buffer) argument of a generated kernel."""

    # Argument name — presumably the name used in the kernel signature; confirm.
    name: str
    # Name of the buffer backing this argument — presumably the graph buffer name; verify.
    buffer: str
    # Element dtype of the tensor.
    dtype: torch.dtype
    offset: sympy.Expr = sympy.S.Zero  # c++ only
    alias_of: Optional[str] = None  # halide only
@dataclasses.dataclass
class SizeArg:
    """A scalar size argument of a kernel, given by a sympy expression."""

    name: str
    expr: sympy.Expr

    @property
    def alias_of(self) -> Optional[str]:
        # Size arguments never alias anything. Presumably this exists so that
        # SizeArg and TensorArg can both be queried for `alias_of`.
        return None
@dataclasses.dataclass
class ConstexprArg:
    """A kernel argument identified only by name.

    NOTE(review): presumably emitted as a compile-time constant (constexpr)
    parameter by backends that support it — confirm against the backends.
    """

    name: str
@dataclasses.dataclass
class TMADescriptorArg:
    """A kernel argument carrying a TMA descriptor, identified only by name."""

    name: str
@dataclasses.dataclass
class DeviceCodegen:
    """Codegen entry points registered for one device type.

    Instances are created by register_backend_for_device() and stored in
    ``device_codegens``.
    """

    # Builds the kernel Scheduling for this device.
    scheduling: SchedulingConstructor
    # Wrapper codegen class used when cpp_wrapper is not requested.
    wrapper_codegen: WrapperConstructor
    # Wrapper codegen class used for cpp_wrapper; None when the backend has none.
    cpp_wrapper_codegen: Optional[WrapperConstructor] = None
# Union of every argument kind a generated kernel may take.
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]
# Registry mapping a device type string (e.g. "cpu", "cuda") to its codegen
# entry points; populated by register_backend_for_device().
device_codegens: dict[str, DeviceCodegen] = {}
class DeviceOpOverrides:
    """Per-device hooks used during code generation.

    Every hook in this base class raises NotImplementedError. Device backends
    presumably subclass it and register an instance through
    register_device_op_overrides(); get_device_op_overrides() looks the
    instance up by device type. All hooks return ``str`` except
    ``cpp_global_scratch``, which returns an optional pair. The exact snippet
    each hook must produce is backend-defined, so check an existing backend's
    overrides before implementing a new one.
    """

    def import_get_raw_stream_as(self, name: str) -> str:
        raise NotImplementedError

    def set_device(self, device_idx: int) -> str:
        raise NotImplementedError

    def synchronize(self) -> str:
        raise NotImplementedError

    def device_guard(self, device_idx: int) -> str:
        raise NotImplementedError

    def cpp_device_guard(self) -> str:
        raise NotImplementedError

    def cpp_aoti_device_guard(self) -> str:
        raise NotImplementedError

    def cpp_stream_guard(self) -> str:
        raise NotImplementedError

    def cpp_aoti_stream_guard(self) -> str:
        raise NotImplementedError

    def cpp_getStreamFromExternal(self) -> str:
        raise NotImplementedError

    def kernel_header(self) -> str:
        raise NotImplementedError

    def kernel_driver(self) -> str:
        raise NotImplementedError

    def cpp_stream_type(self) -> str:
        raise NotImplementedError

    def aoti_get_stream(self) -> str:
        raise NotImplementedError

    def cpp_kernel_type(self) -> str:
        raise NotImplementedError

    def cpp_device_ptr(self) -> str:
        raise NotImplementedError

    def tma_descriptor_helpers(self) -> str:
        raise NotImplementedError

    def cpp_global_scratch(self, idx: int) -> Optional[tuple[str, str]]:
        # optionally return (scratch definition, arg name)
        raise NotImplementedError
# Registry mapping a device type string to its DeviceOpOverrides instance.
# Written by register_device_op_overrides(); get_device_op_overrides() imports
# the built-in override modules when this dict is still empty.
device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# Any new backend that integrates with Inductor must customize both parts to generate
# its own device-specific code.
#
# Kernel code generation is driven by the backend's Scheduling class. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the wrapper, Inductor provides a PythonWrapperCodegen class that generates the Python wrapper
# code bridging the kernels. Out-of-tree backends can inherit from PythonWrapperCodegen
# and override specific member functions to produce backend-specific Python wrapper code.
#
# Other classes used for code generation, such as CppKernel and TritonKernel, typically form part
# of the logic of either the Scheduling or the PythonWrapperCodegen. The Scheduling and
# PythonWrapperCodegen interfaces therefore give a backend flexibility: it can implement these
# classes from scratch, or reuse them by extending and overriding as necessary. Inductor provides
# the registration API, register_backend_for_device, to plug in a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: SchedulingConstructor,
    device_wrapper_codegen: WrapperConstructor,
    device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
) -> None:
    """Register, or replace, the codegen entry points for ``device``.

    Args:
        device: Device type string, e.g. ``"cpu"`` or ``"cuda"``.
        device_scheduling: Constructor for the device's kernel Scheduling.
        device_wrapper_codegen: Python wrapper codegen class.
        device_cpp_wrapper_codegen: C++ wrapper codegen class, if any.
    """
    entry = DeviceCodegen(
        scheduling=device_scheduling,
        wrapper_codegen=device_wrapper_codegen,
        cpp_wrapper_codegen=device_cpp_wrapper_codegen,
    )
    device_codegens[device] = entry
class BackendFeature(Enum):
    """Optional capabilities a codegen backend may advertise.

    A backend's Scheduling reports its supported set through its
    ``get_backend_features(device)`` method; see get_backend_features() and
    has_backend_feature() in this module.
    """

    FOREACH = auto()
    BUCKETIZE = auto()
    INPLACE_BUFFERS = auto()
    MASKED_SCATTER_WITH_INDEX = auto()
    SCAN = auto()
    SORT = auto()
    TUPLE_REDUCTION = auto()
    PREFER_STORE_LOOP_ORDER = auto()
    TRITON_TEMPLATES = auto()
    REDUCE_TO_SINGLE_ELEMENT = auto()
def get_backend_features(
device: Union[torch.device, str, None],
) -> OrderedSet[BackendFeature]:
if device is None:
return OrderedSet()
init_backend_registration()
if isinstance(device, torch.device):
device_type = device.type
else:
assert isinstance(device, str)
device_type = device
device = torch.device(device_type)
scheduling_ctor = get_scheduling_for_device(device_type)
assert scheduling_ctor
scheduling = scheduling_ctor(None)
return scheduling.get_backend_features(device)
def has_backend_feature(
    device: Union[torch.device, str, None], feature: BackendFeature
) -> bool:
    """Return whether the backend for ``device`` supports ``feature``.

    See also V.graph.has_feature
    """
    assert isinstance(feature, BackendFeature)
    supported = get_backend_features(device)
    return feature in supported
def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]:
    """Return the Scheduling constructor registered for ``device``, or None."""
    try:
        codegen = device_codegens[device]
    except KeyError:
        return None
    return codegen.scheduling
def get_wrapper_codegen_for_device(
    device: str, cpp_wrapper: bool = False
) -> Optional[WrapperConstructor]:
    """Return the wrapper codegen class registered for ``device``.

    With ``cpp_wrapper`` set, the C++ wrapper class is returned (itself None
    if the backend registered none); otherwise the Python wrapper class.
    Returns None when nothing is registered for ``device``.
    """
    if device not in device_codegens:
        return None
    codegen = device_codegens[device]
    if cpp_wrapper:
        return codegen.cpp_wrapper_codegen
    return codegen.wrapper_codegen
@functools.lru_cache(None)
def init_backend_registration() -> None:
    """Register Inductor's built-in backends, plus a PrivateUse1 backend if present.

    Thanks to ``lru_cache`` the body runs at most once per process. A device
    that already has a Scheduling registered (for example by an earlier call
    to register_backend_for_device) is left untouched.
    """
    from .cpp import CppScheduling
    from .cpp_wrapper_cpu import CppWrapperCpu
    from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
    from .cpp_wrapper_gpu import CppWrapperGpu
    from .cuda_combined_scheduling import CUDACombinedScheduling
    from .halide import HalideScheduling
    from .mps import MetalScheduling
    from .triton import TritonScheduling
    from .wrapper import PythonWrapperCodegen
    from .xpu_combined_scheduling import SYCLCombinedScheduling

    if get_scheduling_for_device("cpu") is None:
        cpu_backends = {
            "cpp": CppScheduling,
            "halide": HalideScheduling,
            "triton": TritonScheduling,
        }
        register_backend_for_device(
            "cpu",
            # config.cpu_backend is read each time a Scheduling is constructed.
            lambda scheduling: cpu_backends[config.cpu_backend](scheduling),
            PythonWrapperCodegen,
            # Unlike the scheduling choice above, this is fixed at registration time.
            CppWrapperCpuArrayRef
            if config.aot_inductor.allow_stack_allocation
            else CppWrapperCpu,
        )

    if get_scheduling_for_device("cuda") is None:
        # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
        cuda_backends = {
            "triton": CUDACombinedScheduling,
            "halide": HalideScheduling,
        }
        register_backend_for_device(
            "cuda",
            lambda scheduling: cuda_backends[config.cuda_backend](scheduling),
            PythonWrapperCodegen,
            CppWrapperGpu,
        )

    if get_scheduling_for_device("xpu") is None:
        # SYCLCombinedScheduling combines Triton and SYCL C++ scheduling for XPU devices via delegation
        register_backend_for_device(
            "xpu",
            SYCLCombinedScheduling,
            PythonWrapperCodegen,
            CppWrapperGpu,
        )

    if get_scheduling_for_device("mps") is None:
        register_backend_for_device(
            "mps",
            MetalScheduling,
            PythonWrapperCodegen,
            CppWrapperGpu,
        )

    # "privateuseone" is the default name, i.e. no custom backend was renamed in.
    private_backend = torch._C._get_privateuse1_backend_name()
    if (
        private_backend != "privateuseone"
        and get_scheduling_for_device(private_backend) is None
    ):
        from torch.utils.backend_registration import _get_custom_mod_func

        try:
            device_scheduling = _get_custom_mod_func("Scheduling")
            wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen")
            cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen")
        except RuntimeError:
            # Best effort: a custom backend that does not provide Inductor
            # codegen hooks simply gets no Inductor backend. Record why,
            # instead of silently discarding the error.
            log.debug(
                "Skipping Inductor backend registration for PrivateUse1 backend %r",
                private_backend,
                exc_info=True,
            )
        else:
            if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
                register_backend_for_device(
                    private_backend,
                    device_scheduling,
                    wrapper_codegen,
                    cpp_wrapper_codegen,
                )
def index_prevent_reordering(
    index: Sequence[sympy.Expr],
    index_vars: Sequence[sympy.Expr],
    sizes: Sequence[sympy.Expr],
) -> list[sympy.Expr]:
    """Return a new list: ``index`` followed by the contiguous linear index.

    The appended term is ``index_vars`` dotted with the contiguous strides of
    ``sizes``. Adding this contiguous index prevents reordering.
    """
    from ..ir import FlexibleLayout

    strides = FlexibleLayout.contiguous_strides(sizes)
    extended = list(index)
    extended.append(sympy_dot(index_vars, strides))
    return extended
def register_device_op_overrides(
    device: str, device_op_overrides: DeviceOpOverrides
) -> None:
    """Make ``device_op_overrides`` the hooks returned for ``device`` (replacing any previous ones)."""
    device_op_overrides_dict.update({device: device_op_overrides})
def get_device_op_overrides(device: str) -> DeviceOpOverrides:
    """Return the DeviceOpOverrides registered for ``device``.

    While the registry is still empty, the built-in override modules are
    imported first (presumably they register themselves on import).

    Raises:
        KeyError: if no overrides are registered for ``device``.
    """
    assert isinstance(device, str)
    if len(device_op_overrides_dict) == 0:
        # Imported only for their side effects.
        from . import cpu_device_op_overrides, mps_device_op_overrides  # noqa: F401
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401
    return device_op_overrides_dict[device]
DTYPE_TO_COMPUTATION_DTYPE: dict[torch.dtype, torch.dtype] = {
torch.bfloat16: torch.float,
torch.float16: torch.float,
**{
dtype: dtype
for dtype in [
torch.bool,
torch.float32,
torch.float64,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
]
},
}
def deduce_output_dtype_by_name(
    op_name: str,
    *args: Any,
    **kwargs: Any,
) -> Optional[torch.dtype]:
    """Infer an op's output dtype from its name and call arguments alone.

    DataTypePropagation calls this with an FX node's ``args``/``kwargs``.
    When an op carries an explicit dtype, a ``dtype=`` keyword wins over the
    positional slot. Returns None if the name alone does not determine the
    dtype.
    """

    def explicit_dtype(position: int) -> Any:
        # Keyword ``dtype`` takes precedence over ``args[position]``.
        if "dtype" in kwargs:
            return kwargs["dtype"]
        return args[position]

    # Boolean ops are checked first, exactly as before.
    if op_name in boolean_ops():
        return torch.bool
    if op_name in ("to_dtype", "index_expr", "constant"):
        return explicit_dtype(-1)
    if op_name in ("rand", "randn"):
        return torch.float
    if op_name in ("get_index", "randint64", "load_seed"):
        return torch.int64
    if op_name == "reduction":
        return explicit_dtype(1)
    if op_name in ("load", "store", "store_reduction"):
        # args[1] is the buffer name being read or written.
        return V.graph.get_dtype(args[1])  # type: ignore[arg-type]
    if op_name == "to_dtype_bitcast":
        return explicit_dtype(-2)
    return None
class DataTypePropagation:
    """Infer dtypes for the nodes of a LoopBody's FX graphs.

    Each visited node's inferred dtype (possibly None) is stored in
    ``node.meta[OptimizationContext.key].dtype``.
    """

    def __init__(self, body: LoopBody) -> None:
        self.body = body
        # "root" -> the root block's graph; each subblock key -> that subblock's graph.
        self.graphs: dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node) -> Optional[torch.dtype]:
        """Return the promoted dtype of ``node``'s non-placeholder inputs.

        Returns None when there are no such inputs, or when any of them has
        not been assigned a dtype yet.
        """
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node) -> torch.dtype:
        """Return the output dtype of the subgraph keyed by ``node.target``."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]:
        """Infer ``node``'s dtype, or None if it cannot be determined.

        The rules are applied in order:
        - placeholders give None;
        - an "output" node without exactly one arg gives None;
        - getitem gives its source's dtype;
        - masked_subblock* gives the subgraph's output dtype;
        - otherwise a name-based rule, falling back to promoting input dtypes.
        """
        if node.op == "placeholder":
            return None

        if node.target == "output" and len(node.args) != 1:
            # we can infer output node if it only have 1 arg
            return None

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        if (
            output_dtype := deduce_output_dtype_by_name(
                node.target,
                *node.args,
                **node.kwargs,
            )
        ) is not None:
            return output_dtype

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph) -> Optional[torch.dtype]:
        """Annotate every node of ``graph`` in order and return the output node's dtype."""
        assert graph.nodes
        graph_dtype: Optional[torch.dtype] = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self) -> Optional[torch.dtype]:
        """Propagate dtypes starting from the root graph."""
        return self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]:
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
        """Propagate dtypes through the LoopBody of a SchedulerNode."""
        from ..loop_body import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        # Dispatch through ``cls`` (not the hard-coded base class) so that
        # subclasses overriding propagation are honoured, matching propagate_loopbody.
        return cls.propagate_loopbody(node._body)
class PythonPrinter(_PythonPrinter):
def doprint(
self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True
) -> str:
# TODO: why are people passing strings to the printer here :think:
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
expr = V.graph.sizevars.simplify(expr)
return super().doprint(expr)
class OpDecompositions:
    """
    Decomposes inductor ops into simpler ``ops.*`` calls.

    OpOverrides inherits from this class, so these expansions act as its
    defaults for any op a backend does not implement directly.
    """

    @staticmethod
    def identity(value: OpVarT) -> OpVarT:
        # used to trigger cse
        return value

    @staticmethod
    def reciprocal(x: OpVarT) -> OpVarT:
        # 1 / x
        numerator = ops.constant(1, torch.int32)
        return ops.truediv(numerator, x)

    @staticmethod
    def square(x: OpVarT) -> OpVarT:
        # x * x
        return ops.mul(x, x)

    @staticmethod
    def erfc(x: OpVarT) -> OpVarT:
        # 1 - erf(x)
        one = ops.constant(1, torch.float32)
        erf_x = ops.erf(x)
        return ops.sub(one, erf_x)

    @staticmethod
    def erfcx(x: OpVarT) -> OpVarT:
        # exp(x**2) * erfc(x)
        scale = ops.exp(ops.square(x))
        return ops.mul(scale, ops.erfc(x))

    @staticmethod
    def expm1(x: OpVarT) -> OpVarT:
        # exp(x) - 1
        exp_x = ops.exp(x)
        return ops.sub(exp_x, ops.constant(1, torch.float32))

    @staticmethod
    def log10(x: OpVarT) -> OpVarT:
        # log(x) * (1 / ln 10)
        log_x = ops.log(x)
        return ops.mul(log_x, ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x: OpVarT) -> OpVarT:
        # log(x) * (1 / ln 2)
        log_x = ops.log(x)
        return ops.mul(log_x, ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x: OpVarT) -> OpVarT:
        # exp(x * ln 2)
        ln2 = ops.constant(math.log(2), torch.float32)
        return ops.exp(ops.mul(x, ln2))

    @staticmethod
    def log1p(x: OpVarT) -> OpVarT:
        # log(x + 1)
        one = ops.constant(1, torch.int32)
        return ops.log(ops.add(x, one))

    @staticmethod
    def sigmoid(x: OpVarT) -> OpVarT:
        # 1 / (1 + exp(-x))
        one = ops.constant(1, torch.int32)
        denominator = ops.add(one, ops.exp(ops.neg(x)))
        return ops.truediv(one, denominator)

    @staticmethod
    def relu(x: OpVarT) -> OpVarT:
        # max(x, 0)
        zero = ops.constant(0, torch.int32)
        return ops.maximum(x, zero)

    @staticmethod
    def fma(x: OpVarT, y: OpVarT, z: OpVarT) -> OpVarT:
        # for backends that don't override this (halide)
        product = ops.mul(x, y)
        return ops.add(product, z)

    @staticmethod
    def floor_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        floored = ops.floor(a)
        return ops.to_dtype(floored, dtype)

    @staticmethod
    def ceil_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        ceiled = ops.ceil(a)
        return ops.to_dtype(ceiled, dtype)

    @staticmethod
    def trunc_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        truncated = ops.trunc(a)
        return ops.to_dtype(truncated, dtype)

    @staticmethod
    def remainder(a: OpVarT, b: OpVarT) -> OpVarT:
        # Shift r = mod(a, b) by b when r is nonzero and its sign bit differs
        # from b's, so the result follows b's sign.
        # NOTE(review): assumes ops.mod follows the dividend's sign — confirm.
        r = ops.mod(a, b)
        zero = ops.constant(0, torch.int32)
        is_nonzero = ops.ne(r, zero)
        signs_differ = ops.ne(ops.signbit(r), ops.signbit(b))
        needs_shift = ops.and_(is_nonzero, signs_differ)
        shifted = ops.add(r, b)
        return ops.where(needs_shift, shifted, r)

    @staticmethod
    def round_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT:
        rounded = ops.round(a)
        return ops.to_dtype(rounded, dtype)
# OpOverrides.paren() leaves a string unwrapped when this pattern *fullmatches*
# it. That covers three cases, matched case-insensitively:
#   - a run of letters, digits, "_" and ".";
#   - a single "( ... )" group with no ")" inside it;
#   - the empty string (the trailing "|" adds an empty alternative).
_RE_PAREN_NOT_NEEDED = re.compile(r"[a-z0-9_.]+|\([^)]*\)|", flags=re.IGNORECASE)
def _all_in_parens(string: str) -> bool:
if string[0] != "(" or len(string) < 2:
return False
count = 1
for i, char in enumerate(string[1:]):
if char == "(":
count += 1
elif char == ")":
count -= 1
if count == 0 and i != len(string) - 2:
return False
assert count == 0
return True
class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
    """Base ops handler that renders ops as generated-code strings.

    Mixes BasicMathOpsMixin and OpDecompositions into OpsHandler. The
    pointwise defaults here build source strings (OpVarT is ``str``). Memory,
    reduction, scan, sort and bucketize ops raise NotImplementedError because,
    per their error messages, they "should be handled by CSEProxy".
    Backend-specific pointwise implementations are installed by
    ``_initialize_pointwise_overrides``.
    """

    @staticmethod
    def paren(string: OpVarT) -> OpVarT:
        # Returned unchanged: CSE variables, strings that _RE_PAREN_NOT_NEEDED
        # fullmatches, and strings already enclosed by one outer pair of parens.
        # Everything else is wrapped as "(...)".
        if (
            isinstance(string, CSEVariable)
            or _RE_PAREN_NOT_NEEDED.fullmatch(string)
            or _all_in_parens(string)
        ):
            # don't put extra parens for strings that are already wrapped in parens
            return string
        return f"({string})"

    @staticmethod
    def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT:
        # The Python repr of ``value``; ``dtype`` is not used by this default.
        return repr(value)

    @staticmethod
    def libdevice_sigmoid(x: OpVarT) -> OpVarT:
        # 1 / (1 + libdevice_exp(-x))
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))

    # The remaining libdevice_* defaults forward to the generic op of the same name.
    @staticmethod
    def libdevice_abs(x: OpVarT) -> OpVarT:
        return ops.abs(x)

    @staticmethod
    def libdevice_sqrt(x: OpVarT) -> OpVarT:
        return ops.sqrt(x)

    @staticmethod
    def libdevice_cos(x: OpVarT) -> OpVarT:
        return ops.cos(x)

    @staticmethod
    def libdevice_sin(x: OpVarT) -> OpVarT:
        return ops.sin(x)

    @staticmethod
    def libdevice_log(x: OpVarT) -> OpVarT:
        return ops.log(x)

    @staticmethod
    def libdevice_exp(x: OpVarT) -> OpVarT:
        return ops.exp(x)

    # Bitwise/logical ops rendered as infix operators; operands go through paren().
    @staticmethod
    def bitwise_not(x: OpVarT) -> OpVarT:
        return f"~{OpOverrides.paren(x)}"

    @staticmethod
    def logical_not(a: OpVarT) -> OpVarT:
        # Rendered as "<a> == 0".
        return f"{OpOverrides.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}"

    @staticmethod
    def bitwise_or(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}"

    @staticmethod
    def bitwise_xor(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x: OpVarT, y: OpVarT) -> OpVarT:
        return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"

    @staticmethod
    def int_truediv(a: OpVarT, b: OpVarT) -> OpVarT:
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name: str, offset: OpVarT) -> OpVarT:
        # Load from buffer ``name`` at the constant index sympy.Integer(offset).
        return ops.load(name, sympy.Integer(offset))

    def indirect_indexing(
        self,
        var: OpVarT,
        size: Union[sympy.Expr, int],
        check: bool = True,
        wrap_neg: bool = True,
    ) -> sympy.Symbol:
        # Returns an index symbol named after ``var``. ``size``, ``check`` and
        # ``wrap_neg`` are ignored by this default.
        return sympy_index_symbol(str(var))

    # The following ops are expected to be intercepted before reaching this
    # handler (per the messages, by CSEProxy); reaching them here is an error.
    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ) -> None:
        raise NotImplementedError(
            f"{type(self).__name__}: check_bounds should be handled by CSEProxy"
        )

    def load(self, name: str, index: sympy.Expr) -> OpVarT:
        raise NotImplementedError(
            f"{type(self).__name__}: load should be handled by CSEProxy"
        )

    def store(
        self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None
    ) -> None:
        raise NotImplementedError(
            f"{type(self).__name__}: store should be handled by CSEProxy"
        )

    def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None:
        raise NotImplementedError(
            f"{type(self).__name__}: store_reduction should be handled by CSEProxy"
        )

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[OpVarT, tuple[OpVarT, ...]],
    ) -> Union[OpVarT, tuple[OpVarT, ...]]:
        raise NotImplementedError(
            f"{type(self).__name__}: reduction should be handled by CSEProxy"
        )

    def scan(
        self,
        dtypes: tuple[torch.dtype, ...],
        combine_fn: Callable[
            [tuple[OpVarT, ...], tuple[OpVarT, ...]],
            tuple[OpVarT, ...],
        ],
        values: tuple[OpVarT, ...],
    ) -> tuple[OpVarT, ...]:
        raise NotImplementedError(
            f"{type(self).__name__}: scan should be handled by CSEProxy"
        )

    def sort(
        self,
        dtypes: tuple[torch.dtype, ...],
        values: tuple[OpVarT, ...],
        stable: bool,
        descending: bool,
    ) -> tuple[OpVarT, ...]:
        raise NotImplementedError(
            f"{type(self).__name__}: sort should be handled by CSEProxy"
        )

    def bucketize(
        self,
        values: OpVarT,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: OpVarT,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[OpVarT] = None,
    ) -> OpVarT:
        raise NotImplementedError(
            f"{type(self).__name__}: bucketize should be handled by CSEProxy"
        )

    # Backend-specific ops; this base class does not implement them.
    def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT:
        raise NotImplementedError(
            f"{type(self).__name__}: halide_clamp only implemented for Halide backend"
        )

    def inline_asm_elementwise(
        self,
        *inputs: OpVarT,
        asm: str,
        constraints: Optional[str] = None,
        dtype: torch.dtype = torch.float32,
        is_pure: bool = True,
        pack: int = 1,
    ) -> OpVarT:
        raise NotImplementedError(
            f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
        )

    # Graph-structure ops that must never reach code generation.
    def output(self, *args: OpVarT) -> None:
        raise AssertionError(
            f"{type(self).__name__}: ops.output should not appear at codegen time"
        )

    def placeholder(self, index: int) -> OpVarT:
        raise AssertionError(
            f"{type(self).__name__}: ops.placeholder should not appear at codegen time"
        )

    @staticmethod
    def _unimplemented(name: str) -> Callable[..., OpVarT]:
        # Build a stand-in method for op ``name`` that raises NotImplementedError.
        # It is tagged with ``is_unimplemented`` so that _is_unimplemented()
        # recognizes it.
        def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
            raise NotImplementedError(
                f"{type(self).__name__} does not implement ops.{name}"
            )

        unimplemented.__name__ = name
        unimplemented.is_unimplemented = True  # type: ignore[attr-defined]
        return unimplemented

    @classmethod
    def _is_unimplemented(cls, name: str) -> bool:
        # True when ``cls`` has no attribute ``name``, when that attribute is the
        # same object as OpsHandler's default, or when it is a stand-in created
        # by _unimplemented().
        fn = getattr(cls, name, None)
        default_fn = getattr(OpsHandler, name, None)
        return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False)

    @classmethod
    def _initialize_pointwise_overrides(cls, target: str) -> None:
        # For each entry of pointwise_overrides_data (defined elsewhere in this
        # module), take the implementation for ``target``:
        #   - None: install an "unimplemented" stand-in, but only if ``cls``
        #     does not already provide an implementation;
        #   - otherwise: install it as a staticmethod under the op's name,
        #     asserting that ``cls`` does not define that op directly.
        assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target
        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                if cls._is_unimplemented(funcname):
                    setattr(cls, funcname, cls._unimplemented(funcname))
            else:
                assert funcname not in cls.__dict__, (
                    f"multiple definitions of {funcname} on {cls.__name__}"
                )
                impl.__name__ = funcname
                setattr(cls, funcname, staticmethod(impl))
@dataclasses.dataclass
class OverridesData:
name: str