Skip to content

Commit ff908a0

Browse files
committed
Add testdata for Gather with multiple outputs
1 parent ff9268d commit ff908a0

File tree

4 files changed

+17
-0
lines changed

4 files changed

+17
-0
lines changed
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

+17
Original file line numberDiff line numberDiff line change
@@ -977,3 +977,20 @@ def forward(self, x, skip=None):
977977
x = Variable(torch.rand(1, 2, 2, 2))
978978
model = ResizeConv(2, 0, 2)
979979
save_data_and_model("resize_opset11_torch1.6", x, model, 11)
980+
981+
class GatherMultiOutput(nn.Module):
982+
def __init__(self, in_dim = 2):
983+
super(GatherMultiOutput, self).__init__()
984+
self.in_dim = in_dim
985+
self.lin_inp = nn.Linear(in_dim, 2, bias=False)
986+
def forward(self, x):
987+
x_projected = self.lin_inp(x).long()
988+
x_gather = x_projected[:,0,:]
989+
x_float1 = x_gather.float()
990+
x_float2 = x_gather.float()
991+
x_float3 = x_gather.float()
992+
return x_float1+x_float2+x_float3
993+
994+
x = Variable(torch.zeros([1, 2, 2]))
995+
model = GatherMultiOutput()
996+
save_data_and_model("gather_multi_output", x, model)
432 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)