@@ -70,72 +70,73 @@ c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
70
70
}
71
71
const torch::jit::Node* node = v->node ();
72
72
const c10::TypePtr& type = v->type ();
73
+
74
+ c10::Symbol attr_value = c10::Symbol::fromDomainAndUnqualString (c10::attr::value.domainString (), " value" );
75
+
73
76
if (type->isSubtypeOf (c10::TensorType::get ())) {
74
- return node->t (c10::attr::value );
77
+ return node->t (attr_value );
75
78
} else if (type->isSubtypeOf (c10::BoolType::get ())) {
76
- return (bool )node->i (c10::attr::value);
77
- } else if (
78
- type->isSubtypeOf (c10::NumberType::get ()) && node->kindOf (c10::attr::value) == torch::jit::AttributeKind::i) {
79
- return node->i (c10::attr::value);
80
- } else if (
81
- type->isSubtypeOf (c10::NumberType::get ()) && node->kindOf (c10::attr::value) == torch::jit::AttributeKind::f) {
82
- return node->f (c10::attr::value);
79
+ return (bool )node->i (attr_value);
80
+ } else if (type->isSubtypeOf (c10::NumberType::get ()) && node->kindOf (attr_value) == torch::jit::AttributeKind::i) {
81
+ return node->i (attr_value);
82
+ } else if (type->isSubtypeOf (c10::NumberType::get ()) && node->kindOf (attr_value) == torch::jit::AttributeKind::f) {
83
+ return node->f (attr_value);
83
84
} else if (type->isSubtypeOf (c10::ListType::ofInts ())) {
84
85
try {
85
- const auto & is = node->is (c10::attr::value );
86
+ const auto & is = node->is (attr_value );
86
87
return is;
87
88
} catch (const std::exception & ex) {
88
- const auto & ival = node->ival (c10::attr::value );
89
+ const auto & ival = node->ival (attr_value );
89
90
return ival;
90
91
}
91
92
} else if (type->isSubtypeOf (c10::ListType::ofFloats ())) {
92
93
try {
93
- const auto & fs = node->fs (c10::attr::value );
94
+ const auto & fs = node->fs (attr_value );
94
95
return fs;
95
96
} catch (const std::exception & ex) {
96
- const auto & ival = node->ival (c10::attr::value );
97
+ const auto & ival = node->ival (attr_value );
97
98
return ival;
98
99
}
99
100
} else if (type->isSubtypeOf (c10::ListType::ofBools ())) {
100
- const auto bs = c10::fmap<bool >(node->is (c10::attr::value ));
101
+ const auto bs = c10::fmap<bool >(node->is (attr_value ));
101
102
return bs;
102
103
} else if (type->isSubtypeOf (c10::ListType::ofTensors ())) {
103
104
try {
104
- const auto & ts = node->ts (c10::attr::value );
105
+ const auto & ts = node->ts (attr_value );
105
106
return ts;
106
107
} catch (const std::exception & ex) {
107
- const auto & ival = node->ival (c10::attr::value );
108
+ const auto & ival = node->ival (attr_value );
108
109
return ival;
109
110
}
110
111
} else if (type->isSubtypeOf (c10::ListType::ofStrings ())) {
111
112
try {
112
- const auto & ss = node->ss (c10::attr::value );
113
+ const auto & ss = node->ss (attr_value );
113
114
auto vals = c10::impl::GenericList (c10::StringType::get ());
114
115
for (const auto & str : ss) {
115
116
vals.push_back (str);
116
117
}
117
118
return vals;
118
119
} catch (const std::exception & ex) {
119
- const auto & ival = node->ival (c10::attr::value );
120
+ const auto & ival = node->ival (attr_value );
120
121
return ival;
121
122
}
122
- } else if (type->cast <c10::ListType>() && node->kindOf (c10::attr::value ) == torch::jit::AttributeKind::ival) {
123
- const auto & list = node->ival (c10::attr::value );
123
+ } else if (type->cast <c10::ListType>() && node->kindOf (attr_value ) == torch::jit::AttributeKind::ival) {
124
+ const auto & list = node->ival (attr_value );
124
125
TORCHTRT_ASSERT (list.isList (), " Is not a list" );
125
126
return list;
126
- } else if (type->cast <c10::DictType>() && node->kindOf (c10::attr::value ) == torch::jit::AttributeKind::ival) {
127
- const auto & dict = node->ival (c10::attr::value );
127
+ } else if (type->cast <c10::DictType>() && node->kindOf (attr_value ) == torch::jit::AttributeKind::ival) {
128
+ const auto & dict = node->ival (attr_value );
128
129
TORCHTRT_ASSERT (dict.isGenericDict (), " Is not a dict" );
129
130
return dict;
130
- } else if (type->cast <c10::TupleType>() && node->kindOf (c10::attr::value ) == torch::jit::AttributeKind::ival) {
131
- const auto & tup = node->ival (c10::attr::value );
131
+ } else if (type->cast <c10::TupleType>() && node->kindOf (attr_value ) == torch::jit::AttributeKind::ival) {
132
+ const auto & tup = node->ival (attr_value );
132
133
TORCHTRT_ASSERT (tup.isTuple (), " Is not a tuple" );
133
134
return tup;
134
135
} else if (type == c10::StringType::get ()) {
135
- const auto & s = node->s (c10::attr::value );
136
+ const auto & s = node->s (attr_value );
136
137
return s;
137
138
} else if (type == c10::DeviceObjType::get ()) {
138
- auto d = c10::Device (node->s (c10::attr::value ));
139
+ auto d = c10::Device (node->s (attr_value ));
139
140
return d;
140
141
} else if (node->mustBeNone ()) {
141
142
return torch::jit::IValue ();
0 commit comments