Skip to content

Add decimal comparison kernels #8271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ void RegisterEquals(TKernelFamilyMap& kernelFamilyMap) {

AddNumericComparisonKernels<TEqualsOp>(*family);
AddDateComparisonKernels<TDiffDateEqualsOp>(*family);
AddDecimalComparisonKernels<TDecimalEquals>(*family);
RegisterStringKernelEquals(*family);

kernelFamilyMap["Equals"] = std::move(family);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ void RegisterGreater(TKernelFamilyMap& kernelFamilyMap) {

AddNumericComparisonKernels<TGreaterOp>(*family);
AddDateComparisonKernels<TDiffDateGreaterOp>(*family);
AddDecimalComparisonKernels<TDecimalGreater>(*family);
RegisterStringKernelGreater(*family);

kernelFamilyMap["Greater"] = std::move(family);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ void RegisterGreaterOrEqual(TKernelFamilyMap& kernelFamilyMap) {

AddNumericComparisonKernels<TGreaterOrEqualOp>(*family);
AddDateComparisonKernels<TDiffDateGreaterOrEqualOp>(*family);
AddDecimalComparisonKernels<TDecimalGreaterOrEqual>(*family);
RegisterStringKernelGreaterOrEqual(*family);

kernelFamilyMap["GreaterOrEqual"] = std::move(family);
Expand Down
21 changes: 9 additions & 12 deletions ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,21 +296,18 @@ std::shared_ptr<arrow::compute::ScalarKernel> TDecimalKernel::MakeArrowKernel(co

MKQL_ENSURE(*dataType1->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal");
MKQL_ENSURE(*dataType2->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal");
MKQL_ENSURE(*dataResultType->GetDataSlot() == NUdf::EDataSlot::Decimal, "Require decimal");

auto decimalType1 = static_cast<TDataDecimalType*>(dataType1);
auto decimalType2 = static_cast<TDataDecimalType*>(dataType2);
auto decimalResultType = static_cast<TDataDecimalType*>(dataResultType);

MKQL_ENSURE(decimalType1->GetParams() == decimalType2->GetParams(), "Require same precision/scale");
MKQL_ENSURE(decimalType1->GetParams() == decimalResultType->GetParams(), "Require same precision/scale");

ui8 precision = decimalType1->GetParams().first;
MKQL_ENSURE(precision >= 1&& precision <= 35, TStringBuilder() << "Wrong precision: " << (int)precision);

auto k = std::make_shared<arrow::compute::ScalarKernel>(std::vector<arrow::compute::InputType>{
GetPrimitiveInputArrowType(NUdf::EDataSlot::Decimal), GetPrimitiveInputArrowType(NUdf::EDataSlot::Decimal)
}, GetPrimitiveOutputArrowType(NUdf::EDataSlot::Decimal), Exec);
}, GetPrimitiveOutputArrowType(*dataResultType->GetDataSlot()), Exec);
k->null_handling = arrow::compute::NullHandling::INTERSECTION;
k->init = [precision](arrow::compute::KernelContext*, const arrow::compute::KernelInitArgs&) {
auto state = std::make_unique<TDecimalKernel::TKernelState>();
Expand Down Expand Up @@ -713,7 +710,8 @@ arrow::Status ExecDecimalScalarArrayOptImpl(const arrow::compute::ExecBatch& bat

arrow::Status ExecDecimalScalarScalarOptImpl(arrow::compute::KernelContext* kernelCtx,
const arrow::compute::ExecBatch& batch, arrow::Datum* res,
TPrimitiveDataTypeGetter typeGetter, TUntypedBinaryScalarOptFuncPtr func) {
TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetterWithMemPool scalarGetter,
TUntypedBinaryScalarOptFuncPtr func) {
MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
const auto& arg1 = batch.values[0];
const auto& arg2 = batch.values[1];
Expand All @@ -722,9 +720,9 @@ arrow::Status ExecDecimalScalarScalarOptImpl(arrow::compute::KernelContext* kern
} else {
const auto val1Ptr = GetStringScalarValue(*arg1.scalar());
const auto val2Ptr = GetStringScalarValue(*arg2.scalar());
std::shared_ptr<arrow::Buffer> buffer(ARROW_RESULT(arrow::AllocateBuffer(16, kernelCtx->memory_pool())));
auto resDatum = arrow::Datum(std::make_shared<TPrimitiveDataType<NYql::NDecimal::TInt128>::TScalarResult>(buffer));
if (!func(val1Ptr.data(), val2Ptr.data(), buffer->mutable_data())) {
void* resMem;
auto resDatum = scalarGetter(&resMem, kernelCtx->memory_pool());
if (!func(val1Ptr.data(), val2Ptr.data(), resMem)) {
*res = arrow::MakeNullScalar(typeGetter());
} else {
*res = resDatum.scalar();
Expand All @@ -736,7 +734,7 @@ arrow::Status ExecDecimalScalarScalarOptImpl(arrow::compute::KernelContext* kern

arrow::Status ExecDecimalBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,
const arrow::compute::ExecBatch& batch, arrow::Datum* res,
TPrimitiveDataTypeGetter typeGetter,
TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetterWithMemPool scalarGetter,
size_t outputSizeOf,
TUntypedBinaryScalarOptFuncPtr scalarScalarFunc,
TUntypedBinaryArrayOptFuncPtr scalarArrayFunc,
Expand All @@ -747,7 +745,7 @@ arrow::Status ExecDecimalBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,
const auto& arg2 = batch.values[1];
if (arg1.is_scalar()) {
if (arg2.is_scalar()) {
return ExecDecimalScalarScalarOptImpl(kernelCtx, batch, res, typeGetter, scalarScalarFunc);
return ExecDecimalScalarScalarOptImpl(kernelCtx, batch, res, typeGetter, scalarGetter, scalarScalarFunc);
} else {
return ExecDecimalScalarArrayOptImpl(batch, res, scalarArrayFunc);
}
Expand All @@ -769,8 +767,7 @@ arrow::Status ExecDecimalScalarImpl(arrow::compute::KernelContext* kernelCtx,
const auto valPtr = GetPrimitiveScalarValuePtr(*arg.scalar());
std::shared_ptr<arrow::Buffer> buffer(ARROW_RESULT(arrow::AllocateBuffer(16, kernelCtx->memory_pool())));
auto resDatum = arrow::Datum(std::make_shared<TPrimitiveDataType<NYql::NDecimal::TInt128>::TScalarResult>(buffer));
const auto resPtr = GetPrimitiveScalarValueMutablePtr(*resDatum.scalar());
func(valPtr, resPtr);
func(valPtr, buffer->mutable_data());
*res = resDatum.scalar();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ inline std::shared_ptr<arrow::ArrayData> CopyTz(const std::shared_ptr<arrow::Arr

using TPrimitiveDataTypeGetter = std::shared_ptr<arrow::DataType>(*)();
using TPrimitiveDataScalarGetter= arrow::Datum(*)();
using TPrimitiveDataScalarGetterWithMemPool = arrow::Datum(*)(void** result, arrow::MemoryPool*);
using TUntypedBinaryScalarFuncPtr = void(*)(const void*, const void*, void*);
using TUntypedBinaryArrayFuncPtr = void(*)(const void*, const void*, void*, int64_t length, int64_t offset1, int64_t offset2);
using TUntypedBinaryScalarOptFuncPtr = bool(*)(const void*, const void*, void*);
Expand Down Expand Up @@ -1023,7 +1024,7 @@ arrow::Status ExecBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,

arrow::Status ExecDecimalBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,
const arrow::compute::ExecBatch& batch, arrow::Datum* res,
TPrimitiveDataTypeGetter typeGetter,
TPrimitiveDataTypeGetter typeGetter, TPrimitiveDataScalarGetterWithMemPool scalarGetter,
size_t outputSizeOf,
TUntypedBinaryScalarOptFuncPtr scalarScalarFunc,
TUntypedBinaryArrayOptFuncPtr scalarArrayFunc,
Expand All @@ -1032,7 +1033,7 @@ arrow::Status ExecDecimalBinaryOptImpl(arrow::compute::KernelContext* kernelCtx,

arrow::Status ExecDecimalUnaryImpl(arrow::compute::KernelContext* kernelCtx,
const arrow::compute::ExecBatch& batch, arrow::Datum* res,
TPrimitiveDataTypeGetter typeGetter,
TPrimitiveDataTypeGetter typeGetter,
TUntypedUnaryScalarFuncPtr scalarFunc, TUntypedUnaryArrayFuncPtr arrayFunc);

template<typename TInput1, bool Tz1, typename TInput2, bool Tz2, typename TOutput, EPropagateTz PropagateTz,
Expand Down Expand Up @@ -1664,6 +1665,12 @@ struct TDecimalKernelExecs
using TInput2 = NYql::NDecimal::TInt128;
using TOutput = NYql::NDecimal::TInt128;

static arrow::Datum ScalarGetter(void** result, arrow::MemoryPool* memory_pool) {
std::shared_ptr<arrow::Buffer> buffer(ARROW_RESULT(arrow::AllocateBuffer(16, memory_pool)));
*result = buffer->mutable_data();
return arrow::Datum(std::make_shared<TPrimitiveDataType<NYql::NDecimal::TInt128>::TScalarResult>(buffer));
}

template<ui8 precision>
static arrow::Status ExecImpl(arrow::compute::KernelContext* kernelCtx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
auto scalarScalarFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance<precision>>::ScalarScalarCoreOpt;
Expand All @@ -1672,7 +1679,7 @@ struct TDecimalKernelExecs
auto arrayArrayFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance<precision>>::ArrayArrayCoreOpt;

return ExecDecimalBinaryOptImpl(kernelCtx, batch, res,
&GetPrimitiveDataType<TOutput>,
&GetPrimitiveDataType<TOutput>, &ScalarGetter,
sizeof(TOutput),
(TUntypedBinaryScalarOptFuncPtr)scalarScalarFunc,
(TUntypedBinaryArrayOptFuncPtr)scalarArrayFunc,
Expand Down Expand Up @@ -1709,5 +1716,46 @@ void AddBinaryDecimalKernels(TKernelFamilyBase& owner) {
owner.Adopt(argTypes, returnType, std::move(kernel));
}

template<class TFuncInstance>
struct TDecimalComparisonKernelExecs
{
using TInput1 = NYql::NDecimal::TInt128;
using TInput2 = NYql::NDecimal::TInt128;
using TOutput = bool;

static arrow::Datum ScalarGetter(void** resMem, arrow::MemoryPool*) {
auto result = MakeDefaultScalarDatum<TOutput>();
*resMem = GetPrimitiveScalarValueMutablePtr(*result.scalar());
return result;
}

static arrow::Status Exec(arrow::compute::KernelContext* kernelCtx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
auto scalarScalarFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance>::ScalarScalarCoreOpt;
auto scalarArrayFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance>::ScalarArrayCoreOpt;
auto arrayScalarFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance>::ArrayScalarCoreOpt;
auto arrayArrayFunc = &TBinaryKernelOptExecsImpl<TInput1, TInput2, TOutput, TFuncInstance>::ArrayArrayCoreOpt;

return ExecDecimalBinaryOptImpl(kernelCtx, batch, res,
&GetPrimitiveDataType<TOutput>, &ScalarGetter,
sizeof(TOutput),
(TUntypedBinaryScalarOptFuncPtr)scalarScalarFunc,
(TUntypedBinaryArrayOptFuncPtr)scalarArrayFunc,
(TUntypedBinaryArrayOptFuncPtr)arrayScalarFunc,
(TUntypedBinaryArrayOptFuncPtr)arrayArrayFunc);
}
};

template<class TFunc>
void AddDecimalComparisonKernels(TKernelFamilyBase& owner) {
auto type1 = NUdf::GetDataTypeInfo(NUdf::EDataSlot::Decimal).TypeId;
auto type2 = type1;
auto returnType = NUdf::GetDataTypeInfo(NUdf::EDataSlot::Bool).TypeId;
std::vector<NUdf::TDataTypeId> argTypes({ type1, type2 });

using Execs = TDecimalComparisonKernelExecs<TFunc>;
auto kernel = std::make_unique<TDecimalKernel>(owner, argTypes, returnType, &Execs::Exec, TKernel::ENullMode::Default);
owner.Adopt(argTypes, returnType, std::move(kernel));
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ void RegisterLess(TKernelFamilyMap& kernelFamilyMap) {

AddNumericComparisonKernels<TLessOp>(*family);
AddDateComparisonKernels<TDiffDateLessOp>(*family);
AddDecimalComparisonKernels<TDecimalLess>(*family);
RegisterStringKernelLess(*family);

kernelFamilyMap["Less"] = std::move(family);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ void RegisterLessOrEqual(TKernelFamilyMap& kernelFamilyMap) {

AddNumericComparisonKernels<TLessOrEqualOp>(*family);
AddDateComparisonKernels<TDiffDateLessOrEqualOp>(*family);
AddDecimalComparisonKernels<TDecimalLessOrEqual>(*family);
RegisterStringKernelLessOrEqual(*family);

kernelFamilyMap["LessOrEqual"] = std::move(family);
Expand Down
22 changes: 22 additions & 0 deletions ydb/library/yql/tests/sql/dq_file/part18/canondata/result.json
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,28 @@
}
],
"test.test[blocks-date_sub_scalar--Results]": [],
"test.test[blocks-decimal_comparison--Analyze]": [
{
"checksum": "84e84352daef0d01b5fc6884ec3ebf29",
"size": 3719,
"uri": "https://{canondata_backend}/1924537/081e3ea5ef34a4fe33a8e971e47d53ea3a5151a4/resource.tar.gz#test.test_blocks-decimal_comparison--Analyze_/plan.txt"
}
],
"test.test[blocks-decimal_comparison--Debug]": [
{
"checksum": "8277613af5703692784868526580bfe1",
"size": 2214,
"uri": "https://{canondata_backend}/1924537/081e3ea5ef34a4fe33a8e971e47d53ea3a5151a4/resource.tar.gz#test.test_blocks-decimal_comparison--Debug_/opt.yql_patched"
}
],
"test.test[blocks-decimal_comparison--Plan]": [
{
"checksum": "84e84352daef0d01b5fc6884ec3ebf29",
"size": 3719,
"uri": "https://{canondata_backend}/1924537/081e3ea5ef34a4fe33a8e971e47d53ea3a5151a4/resource.tar.gz#test.test_blocks-decimal_comparison--Plan_/plan.txt"
}
],
"test.test[blocks-decimal_comparison--Results]": [],
"test.test[blocks-filter_partial_expr--Analyze]": [
{
"checksum": "e8f201d2a8a9bec0b35c5b959be1c92a",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,20 @@
"uri": "https://{canondata_backend}/1920236/24fc177ad15ffb34261b5dbf2e1c5c96c4ad7faf/resource.tar.gz#test.test_blocks-date_sub_interval_scalar--Plan_/plan.txt"
}
],
"test.test[blocks-decimal_comparison--Debug]": [
{
"checksum": "d9a0988969c48defdd31a7b6259e79f0",
"size": 3807,
"uri": "https://{canondata_backend}/1784826/2b2e157f3cee0db3cfdb7f3fc538c5f30f0b593e/resource.tar.gz#test.test_blocks-decimal_comparison--Debug_/opt.yql_patched"
}
],
"test.test[blocks-decimal_comparison--Plan]": [
{
"checksum": "e65c86ba6e55033fe5f47711b65bba6d",
"size": 4093,
"uri": "https://{canondata_backend}/1784826/2b2e157f3cee0db3cfdb7f3fc538c5f30f0b593e/resource.tar.gz#test.test_blocks-decimal_comparison--Plan_/plan.txt"
}
],
"test.test[blocks-distinct_pure_all--Debug]": [
{
"checksum": "a368ba0efe31dc0503d90694682eb845",
Expand Down
14 changes: 14 additions & 0 deletions ydb/library/yql/tests/sql/sql2yql/canondata/result.json
Original file line number Diff line number Diff line change
Expand Up @@ -3779,6 +3779,13 @@
"uri": "https://{canondata_backend}/1942100/4770669c24007543908dd55606255f269883b26e/resource.tar.gz#test_sql2yql.test_blocks-date_top_sort_/sql.yql"
}
],
"test_sql2yql.test[blocks-decimal_comparison]": [
{
"checksum": "abee33c1efc8dcd301cc6154c36cc219",
"size": 4058,
"uri": "https://{canondata_backend}/1773845/16ab8549a56b6d73fdf2b214a79ffc536fd46be5/resource.tar.gz#test_sql2yql.test_blocks-decimal_comparison_/sql.yql"
}
],
"test_sql2yql.test[blocks-decimal_op_decimal]": [
{
"checksum": "7b42147f7c8462d1ffa32e88ba91d961",
Expand Down Expand Up @@ -23316,6 +23323,13 @@
"uri": "https://{canondata_backend}/1942100/4770669c24007543908dd55606255f269883b26e/resource.tar.gz#test_sql_format.test_blocks-date_top_sort_/formatted.sql"
}
],
"test_sql_format.test[blocks-decimal_comparison]": [
{
"checksum": "83c1f5e627f7e182044701e4143fd101",
"size": 630,
"uri": "https://{canondata_backend}/1773845/16ab8549a56b6d73fdf2b214a79ffc536fd46be5/resource.tar.gz#test_sql_format.test_blocks-decimal_comparison_/formatted.sql"
}
],
"test_sql_format.test[blocks-decimal_op_decimal]": [
{
"checksum": "078fae170f284e83ac9c9c64fc31c788",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
in Input input_decimal.txt
22 changes: 22 additions & 0 deletions ydb/library/yql/tests/sql/suites/blocks/decimal_comparison.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
USE plato;

SELECT
cs_ext_list_price == cs_ext_tax,
cs_ext_list_price < cs_ext_tax,
cs_ext_list_price <= cs_ext_tax,
cs_ext_list_price > cs_ext_tax,
cs_ext_list_price >= cs_ext_tax,

cs_ext_tax == decimal("26.91", 7, 2),
cs_ext_tax < decimal("26.91", 7, 2),
cs_ext_tax <= decimal("26.91", 7, 2),
cs_ext_tax > decimal("26.91", 7, 2),
cs_ext_tax >= decimal("26.91", 7, 2),

decimal("26.91", 7, 2) == cs_ext_tax,
decimal("26.91", 7, 2) < cs_ext_tax,
decimal("26.91", 7, 2) <= cs_ext_tax,
decimal("26.91", 7, 2) > cs_ext_tax,
decimal("26.91", 7, 2) >= cs_ext_tax,
FROM Input;

Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,34 @@
"uri": "https://{canondata_backend}/1937429/45b62dc8690b9f212c440cd9ce32dca44fce59af/resource.tar.gz#test.test_blocks-date_sub_scalar--Results_/results.txt"
}
],
"test.test[blocks-decimal_comparison--Debug]": [
{
"checksum": "9627ad92732c696c533b45bb2706e012",
"size": 2878,
"uri": "https://{canondata_backend}/1777230/4826a2658ee004e3149897383e92334eabb6d44d/resource.tar.gz#test.test_blocks-decimal_comparison--Debug_/opt.yql"
}
],
"test.test[blocks-decimal_comparison--Peephole]": [
{
"checksum": "ed741ff1d0a63c734018572cb802403a",
"size": 3134,
"uri": "https://{canondata_backend}/1777230/4826a2658ee004e3149897383e92334eabb6d44d/resource.tar.gz#test.test_blocks-decimal_comparison--Peephole_/opt.yql"
}
],
"test.test[blocks-decimal_comparison--Plan]": [
{
"checksum": "4578f9a1b54bf43aa6ebe0c94d50b4c1",
"size": 4059,
"uri": "https://{canondata_backend}/1777230/4826a2658ee004e3149897383e92334eabb6d44d/resource.tar.gz#test.test_blocks-decimal_comparison--Plan_/plan.txt"
}
],
"test.test[blocks-decimal_comparison--Results]": [
{
"checksum": "4db899cc84d90660be094653923f5e4b",
"size": 20076,
"uri": "https://{canondata_backend}/1777230/4826a2658ee004e3149897383e92334eabb6d44d/resource.tar.gz#test.test_blocks-decimal_comparison--Results_/results.txt"
}
],
"test.test[blocks-filter_partial_expr--Debug]": [
{
"checksum": "074405dc3a47b47ba5b6d4f37ed07c8e",
Expand Down
Loading