-
Notifications
You must be signed in to change notification settings - Fork 3.2k
/
Copy pathop-resolve-rules.ts
129 lines (126 loc) · 5.53 KB
/
op-resolve-rules.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax';
import {attention} from './ops/attention';
import {batchNorm} from './ops/batch-norm';
import {biasAdd} from './ops/bias-add';
import {biasSplitGelu} from './ops/bias-split-gelu';
import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {cumsum, parseCumSumAttributes} from './ops/cumsum';
import {einsum, parseEinsumAttributes} from './ops/einsum';
import {expand} from './ops/expand';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {instanceNorm} from './ops/instance-norm';
import {layerNorm} from './ops/layer-norm';
import {matMul} from './ops/matmul';
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
import {pad} from './ops/pad';
import * as pool from './ops/pool';
import {range} from './ops/range';
import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
import {skipLayerNorm} from './ops/skip-layer-norm';
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
import {tile} from './ops/tile';
import {parseTransposeAttributes, transpose} from './ops/transpose';
import * as unaryOps from './ops/unary-op';
import {where} from './ops/where';
import {ComputeContext} from './types';
/**
 * Entry point of an operator kernel.
 *
 * @param context - the compute context the kernel executes against (`ComputeContext`, imported from `./types`).
 * @param attribute - optional operator attributes. Presumably this is the value returned by the operator's
 *     {@link ParseAttributeFunction} when one is registered — NOTE(review): confirm against the dispatching caller.
 */
export type RunFunction = (context: ComputeContext, attribute?: unknown) => void;
/**
 * Converts an operator's raw attribute value into another value.
 * Both input and output are typed `unknown` at this level; the concrete shapes are not visible here.
 */
export type ParseAttributeFunction = (attributeRaw: unknown) => unknown;
/**
 * One registered operator: either a lone run function, or a run function paired with an attribute parser.
 */
export type OperatorImplementation = [RunFunction]|[RunFunction, ParseAttributeFunction];
/**
 * Operator dispatch table for the WebGPU backend.
 *
 * Each key is an operator type name (for example `Conv` or `MatMul`). Each value is an
 * {@link OperatorImplementation}: the kernel's run function, optionally followed by an attribute parser.
 * Several operators share implementations:
 * - `FusedConv` reuses the `Conv` kernel and parser.
 * - `Elu`, `LeakyRelu` and `ThresholdedRelu` share `parseAlphaAttributes`.
 *
 * The table is written as an object literal and handed to `Object.entries`. Non-integer string keys are
 * enumerated in insertion order, so the resulting Map iterates in exactly the order listed below.
 * Note that this order is not strictly alphabetical: see the `Reduce*` group and the `Slice`/`Sqrt` entries.
 */
export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> =
    new Map(Object.entries<OperatorImplementation>({
      Abs: [unaryOps.abs],
      Acos: [unaryOps.acos],
      Acosh: [unaryOps.acosh],
      Add: [binaryOps.add],
      ArgMax: [argMax, parseArgMinMaxAttributes],
      ArgMin: [argMin, parseArgMinMaxAttributes],
      Asin: [unaryOps.asin],
      Asinh: [unaryOps.asinh],
      Atan: [unaryOps.atan],
      Atanh: [unaryOps.atanh],
      Attention: [attention],
      // TODO: support new attributes for AveragePool-10
      AveragePool: [pool.averagePool, pool.parseAveragePoolAttributes],
      BatchNormalization: [batchNorm],
      BiasAdd: [biasAdd],
      BiasSplitGelu: [biasSplitGelu],
      Cast: [unaryOps.cast, unaryOps.parseCastAttributes],
      Ceil: [unaryOps.ceil],
      Clip: [unaryOps.clip],
      Concat: [concat, parseConcatAttributes],
      Conv: [conv, parseConvAttributes],
      ConvTranspose: [convTranspose, parseConvTransposeAttributes],
      Cos: [unaryOps.cos],
      Cosh: [unaryOps.cosh],
      CumSum: [cumsum, parseCumSumAttributes],
      Div: [binaryOps.div],
      Einsum: [einsum, parseEinsumAttributes],
      Elu: [unaryOps.elu, unaryOps.parseAlphaAttributes],
      Equal: [binaryOps.equal],
      Erf: [unaryOps.erf],
      Exp: [unaryOps.exp],
      Expand: [expand],
      Floor: [unaryOps.floor],
      // Shares the plain Conv kernel and attribute parser.
      FusedConv: [conv, parseConvAttributes],
      Gather: [gather, parseGatherAttributes],
      GatherElements: [gatherElements, parseGatherElementsAttributes],
      Gelu: [unaryOps.gelu],
      Gemm: [gemm, parseGemmAttributes],
      GlobalAveragePool: [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes],
      GlobalMaxPool: [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes],
      Greater: [binaryOps.greater],
      GreaterOrEqual: [binaryOps.greaterOrEqual],
      InstanceNormalization: [instanceNorm],
      LayerNormalization: [layerNorm],
      LeakyRelu: [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes],
      Less: [binaryOps.less],
      LessOrEqual: [binaryOps.lessOrEqual],
      Log: [unaryOps.log],
      MatMul: [matMul],
      // TODO: support new attributes for MaxPool-8 and MaxPool-10
      MaxPool: [pool.maxPool, pool.parseMaxPoolAttributes],
      Mul: [binaryOps.mul],
      MultiHeadAttention: [multiHeadAttention, parseMultiHeadAttentionAttributes],
      Neg: [unaryOps.neg],
      Not: [unaryOps.not],
      Pad: [pad],
      Pow: [binaryOps.pow],
      Range: [range],
      Reciprocal: [unaryOps.reciprocal],
      ReduceMin: [reduceMin],
      ReduceMean: [reduceMean],
      ReduceMax: [reduceMax],
      ReduceSum: [reduceSum],
      ReduceProd: [reduceProd],
      ReduceL1: [reduceL1],
      ReduceL2: [reduceL2],
      ReduceLogSum: [reduceLogSum],
      ReduceLogSumExp: [reduceLogSumExp],
      ReduceSumSquare: [reduceSumSquare],
      Relu: [unaryOps.relu],
      Resize: [resize, parseResizeAttributes],
      Sigmoid: [unaryOps.sigmoid],
      Sin: [unaryOps.sin],
      Sinh: [unaryOps.sinh],
      Slice: [slice, parseSliceAttributes],
      SkipLayerNormalization: [skipLayerNorm],
      Split: [split, parseSplitAttributes],
      Sqrt: [unaryOps.sqrt],
      Softmax: [softmax, parseSoftmaxAttributes],
      Sub: [binaryOps.sub],
      Tan: [unaryOps.tan],
      Tanh: [unaryOps.tanh],
      ThresholdedRelu: [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes],
      Tile: [tile],
      Transpose: [transpose, parseTransposeAttributes],
      Where: [where],
    }));