Skip to content

Commit fa7d6d9

Browse files
committed
feat(aten::masked_fill): In progress implementation of masked_fill
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 6aaba3b commit fa7d6d9

File tree

8 files changed

+310
-1
lines changed

8 files changed

+310
-1
lines changed

Diff for: core/conversion/converters/converter_util.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,27 @@ nvinfer1::ILayer* add_elementwise(
122122
return ele;
123123
}
124124

125+
// Returns `tensor` untouched when its type already equals `dtype`. Otherwise adds an identity
// layer fed by `tensor`, sets the type of that layer's output to `dtype`, names the layer after
// the cast being performed, and returns the layer's output.
nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
  // Nothing to do if the tensor already has the requested type
  if (tensor->getType() == dtype) {
    return tensor;
  }

  // The tensor's address is used as its identifier in log messages and in the layer name
  std::ostringstream addr_stream;
  addr_stream << reinterpret_cast<int*>(tensor);
  const std::string addr = addr_stream.str();

  auto identity = ctx->net->addIdentity(*tensor);
  TRTORCH_CHECK(identity, "Unable to create identity layer for ITensor: " << addr);

  auto converted = identity->getOutput(0);
  converted->setType(dtype);

  LOG_DEBUG(ctx->logger, "Casting ITensor " << addr << " from " << tensor->getType() << " to " << dtype);

  std::stringstream layer_name;
  layer_name << "[Cast ITensor " << addr << " from " << tensor->getType() << " to " << dtype << "]";
  identity->setName(layer_name.str().c_str());

  return converted;
}
145+
125146
} // namespace converters
126147
} // namespace conversion
127148
} // namespace core

Diff for: core/conversion/converters/converter_util.h

+3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ nvinfer1::ILayer* add_elementwise(
4242
nvinfer1::ITensor* other,
4343
const std::string& name);
4444

45+
// If the ITensor's type differs from dtype, add an Identity layer that casts it to dtype and return
// that layer's output; otherwise return the ITensor unchanged
46+
nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype);
47+
4548
} // namespace converters
4649
} // namespace conversion
4750
} // namespace core

Diff for: core/conversion/converters/impl/select.cpp

+23-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "NvInfer.h"
44
#include "c10/util/intrusive_ptr.h"
55
#include "core/conversion/converters/converters.h"
6+
#include "core/conversion/converters/converter_util.h"
67
#include "core/conversion/tensorcontainer/TensorContainer.h"
78
#include "core/util/prelude.h"
89
#include "torch/torch.h"
@@ -247,7 +248,28 @@ auto select_registrations TRTORCH_UNUSED =
247248
add_split(ctx, n, args, true);
248249
LOG_DEBUG("Converted split op into a list of IValues");
249250
return true;
250-
}});
251+
}})
252+
.pattern({
    "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
    // Lowers masked_fill to a select layer: the mask (cast to bool) is the condition, a constant
    // tensor filled with `value` is chosen where the condition is true, and `self` elsewhere.
    // Returns true once the node output has been associated with the select layer's output.
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      auto self = args[0].ITensorOrFreeze(ctx);
      auto mask = castITensor(ctx, args[1].ITensorOrFreeze(ctx), nvinfer1::DataType::kBOOL);

      // Validate the shapes before materializing the fill constant
      TRTORCH_CHECK(
          util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false),
          "Self and mask tensors are not broadcastable");

      // The fill value is expanded to a constant with the same dimensions as self
      auto val = args[2].unwrapToScalar().to<float>();
      auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));

      // addSelect(condition, then, else): masked_fill writes `value` where the mask is true,
      // so the fill constant is the "then" operand and self is the "else" operand
      auto new_layer = ctx->net->addSelect(*mask, *val_t, *self);
      TRTORCH_CHECK(new_layer, "Unable to create layer for aten::masked_fill");

      new_layer->setName(util::node_info(n).c_str());

      auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
      LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
      return true;
    }});
251273

252274
} // namespace
253275
} // namespace impl

Diff for: core/conversion/evaluators/aten.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "torch/torch.h"
99

1010
#include "core/conversion/evaluators/eval_macros.h"
11+
#include "core/conversion/evaluators/eval_util.h"
1112
#include "core/conversion/evaluators/evaluators.h"
1213

1314
namespace trtorch {
@@ -566,6 +567,21 @@ auto aten_registrations TRTORCH_UNUSED =
566567
return {};
567568
},
568569
EvalOptions()})
570+
.evaluator({c10::Symbol::fromQualString("aten::tensor"),
            // Evaluates aten::tensor at conversion time: builds a tensor from the list input
            // using the dtype and device inputs. A Byte result is returned converted to Int.
            [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
              auto list_arg = args.at(n->input(0)).IValue();
              auto dtype_arg = args.at(n->input(1)).IValue();
              auto device_arg = args.at(n->input(2)).IValue();

              auto result = createTensorFromList(*list_arg, *dtype_arg, *device_arg);
              LOG_DEBUG(result);

              const bool is_byte = (result.dtype() == at::kByte);
              if (!is_byte) {
                return result;
              }
              return result.to(at::kInt);
            },
            EvalOptions().validSchemas({
                "aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)"
            })})
569585
.evaluator({c10::Symbol::fromQualString("aten::arange"),
570586
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
571587
int input_size = n->inputs().size();

Diff for: core/conversion/evaluators/eval_util.cpp

+201
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
#include "ATen/InitialTensorOptions.h"
12
#include "ATen/core/List.h"
23
#include "ATen/core/functional.h"
34
#include "ATen/core/ivalue.h"
5+
#include "ATen/core/jit_type.h"
6+
#include "c10/util/irange.h"
47
#include "core/util/prelude.h"
58

69
namespace trtorch {
@@ -91,6 +94,204 @@ c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
9194
}
9295
}
9396

97+
// Throws unless `elem_type` is a number type or bool. When the element type is a Tensor type
// and `empty_list` is set, the error text additionally explains that empty lists default to
// List[Tensor].
void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
  const bool is_supported =
      elem_type->isSubtypeOf(c10::NumberType::get()) || elem_type == c10::BoolType::get();
  if (is_supported) {
    return;
  }

  std::stringstream msg;
  msg << "Input must be of ints, floats, or bools, "
      << "got " << elem_type->repr_str();

  // special case empty list torch.tensor([])
  if (elem_type->isSubtypeOf(c10::TensorType::get()) && empty_list) {
    msg << "\nEmpty lists default to List[Tensor]. Add a variable "
           "annotation to the assignment to create an empty list "
           "of another type (torch.jit.annotate(List[T, []]) where T "
           "is the type of elements in the list for Python 2)";
  }

  TRTORCH_THROW_ERROR(msg.str());
}
115+
116+
void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
117+
if (seq_size != n) {
118+
TRTORCH_THROW_ERROR(
119+
"Expected sequence of length "
120+
<< n
121+
<< " at dim "
122+
<< dim
123+
<< " (got "
124+
<< seq_size
125+
<< ")");
126+
}
127+
}
128+
129+
130+
131+
template <typename DTYPE>
132+
void storeLastDimension(
133+
char* data,
134+
const std::vector<int64_t>& sizes,
135+
const c10::ArrayRef<int64_t>& strides,
136+
int64_t dim,
137+
int elementSize,
138+
at::ArrayRef<torch::jit::IValue> obj) {
139+
auto n = sizes[dim];
140+
auto seq_size = obj.size();
141+
checkSequenceSize(n, dim, seq_size);
142+
for (const auto i : c10::irange(n)) {
143+
*(DTYPE*)data = obj[i].to<DTYPE>();
144+
data += strides[dim] * elementSize;
145+
}
146+
}
147+
148+
149+
void storeLastDimensionFloat(
150+
char* data,
151+
const std::vector<int64_t>& sizes,
152+
const c10::ArrayRef<int64_t>& strides,
153+
int64_t dim,
154+
int elementSize,
155+
at::ArrayRef<torch::jit::IValue> obj) {
156+
auto n = sizes[dim];
157+
auto seq_size = obj.size();
158+
checkSequenceSize(n, dim, seq_size);
159+
for (int64_t i = 0; i < n; i++) {
160+
*(float*)data = static_cast<float>(obj[i].to<double>());
161+
data += strides[dim] * elementSize;
162+
}
163+
}
164+
165+
void storeLastDimensionHalf(
166+
char* data,
167+
const std::vector<int64_t>& sizes,
168+
const c10::ArrayRef<int64_t>& strides,
169+
int64_t dim,
170+
int elementSize,
171+
at::ArrayRef<torch::jit::IValue> obj) {
172+
auto n = sizes[dim];
173+
auto seq_size = obj.size();
174+
checkSequenceSize(n, dim, seq_size);
175+
for (int64_t i = 0; i < n; i++) {
176+
*(at::Half*)data = at::convert<at::Half, double>(obj[i].to<double>());
177+
data += strides[dim] * elementSize;
178+
}
179+
}
180+
181+
void recursiveStore(
182+
char* data,
183+
const std::vector<int64_t>& sizes,
184+
const c10::ArrayRef<int64_t>& strides,
185+
int64_t dim,
186+
int tenElementSize,
187+
const torch::jit::IValue& obj) {
188+
auto ndim = sizes.size();
189+
auto n = sizes[dim];
190+
auto seq = obj.toListRef();
191+
checkSequenceSize(n, dim, seq.size());
192+
if (dim + 1 < static_cast<long>(ndim)) {
193+
for (const auto i : c10::irange(n)) {
194+
recursiveStore(data, sizes, strides, dim + 1, tenElementSize, seq[i]);
195+
data += strides[dim] * tenElementSize;
196+
}
197+
} else {
198+
if (obj.isIntList()) {
199+
storeLastDimension<int64_t>(
200+
data, sizes, strides, dim, tenElementSize, seq);
201+
} else if (obj.isBoolList()) {
202+
storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, seq);
203+
} else if (obj.isDoubleList()) {
204+
if (tenElementSize ==
205+
static_cast<int>(elementSize(at::ScalarType::Double))) {
206+
storeLastDimension<double>(
207+
data, sizes, strides, dim, tenElementSize, seq);
208+
} else if (
209+
tenElementSize ==
210+
static_cast<int>(elementSize(at::ScalarType::Float))) {
211+
storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq);
212+
} else if (
213+
tenElementSize ==
214+
static_cast<int>(elementSize(at::ScalarType::Half))) {
215+
storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq);
216+
} else {
217+
TORCH_INTERNAL_ASSERT(false);
218+
}
219+
} else {
220+
TORCH_INTERNAL_ASSERT(false);
221+
}
222+
}
223+
}
224+
225+
// Converts `self` to the scalar type and device given by `dtype` and `device`. A None value
// for either means "keep what the tensor already has". The tensor is returned as-is when
// neither the scalar type nor the device would change.
at::Tensor castTensorTo(
    at::Tensor self,
    const torch::jit::IValue& dtype,
    const torch::jit::IValue& device) {
  at::ScalarType target_type = self.scalar_type();
  if (!dtype.isNone()) {
    target_type = dtype.toScalarType();
  }
  c10::Device target_device = device.isNone() ? self.device() : device.toDevice();

  const bool already_matches = (target_type == self.scalar_type()) && (target_device == self.device());
  if (already_matches) {
    return self;
  }

  self = self.to(target_device, target_type);
  return self;
}
237+
238+
// Derives the size of each dimension of a nested list by following the first element at every
// nesting level. Descent stops at the first level that is empty or whose first element is not
// itself a list.
std::vector<int64_t> compute_sizes(const torch::jit::IValue& seq) {
  std::vector<int64_t> dims;
  auto level = seq.toList();
  dims.push_back(level.size());

  while (level.size() != 0 && level.get(0).isList()) {
    level = level.get(0).toList();
    dims.push_back(level.size());
  }

  return dims;
}
250+
251+
// Builds a tensor from a (possibly nested) list IValue.
//
// data:   list of ints, floats or bools, optionally nested; anything else is rejected by
//         checkListInputType
// dtype:  requested scalar type, or None to keep the type derived from the list elements
// device: requested device, or None to keep the device the tensor was created on
//
// Double element lists are created with the current default dtype. A warning is logged when an
// empty list is given without a dtype and the resulting type differs from the default dtype.
at::Tensor createTensorFromList(const torch::jit::IValue& data, const torch::jit::IValue& dtype, const torch::jit::IValue& device) {
  // Strip the nested list types to get at the scalar element type
  auto elem_type = data.type();
  while (auto list_type = elem_type->cast<c10::ListType>()) {
    elem_type = list_type->getElementType();
  }

  auto sizes = compute_sizes(data);
  checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0);

  at::ScalarType initial_scalar_type = c10::scalarTypeFromJitType(elem_type);
  if (initial_scalar_type == at::ScalarType::Double) {
    initial_scalar_type = at::typeMetaToScalarType(c10::get_default_dtype());
  }

  auto tensor =
      at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type));

  // Nothing to copy for an empty tensor
  if (tensor.numel() != 0) {
    recursiveStore(
        static_cast<char*>(tensor.data_ptr()),
        sizes,
        tensor.strides(),
        0,
        tensor.element_size(),
        data);
  }

  tensor = castTensorTo(tensor, dtype, device);
  auto default_type = at::typeMetaToScalarType(at::get_default_dtype());

  if (dtype.isNone() && tensor.scalar_type() != default_type &&
      tensor.numel() == 0) {
    LOG_WARNING(
        "Creating a tensor from an empty "
        << elem_type->repr_str()
        << " list will create a tensor of default floating point type (currently "
        << default_type
        << ") in python but a tensor of type "
        << elem_type->repr_str()
        << " in torchscript.\n"
        << "Pass in a dtype argument to ensure consistent behavior");
  }

  return tensor;
}
294+
94295
} // namespace evaluators
95296
} // namespace conversion
96297
} // namespace core

Diff for: core/conversion/evaluators/eval_util.h

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace conversion {
88
namespace evaluators {
99

1010
c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
11+
at::Tensor createTensorFromList(const torch::jit::IValue& data, const torch::jit::IValue& dtype, const torch::jit::IValue& device);
1112

1213
} // namespace evaluators
1314
} // namespace conversion

Diff for: core/util/trt_util.h

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
3131
return stream << "Int8";
3232
case nvinfer1::DataType::kINT32:
3333
return stream << "Int32";
34+
case nvinfer1::DataType::kBOOL:
35+
return stream << "Bool";
3436
default:
3537
return stream << "Unknown Data Type";
3638
}

Diff for: tests/core/conversion/converters/test_select.cpp

+43
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "gtest/gtest.h"
55
#include "tests/util/util.h"
66
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "core/lowering/passes/passes.h"
8+
79

810
TEST(Converters, ATenSelectIntConvertsCorrectly) {
911
const auto graph = R"IR(
@@ -398,3 +400,44 @@ TEST(Converters, ATenSplitAndAddConvertsCorrectly) {
398400
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
399401
}
400402
}
403+
404+
// Builds a graph that creates a constant mask with aten::tensor, moves it to the device and
// applies aten::masked_fill with the value 2 to the input, then checks that the TensorRT
// engine output matches the TorchScript output for an all-zeros input.
TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
  // Every value in the IR has a unique name: the nested mask list is %6, built from %5
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %44 : Device = prim::Constant[value="cuda"]()
      %8 : bool = prim::Constant[value=0]()
      %7 : None = prim::Constant()
      %1 : int = prim::Constant[value=0]() # bert.py:5:26
      %2 : int = prim::Constant[value=1]() # bert.py:5:32
      %33 : int = prim::Constant[value=2]() # bert.py:6:31
      %3 : int[] = prim::ListConstruct(%1, %1, %2)
      %4 : int[] = prim::ListConstruct(%2, %2, %1)
      %5 : int[][] = prim::ListConstruct(%3, %4)
      %6 : int[][][] = prim::ListConstruct(%5)
      %9 : Tensor = aten::tensor(%6, %1, %7, %8) # bert.py:5:11
      %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11
      %mask.2 : Tensor = trt::const(%mask.1)
      %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11
      return (%34, %mask.2))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::zeros({1, 2, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  trtorch::core::lowering::passes::RemoveNOPs(g);
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  // Diagnostic output: TorchScript result, TensorRT result, then the mask returned by the engine
  std::cout << jit_results[0] << trt_results[0].reshape_as(jit_results[0]) << std::endl;

  std::cout << trt_results[1].reshape_as(jit_results[0]) << std::endl;

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 commit comments

Comments
 (0)