Skip to content

Commit ad70664

Browse files
committed
Add testdata for Gather with multiple outputs
1 parent c3d6342 commit ad70664

File tree

4 files changed

+17
-0
lines changed

4 files changed

+17
-0
lines changed
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

+17
Original file line numberDiff line numberDiff line change
@@ -986,3 +986,20 @@ def forward(self, x):
986986
x = Variable(torch.randn(1, 3, 2, 2))
987987
model = Scale()
988988
save_data_and_model("scale", x, model)
989+
990+
class GatherMultiOutput(nn.Module):
991+
def __init__(self, in_dim = 2):
992+
super(GatherMultiOutput, self).__init__()
993+
self.in_dim = in_dim
994+
self.lin_inp = nn.Linear(in_dim, 2, bias=False)
995+
def forward(self, x):
996+
x_projected = self.lin_inp(x).long()
997+
x_gather = x_projected[:,0,:]
998+
x_float1 = x_gather.float()
999+
x_float2 = x_gather.float()
1000+
x_float3 = x_gather.float()
1001+
return x_float1+x_float2+x_float3
1002+
1003+
x = Variable(torch.zeros([1, 2, 2]))
1004+
model = GatherMultiOutput()
1005+
save_data_and_model("gather_multi_output", x, model)
432 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)