1
+ #include < bitset>
1
2
#include " core/util/prelude.h"
2
3
#include " core/conversion/converters/converters.h"
3
4
@@ -22,25 +23,36 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
22
23
TRTORCH_CHECK (mean_layer, " Unable to create mean layer from node: " << *n);
23
24
24
25
mean_layer->setName (util::node_info (n).c_str ());
25
- ctx->AssociateValueAndTensor (n->outputs ()[0 ], mean_layer->getOutput (0 ));
26
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mean_layer->getOutput (0 ));
27
+
28
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
26
29
return true ;
27
30
}
28
31
}).pattern({
29
- " aten::mean.dim(Tensor self, int[1 ] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)" ,
32
+ " aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)" ,
30
33
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
31
34
auto in_tensor = args[0 ].ITensor ();
32
- auto dim = args[1 ].unwrapToIntList ();
33
- auto keepdim = args[ 2 ]. unwrapToBool ();
35
+ auto dims = args[1 ].unwrapToIntList ();
36
+ LOG_DEBUG ( " Dim to reduce: " << util::toDims (dims)); // Some abuse of toDim but just for debug info
34
37
35
- uint32_t axis_mask = 1 << dim[0 ];
38
+ uint32_t axis_mask = 0 ;
39
+ for (int d = 0 ; d < dims.size (); d++) {
40
+ axis_mask |= 1 << dims[d];
41
+ }
42
+ LOG_DEBUG (" Axis Mask" << std::bitset<32 >(axis_mask));
43
+
44
+ auto keepdim = args[2 ].unwrapToBool ();
45
+ LOG_DEBUG (" Keep dims :" << keepdim);
36
46
37
47
LOG_WARNING (" Mean converter disregards dtype" );
38
48
auto mean_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kAVG , axis_mask, keepdim);
39
49
40
50
TRTORCH_CHECK (mean_layer, " Unable to create mean layer from node: " << *n);
41
51
42
52
mean_layer->setName (util::node_info (n).c_str ());
43
- ctx->AssociateValueAndTensor (n->outputs ()[0 ], mean_layer->getOutput (0 ));
53
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mean_layer->getOutput (0 ));
54
+
55
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
44
56
return true ;
45
57
}
46
58
});
0 commit comments