|
| 1 | +#include "mkql_builtins_impl.h" // Y_IGNORE |
| 2 | + |
| 3 | +namespace NKikimr { |
| 4 | +namespace NMiniKQL { |
| 5 | + |
| 6 | +template <typename T> |
| 7 | +arrow::compute::InputType GetPrimitiveInputArrowType(bool tz) { |
| 8 | + return arrow::compute::InputType(AddTzType(tz, GetPrimitiveDataType<T>()), arrow::ValueDescr::ANY); |
| 9 | +} |
| 10 | + |
| 11 | +template <typename T> |
| 12 | +arrow::compute::OutputType GetPrimitiveOutputArrowType(bool tz) { |
| 13 | + return arrow::compute::OutputType(AddTzType(tz, GetPrimitiveDataType<T>())); |
| 14 | +} |
| 15 | + |
| 16 | +template arrow::compute::InputType GetPrimitiveInputArrowType<bool>(bool tz); |
| 17 | +template arrow::compute::InputType GetPrimitiveInputArrowType<i8>(bool tz); |
| 18 | +template arrow::compute::InputType GetPrimitiveInputArrowType<ui8>(bool tz); |
| 19 | +template arrow::compute::InputType GetPrimitiveInputArrowType<i16>(bool tz); |
| 20 | +template arrow::compute::InputType GetPrimitiveInputArrowType<ui16>(bool tz); |
| 21 | +template arrow::compute::InputType GetPrimitiveInputArrowType<i32>(bool tz); |
| 22 | +template arrow::compute::InputType GetPrimitiveInputArrowType<ui32>(bool tz); |
| 23 | +template arrow::compute::InputType GetPrimitiveInputArrowType<i64>(bool tz); |
| 24 | +template arrow::compute::InputType GetPrimitiveInputArrowType<ui64>(bool tz); |
| 25 | +template arrow::compute::InputType GetPrimitiveInputArrowType<float>(bool tz); |
| 26 | +template arrow::compute::InputType GetPrimitiveInputArrowType<double>(bool tz); |
| 27 | +template arrow::compute::InputType GetPrimitiveInputArrowType<char*>(bool tz); |
| 28 | +template arrow::compute::InputType GetPrimitiveInputArrowType<NYql::NUdf::TUtf8>(bool tz); |
| 29 | + |
| 30 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<bool>(bool tz); |
| 31 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<i8>(bool tz); |
| 32 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui8>(bool tz); |
| 33 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<i16>(bool tz); |
| 34 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui16>(bool tz); |
| 35 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<i32>(bool tz); |
| 36 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui32>(bool tz); |
| 37 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<i64>(bool tz); |
| 38 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<ui64>(bool tz); |
| 39 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<float>(bool tz); |
| 40 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<double>(bool tz); |
| 41 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<char*>(bool tz); |
| 42 | +template arrow::compute::OutputType GetPrimitiveOutputArrowType<NYql::NUdf::TUtf8>(bool tz); |
| 43 | + |
| 44 | +arrow::compute::InputType GetPrimitiveInputArrowType(NUdf::EDataSlot slot) { |
| 45 | + switch (slot) { |
| 46 | + case NUdf::EDataSlot::Bool: return GetPrimitiveInputArrowType<bool>(); |
| 47 | + case NUdf::EDataSlot::Int8: return GetPrimitiveInputArrowType<i8>(); |
| 48 | + case NUdf::EDataSlot::Uint8: return GetPrimitiveInputArrowType<ui8>(); |
| 49 | + case NUdf::EDataSlot::Int16: return GetPrimitiveInputArrowType<i16>(); |
| 50 | + case NUdf::EDataSlot::Uint16: return GetPrimitiveInputArrowType<ui16>(); |
| 51 | + case NUdf::EDataSlot::Int32: return GetPrimitiveInputArrowType<i32>(); |
| 52 | + case NUdf::EDataSlot::Uint32: return GetPrimitiveInputArrowType<ui32>(); |
| 53 | + case NUdf::EDataSlot::Int64: return GetPrimitiveInputArrowType<i64>(); |
| 54 | + case NUdf::EDataSlot::Uint64: return GetPrimitiveInputArrowType<ui64>(); |
| 55 | + case NUdf::EDataSlot::Float: return GetPrimitiveInputArrowType<float>(); |
| 56 | + case NUdf::EDataSlot::Double: return GetPrimitiveInputArrowType<double>(); |
| 57 | + case NUdf::EDataSlot::String: return GetPrimitiveInputArrowType<char*>(); |
| 58 | + case NUdf::EDataSlot::Utf8: return GetPrimitiveInputArrowType<NYql::NUdf::TUtf8>(); |
| 59 | + case NUdf::EDataSlot::Date: return GetPrimitiveInputArrowType<ui16>(); |
| 60 | + case NUdf::EDataSlot::TzDate: return GetPrimitiveInputArrowType<ui16>(true); |
| 61 | + case NUdf::EDataSlot::Datetime: return GetPrimitiveInputArrowType<ui32>(); |
| 62 | + case NUdf::EDataSlot::TzDatetime: return GetPrimitiveInputArrowType<ui32>(true); |
| 63 | + case NUdf::EDataSlot::Timestamp: return GetPrimitiveInputArrowType<ui64>(); |
| 64 | + case NUdf::EDataSlot::TzTimestamp: return GetPrimitiveInputArrowType<ui64>(true); |
| 65 | + case NUdf::EDataSlot::Interval: return GetPrimitiveInputArrowType<i64>(); |
| 66 | + case NUdf::EDataSlot::Date32: return GetPrimitiveInputArrowType<i32>(); |
| 67 | + case NUdf::EDataSlot::TzDate32: return GetPrimitiveInputArrowType<i32>(true); |
| 68 | + case NUdf::EDataSlot::Datetime64: return GetPrimitiveInputArrowType<i64>(); |
| 69 | + case NUdf::EDataSlot::TzDatetime64: return GetPrimitiveInputArrowType<i64>(true); |
| 70 | + case NUdf::EDataSlot::Timestamp64: return GetPrimitiveInputArrowType<i64>(); |
| 71 | + case NUdf::EDataSlot::TzTimestamp64: return GetPrimitiveInputArrowType<i64>(true); |
| 72 | + case NUdf::EDataSlot::Interval64: return GetPrimitiveInputArrowType<i64>(); |
| 73 | + default: |
| 74 | + ythrow yexception() << "Unexpected data slot: " << slot; |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +arrow::compute::OutputType GetPrimitiveOutputArrowType(NUdf::EDataSlot slot) { |
| 79 | + switch (slot) { |
| 80 | + case NUdf::EDataSlot::Bool: return GetPrimitiveOutputArrowType<bool>(); |
| 81 | + case NUdf::EDataSlot::Int8: return GetPrimitiveOutputArrowType<i8>(); |
| 82 | + case NUdf::EDataSlot::Uint8: return GetPrimitiveOutputArrowType<ui8>(); |
| 83 | + case NUdf::EDataSlot::Int16: return GetPrimitiveOutputArrowType<i16>(); |
| 84 | + case NUdf::EDataSlot::Uint16: return GetPrimitiveOutputArrowType<ui16>(); |
| 85 | + case NUdf::EDataSlot::Int32: return GetPrimitiveOutputArrowType<i32>(); |
| 86 | + case NUdf::EDataSlot::Uint32: return GetPrimitiveOutputArrowType<ui32>(); |
| 87 | + case NUdf::EDataSlot::Int64: return GetPrimitiveOutputArrowType<i64>(); |
| 88 | + case NUdf::EDataSlot::Uint64: return GetPrimitiveOutputArrowType<ui64>(); |
| 89 | + case NUdf::EDataSlot::Float: return GetPrimitiveOutputArrowType<float>(); |
| 90 | + case NUdf::EDataSlot::Double: return GetPrimitiveOutputArrowType<double>(); |
| 91 | + case NUdf::EDataSlot::String: return GetPrimitiveOutputArrowType<char*>(); |
| 92 | + case NUdf::EDataSlot::Utf8: return GetPrimitiveOutputArrowType<NYql::NUdf::TUtf8>(); |
| 93 | + case NUdf::EDataSlot::Date: return GetPrimitiveOutputArrowType<ui16>(); |
| 94 | + case NUdf::EDataSlot::TzDate: return GetPrimitiveOutputArrowType<ui16>(true); |
| 95 | + case NUdf::EDataSlot::Datetime: return GetPrimitiveOutputArrowType<ui32>(); |
| 96 | + case NUdf::EDataSlot::TzDatetime: return GetPrimitiveOutputArrowType<ui32>(true); |
| 97 | + case NUdf::EDataSlot::Timestamp: return GetPrimitiveOutputArrowType<ui64>(); |
| 98 | + case NUdf::EDataSlot::TzTimestamp: return GetPrimitiveOutputArrowType<ui64>(true); |
| 99 | + case NUdf::EDataSlot::Interval: return GetPrimitiveOutputArrowType<i64>(); |
| 100 | + case NUdf::EDataSlot::Date32: return GetPrimitiveOutputArrowType<i32>(); |
| 101 | + case NUdf::EDataSlot::TzDate32: return GetPrimitiveOutputArrowType<i32>(true); |
| 102 | + case NUdf::EDataSlot::Datetime64: return GetPrimitiveOutputArrowType<i64>(); |
| 103 | + case NUdf::EDataSlot::TzDatetime64: return GetPrimitiveOutputArrowType<i64>(true); |
| 104 | + case NUdf::EDataSlot::Timestamp64: return GetPrimitiveOutputArrowType<i64>(); |
| 105 | + case NUdf::EDataSlot::TzTimestamp64: return GetPrimitiveOutputArrowType<i64>(true); |
| 106 | + case NUdf::EDataSlot::Interval64: return GetPrimitiveOutputArrowType<i64>(); |
| 107 | + default: |
| 108 | + ythrow yexception() << "Unexpected data slot: " << slot; |
| 109 | + } |
| 110 | +} |
| 111 | + |
| 112 | +std::shared_ptr<arrow::DataType> AddTzType(bool addTz, const std::shared_ptr<arrow::DataType>& type) { |
| 113 | + if (!addTz) { |
| 114 | + return type; |
| 115 | + } |
| 116 | + |
| 117 | + std::vector<std::shared_ptr<arrow::Field>> fields { |
| 118 | + std::make_shared<arrow::Field>("datetime", type, false), |
| 119 | + std::make_shared<arrow::Field>("timezoneId", arrow::uint16(), false) |
| 120 | + }; |
| 121 | + |
| 122 | + return std::make_shared<arrow::StructType>(fields); |
| 123 | +} |
| 124 | + |
| 125 | +std::shared_ptr<arrow::DataType> AddTzType(EPropagateTz propagateTz, const std::shared_ptr<arrow::DataType>& type) { |
| 126 | + return AddTzType(propagateTz != EPropagateTz::None, type); |
| 127 | +} |
| 128 | + |
| 129 | +std::shared_ptr<arrow::Scalar> ExtractTz(bool isTz, const std::shared_ptr<arrow::Scalar>& value) { |
| 130 | + if (!isTz) { |
| 131 | + return value; |
| 132 | + } |
| 133 | + |
| 134 | + const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*value); |
| 135 | + return structScalar.value[0]; |
| 136 | +} |
| 137 | + |
| 138 | +std::shared_ptr<arrow::ArrayData> ExtractTz(bool isTz, const std::shared_ptr<arrow::ArrayData>& value) { |
| 139 | + if (!isTz) { |
| 140 | + return value; |
| 141 | + } |
| 142 | + |
| 143 | + return value->child_data[0]; |
| 144 | +} |
| 145 | + |
| 146 | +std::shared_ptr<arrow::Scalar> WithTz(bool propagateTz, const std::shared_ptr<arrow::Scalar>& input, |
| 147 | + const std::shared_ptr<arrow::Scalar>& value) { |
| 148 | + if (!propagateTz) { |
| 149 | + return value; |
| 150 | + } |
| 151 | + |
| 152 | + const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*input); |
| 153 | + auto tzId = structScalar.value[1]; |
| 154 | + return std::make_shared<arrow::StructScalar>(arrow::StructScalar::ValueType{value, tzId}, input->type); |
| 155 | +} |
| 156 | + |
| 157 | +std::shared_ptr<arrow::Scalar> WithTz(EPropagateTz propagateTz, |
| 158 | + const std::shared_ptr<arrow::Scalar>& input1, |
| 159 | + const std::shared_ptr<arrow::Scalar>& input2, |
| 160 | + const std::shared_ptr<arrow::Scalar>& value) { |
| 161 | + if (propagateTz == EPropagateTz::None) { |
| 162 | + return value; |
| 163 | + } |
| 164 | + |
| 165 | + const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(propagateTz == EPropagateTz::FromLeft ? *input1 : *input2); |
| 166 | + const auto tzId = structScalar.value[1]; |
| 167 | + return std::make_shared<arrow::StructScalar>(arrow::StructScalar::ValueType{value,tzId}, propagateTz == EPropagateTz::FromLeft ? input1->type : input2->type); |
| 168 | +} |
| 169 | + |
| 170 | +std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, bool propagateTz, |
| 171 | + const std::shared_ptr<arrow::ArrayData>& input, arrow::MemoryPool* pool, |
| 172 | + size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) { |
| 173 | + if (!propagateTz) { |
| 174 | + return res; |
| 175 | + } |
| 176 | + |
| 177 | + Y_ENSURE(res->child_data.empty()); |
| 178 | + std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool)); |
| 179 | + res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer })); |
| 180 | + res->child_data.push_back(input->child_data[1]); |
| 181 | + return res->child_data[0]; |
| 182 | +} |
| 183 | + |
| 184 | +std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, EPropagateTz propagateTz, |
| 185 | + const std::shared_ptr<arrow::ArrayData>& input1, |
| 186 | + const std::shared_ptr<arrow::Scalar>& input2, |
| 187 | + arrow::MemoryPool* pool, |
| 188 | + size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) { |
| 189 | + if (propagateTz == EPropagateTz::None) { |
| 190 | + return res; |
| 191 | + } |
| 192 | + |
| 193 | + Y_ENSURE(res->child_data.empty()); |
| 194 | + std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool)); |
| 195 | + res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer })); |
| 196 | + if (propagateTz == EPropagateTz::FromLeft) { |
| 197 | + res->child_data.push_back(input1->child_data[1]); |
| 198 | + } else { |
| 199 | + const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*input2); |
| 200 | + auto tzId = ARROW_RESULT(arrow::MakeArrayFromScalar(*structScalar.value[1], res->length, pool))->data(); |
| 201 | + res->child_data.push_back(tzId); |
| 202 | + } |
| 203 | + |
| 204 | + return res->child_data[0]; |
| 205 | +} |
| 206 | + |
| 207 | +std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, EPropagateTz propagateTz, |
| 208 | + const std::shared_ptr<arrow::Scalar>& input1, |
| 209 | + const std::shared_ptr<arrow::ArrayData>& input2, |
| 210 | + arrow::MemoryPool* pool, |
| 211 | + size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) { |
| 212 | + if (propagateTz == EPropagateTz::None) { |
| 213 | + return res; |
| 214 | + } |
| 215 | + |
| 216 | + Y_ENSURE(res->child_data.empty()); |
| 217 | + std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool)); |
| 218 | + res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer })); |
| 219 | + if (propagateTz == EPropagateTz::FromLeft) { |
| 220 | + const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*input1); |
| 221 | + auto tzId = ARROW_RESULT(arrow::MakeArrayFromScalar(*structScalar.value[1], res->length, pool))->data(); |
| 222 | + res->child_data.push_back(tzId); |
| 223 | + } else { |
| 224 | + res->child_data.push_back(input2->child_data[1]); |
| 225 | + } |
| 226 | + |
| 227 | + return res->child_data[0]; |
| 228 | +} |
| 229 | + |
| 230 | +std::shared_ptr<arrow::ArrayData> CopyTzImpl(const std::shared_ptr<arrow::ArrayData>& res, EPropagateTz propagateTz, |
| 231 | + const std::shared_ptr<arrow::ArrayData>& input1, |
| 232 | + const std::shared_ptr<arrow::ArrayData>& input2, |
| 233 | + arrow::MemoryPool* pool, |
| 234 | + size_t sizeOf, const std::shared_ptr<arrow::DataType>& outputType) { |
| 235 | + if (propagateTz == EPropagateTz::None) { |
| 236 | + return res; |
| 237 | + } |
| 238 | + |
| 239 | + Y_ENSURE(res->child_data.empty()); |
| 240 | + std::shared_ptr<arrow::Buffer> buffer(NUdf::AllocateResizableBuffer(sizeOf * res->length, pool)); |
| 241 | + res->child_data.push_back(arrow::ArrayData::Make(outputType, res->length, { nullptr, buffer })); |
| 242 | + if (propagateTz == EPropagateTz::FromLeft) { |
| 243 | + res->child_data.push_back(input1->child_data[1]); |
| 244 | + } else { |
| 245 | + res->child_data.push_back(input2->child_data[1]); |
| 246 | + } |
| 247 | + |
| 248 | + return res->child_data[0]; |
| 249 | +} |
| 250 | + |
| 251 | +TPlainKernel::TPlainKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, |
| 252 | + NUdf::TDataTypeId returnType, std::unique_ptr<arrow::compute::ScalarKernel>&& arrowKernel, |
| 253 | + TKernel::ENullMode nullMode) |
| 254 | + : TKernel(family, argTypes, returnType, nullMode) |
| 255 | + , ArrowKernel(std::move(arrowKernel)) |
| 256 | +{ |
| 257 | +} |
| 258 | + |
| 259 | +const arrow::compute::ScalarKernel& TPlainKernel::GetArrowKernel() const { |
| 260 | + return *ArrowKernel; |
| 261 | +} |
| 262 | + |
| 263 | +void AddUnaryKernelImpl(TKernelFamilyBase& owner, NUdf::EDataSlot arg1, NUdf::EDataSlot res, |
| 264 | + TStatelessArrayKernelExec exec, TKernel::ENullMode nullMode) { |
| 265 | + auto type1 = NUdf::GetDataTypeInfo(arg1).TypeId; |
| 266 | + auto returnType = NUdf::GetDataTypeInfo(res).TypeId; |
| 267 | + std::vector<NUdf::TDataTypeId> argTypes({ type1 }); |
| 268 | + |
| 269 | + auto k = std::make_unique<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{ |
| 270 | + GetPrimitiveInputArrowType(arg1) |
| 271 | + }, GetPrimitiveOutputArrowType(res), exec); |
| 272 | + |
| 273 | + switch (nullMode) { |
| 274 | + case TKernel::ENullMode::Default: |
| 275 | + k->null_handling = arrow::compute::NullHandling::INTERSECTION; |
| 276 | + break; |
| 277 | + case TKernel::ENullMode::AlwaysNull: |
| 278 | + k->null_handling = arrow::compute::NullHandling::COMPUTED_PREALLOCATE; |
| 279 | + break; |
| 280 | + case TKernel::ENullMode::AlwaysNotNull: |
| 281 | + k->null_handling = arrow::compute::NullHandling::OUTPUT_NOT_NULL; |
| 282 | + break; |
| 283 | + } |
| 284 | + |
| 285 | + owner.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(owner, argTypes, returnType, std::move(k), nullMode)); |
| 286 | +} |
| 287 | + |
| 288 | +void AddBinaryKernelImpl(TKernelFamilyBase& owner, NUdf::EDataSlot arg1, NUdf::EDataSlot arg2, NUdf::EDataSlot res, |
| 289 | + TStatelessArrayKernelExec exec, TKernel::ENullMode nullMode) { |
| 290 | + auto type1 = NUdf::GetDataTypeInfo(arg1).TypeId; |
| 291 | + auto type2 = NUdf::GetDataTypeInfo(arg2).TypeId; |
| 292 | + auto returnType = NUdf::GetDataTypeInfo(res).TypeId; |
| 293 | + std::vector<NUdf::TDataTypeId> argTypes({ type1, type2 }); |
| 294 | + |
| 295 | + auto k = std::make_unique<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{ |
| 296 | + GetPrimitiveInputArrowType(arg1), GetPrimitiveInputArrowType(arg2) |
| 297 | + }, GetPrimitiveOutputArrowType(res), exec); |
| 298 | + |
| 299 | + switch (nullMode) { |
| 300 | + case TKernel::ENullMode::Default: |
| 301 | + k->null_handling = arrow::compute::NullHandling::INTERSECTION; |
| 302 | + break; |
| 303 | + case TKernel::ENullMode::AlwaysNull: |
| 304 | + k->null_handling = arrow::compute::NullHandling::COMPUTED_PREALLOCATE; |
| 305 | + break; |
| 306 | + case TKernel::ENullMode::AlwaysNotNull: |
| 307 | + k->null_handling = arrow::compute::NullHandling::OUTPUT_NOT_NULL; |
| 308 | + break; |
| 309 | + } |
| 310 | + |
| 311 | + owner.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(owner, argTypes, returnType, std::move(k), nullMode)); |
| 312 | +} |
| 313 | + |
| 314 | +} // namespace NMiniKQL |
| 315 | +} // namespace NKikimr |
0 commit comments