Skip to content

Commit c7123e3

Browse files
authored
Reduce templates (invoke kernels) (#6097)
1 parent 3158d24 commit c7123e3

File tree

4 files changed

+399
-195
lines changed

4 files changed

+399
-195
lines changed
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
#include "mkql_builtins_impl.h" // Y_IGNORE
2+
3+
namespace NKikimr {
4+
namespace NMiniKQL {
5+
6+
template <typename T>
7+
arrow::compute::InputType GetPrimitiveInputArrowType(bool tz) {
8+
return arrow::compute::InputType(AddTzType(tz, GetPrimitiveDataType<T>()), arrow::ValueDescr::ANY);
9+
}
10+
11+
template <typename T>
12+
arrow::compute::OutputType GetPrimitiveOutputArrowType(bool tz) {
13+
return arrow::compute::OutputType(AddTzType(tz, GetPrimitiveDataType<T>()));
14+
}
15+
16+
template arrow::compute::InputType GetPrimitiveInputArrowType<bool>(bool tz);
17+
template arrow::compute::InputType GetPrimitiveInputArrowType<i8>(bool tz);
18+
template arrow::compute::InputType GetPrimitiveInputArrowType<ui8>(bool tz);
19+
template arrow::compute::InputType GetPrimitiveInputArrowType<i16>(bool tz);
20+
template arrow::compute::InputType GetPrimitiveInputArrowType<ui16>(bool tz);
21+
template arrow::compute::InputType GetPrimitiveInputArrowType<i32>(bool tz);
22+
template arrow::compute::InputType GetPrimitiveInputArrowType<ui32>(bool tz);
23+
template arrow::compute::InputType GetPrimitiveInputArrowType<i64>(bool tz);
24+
template arrow::compute::InputType GetPrimitiveInputArrowType<ui64>(bool tz);
25+
template arrow::compute::InputType GetPrimitiveInputArrowType<float>(bool tz);
26+
template arrow::compute::InputType GetPrimitiveInputArrowType<double>(bool tz);
27+
template arrow::compute::InputType GetPrimitiveInputArrowType<char*>(bool tz);
28+
template arrow::compute::InputType GetPrimitiveInputArrowType<NYql::NUdf::TUtf8>(bool tz);
29+
30+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<bool>(bool tz);
31+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<i8>(bool tz);
32+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui8>(bool tz);
33+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<i16>(bool tz);
34+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui16>(bool tz);
35+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<i32>(bool tz);
36+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui32>(bool tz);
37+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<i64>(bool tz);
38+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui64>(bool tz);
39+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<float>(bool tz);
40+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<double>(bool tz);
41+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<char*>(bool tz);
42+
template arrow::compute::OutputType GetPrimitiveOutputArrowType<NYql::NUdf::TUtf8>(bool tz);
43+
44+
arrow::compute::InputType GetPrimitiveInputArrowType(NUdf::EDataSlot slot) {
45+
switch (slot) {
46+
case NUdf::EDataSlot::Bool: return GetPrimitiveInputArrowType<bool>();
47+
case NUdf::EDataSlot::Int8: return GetPrimitiveInputArrowType<i8>();
48+
case NUdf::EDataSlot::Uint8: return GetPrimitiveInputArrowType<ui8>();
49+
case NUdf::EDataSlot::Int16: return GetPrimitiveInputArrowType<i16>();
50+
case NUdf::EDataSlot::Uint16: return GetPrimitiveInputArrowType<ui16>();
51+
case NUdf::EDataSlot::Int32: return GetPrimitiveInputArrowType<i32>();
52+
case NUdf::EDataSlot::Uint32: return GetPrimitiveInputArrowType<ui32>();
53+
case NUdf::EDataSlot::Int64: return GetPrimitiveInputArrowType<i64>();
54+
case NUdf::EDataSlot::Uint64: return GetPrimitiveInputArrowType<ui64>();
55+
case NUdf::EDataSlot::Float: return GetPrimitiveInputArrowType<float>();
56+
case NUdf::EDataSlot::Double: return GetPrimitiveInputArrowType<double>();
57+
case NUdf::EDataSlot::String: return GetPrimitiveInputArrowType<char*>();
58+
case NUdf::EDataSlot::Utf8: return GetPrimitiveInputArrowType<NYql::NUdf::TUtf8>();
59+
case NUdf::EDataSlot::Date: return GetPrimitiveInputArrowType<ui16>();
60+
case NUdf::EDataSlot::TzDate: return GetPrimitiveInputArrowType<ui16>(true);
61+
case NUdf::EDataSlot::Datetime: return GetPrimitiveInputArrowType<ui32>();
62+
case NUdf::EDataSlot::TzDatetime: return GetPrimitiveInputArrowType<ui32>(true);
63+
case NUdf::EDataSlot::Timestamp: return GetPrimitiveInputArrowType<ui64>();
64+
case NUdf::EDataSlot::TzTimestamp: return GetPrimitiveInputArrowType<ui64>(true);
65+
case NUdf::EDataSlot::Interval: return GetPrimitiveInputArrowType<i64>();
66+
case NUdf::EDataSlot::Date32: return GetPrimitiveInputArrowType<i32>();
67+
case NUdf::EDataSlot::TzDate32: return GetPrimitiveInputArrowType<i32>(true);
68+
case NUdf::EDataSlot::Datetime64: return GetPrimitiveInputArrowType<i64>();
69+
case NUdf::EDataSlot::TzDatetime64: return GetPrimitiveInputArrowType<i64>(true);
70+
case NUdf::EDataSlot::Timestamp64: return GetPrimitiveInputArrowType<i64>();
71+
case NUdf::EDataSlot::TzTimestamp64: return GetPrimitiveInputArrowType<i64>(true);
72+
case NUdf::EDataSlot::Interval64: return GetPrimitiveInputArrowType<i64>();
73+
default:
74+
ythrow yexception() << "Unexpected data slot: " << slot;
75+
}
76+
}
77+
78+
arrow::compute::OutputType GetPrimitiveOutputArrowType(NUdf::EDataSlot slot) {
79+
switch (slot) {
80+
case NUdf::EDataSlot::Bool: return GetPrimitiveOutputArrowType<bool>();
81+
case NUdf::EDataSlot::Int8: return GetPrimitiveOutputArrowType<i8>();
82+
case NUdf::EDataSlot::Uint8: return GetPrimitiveOutputArrowType<ui8>();
83+
case NUdf::EDataSlot::Int16: return GetPrimitiveOutputArrowType<i16>();
84+
case NUdf::EDataSlot::Uint16: return GetPrimitiveOutputArrowType<ui16>();
85+
case NUdf::EDataSlot::Int32: return GetPrimitiveOutputArrowType<i32>();
86+
case NUdf::EDataSlot::Uint32: return GetPrimitiveOutputArrowType<ui32>();
87+
case NUdf::EDataSlot::Int64: return GetPrimitiveOutputArrowType<i64>();
88+
case NUdf::EDataSlot::Uint64: return GetPrimitiveOutputArrowType<ui64>();
89+
case NUdf::EDataSlot::Float: return GetPrimitiveOutputArrowType<float>();
90+
case NUdf::EDataSlot::Double: return GetPrimitiveOutputArrowType<double>();
91+
case NUdf::EDataSlot::String: return GetPrimitiveOutputArrowType<char*>();
92+
case NUdf::EDataSlot::Utf8: return GetPrimitiveOutputArrowType<NYql::NUdf::TUtf8>();
93+
case NUdf::EDataSlot::Date: return GetPrimitiveOutputArrowType<ui16>();
94+
case NUdf::EDataSlot::TzDate: return GetPrimitiveOutputArrowType<ui16>(true);
95+
case NUdf::EDataSlot::Datetime: return GetPrimitiveOutputArrowType<ui32>();
96+
case NUdf::EDataSlot::TzDatetime: return GetPrimitiveOutputArrowType<ui32>(true);
97+
case NUdf::EDataSlot::Timestamp: return GetPrimitiveOutputArrowType<ui64>();
98+
case NUdf::EDataSlot::TzTimestamp: return GetPrimitiveOutputArrowType<ui64>(true);
99+
case NUdf::EDataSlot::Interval: return GetPrimitiveOutputArrowType<i64>();
100+
case NUdf::EDataSlot::Date32: return GetPrimitiveOutputArrowType<i32>();
101+
case NUdf::EDataSlot::TzDate32: return GetPrimitiveOutputArrowType<i32>(true);
102+
case NUdf::EDataSlot::Datetime64: return GetPrimitiveOutputArrowType<i64>();
103+
case NUdf::EDataSlot::TzDatetime64: return GetPrimitiveOutputArrowType<i64>(true);
104+
case NUdf::EDataSlot::Timestamp64: return GetPrimitiveOutputArrowType<i64>();
105+
case NUdf::EDataSlot::TzTimestamp64: return GetPrimitiveOutputArrowType<i64>(true);
106+
case NUdf::EDataSlot::Interval64: return GetPrimitiveOutputArrowType<i64>();
107+
default:
108+
ythrow yexception() << "Unexpected data slot: " << slot;
109+
}
110+
}
111+
112+
std::shared_ptr<arrow::DataType> AddTzType(bool addTz, const std::shared_ptr<arrow::DataType>& type) {
113+
if (!addTz) {
114+
return type;
115+
}
116+
117+
std::vector<std::shared_ptr<arrow::Field>> fields {
118+
std::make_shared<arrow::Field>("datetime", type, false),
119+
std::make_shared<arrow::Field>("timezoneId", arrow::uint16(), false)
120+
};
121+
122+
return std::make_shared<arrow::StructType>(fields);
123+
}
124+
125+
std::shared_ptr<arrow::DataType> AddTzType(EPropagateTz propagateTz, const std::shared_ptr<arrow::DataType>& type) {
126+
return AddTzType(propagateTz != EPropagateTz::None, type);
127+
}
128+
129+
std::shared_ptr<arrow::Scalar> ExtractTz(bool isTz, const std::shared_ptr<arrow::Scalar>& value) {
130+
if (!isTz) {
131+
return value;
132+
}
133+
134+
const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*value);
135+
return structScalar.value[0];
136+
}
137+
138+
std::shared_ptr<arrow::ArrayData> ExtractTz(bool isTz, const std::shared_ptr<arrow::ArrayData>& value) {
139+
if (!isTz) {
140+
return value;
141+
}
142+
143+
return value->child_data[0];
144+
}
145+
146+
std::shared_ptr<arrow::Scalar> WithTz(bool propagateTz, const std::shared_ptr<arrow::Scalar>& input,
147+
const std::shared_ptr<arrow::Scalar>& value) {
148+
if (!propagateTz) {
149+
return value;
150+
}
151+
152+
const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*input);
153+
auto tzId = structScalar.value[1];
154+
return std::make_shared<arrow::StructScalar>(arrow::StructScalar::ValueType{value, tzId}, input->type);
155+
}
156+
157+
std::shared_ptr<arrow::Scalar> WithTz(EPropagateTz propagateTz,
158+
const std::shared_ptr<arrow::Scalar>& input1,
159+
const std::shared_ptr<arrow::Scalar>& input2,
160+
const std::shared_ptr<arrow::Scalar>& value) {
161+
if (propagateTz == EPropagateTz::None) {
162+
return value;
163+
}
164+
165+
const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(propagateTz == EPropagateTz::FromLeft ? *input1 : *input2);
166+
const auto tzId = structScalar.value[1];
167+
return std::make_shared<arrow::StructScalar>(arrow::StructScalar::ValueType{value,tzId}, propagateTz == EPropagateTz::FromLeft ? input1->type : input2->type);
168+
}
169+
170+
std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, bool propagateTz,
171+
const std::shared_ptr<arrow::ArrayData>& input, arrow::MemoryPool* pool,
172+
size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) {
173+
if (!propagateTz) {
174+
return res;
175+
}
176+
177+
Y_ENSURE(res->child_data.empty());
178+
std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool));
179+
res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer }));
180+
res->child_data.push_back(input->child_data[1]);
181+
return res->child_data[0];
182+
}
183+
184+
std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, EPropagateTz propagateTz,
185+
const std::shared_ptr<arrow::ArrayData>& input1,
186+
const std::shared_ptr<arrow::Scalar>& input2,
187+
arrow::MemoryPool* pool,
188+
size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) {
189+
if (propagateTz == EPropagateTz::None) {
190+
return res;
191+
}
192+
193+
Y_ENSURE(res->child_data.empty());
194+
std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool));
195+
res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer }));
196+
if (propagateTz == EPropagateTz::FromLeft) {
197+
res->child_data.push_back(input1->child_data[1]);
198+
} else {
199+
const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*input2);
200+
auto tzId = ARROW_RESULT(arrow::MakeArrayFromScalar(*structScalar.value[1], res->length, pool))->data();
201+
res->child_data.push_back(tzId);
202+
}
203+
204+
return res->child_data[0];
205+
}
206+
207+
std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, EPropagateTz propagateTz,
208+
const std::shared_ptr<arrow::Scalar>& input1,
209+
const std::shared_ptr<arrow::ArrayData>& input2,
210+
arrow::MemoryPool* pool,
211+
size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) {
212+
if (propagateTz == EPropagateTz::None) {
213+
return res;
214+
}
215+
216+
Y_ENSURE(res->child_data.empty());
217+
std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool));
218+
res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer }));
219+
if (propagateTz == EPropagateTz::FromLeft) {
220+
const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*input1);
221+
auto tzId = ARROW_RESULT(arrow::MakeArrayFromScalar(*structScalar.value[1], res->length, pool))->data();
222+
res->child_data.push_back(tzId);
223+
} else {
224+
res->child_data.push_back(input2->child_data[1]);
225+
}
226+
227+
return res->child_data[0];
228+
}
229+
230+
std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, EPropagateTz propagateTz,
231+
const std::shared_ptr<arrow::ArrayData>& input1,
232+
const std::shared_ptr<arrow::ArrayData>& input2,
233+
arrow::MemoryPool* pool,
234+
size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) {
235+
if (propagateTz == EPropagateTz::None) {
236+
return res;
237+
}
238+
239+
Y_ENSURE(res->child_data.empty());
240+
std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool));
241+
res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer }));
242+
if (propagateTz == EPropagateTz::FromLeft) {
243+
res->child_data.push_back(input1->child_data[1]);
244+
} else {
245+
res->child_data.push_back(input2->child_data[1]);
246+
}
247+
248+
return res->child_data[0];
249+
}
250+
251+
TPlainKernel::TPlainKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes,
252+
NUdf::TDataTypeId returnType, std::unique_ptr<arrow::compute::ScalarKernel>&& arrowKernel,
253+
TKernel::ENullMode nullMode)
254+
: TKernel(family, argTypes, returnType, nullMode)
255+
, ArrowKernel(std::move(arrowKernel))
256+
{
257+
}
258+
259+
const arrow::compute::ScalarKernel& TPlainKernel::GetArrowKernel() const {
260+
return *ArrowKernel;
261+
}
262+
263+
void AddUnaryKernelImpl(TKernelFamilyBase& owner, NUdf::EDataSlot arg1, NUdf::EDataSlot res,
264+
TStatelessArrayKernelExec exec, TKernel::ENullMode nullMode) {
265+
auto type1 = NUdf::GetDataTypeInfo(arg1).TypeId;
266+
auto returnType = NUdf::GetDataTypeInfo(res).TypeId;
267+
std::vector<NUdf::TDataTypeId> argTypes({ type1 });
268+
269+
auto k = std::make_unique<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{
270+
GetPrimitiveInputArrowType(arg1)
271+
}, GetPrimitiveOutputArrowType(res), exec);
272+
273+
switch (nullMode) {
274+
case TKernel::ENullMode::Default:
275+
k->null_handling = arrow::compute::NullHandling::INTERSECTION;
276+
break;
277+
case TKernel::ENullMode::AlwaysNull:
278+
k->null_handling = arrow::compute::NullHandling::COMPUTED_PREALLOCATE;
279+
break;
280+
case TKernel::ENullMode::AlwaysNotNull:
281+
k->null_handling = arrow::compute::NullHandling::OUTPUT_NOT_NULL;
282+
break;
283+
}
284+
285+
owner.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(owner, argTypes, returnType, std::move(k), nullMode));
286+
}
287+
288+
void AddBinaryKernelImpl(TKernelFamilyBase& owner, NUdf::EDataSlot arg1, NUdf::EDataSlot arg2, NUdf::EDataSlot res,
289+
TStatelessArrayKernelExec exec, TKernel::ENullMode nullMode) {
290+
auto type1 = NUdf::GetDataTypeInfo(arg1).TypeId;
291+
auto type2 = NUdf::GetDataTypeInfo(arg2).TypeId;
292+
auto returnType = NUdf::GetDataTypeInfo(res).TypeId;
293+
std::vector<NUdf::TDataTypeId> argTypes({ type1, type2 });
294+
295+
auto k = std::make_unique<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{
296+
GetPrimitiveInputArrowType(arg1), GetPrimitiveInputArrowType(arg2)
297+
}, GetPrimitiveOutputArrowType(res), exec);
298+
299+
switch (nullMode) {
300+
case TKernel::ENullMode::Default:
301+
k->null_handling = arrow::compute::NullHandling::INTERSECTION;
302+
break;
303+
case TKernel::ENullMode::AlwaysNull:
304+
k->null_handling = arrow::compute::NullHandling::COMPUTED_PREALLOCATE;
305+
break;
306+
case TKernel::ENullMode::AlwaysNotNull:
307+
k->null_handling = arrow::compute::NullHandling::OUTPUT_NOT_NULL;
308+
break;
309+
}
310+
311+
owner.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(owner, argTypes, returnType, std::move(k), nullMode));
312+
}
313+
314+
} // namespace NMiniKQL
315+
} // namespace NKikimr

0 commit comments

Comments
 (0)