-
Notifications
You must be signed in to change notification settings - Fork 3.1k
/
Copy path split.ts
135 lines (124 loc) · 5.65 KB
/
split.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';
import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
/**
 * Attributes of the Split kernel.
 */
export interface SplitAttributes extends AttributeWithCacheKey {
  /** Number of output tensors the input is divided into. */
  readonly numOutputs: number;
  /**
   * Extent of each output along `axis`, one entry per output. `parseSplitAttributes` enforces that
   * its length equals `numOutputs`.
   */
  readonly splitSizes: number[];
  /** Axis to split along; normalized against the input rank (`ShapeUtil.normalizeAxis`) before use. */
  readonly axis: number;
}
/**
 * Rejects a missing or empty input list; Split needs at least the tensor to split.
 *
 * @throws Error('too few inputs') when `inputs` is falsy or has no elements.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  const inputCount = inputs ? inputs.length : 0;
  if (inputCount < 1) {
    throw new Error('too few inputs');
  }
};
/**
 * Resolves Split attributes when the split sizes may be supplied as a second input.
 *
 * If `inputs[1]` is non-empty, its int64 values become `splitSizes` and `numOutputs` is their
 * count. If it is empty, the attribute-provided sizes are kept, so an empty optional input behaves
 * like the single-input path in `split` (which uses `attributes` as-is).
 *
 * @param inputs kernel inputs; `inputs[1]` must exist (caller only invokes this with 2+ inputs).
 * @param attributes attributes parsed from the node; `axis` is always carried over unchanged.
 * @returns a fresh attribute object with its cache key recomputed.
 */
const createSplitAttributesFromInputs =
    (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
      const splitInput = inputs[1];
      if (splitInput.dims[0] > 0) {
        const splitSizes = Array.from(splitInput.getBigInt64Array(), Number);
        return createAttributeWithCacheKey({numOutputs: splitSizes.length, axis: attributes.axis, splitSizes});
      }
      // Empty split input: previously this dropped the attribute sizes and left `splitSizes` empty,
      // which made every output shape NaN downstream. Copy so the new attributes do not alias.
      return createAttributeWithCacheKey(
          {numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes: [...attributes.splitSizes]});
    };
/**
 * Emits the WGSL helper `calculateOutputIndex(index)`, which maps a coordinate along the split
 * axis to the output that owns it. The helper scans the running-sum uniform
 * `uniforms.size_in_split_axis` and returns the first slot whose bound exceeds `index`, or
 * `numberOfTensors` when none does.
 *
 * @param numberOfTensors number of split outputs (length of the running-sum uniform).
 * @returns WGSL source text for the helper function.
 */
const calculateOutputIndexImpl = (numberOfTensors: number): string => {
  const upperBound = getElementAt('uniforms.size_in_split_axis', 'i', numberOfTensors);
  const shaderLines = [
    '',
    'fn calculateOutputIndex(index: u32) -> u32 {',
    `for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) {`,
    `if (index < ${upperBound}) {`,
    'return i;',
    '}',
    '}',
    `return ${numberOfTensors}u;`,
    '}',
  ];
  return shaderLines.join('\n');
};
/**
 * Emits the WGSL helper `writeBufferData(output_number, indices, global_idx)`, which stores
 * `input[global_idx]` into output `output_number` at `indices`.
 *
 * With one output the store is unconditional; otherwise it is an if / else-if chain whose final
 * `else` branch targets the last output (so any unmatched `output_number` also lands there).
 *
 * @param outputs one indices helper per split output; `outputs[0].type.indices` types the
 *     `indices` parameter.
 * @returns WGSL source text for the helper function.
 * @throws Error when `outputs` is empty (there is nothing to write to).
 */
const writeBufferDataImpl = (outputs: readonly IndicesHelper[]): string => {
  const numberOfTensors = outputs.length;
  if (numberOfTensors === 0) {
    throw new Error('Split requires at least one output');
  }
  const codeLines = outputs.map((output, i) => {
    const returnSnippet = output.setByIndices('indices', 'input[global_idx]');
    if (numberOfTensors === 1) {
      return returnSnippet;
    }
    if (i === 0) {
      return `if (output_number == ${i}u) { ${returnSnippet} }`;
    }
    if (i === numberOfTensors - 1) {
      return `else { ${returnSnippet} }`;
    }
    // Use the same `u` (u32) literal suffix as the first branch so all comparisons are uniform.
    return `else if (output_number == ${i}u) { ${returnSnippet} }`;
  });
  return `
fn writeBufferData(output_number: u32, indices: ${outputs[0].type.indices}, global_idx: u32) {
${codeLines.join('\n')}
}`;
};
/**
 * Builds the Split program. Each invocation handles one element of `inputs[0]`: it finds the
 * output whose range along the split axis contains the element, rebases the axis coordinate into
 * that output, and stores the value there.
 *
 * @param inputs kernel inputs; only `inputs[0]` (the tensor being split) is read here.
 * @param attributes resolved attributes; `splitSizes` is expected to hold `numOutputs` entries.
 * @returns program info whose run data declares one output per split.
 */
const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const inputSize = ShapeUtil.size(inputShape);
  const dataType = inputs[0].dataType;
  // `attributes.axis` may be negative; all shape work below must use this normalized value.
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
  const outputs = new Array<IndicesHelper>(attributes.numOutputs);
  const input = inputVariable('input', dataType, inputShape);
  // Inclusive running sums of the split sizes: output i owns axis coordinates
  // [sizeInSplitAxis[i - 1], sizeInSplitAxis[i]) (with an implicit 0 before the first entry).
  const sizeInSplitAxis = new Array<number>(attributes.numOutputs);
  const outputsTensorInfo: TensorInfo[] = [];
  const outputShapes: number[][] = [];
  let previousSum = 0;
  const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}];
  for (let i = 0; i < attributes.numOutputs; i++) {
    previousSum += attributes.splitSizes[i];
    sizeInSplitAxis[i] = previousSum;
    const outputShape = inputShape.slice();
    // Index with the normalized axis: a negative index would merely add a named property to the
    // array and leave the real dimension at its full input size.
    outputShape[axis] = attributes.splitSizes[i];
    outputShapes.push(outputShape);
    outputs[i] = outputVariable(`output${i}`, dataType, outputShape);
    outputsTensorInfo.push({dims: outputShape, dataType});
  }
  programUniforms.push(
      {type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes));
  const getShaderSource = (shaderHelper: ShaderHelper) => `
${
      shaderHelper.registerUniform('input_size', 'u32')
          .registerUniform('size_in_split_axis', 'u32', sizeInSplitAxis.length)
          .declareVariables(input, ...outputs)}
${calculateOutputIndexImpl(sizeInSplitAxis.length)}
${writeBufferDataImpl(outputs)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.input_size')}
var indices = ${input.offsetToIndices('global_idx')};
var index = ${input.indicesGet('indices', axis)};
let output_number = calculateOutputIndex(index);
if (output_number != 0) {
index -= ${getElementAt('uniforms.size_in_split_axis', 'output_number - 1u', sizeInSplitAxis.length)};
${input.indicesSet('indices', axis, 'index')};
}
writeBufferData(output_number, indices, global_idx);
}`;
  return {
    name: 'Split',
    shaderCache: {hint: `${attributes.cacheKey};${inputShape}`, inputDependencies: ['rank']},
    getShaderSource,
    getRunData: () => ({
      outputs: outputsTensorInfo,
      dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)},
      programUniforms
    })
  };
};
/**
 * Split operator entry point.
 *
 * Validates the inputs and resolves the attributes. When a second input is present, the split
 * sizes are taken from it via `createSplitAttributesFromInputs`. It then runs the program with
 * only input 0 bound.
 */
export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
  const {inputs} = context;
  validateInputs(inputs);
  let effectiveAttributes = attributes;
  if (inputs.length !== 1) {
    effectiveAttributes = createSplitAttributesFromInputs(inputs, attributes);
  }
  const programInfo = createSplitProgramInfo(inputs, effectiveAttributes);
  context.compute(programInfo, {inputs: [0]});
};
/**
 * Converts the raw attribute record into `SplitAttributes`.
 *
 * A negative `numOutputs` is treated as unspecified and replaced by `splitSizes.length`. A missing
 * `splitSizes` is treated as an empty list.
 *
 * @param attributes raw record; reads `axis`, `numOutputs` and `splitSizes`.
 * @returns attributes with a computed cache key.
 * @throws Error when the resolved `numOutputs` differs from the number of split sizes.
 */
export const parseSplitAttributes = (attributes: Record<string, unknown>): SplitAttributes => {
  const axis = attributes.axis as number;
  const splitSizes = (attributes.splitSizes as number[] | undefined) ?? [];
  const requestedOutputs = attributes.numOutputs as number;
  const numOutputs = requestedOutputs < 0 ? splitSizes.length : requestedOutputs;
  if (numOutputs !== splitSizes.length) {
    throw new Error('numOutputs and splitSizes length must be equal');
  }
  return createAttributeWithCacheKey({axis, numOutputs, splitSizes});
};