Skip to content

Commit e67841f

Browse files
pk-ghouseroad
authored andcommitted
Add Scatter op to ONNX (onnx#1517)
* Adding Scatter op to ONNX * fixing build onnx#6598 break * fixing build onnx#6601 break * Addressing PR Feedback onnx#1 * fixing build onnx#6601 break * removing renamed/redundant test folders from commit * updating testcoverage.md * Addressing PR Feedback onnx#2 * updating testcoverage file * Update onnx/defs/tensor/defs.cc Co-Authored-By: pk-g <[email protected]> * resolving merge conflicts * resolving merge conflicts onnx#2 * Applying PR Feedback * applying PR feedback and resolving merge conflicts
1 parent fe8e76b commit e67841f

File tree

17 files changed

+440
-1
lines changed

17 files changed

+440
-1
lines changed

docs/Changelog.md

+74
Original file line numberDiff line numberDiff line change
@@ -9263,6 +9263,80 @@ This version of the operator has been available since version 9 of the default O
92639263
<dd>All Tensor types</dd>
92649264
</dl>
92659265

9266+
### <a name="Scatter-9"></a>**Scatter-9**</a>
9267+
9268+
Given `data`, `updates` and `indices` input tensors of rank r >= 1, write the values provided by `updates`
9269+
into the first input, `data`, along `axis` dimension of `data` (by default outer-most one as axis=0) at corresponding `indices`.
9270+
For each entry in `updates`, the target index in `data` is specified by corresponding entry in `indices`
9271+
for dimension = axis, and index in source for dimension != axis. For instance, in a 2-D tensor case,
9272+
data[indices[i][j]][j] = updates[i][j] if axis = 0, or data[i][indices[i][j]] = updates[i][j] if axis = 1,
9273+
where i and j are loop counters from 0 up to the respective size in `updates` - 1.
9274+
9275+
Example 1:
9276+
data = [
9277+
[0.0, 0.0, 0.0],
9278+
[0.0, 0.0, 0.0],
9279+
[0.0, 0.0, 0.0],
9280+
]
9281+
indices = [
9282+
[1, 0, 2],
9283+
[0, 2, 1],
9284+
]
9285+
updates = [
9286+
[1.0, 1.1, 1.2],
9287+
[2.0, 2.1, 2.2],
9288+
]
9289+
output = [
9290+
[2.0, 1.1, 0.0]
9291+
[1.0, 0.0, 2.2]
9292+
[0.0, 2.1, 1.2]
9293+
]
9294+
9295+
Example 2:
9296+
data = [[1.0, 2.0, 3.0, 4.0, 5.0]]
9297+
indices = [[1, 3]]
9298+
updates = [[1.1, 2.1]]
9299+
axis = 1
9300+
output = [[1.0, 1.1, 3.0, 2.1, 5.0]]
9301+
9302+
#### Version
9303+
9304+
This version of the operator has been available since version 9 of the default ONNX operator set.
9305+
9306+
#### Attributes
9307+
9308+
<dl>
9309+
<dt><tt>axis</tt> : int (default is 0)</dt>
9310+
<dd>Which axis to scatter on. Negative value means counting dimensions from the back. Accepted range in [-r, r-1]</dd>
9311+
</dl>
9312+
9313+
#### Inputs
9314+
9315+
<dl>
9316+
<dt><tt>data</tt> : T</dt>
9317+
<dd>Tensor of rank r >= 1.</dd>
9318+
<dt><tt>indices</tt> : Tind</dt>
9319+
<dd>Tensor of int32/int64 indices, of r >= 1 (same rank as input).</dd>
9320+
<dt><tt>updates</tt> : T</dt>
9321+
<dd>Tensor of rank r >=1 (same rank and shape as indices)</dd>
9322+
</dl>
9323+
9324+
#### Outputs
9325+
9326+
<dl>
9327+
<dt><tt>output</tt> : T</dt>
9328+
<dd>Tensor of rank r >= 1 (same rank as input).</dd>
9329+
</dl>
9330+
9331+
#### Type Constraints
9332+
9333+
<dl>
9334+
<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
9335+
<dd>Input and output types can be of any tensor type.</dd>
9336+
<dt><tt>Tind</tt> : tensor(int32), tensor(int64)</dt>
9337+
<dd>Constrain indices to integer types</dd>
9338+
</dl>
9339+
92669340
### <a name="Sign-9"></a>**Sign-9**</a>
92679341

92689342
Calculate the sign of the given input tensor element-wise.

docs/Operators.md

+127
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
* <a href="#Relu">Relu</a>
9797
* <a href="#Reshape">Reshape</a>
9898
* <a href="#Scan">Scan</a>
99+
* <a href="#Scatter">Scatter</a>
99100
* <a href="#Selu">Selu</a>
100101
* <a href="#Shape">Shape</a>
101102
* <a href="#Sigmoid">Sigmoid</a>
@@ -9879,6 +9880,132 @@ expect(node, inputs=[initial, x], outputs=[y, z],
98799880
</details>
98809881

98819882

9883+
### <a name="Scatter"></a><a name="scatter">**Scatter**</a>
9884+
9885+
Given `data`, `updates` and `indices` input tensors of rank r >= 1, write the values provided by `updates`
9886+
into the first input, `data`, along `axis` dimension of `data` (by default outer-most one as axis=0) at corresponding `indices`.
9887+
For each entry in `updates`, the target index in `data` is specified by corresponding entry in `indices`
9888+
for dimension = axis, and index in source for dimension != axis. For instance, in a 2-D tensor case,
9889+
data[indices[i][j]][j] = updates[i][j] if axis = 0, or data[i][indices[i][j]] = updates[i][j] if axis = 1,
9890+
where i and j are loop counters from 0 up to the respective size in `updates` - 1.
9891+
9892+
Example 1:
9893+
data = [
9894+
[0.0, 0.0, 0.0],
9895+
[0.0, 0.0, 0.0],
9896+
[0.0, 0.0, 0.0],
9897+
]
9898+
indices = [
9899+
[1, 0, 2],
9900+
[0, 2, 1],
9901+
]
9902+
updates = [
9903+
[1.0, 1.1, 1.2],
9904+
[2.0, 2.1, 2.2],
9905+
]
9906+
output = [
9907+
[2.0, 1.1, 0.0]
9908+
[1.0, 0.0, 2.2]
9909+
[0.0, 2.1, 1.2]
9910+
]
9911+
9912+
Example 2:
9913+
data = [[1.0, 2.0, 3.0, 4.0, 5.0]]
9914+
indices = [[1, 3]]
9915+
updates = [[1.1, 2.1]]
9916+
axis = 1
9917+
output = [[1.0, 1.1, 3.0, 2.1, 5.0]]
9918+
9919+
#### Version
9920+
9921+
This version of the operator has been available since version 9 of the default ONNX operator set.
9922+
9923+
#### Attributes
9924+
9925+
<dl>
9926+
<dt><tt>axis</tt> : int (default is 0)</dt>
9927+
<dd>Which axis to scatter on. Negative value means counting dimensions from the back. Accepted range in [-r, r-1]</dd>
9928+
</dl>
9929+
9930+
#### Inputs
9931+
9932+
<dl>
9933+
<dt><tt>data</tt> : T</dt>
9934+
<dd>Tensor of rank r >= 1.</dd>
9935+
<dt><tt>indices</tt> : Tind</dt>
9936+
<dd>Tensor of int32/int64 indices, of r >= 1 (same rank as input).</dd>
9937+
<dt><tt>updates</tt> : T</dt>
9938+
<dd>Tensor of rank r >=1 (same rank and shape as indices)</dd>
9939+
</dl>
9940+
9941+
#### Outputs
9942+
9943+
<dl>
9944+
<dt><tt>output</tt> : T</dt>
9945+
<dd>Tensor of rank r >= 1 (same rank as input).</dd>
9946+
</dl>
9947+
9948+
#### Type Constraints
9949+
9950+
<dl>
9951+
<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
9952+
<dd>Input and output types can be of any tensor type.</dd>
9953+
<dt><tt>Tind</tt> : tensor(int32), tensor(int64)</dt>
9954+
<dd>Constrain indices to integer types</dd>
9955+
</dl>
9956+
9957+
9958+
#### Examples
9959+
9960+
<details>
9961+
<summary>scatter_with_axis</summary>
9962+
9963+
```python
9964+
node = onnx.helper.make_node(
9965+
'Scatter',
9966+
inputs=['data', 'indices', 'updates'],
9967+
outputs=['y'],
9968+
axis=1,
9969+
)
9970+
data = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
9971+
indices = np.array([[1, 3]], dtype=np.int64)
9972+
updates = np.array([[1.1, 2.1]], dtype=np.float32)
9973+
9974+
y = np.array([[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=np.float32)
9975+
9976+
expect(node, inputs=[data, indices, updates], outputs=[y],
9977+
name='test_scatter_with_axis')
9978+
```
9979+
9980+
</details>
9981+
9982+
9983+
<details>
9984+
<summary>scatter_without_axis</summary>
9985+
9986+
```python
9987+
node = onnx.helper.make_node(
9988+
'Scatter',
9989+
inputs=['data', 'indices', 'updates'],
9990+
outputs=['y'],
9991+
)
9992+
data = np.zeros((3, 3), dtype=np.float32)
9993+
indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int64)
9994+
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32)
9995+
9996+
y = np.array([
9997+
[2.0, 1.1, 0.0],
9998+
[1.0, 0.0, 2.2],
9999+
[0.0, 2.1, 1.2]
10000+
], dtype=np.float32)
10001+
10002+
expect(node, inputs=[data, indices, updates], outputs=[y],
10003+
name='test_scatter_without_axis')
10004+
```
10005+
10006+
</details>
10007+
10008+
988210009
### <a name="Selu"></a><a name="selu">**Selu**</a>
988310010

988410011
Selu takes one input data (Tensor<T>) and produces one output data

docs/TestCoverage.md

+50-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* [Overall Test Coverage](#overall-test-coverage)
66
# Node Test Coverage
77
## Summary
8-
Node tests have covered 105/112 (93.75%, 5 generators excluded) common operators.
8+
Node tests have covered 106/113 (93.81%, 5 generators excluded) common operators.
99

1010
Node tests have covered 2/12 (16.67%, 0 generators excluded) experimental operators.
1111

@@ -5184,6 +5184,55 @@ expect(node, inputs=[initial, x], outputs=[y, z],
51845184
</details>
51855185

51865186

5187+
### Scatter
5188+
There are 2 test cases, listed as following:
5189+
<details>
5190+
<summary>scatter_with_axis</summary>
5191+
5192+
```python
5193+
node = onnx.helper.make_node(
5194+
'Scatter',
5195+
inputs=['data', 'indices', 'updates'],
5196+
outputs=['y'],
5197+
axis=1,
5198+
)
5199+
data = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
5200+
indices = np.array([[1, 3]], dtype=np.int64)
5201+
updates = np.array([[1.1, 2.1]], dtype=np.float32)
5202+
5203+
y = np.array([[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=np.float32)
5204+
5205+
expect(node, inputs=[data, indices, updates], outputs=[y],
5206+
name='test_scatter_with_axis')
5207+
```
5208+
5209+
</details>
5210+
<details>
5211+
<summary>scatter_without_axis</summary>
5212+
5213+
```python
5214+
node = onnx.helper.make_node(
5215+
'Scatter',
5216+
inputs=['data', 'indices', 'updates'],
5217+
outputs=['y'],
5218+
)
5219+
data = np.zeros((3, 3), dtype=np.float32)
5220+
indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int64)
5221+
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32)
5222+
5223+
y = np.array([
5224+
[2.0, 1.1, 0.0],
5225+
[1.0, 0.0, 2.2],
5226+
[0.0, 2.1, 1.2]
5227+
], dtype=np.float32)
5228+
5229+
expect(node, inputs=[data, indices, updates], outputs=[y],
5230+
name='test_scatter_without_axis')
5231+
```
5232+
5233+
</details>
5234+
5235+
51875236
### Selu
51885237
There are 2 test cases, listed as following:
51895238
<details>
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
import numpy as np # type: ignore
7+
8+
import onnx
9+
from ..base import Base
10+
from . import expect
11+
12+
13+
class Scatter(Base):
14+
15+
@staticmethod
16+
def export_scatter_without_axis(): # type: () -> None
17+
node = onnx.helper.make_node(
18+
'Scatter',
19+
inputs=['data', 'indices', 'updates'],
20+
outputs=['y'],
21+
)
22+
data = np.zeros((3, 3), dtype=np.float32)
23+
indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int64)
24+
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32)
25+
26+
y = np.array([
27+
[2.0, 1.1, 0.0],
28+
[1.0, 0.0, 2.2],
29+
[0.0, 2.1, 1.2]
30+
], dtype=np.float32)
31+
32+
expect(node, inputs=[data, indices, updates], outputs=[y],
33+
name='test_scatter_without_axis')
34+
35+
@staticmethod
36+
def export_scatter_with_axis(): # type: () -> None
37+
node = onnx.helper.make_node(
38+
'Scatter',
39+
inputs=['data', 'indices', 'updates'],
40+
outputs=['y'],
41+
axis=1,
42+
)
43+
data = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
44+
indices = np.array([[1, 3]], dtype=np.int64)
45+
updates = np.array([[1.1, 2.1]], dtype=np.float32)
46+
47+
y = np.array([[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=np.float32)
48+
49+
expect(node, inputs=[data, indices, updates], outputs=[y],
50+
name='test_scatter_with_axis')
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
 backend-test:�
2+
1
3+
data
4+
indices
5+
updatesy"Scatter*
6+
axis�test_scatter_with_axisZ
7+
data
8+

9+

10+
Z
11+
indices
12+

13+

14+
Z
15+
updates
16+

17+

18+
b
19+
y
20+

21+

22+
B
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
BupdatesJ�̌?ff@
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
 backend-test:�
2+
$
3+
data
4+
indices
5+
updatesy"Scattertest_scatter_without_axisZ
6+
data
7+

8+

9+
Z
10+
indices
11+

12+

13+
Z
14+
updates
15+

16+

17+
b
18+
y
19+

20+

21+
B
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)