Skip to content

Commit a8e693f

Browse files
committed
fix undefined attr issue (#1783)
1 parent 5eb8c8e commit a8e693f

File tree

2 files changed

+28
-26
lines changed

2 files changed

+28
-26
lines changed

core/conversion/evaluators/eval_util.cpp

+26-25
Original file line numberDiff line numberDiff line change
@@ -70,72 +70,73 @@ c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
7070
}
7171
const torch::jit::Node* node = v->node();
7272
const c10::TypePtr& type = v->type();
73+
74+
c10::Symbol attr_value = c10::Symbol::fromDomainAndUnqualString(c10::attr::value.domainString(), "value");
75+
7376
if (type->isSubtypeOf(c10::TensorType::get())) {
74-
return node->t(c10::attr::value);
77+
return node->t(attr_value);
7578
} else if (type->isSubtypeOf(c10::BoolType::get())) {
76-
return (bool)node->i(c10::attr::value);
77-
} else if (
78-
type->isSubtypeOf(c10::NumberType::get()) && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::i) {
79-
return node->i(c10::attr::value);
80-
} else if (
81-
type->isSubtypeOf(c10::NumberType::get()) && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::f) {
82-
return node->f(c10::attr::value);
79+
return (bool)node->i(attr_value);
80+
} else if (type->isSubtypeOf(c10::NumberType::get()) && node->kindOf(attr_value) == torch::jit::AttributeKind::i) {
81+
return node->i(attr_value);
82+
} else if (type->isSubtypeOf(c10::NumberType::get()) && node->kindOf(attr_value) == torch::jit::AttributeKind::f) {
83+
return node->f(attr_value);
8384
} else if (type->isSubtypeOf(c10::ListType::ofInts())) {
8485
try {
85-
const auto& is = node->is(c10::attr::value);
86+
const auto& is = node->is(attr_value);
8687
return is;
8788
} catch (const std::exception& ex) {
88-
const auto& ival = node->ival(c10::attr::value);
89+
const auto& ival = node->ival(attr_value);
8990
return ival;
9091
}
9192
} else if (type->isSubtypeOf(c10::ListType::ofFloats())) {
9293
try {
93-
const auto& fs = node->fs(c10::attr::value);
94+
const auto& fs = node->fs(attr_value);
9495
return fs;
9596
} catch (const std::exception& ex) {
96-
const auto& ival = node->ival(c10::attr::value);
97+
const auto& ival = node->ival(attr_value);
9798
return ival;
9899
}
99100
} else if (type->isSubtypeOf(c10::ListType::ofBools())) {
100-
const auto bs = c10::fmap<bool>(node->is(c10::attr::value));
101+
const auto bs = c10::fmap<bool>(node->is(attr_value));
101102
return bs;
102103
} else if (type->isSubtypeOf(c10::ListType::ofTensors())) {
103104
try {
104-
const auto& ts = node->ts(c10::attr::value);
105+
const auto& ts = node->ts(attr_value);
105106
return ts;
106107
} catch (const std::exception& ex) {
107-
const auto& ival = node->ival(c10::attr::value);
108+
const auto& ival = node->ival(attr_value);
108109
return ival;
109110
}
110111
} else if (type->isSubtypeOf(c10::ListType::ofStrings())) {
111112
try {
112-
const auto& ss = node->ss(c10::attr::value);
113+
const auto& ss = node->ss(attr_value);
113114
auto vals = c10::impl::GenericList(c10::StringType::get());
114115
for (const auto& str : ss) {
115116
vals.push_back(str);
116117
}
117118
return vals;
118119
} catch (const std::exception& ex) {
119-
const auto& ival = node->ival(c10::attr::value);
120+
const auto& ival = node->ival(attr_value);
120121
return ival;
121122
}
122-
} else if (type->cast<c10::ListType>() && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
123-
const auto& list = node->ival(c10::attr::value);
123+
} else if (type->cast<c10::ListType>() && node->kindOf(attr_value) == torch::jit::AttributeKind::ival) {
124+
const auto& list = node->ival(attr_value);
124125
TORCHTRT_ASSERT(list.isList(), "Is not a list");
125126
return list;
126-
} else if (type->cast<c10::DictType>() && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
127-
const auto& dict = node->ival(c10::attr::value);
127+
} else if (type->cast<c10::DictType>() && node->kindOf(attr_value) == torch::jit::AttributeKind::ival) {
128+
const auto& dict = node->ival(attr_value);
128129
TORCHTRT_ASSERT(dict.isGenericDict(), "Is not a dict");
129130
return dict;
130-
} else if (type->cast<c10::TupleType>() && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
131-
const auto& tup = node->ival(c10::attr::value);
131+
} else if (type->cast<c10::TupleType>() && node->kindOf(attr_value) == torch::jit::AttributeKind::ival) {
132+
const auto& tup = node->ival(attr_value);
132133
TORCHTRT_ASSERT(tup.isTuple(), "Is not a tuple");
133134
return tup;
134135
} else if (type == c10::StringType::get()) {
135-
const auto& s = node->s(c10::attr::value);
136+
const auto& s = node->s(attr_value);
136137
return s;
137138
} else if (type == c10::DeviceObjType::get()) {
138-
auto d = c10::Device(node->s(c10::attr::value));
139+
auto d = c10::Device(node->s(attr_value));
139140
return d;
140141
} else if (node->mustBeNone()) {
141142
return torch::jit::IValue();

core/ir/ir.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block*
160160
LOG_GRAPH("Input outputs a Tensor");
161161
if (in->node()->kind() == torch::jit::prim::Constant) {
162162
LOG_GRAPH("Input is a constant");
163-
auto const_val = in->node()->t(c10::attr::value);
163+
auto const_val =
164+
in->node()->t(c10::Symbol::fromDomainAndUnqualString(c10::attr::value.domainString(), "value"));
164165
LOG_GRAPH("Found that constant tensor has type: " << const_val.scalar_type());
165166
dtype = {const_val.scalar_type()};
166167
goto exit_first_calc_dtype;

0 commit comments

Comments
 (0)