@@ -942,6 +942,7 @@ inline std::shared_ptr<arrow::ArrayData> CopyTz(const std::shared_ptr<arrow::Arr
942
942
943
943
using TPrimitiveDataTypeGetter = std::shared_ptr<arrow::DataType>(*)();
944
944
using TPrimitiveDataScalarGetter= arrow::Datum(*)();
945
+ using TPrimitiveDataScalarGetterWithMemPool = arrow::Datum(*)(void** result, arrow::MemoryPool*);
945
946
using TUntypedBinaryScalarFuncPtr = void(*)(const void*, const void*, void*);
946
947
using TUntypedBinaryArrayFuncPtr = void(*)(const void*, const void*, void*, int64_t length, int64_t offset1, int64_t offset2);
947
948
using TUntypedBinaryScalarOptFuncPtr = bool(*)(const void*, const void*, void*);
@@ -1023,7 +1024,7 @@ arrow::Status ExecBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,
1023
1024
1024
1025
arrow::Status ExecDecimalBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,
1025
1026
const arrow::compute::ExecBatch& batch, arrow::Datum* res,
1026
- TPrimitiveDataTypeGetter typeGetter,
1027
+ TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetterWithMemPool scalarGetter,
1027
1028
size_t outputSizeOf,
1028
1029
TUntypedBinaryScalarOptFuncPtr scalarScalarFunc,
1029
1030
TUntypedBinaryArrayOptFuncPtr scalarArrayFunc,
@@ -1032,7 +1033,7 @@ arrow::Status ExecDecimalBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,
1032
1033
1033
1034
arrow::Status ExecDecimalUnaryImpl(arrow::compute::KernelContext* kernelCtx,
1034
1035
const arrow::compute::ExecBatch& batch, arrow::Datum* res,
1035
- TPrimitiveDataTypeGetter typeGetter,
1036
+ TPrimitiveDataTypeGetter typeGetter,
1036
1037
TUntypedUnaryScalarFuncPtr scalarFunc, TUntypedUnaryArrayFuncPtr arrayFunc);
1037
1038
1038
1039
template<typename TInput1, bool Tz1, typename TInput2, bool Tz2, typename TOutput, EPropagateTz PropagateTz,
@@ -1664,6 +1665,12 @@ struct TDecimalKernelExecs
1664
1665
using TInput2 = NYql::NDecimal::TInt128;
1665
1666
using TOutput = NYql::NDecimal::TInt128;
1666
1667
1668
+ static arrow::Datum ScalarGetter(void** result, arrow::MemoryPool* memory_pool) {
1669
+ std::shared_ptr<arrow::Buffer> buffer(ARROW_RESULT(arrow::AllocateBuffer(16, memory_pool)));
1670
+ *result = buffer->mutable_data();
1671
+ return arrow::Datum(std::make_shared<TPrimitiveDataType<NYql::NDecimal::TInt128>::TScalarResult>(buffer));
1672
+ }
1673
+
1667
1674
template<ui8 precision>
1668
1675
static arrow::Status ExecImpl(arrow::compute::KernelContext* kernelCtx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
1669
1676
auto scalarScalarFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance<precision>>::ScalarScalarCoreOpt;
@@ -1672,7 +1679,7 @@ struct TDecimalKernelExecs
1672
1679
auto arrayArrayFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance<precision>>::ArrayArrayCoreOpt;
1673
1680
1674
1681
return ExecDecimalBinaryOptImpl(kernelCtx, batch, res,
1675
- &GetPrimitiveDataType<TOutput>,
1682
+ &GetPrimitiveDataType<TOutput>, &ScalarGetter,
1676
1683
sizeof(TOutput),
1677
1684
(TUntypedBinaryScalarOptFuncPtr)scalarScalarFunc,
1678
1685
(TUntypedBinaryArrayOptFuncPtr)scalarArrayFunc,
@@ -1709,5 +1716,46 @@ void AddBinaryDecimalKernels(TKernelFamilyBase& owner) {
1709
1716
owner.Adopt(argTypes, returnType, std::move(kernel));
1710
1717
}
1711
1718
1719
+ template<class TFuncInstance>
1720
+ struct TDecimalComparisonKernelExecs
1721
+ {
1722
+ using TInput1 = NYql::NDecimal::TInt128;
1723
+ using TInput2 = NYql::NDecimal::TInt128;
1724
+ using TOutput = bool;
1725
+
1726
+ static arrow::Datum ScalarGetter(void** resMem, arrow::MemoryPool*) {
1727
+ auto result = MakeDefaultScalarDatum<TOutput>();
1728
+ *resMem = GetPrimitiveScalarValueMutablePtr(*result.scalar());
1729
+ return result;
1730
+ }
1731
+
1732
+ static arrow::Status Exec(arrow::compute::KernelContext* kernelCtx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
1733
+ auto scalarScalarFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance>::ScalarScalarCoreOpt;
1734
+ auto scalarArrayFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance>::ScalarArrayCoreOpt;
1735
+ auto arrayScalarFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance>::ArrayScalarCoreOpt;
1736
+ auto arrayArrayFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance>::ArrayArrayCoreOpt;
1737
+
1738
+ return ExecDecimalBinaryOptImpl(kernelCtx, batch, res,
1739
+ &GetPrimitiveDataType<TOutput>, &ScalarGetter,
1740
+ sizeof(TOutput),
1741
+ (TUntypedBinaryScalarOptFuncPtr)scalarScalarFunc,
1742
+ (TUntypedBinaryArrayOptFuncPtr)scalarArrayFunc,
1743
+ (TUntypedBinaryArrayOptFuncPtr)arrayScalarFunc,
1744
+ (TUntypedBinaryArrayOptFuncPtr)arrayArrayFunc);
1745
+ }
1746
+ };
1747
+
1748
+ template<class TFunc>
1749
+ void AddDecimalComparisonKernels(TKernelFamilyBase& owner) {
1750
+ auto type1 = NUdf::GetDataTypeInfo(NUdf::EDataSlot::Decimal).TypeId;
1751
+ auto type2 = type1;
1752
+ auto returnType = NUdf::GetDataTypeInfo(NUdf::EDataSlot::Bool).TypeId;
1753
+ std::vector<NUdf::TDataTypeId> argTypes({ type1, type2 });
1754
+
1755
+ using Execs = TDecimalComparisonKernelExecs<TFunc>;
1756
+ auto kernel = std::make_unique<TDecimalKernel>(owner, argTypes, returnType, &Execs::Exec, TKernel::ENullMode::Default);
1757
+ owner.Adopt(argTypes, returnType, std::move(kernel));
1758
+ }
1759
+
1712
1760
}
1713
1761
}
0 commit comments