Skip to content

Exact vector search #2519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Mar 14, 2024
Merged
96 changes: 96 additions & 0 deletions ydb/library/yql/udfs/common/knn/knn-enumerator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#pragma once

#include <ydb/library/yql/public/udf/udf_helpers.h>

#include <util/generic/buffer.h>
#include <util/stream/format.h>

using namespace NYql;
using namespace NYql::NUdf;

template <typename TCallback>
void EnumerateVector(const TUnboxedValuePod vector, TCallback&& callback) {
const auto elements = vector.GetElements();
if (elements) {
const auto size = vector.GetListLength();

for (ui32 i = 0; i < size; ++i) {
callback(elements[i].Get<float>());
}
} else {
TUnboxedValue value;
const auto it = vector.GetListIterator();
while (it.Next(value)) {
callback(value.Get<float>());
}
}
}

// Walks two float lists in lockstep and calls callback(el1, el2) for every pair of
// elements at the same position; el1 always comes from vector1 and el2 from vector2.
//
// Returns true when both lists have the same length and false on a length mismatch.
// When a mismatch is detected the callback may already have been invoked for a prefix
// of the data, so callers must discard whatever they accumulated if false is returned.
template <typename TCallback>
bool EnumerateVectors(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2, TCallback&& callback) {

    // Both lists expose an element array: compare lengths up front, then index both.
    auto enumerateBothSized = [&callback] (const TUnboxedValuePod list1, const TUnboxedValue* elements1, const TUnboxedValuePod list2, const TUnboxedValue* elements2) {
        const auto size1 = list1.GetListLength();
        const auto size2 = list2.GetListLength();

        // Length mismatch
        if (size1 != size2)
            return false;

        for (ui64 i = 0; i < size1; ++i) {
            callback(elements1[i].Get<float>(), elements2[i].Get<float>());
        }

        return true;
    };

    // Only `sized` exposes an element array; `lazy` has to be iterated.
    // emit(sizedElement, lazyElement) is called for every pair.
    auto enumerateOneSized = [] (const TUnboxedValuePod sized, const TUnboxedValue* elements, const TUnboxedValuePod lazy, auto&& emit) {
        const auto size = sized.GetListLength();
        ui64 idx = 0;
        TUnboxedValue value;
        const auto it = lazy.GetListIterator();

        while (it.Next(value)) {
            // Length mismatch: the iterated list is longer. Stop before indexing
            // past the end of the element array.
            if (idx >= size)
                return false;

            emit(elements[idx++].Get<float>(), value.Get<float>());
        }

        // Length mismatch if the iterated list ended before the array did.
        return idx == size;
    };

    // Neither list exposes an element array: advance both iterators together.
    auto enumerateNoSized = [&callback] (const TUnboxedValuePod list1, const TUnboxedValuePod list2) {
        TUnboxedValue value1, value2;
        const auto it1 = list1.GetListIterator();
        const auto it2 = list2.GetListIterator();

        for (;;) {
            // Advance both iterators unconditionally. Short-circuiting here would let
            // the first iterator consume an element the second one does not have,
            // hiding a mismatch of exactly one element.
            const bool has1 = it1.Next(value1);
            const bool has2 = it2.Next(value2);

            // Length mismatch
            if (has1 != has2)
                return false;

            if (!has1)
                return true;

            callback(value1.Get<float>(), value2.Get<float>());
        }
    };

    const auto elements1 = vector1.GetElements();
    const auto elements2 = vector2.GetElements();

    if (elements1 && elements2)
        return enumerateBothSized(vector1, elements1, vector2, elements2);

    if (elements1)
        return enumerateOneSized(vector1, elements1, vector2, callback);

    if (elements2) {
        // vector2 is the indexed list here, so swap the arguments back to keep the
        // callback(el1, el2) order promised to the caller.
        return enumerateOneSized(vector2, elements2, vector1,
            [&callback] (float el2, float el1) { callback(el1, el2); });
    }

    return enumerateNoSized(vector1, vector2);
}
93 changes: 93 additions & 0 deletions ydb/library/yql/udfs/common/knn/knn-serializer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#pragma once

#include "knn-enumerator.h"

#include <ydb/library/yql/public/udf/udf_helpers.h>

#include <util/generic/buffer.h>
#include <util/stream/format.h>

using namespace NYql;
using namespace NYql::NUdf;

// Tag identifying the binary layout of a serialized vector. The serializer writes it
// as the first byte of the output string and the deserializer dispatches on it.
enum EFormat : ui8 {
// Format byte followed by the raw bytes of each float element.
FloatVector = 1
};


// Serializer for the EFormat::FloatVector layout: one format byte followed by the
// in-memory bytes of every float element, in list order.
class TFloatVectorSerializer {
public:
    // Serializes the float list `x` into a string value created through `valueBuilder`.
    // When the list reports a fast length, the output string is allocated at its exact
    // final size and filled in place; otherwise the bytes are collected in a temporary
    // TString and copied into a new string value.
    static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
        auto serialize = [&x] (IOutputStream& outStream) {
            const EFormat format = EFormat::FloatVector;
            outStream.Write(&format, sizeof(format));
            EnumerateVector(x, [&outStream] (float element) { outStream.Write(&element, sizeof(float)); });
        };

        if (x.HasFastListLength()) {
            auto str = valueBuilder->NewStringNotFilled(sizeof(ui8) + x.GetListLength() * sizeof(float));
            auto strRef = str.AsStringRef();
            TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());

            serialize(memoryOutput);
            return str;
        } else {
            TString str;
            TStringOutput stringOutput(str);

            serialize(stringOutput);
            return valueBuilder->NewString(str);
        }
    }

    // Rebuilds a float list from `str`, whose first byte is the format tag (the caller
    // is expected to have inspected it already). Returns an empty value when `str` is
    // empty or when the payload after the tag is not a whole number of floats.
    static TUnboxedValue Deserialize(const IValueBuilder *valueBuilder, const TStringRef& str) {
        // Without this guard `Size() - 1` would wrap around and `Data() + 1` would
        // point past an empty buffer.
        if (str.Size() == 0)
            return {};

        //skip format byte, it was already read
        const char* buf = str.Data() + 1;
        const size_t len = str.Size() - 1;

        if (len % sizeof(float) != 0)
            return {};

        const ui32 count = len / sizeof(float);

        TUnboxedValue* items = nullptr;
        auto res = valueBuilder->NewArray(count, items);

        TMemoryInput inStr(buf, len);
        for (ui32 i = 0; i < count; ++i) {
            float element;
            if (inStr.Read(&element, sizeof(float)) != sizeof(float))
                return {};
            *items++ = TUnboxedValuePod{element};
        }

        return res.Release();
    }
};


// Entry point that routes (de)serialization requests to the serializer matching the
// requested or stored format tag.
class TSerializerFacade {
public:
    // Serializes `x` in the given `format`; yields an empty value for an unknown format.
    static TUnboxedValue Serialize(EFormat format, const IValueBuilder* valueBuilder, const TUnboxedValue x) {
        if (format == EFormat::FloatVector) {
            return TFloatVectorSerializer::Serialize(valueBuilder, x);
        }

        return {};
    }

    // Reads the format tag from the first byte of `str` and hands the whole string to
    // the matching deserializer; yields an empty value for empty input or an unknown tag.
    static TUnboxedValue Deserialize(const IValueBuilder *valueBuilder, const TStringRef& str) {
        if (str.Size() == 0) {
            return {};
        }

        const ui8 tag = str.Data()[0];
        if (tag == EFormat::FloatVector) {
            return TFloatVectorSerializer::Deserialize(valueBuilder, str);
        }

        return {};
    }
};

117 changes: 117 additions & 0 deletions ydb/library/yql/udfs/common/knn/knn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#include "knn-enumerator.h"
#include "knn-serializer.h"

#include <ydb/library/yql/public/udf/udf_helpers.h>

#include <util/generic/buffer.h>
#include <util/stream/format.h>

using namespace NYql;
using namespace NYql::NUdf;


SIMPLE_STRICT_UDF(TToBinaryString, char*(TAutoMap<TListType<float>>)) {
    // The format is fixed for now; it will be taken from args in future.
    return TSerializerFacade::Serialize(EFormat::FloatVector, valueBuilder, args[0]);
}

SIMPLE_STRICT_UDF(TFromBinaryString, TOptional<TListType<float>>(const char*)) {
    // Hand the raw string argument to the facade, which dispatches on the format byte.
    const TStringRef serialized = args[0].AsStringRef();
    return TSerializerFacade::Deserialize(valueBuilder, serialized);
}


// Sums the products of same-position elements of the two float lists.
// Yields an empty optional when EnumerateVectors reports a length mismatch.
std::optional<float> InnerProductSimilarity(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
    float dot = 0;
    const auto accumulate = [&dot](float lhs, float rhs) {
        dot += lhs * rhs;
    };

    if (EnumerateVectors(vector1, vector2, accumulate)) {
        return dot;
    }

    return std::nullopt;
}

// Computes the cosine of the angle between the two float lists:
// inner product divided by the Euclidean norm of each list.
//
// Yields an empty optional when EnumerateVectors reports a length mismatch.
// No special handling is done for a zero norm: if either list has norm 0 (including
// two empty lists) the division produces NaN.
std::optional<float> CosineSimilarity(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
    float len1 = 0;
    float len2 = 0;
    float innerProduct = 0;

    // A single pass accumulates the inner product and both squared norms.
    if (!EnumerateVectors(vector1, vector2, [&](float el1, float el2) {
        innerProduct += el1 * el2;
        len1 += el1 * el1;
        len2 += el2 * el2;
    }))
        return {};

    // sqrtf keeps the computation in single precision, matching EuclideanDistance.
    len1 = sqrtf(len1);
    len2 = sqrtf(len2);

    const float cosine = innerProduct / len1 / len2;

    return cosine;
}

// Square root of the sum of squared same-position differences of the two float lists.
// Yields an empty optional when EnumerateVectors reports a length mismatch.
std::optional<float> EuclideanDistance(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
    float squaredSum = 0;
    const auto accumulate = [&squaredSum](float lhs, float rhs) {
        squaredSum += (lhs - rhs) * (lhs - rhs);
    };

    if (EnumerateVectors(vector1, vector2, accumulate)) {
        return sqrtf(squaredSum);
    }

    return std::nullopt;
}

SIMPLE_STRICT_UDF(TInnerProductSimilarity, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
    Y_UNUSED(valueBuilder);

    // An empty result is returned when the helper reports a length mismatch.
    if (const auto similarity = InnerProductSimilarity(args[0], args[1])) {
        return TUnboxedValuePod{*similarity};
    }

    return {};
}

SIMPLE_STRICT_UDF(TCosineSimilarity, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
    Y_UNUSED(valueBuilder);

    // An empty result is returned when the helper reports a length mismatch.
    if (const auto similarity = CosineSimilarity(args[0], args[1])) {
        return TUnboxedValuePod{*similarity};
    }

    return {};
}

SIMPLE_STRICT_UDF(TCosineDistance, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
    Y_UNUSED(valueBuilder);

    // Distance is one minus the cosine similarity; empty on a length mismatch.
    if (const auto similarity = CosineSimilarity(args[0], args[1])) {
        return TUnboxedValuePod{1 - *similarity};
    }

    return {};
}

SIMPLE_STRICT_UDF(TEuclideanDistance, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
    Y_UNUSED(valueBuilder);

    // An empty result is returned when the helper reports a length mismatch.
    if (const auto result = EuclideanDistance(args[0], args[1])) {
        return TUnboxedValuePod{*result};
    }

    return {};
}

// Module definition: TKnnModule groups the six UDF classes declared in this file
// with SIMPLE_STRICT_UDF (serialization helpers plus similarity/distance functions).
SIMPLE_MODULE(TKnnModule,
TFromBinaryString,
TToBinaryString,
TInnerProductSimilarity,
TCosineSimilarity,
TCosineDistance,
TEuclideanDistance
)

// NOTE(review): presumably exposes TKnnModule to the UDF loader — confirm against udf_helpers.h.
REGISTER_MODULES(TKnnModule)

37 changes: 37 additions & 0 deletions ydb/library/yql/udfs/common/knn/test/canondata/result.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"test.test[CosineDistance]": [
{
"uri": "file://test.test_CosineDistance_/results.txt"
}
],
"test.test[CosineSimilarity]": [
{
"uri": "file://test.test_CosineSimilarity_/results.txt"
}
],
"test.test[DeserializationError]": [
{
"uri": "file://test.test_DeserializationError_/results.txt"
}
],
"test.test[EuclideanDistance]": [
{
"uri": "file://test.test_EuclideanDistance_/results.txt"
}
],
"test.test[InnerProductSimilarity]": [
{
"uri": "file://test.test_InnerProductSimilarity_/results.txt"
}
],
"test.test[LazyListSerialization]": [
{
"uri": "file://test.test_LazyListSerialization_/results.txt"
}
],
"test.test[ListSerialization]": [
{
"uri": "file://test.test_ListSerialization_/results.txt"
}
]
}
Loading