Skip to content

Fast dot product in Knn UDF #3188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions ydb/library/yql/udfs/common/knn/knn-serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

#include <ydb/library/yql/public/udf/udf_helpers.h>

#include <util/generic/array_ref.h>
#include <util/generic/buffer.h>
#include <util/stream/format.h>

using namespace NYql;
using namespace NYql::NUdf;

enum EFormat : ui8 {
enum EFormat : ui32 {
FloatVector = 1
};

Expand All @@ -20,12 +21,12 @@ class TFloatVectorSerializer {
static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
auto serialize = [&x] (IOutputStream& outStream) {
const EFormat format = EFormat::FloatVector;
outStream.Write(&format, 1);
outStream.Write(&format, sizeof(ui32));
EnumerateVector(x, [&outStream] (float element) { outStream.Write(&element, sizeof(float)); });
};

if (x.HasFastListLength()) {
auto str = valueBuilder->NewStringNotFilled(sizeof(ui8) + x.GetListLength() * sizeof(float));
auto str = valueBuilder->NewStringNotFilled(sizeof(ui32) + x.GetListLength() * sizeof(float));
auto strRef = str.AsStringRef();
TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());

Expand All @@ -41,9 +42,9 @@ class TFloatVectorSerializer {
}

static TUnboxedValue Deserialize(const IValueBuilder *valueBuilder, const TStringRef& str) {
//skip format byte, it was already read
const char* buf = str.Data() + 1;
const size_t len = str.Size() - 1;
//skip format header, it was already read
const char* buf = str.Data() + sizeof(ui32);
const size_t len = str.Size() - sizeof(ui32);

if (len % sizeof(float) != 0)
return {};
Expand All @@ -63,6 +64,19 @@ class TFloatVectorSerializer {

return res.Release();
}

static const TArrayRef<const float> GetArray(const TStringRef& str) {
//skip format header, it was already read
const char* buf = str.Data() + sizeof(ui32);
const size_t len = str.Size() - sizeof(ui32);

if (len % sizeof(float) != 0)
return {};

const ui32 count = len / sizeof(float);

return MakeArrayRef(reinterpret_cast<const float*>(buf), count);
}
};


Expand All @@ -81,13 +95,26 @@ class TSerializerFacade {
if (str.Size() == 0)
return {};

ui8 formatByte = str.Data()[0];
switch (formatByte) {
const ui32* format = reinterpret_cast<const ui32*>(str.Data());
switch (*format) {
case EFormat::FloatVector:
return TFloatVectorSerializer::Deserialize(valueBuilder, str);
default:
return {};
}
}

static const TArrayRef<const float> GetArray(const TStringRef& str) {
if (str.Size() == 0)
return {};

const ui32* format = reinterpret_cast<const ui32*>(str.Data());
switch (*format) {
case EFormat::FloatVector:
return TFloatVectorSerializer::GetArray(str);
default:
return {};
}
}
};

90 changes: 25 additions & 65 deletions ydb/library/yql/udfs/common/knn/knn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <ydb/library/yql/public/udf/udf_helpers.h>

#include <library/cpp/dot_product/dot_product.h>
#include <util/generic/buffer.h>
#include <util/stream/format.h>

Expand All @@ -23,94 +24,53 @@ SIMPLE_STRICT_UDF(TFromBinaryString, TOptional<TListType<float>>(TAutoMap<const
return TSerializerFacade::Deserialize(valueBuilder, str);
}

SIMPLE_STRICT_UDF(TInnerProductSimilarity, TOptional<float>(TAutoMap<const char*>, TAutoMap<const char*>)) {
Y_UNUSED(valueBuilder);

std::optional<float> InnerProductSimilarity(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
float ret = 0;

if (!EnumerateVectors(vector1, vector2, [&ret](float el1, float el2) { ret += el1 * el2;}))
return {};

return ret;
}

std::optional<float> CosineSimilarity(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
float len1 = 0;
float len2 = 0;
float innerProduct = 0;

if (!EnumerateVectors(vector1, vector2, [&](float el1, float el2) {
innerProduct += el1 * el2;
len1 += el1 * el1;
len2 += el2 * el2;
}))
return {};

len1 = sqrt(len1);
len2 = sqrt(len2);

float cosine = innerProduct / len1 / len2;

return cosine;
}

std::optional<float> EuclideanDistance(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
float ret = 0;
const TArrayRef<const float> vector1 = TSerializerFacade::GetArray(args[0].AsStringRef());
const TArrayRef<const float> vector2 = TSerializerFacade::GetArray(args[1].AsStringRef());

if (!EnumerateVectors(vector1, vector2, [&ret](float el1, float el2) { ret += (el1 - el2) * (el1 - el2);}))
if (vector1.size() != vector2.size() || vector1.empty() || vector2.empty())
return {};

ret = sqrtf(ret);

return ret;
const float dotProduct = DotProduct(vector1.data(), vector2.data(), vector1.size());
return TUnboxedValuePod{dotProduct};
}

SIMPLE_STRICT_UDF(TInnerProductSimilarity, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
SIMPLE_STRICT_UDF(TCosineSimilarity, TOptional<float>(TAutoMap<const char*>, TAutoMap<const char*>)) {
Y_UNUSED(valueBuilder);

auto innerProduct = InnerProductSimilarity(args[0], args[1]);
if (!innerProduct)
return {};

return TUnboxedValuePod{innerProduct.value()};
}
const TArrayRef<const float> vector1 = TSerializerFacade::GetArray(args[0].AsStringRef());
const TArrayRef<const float> vector2 = TSerializerFacade::GetArray(args[1].AsStringRef());

SIMPLE_STRICT_UDF(TCosineSimilarity, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
Y_UNUSED(valueBuilder);

auto cosine = CosineSimilarity(args[0], args[1]);
if (!cosine)
return {};
if (vector1.size() != vector2.size() || vector1.empty() || vector2.empty())
return {};

return TUnboxedValuePod{cosine.value()};
const auto [ll, lr, rr] = TriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
const float cosine = lr / std::sqrt(ll * rr);
return TUnboxedValuePod{cosine};
}

SIMPLE_STRICT_UDF(TCosineDistance, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
SIMPLE_STRICT_UDF(TCosineDistance, TOptional<float>(TAutoMap<const char*>, TAutoMap<const char*>)) {
Y_UNUSED(valueBuilder);

auto cosine = CosineSimilarity(args[0], args[1]);
if (!cosine)
return {};

return TUnboxedValuePod{1 - cosine.value()};
}
const TArrayRef<const float> vector1 = TSerializerFacade::GetArray(args[0].AsStringRef());
const TArrayRef<const float> vector2 = TSerializerFacade::GetArray(args[1].AsStringRef());

SIMPLE_STRICT_UDF(TEuclideanDistance, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
Y_UNUSED(valueBuilder);

auto distance = EuclideanDistance(args[0], args[1]);
if (!distance)
return {};
if (vector1.size() != vector2.size() || vector1.empty() || vector2.empty())
return {};

return TUnboxedValuePod{distance.value()};
const auto [ll, lr, rr] = TriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
const float cosine = lr / std::sqrt(ll * rr);
return TUnboxedValuePod{1 - cosine};
}

SIMPLE_MODULE(TKnnModule,
TFromBinaryString,
TToBinaryString,
TInnerProductSimilarity,
TCosineSimilarity,
TCosineDistance,
TEuclideanDistance
TCosineDistance
)

REGISTER_MODULES(TKnnModule)
Expand Down
5 changes: 0 additions & 5 deletions ydb/library/yql/udfs/common/knn/test/canondata/result.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@
"uri": "file://test.test_DeserializationError_/results.txt"
}
],
"test.test[EuclideanDistance]": [
{
"uri": "file://test.test_EuclideanDistance_/results.txt"
}
],
"test.test[InnerProductSimilarity]": [
{
"uri": "file://test.test_InnerProductSimilarity_/results.txt"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"Data" = [
[
[
"0.025368154"
"0.025368094"
]
]
]
Expand Down Expand Up @@ -54,131 +54,7 @@
"Data" = [
[
[
"0.025368154"
]
]
]
}
]
};
{
"Write" = [
{
"Type" = [
"ListType";
[
"StructType";
[
[
"column0";
[
"OptionalType";
[
"DataType";
"Float"
]
]
]
]
]
];
"Data" = [
[
[
"0.025368154"
]
]
]
}
]
};
{
"Write" = [
{
"Type" = [
"ListType";
[
"StructType";
[
[
"column0";
[
"OptionalType";
[
"DataType";
"Float"
]
]
]
]
]
];
"Data" = [
[
[
"0.025368154"
]
]
]
}
]
};
{
"Write" = [
{
"Type" = [
"ListType";
[
"StructType";
[
[
"column0";
[
"OptionalType";
[
"DataType";
"Float"
]
]
]
]
]
];
"Data" = [
[
[
"0.025368154"
]
]
]
}
]
};
{
"Write" = [
{
"Type" = [
"ListType";
[
"StructType";
[
[
"column0";
[
"OptionalType";
[
"DataType";
"Float"
]
]
]
]
]
];
"Data" = [
[
[
"0.000000059604645"
"0"
]
]
]
Expand Down
Loading