Skip to content

Commit 2cc3226

Browse files
committed
feat(//core/conversion/evaluators): Adds new applicability filters for
evaluators. This allows developers to specifically blacklist or whiteline specific cases of node kinds so that they will run on a subset of cases instead of any instance. This is important in the case of prim::Loop where we want to evaluate some loops and not others. This also lets us use function schemas to target node, for instance there is now a aten::mul.Tensor converter and an aten::mul.int evaluator. In Tensor cases the converter will be called, in int cases the evaluator will. We cannot switch to keying on function schema like we do for converters because some node kinds dont have a schema so we do schema white listing instead. This commit also adds the following evaluators: - aten::mul.int - aten::sub.int - aten::__round_to_zero_floordiv - aten::slice.t - aten::len.t - prim::min.self_int Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 2f394fb commit 2cc3226

File tree

4 files changed

+169
-13
lines changed

4 files changed

+169
-13
lines changed

Diff for: core/conversion/evaluators/NodeEvaluatorRegistry.cpp

+44-12
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,59 @@ namespace core {
1515
namespace conversion {
1616
namespace evaluators {
1717
namespace {
18-
using EvaluatorLUT = std::unordered_map<torch::jit::NodeKind, NodeEvaluator>;
18+
using EvaluatorLUT = std::unordered_map<torch::jit::NodeKind, EvalRegistration>;
19+
20+
bool FindInVec(std::vector<c10::OperatorName>& names, c10::OperatorName target) {
21+
for (auto n : names) {
22+
if (n == target) {
23+
return true;
24+
}
25+
}
26+
return false;
27+
}
1928

2029
class NodeEvaluatorRegistry {
2130
public:
22-
void RegisterEvaluator(torch::jit::NodeKind node_kind, NodeEvaluator& evaluator) {
31+
void RegisterEvaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
2332
LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
24-
evaluator_lut_[node_kind] = std::move(evaluator);
33+
evaluator_lut_[node_kind] = std::move(eval_reg);
2534
}
2635

27-
NodeEvaluator GetEvaluator(const torch::jit::NodeKind node_kind) {
36+
// Looks up an evaluator for `n` and applies the registration's
// applicability filters (output-type blacklist and schema whitelist).
// Returns nullptr when no registered evaluator applies to this node.
NodeEvaluator FindEvaluator(const torch::jit::Node* n) {
  auto node_kind = n->kind();
  auto iter = evaluator_lut_.find(node_kind);
  if (iter == evaluator_lut_.end()) {
    return nullptr;
  }
  // Reference, not copy: the registration holds a std::function and vectors.
  auto& eval_reg = iter->second;
  if (eval_reg.options.use()) {
    // Reject nodes producing a blacklisted output type (e.g. Tensor-producing
    // prim::Loop nodes should be converted, not evaluated).
    for (auto o : n->outputs()) {
      if (eval_reg.options.blacklisted_output_types.find(o->type()) != eval_reg.options.blacklisted_output_types.end()) {
        return nullptr;
      }
    }

    // If a schema whitelist was given, only evaluate matching overloads
    // (e.g. aten::mul.int evaluates while aten::mul.Tensor converts).
    if (eval_reg.options.valid_schemas.size() != 0) {
      auto schema = n->maybeSchema();
      // BUG FIX: message previously read "...but schema for node is retrievable"
      // (missing "not") and lacked a space after the qualified name.
      TRTORCH_CHECK(schema, "Evaluator for " << node_kind.toQualString() << " only runs on certain schemas, but the schema for this node is not retrievable");
      if (!FindInVec(eval_reg.options.valid_schemas, schema->operator_name())) {
        return nullptr;
      }
    }
  }

  return eval_reg.evaluator;
}
61+
62+
NodeEvaluator GetEvaluator(const torch::jit::Node* n) {
63+
auto evaluator = FindEvaluator(n);
64+
TRTORCH_CHECK(evaluator, "Requested evaluator for " << n->kind().toQualString() << ", but no such evaluator was found");
65+
return evaluator;
3466
}
3567

3668
bool EvalAtConversionTime(const torch::jit::Node* n) {
37-
auto eval_at_conversion = evaluator_lut_.find(n->kind());
38-
if (eval_at_conversion == evaluator_lut_.end()) {
69+
auto evaluator = FindEvaluator(n);
70+
if (evaluator == nullptr) {
3971
return false;
4072
} else {
4173
return true;
@@ -58,16 +90,16 @@ bool shouldEvalAtConversionTime(const torch::jit::Node* n) {
5890
}
5991

6092
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
61-
auto evaluator = get_evaluator_registry().GetEvaluator(n->kind());
93+
auto evaluator = get_evaluator_registry().GetEvaluator(n);
6294
return evaluator(n, args);
6395
}
6496

65-
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator) {
66-
get_evaluator_registry().RegisterEvaluator(node_kind, evaluator);
97+
void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
98+
get_evaluator_registry().RegisterEvaluator(node_kind, std::move(eval_reg));
6799
}
68100

69101
void register_node_evaluator(EvalRegistration r) {
70-
register_node_evaluator(r.kind, r.evaluator);
102+
register_node_evaluator(r.kind, std::move(r));
71103
}
72104

73105
RegisterNodeEvaluators&& RegisterNodeEvaluators::evaluator(EvalRegistration r) && {

Diff for: core/conversion/evaluators/aten.cpp

+73-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ namespace conversion {
1515
namespace evaluators {
1616
namespace {
1717

18+
19+
// Maps a possibly-negative Python-style index to an offset from the front of
// the list: negative values count back from the end (may still be negative if
// idx < -list_size; callers clamp the result).
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  return idx < 0 ? list_size + idx : idx;
}
26+
1827
auto aten_registrations = RegisterNodeEvaluators()
1928
.evaluator({
2029
c10::Symbol::fromQualString("aten::zeros"),
@@ -25,9 +34,72 @@ auto aten_registrations = RegisterNodeEvaluators()
2534
.layout(torch::kStrided)
2635
.device(torch::kCUDA);
2736

28-
auto out_tensor = torch::zeros(args.at(&(n->output()[0])).unwrapToIntList().vec(), options);
37+
auto out_tensor = torch::zeros(args.at(&(n->input()[0])).unwrapToIntList().vec(), options);
2938
return out_tensor;
3039
}
40+
}).evaluator({
41+
c10::Symbol::fromQualString("aten::mul"),
42+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
43+
auto a = args.at(&(n->input()[0])).unwrapToInt();
44+
auto b = args.at(&(n->input()[1])).unwrapToInt();
45+
return a * b;
46+
},
47+
EvalOptions().validSchemas({"aten::mul.int(int a, int b) -> (int)"})
48+
}).evaluator({
49+
c10::Symbol::fromQualString("aten::sub"),
50+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
51+
auto a = args.at(&(n->input()[0])).unwrapToInt();
52+
auto b = args.at(&(n->input()[1])).unwrapToInt();
53+
return a - b;
54+
},
55+
EvalOptions().validSchemas({"aten::sub.int(int a, int b) -> (int)"})
56+
}).evaluator({
57+
c10::Symbol::fromQualString("aten::__round_to_zero_floordiv"),
58+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
59+
auto a = args.at(&(n->input()[0])).unwrapToInt();
60+
auto b = args.at(&(n->input()[1])).unwrapToInt();
61+
return a / b;
62+
},
63+
EvalOptions().validSchemas({"aten::__round_to_zero_floordiv(int a, int b) -> (int)"})
64+
}).evaluator({
    c10::Symbol::fromQualString("aten::slice"),
    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
      // aten::slice.t(l, start, end, step): input 0 is the list, inputs
      // 1, 2 and 3 are the slice parameters.
      c10::List<c10::IValue> list = args.at(&(n->input()[0])).IValue()->to<c10::List<c10::IValue>>();
      // BUG FIX: start/end/step were all read from input 0 (the list itself);
      // they live at inputs 1, 2 and 3 respectively.
      int64_t start = args.at(&(n->input()[1])).unwrapToInt();
      int64_t end = args.at(&(n->input()[2])).unwrapToInt();
      int64_t step = args.at(&(n->input()[3])).unwrapToInt();

      const int64_t list_size = list.size();

      // clamp start and end to the bounds of the list
      const auto normalized_start = std::max((int64_t)0, normalizeIndex(start, list_size));
      const auto normalized_end = std::min(list_size, normalizeIndex(end, list_size));

      auto sliced_list = c10::impl::GenericList(list.elementType());
      if (normalized_end <= normalized_start) {
        // early exit if the slice is trivially empty
        return sliced_list;
      }

      sliced_list.reserve(normalized_end - normalized_start);

      for (auto i = normalized_start; i < normalized_end; i += step) {
        sliced_list.push_back(list.get(i));
      }

      return sliced_list;
    },
    EvalOptions().validSchemas({"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})
96+
}).evaluator({
97+
c10::Symbol::fromQualString("aten::len"),
98+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
99+
c10::List<c10::IValue> list = args.at(&(n->input()[0])).IValue()->to<c10::List<c10::IValue>>();
100+
return static_cast<int64_t>(list.size());
101+
},
102+
EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})
31103
});
32104
}
33105
} // namespace evaluators

Diff for: core/conversion/evaluators/evaluators.h

+27
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,36 @@ inline bool constTypesOnly(kwargs& args) {
3535
// when writing evaluators
3636
typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, kwargs&)> NodeEvaluator;
3737

38+
struct EvalOptions {
39+
std::set<c10::TypePtr> blacklisted_output_types;
40+
std::vector<c10::OperatorName> valid_schemas;
41+
EvalOptions() = default;
42+
EvalOptions& blacklistOutputTypes(std::set<c10::TypePtr> types) {
43+
use_options = true;
44+
blacklisted_output_types = types;
45+
return *this;
46+
}
47+
EvalOptions& validSchemas(std::set<std::string> schemas) {
48+
use_options = true;
49+
for (auto s : schemas) {
50+
valid_schemas.push_back(torch::jit::parseSchema(s).operator_name());
51+
}
52+
return *this;
53+
}
54+
bool use() { return use_options; }
55+
private:
56+
bool use_options = false;
57+
};
58+
3859
struct EvalRegistration {
3960
torch::jit::NodeKind kind;
4061
NodeEvaluator evaluator;
62+
EvalOptions options;
63+
EvalRegistration() = default;
64+
EvalRegistration(torch::jit::NodeKind _kind, NodeEvaluator _evaluator)
65+
: kind(_kind), evaluator(_evaluator), options(EvalOptions()) {};
66+
EvalRegistration(torch::jit::NodeKind _kind, NodeEvaluator _evaluator, EvalOptions _options)
67+
: kind(_kind), evaluator(_evaluator), options(_options) {};
4168
};
4269

4370
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);

Diff for: core/conversion/evaluators/prim.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include <limits>
2+
13
#include "torch/csrc/jit/ir/ir.h"
24
#include "torch/csrc/jit/ir/constants.h"
35
#include "ATen/core/functional.h"
@@ -92,6 +94,29 @@ auto prim_registrations = RegisterNodeEvaluators()
9294
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
9395
}
9496
}
97+
}).evaluator({
    torch::jit::prim::Loop,
    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
      // TODO(review): placeholder — dumps the loop node to stdout and
      // evaluates to nothing. Loop evaluation is not implemented yet; the
      // Tensor blacklist below keeps Tensor-producing loops in the
      // conversion path instead of this one.
      std::cout << *n << '\n';  // '\n' instead of std::endl: no flush needed
      return {};
    },
    EvalOptions().blacklistOutputTypes({c10::TensorType::get()})
105+
}).evaluator({
    c10::Symbol::fromQualString("prim::min"),
    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
      // prim::min.self_int: minimum element of an int list.
      // NOTE(review): an empty list yields INT64_MAX — confirm whether the
      // graph can ever feed an empty list here.
      auto a = args.at(&(n->input()[0])).unwrapToIntList();
      int64_t min = std::numeric_limits<int64_t>::max();

      for (size_t i = 0; i < a.size(); i++) {
        if (a[i] < min) {
          // BUG FIX: previously `min = i;` — stored the loop index instead of
          // the element value, returning the wrong result.
          min = a[i];
        }
      }

      return min;
    },
    EvalOptions().validSchemas({"prim::min.self_int(int[] self) -> (int)"})
95120
});
96121
}
97122
} // namespace evaluators

0 commit comments

Comments
 (0)