Skip to content

Commit 2b011a8

Browse files
authored
YQ-2549 Checkpointing in match_recognize / merge to stable (#2426)
1 parent dd39aa2 commit 2b011a8

13 files changed

+1570
-33
lines changed

ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp

Lines changed: 196 additions & 23 deletions
Large diffs are not rendered by default.

ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_list.h

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
#pragma once
2+
3+
#include "mkql_match_recognize_save_load.h"
4+
25
#include <ydb/library/yql/minikql/defs.h>
36
#include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h>
47
#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
8+
#include <ydb/library/yql/minikql/comp_nodes/mkql_saveload.h>
59
#include <ydb/library/yql/public/udf/udf_value.h>
610
#include <unordered_map>
711

@@ -131,15 +135,37 @@ class TSparseList {
131135
}
132136
}
133137

138+
void Save(TOutputSerializer& serializer) const {
139+
serializer(Storage.size());
140+
for (const auto& [key, item]: Storage) {
141+
serializer(key, item.Value, item.LockCount);
142+
}
143+
}
144+
145+
void Load(TInputSerializer& serializer) {
146+
auto size = serializer.Read<TStorage::size_type>();
147+
Storage.reserve(size);
148+
for (size_t i = 0; i < size; ++i) {
149+
TStorage::key_type key;
150+
NUdf::TUnboxedValue row;
151+
decltype(TItem::LockCount) lockCount;
152+
serializer(key, row, lockCount);
153+
Storage.emplace(key, TItem{row, lockCount});
154+
}
155+
}
156+
134157
private:
135158
//TODO consider to replace hash table with contiguous chunks
136159
using TAllocator = TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>;
137-
std::unordered_map<
160+
161+
using TStorage = std::unordered_map<
138162
size_t,
139163
TItem,
140164
std::hash<size_t>,
141165
std::equal_to<size_t>,
142-
TAllocator> Storage;
166+
TAllocator>;
167+
168+
TStorage Storage;
143169
};
144170
using TContainerPtr = TContainer::TPtr;
145171

@@ -242,6 +268,14 @@ class TSparseList {
242268
ToIndex = -1;
243269
}
244270

271+
void Save(TOutputSerializer& serializer) const {
272+
serializer(Container, FromIndex, ToIndex);
273+
}
274+
275+
void Load(TInputSerializer& serializer) {
276+
serializer(Container, FromIndex, ToIndex);
277+
}
278+
245279
private:
246280
TRange(TContainerPtr container, size_t index)
247281
: Container(container)
@@ -297,6 +331,14 @@ class TSparseList {
297331
return Size() == 0;
298332
}
299333

334+
void Save(TOutputSerializer& serializer) const {
335+
serializer(Container, ListSize);
336+
}
337+
338+
void Load(TInputSerializer& serializer) {
339+
serializer(Container, ListSize);
340+
}
341+
300342
private:
301343
TContainerPtr Container = MakeIntrusive<TContainer>();
302344
size_t ListSize = 0; //impl: max index ever stored + 1

ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_matched_vars.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {
88

99
template<class R>
1010
using TMatchedVar = std::vector<R, TMKQLAllocator<R>>;
11+
1112
template<class R>
1213
void Extend(TMatchedVar<R>& var, const R& r) {
1314
if (var.empty()) {
@@ -110,8 +111,7 @@ class TMatchedVarsValue : public TComputationValue<TMatchedVarsValue<R>> {
110111
: TComputationValue<TMatchedVarsValue>(memInfo)
111112
, HolderFactory(holderFactory)
112113
, Vars(vars)
113-
{
114-
}
114+
{}
115115

116116
NUdf::TUnboxedValue GetElement(ui32 index) const override {
117117
return HolderFactory.Create<TRangeList>(HolderFactory, Vars[index]);

ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_measure_arg.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class TRowForMeasureValue: public TComputationValue<TRowForMeasureValue>
3232
, VarNames(varNames)
3333
, MatchNumber(matchNumber)
3434
{}
35+
3536
NUdf::TUnboxedValue GetElement(ui32 index) const override {
3637
switch(ColumnOrder[index].first) {
3738
case EMeasureInputDataSpecialColumns::Classifier: {

ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h

Lines changed: 132 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "mkql_match_recognize_matched_vars.h"
4+
#include "mkql_match_recognize_save_load.h"
45
#include "../computation/mkql_computation_node_holders.h"
56
#include "../computation/mkql_computation_node_impl.h"
67
#include <ydb/library/yql/core/sql_types/match_recognize.h>
@@ -12,20 +13,38 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {
1213
using namespace NYql::NMatchRecognize;
1314

1415
struct TVoidTransition {
16+
friend bool operator==(const TVoidTransition&, const TVoidTransition&) {
17+
return true;
18+
}
1519
};
1620
using TEpsilonTransition = size_t; //to
1721
using TEpsilonTransitions = std::vector<TEpsilonTransition, TMKQLAllocator<TEpsilonTransition>>;
1822
using TMatchedVarTransition = std::pair<std::pair<ui32, bool>, size_t>; //{{varIndex, saveState}, to}
1923
using TQuantityEnterTransition = size_t; //to
2024
using TQuantityExitTransition = std::pair<std::pair<ui64, ui64>, std::pair<size_t, size_t>>; //{{min, max}, {foFindMore, toMatched}}
21-
using TNfaTransition = std::variant<
25+
26+
template <typename... Ts>
27+
struct TVariantHelper {
28+
using TVariant = std::variant<Ts...>;
29+
using TTuple = std::tuple<Ts...>;
30+
31+
static std::variant<Ts...> getVariantByIndex(size_t i) {
32+
MKQL_ENSURE(i < sizeof...(Ts), "Wrong variant index");
33+
static std::variant<Ts...> table[] = { Ts{ }... };
34+
return table[i];
35+
}
36+
};
37+
38+
using TNfaTransitionHelper = TVariantHelper<
2239
TVoidTransition,
2340
TMatchedVarTransition,
2441
TEpsilonTransitions,
2542
TQuantityEnterTransition,
2643
TQuantityExitTransition
2744
>;
2845

46+
using TNfaTransition = TNfaTransitionHelper::TVariant;
47+
2948
struct TNfaTransitionDestinationVisitor {
3049
std::function<size_t(size_t)> callback;
3150

@@ -61,11 +80,42 @@ struct TNfaTransitionDestinationVisitor {
6180
};
6281

6382
struct TNfaTransitionGraph {
64-
std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>> Transitions;
83+
using TTransitions = std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>>;
84+
85+
TTransitions Transitions;
6586
size_t Input;
6687
size_t Output;
6788

6889
using TPtr = std::shared_ptr<TNfaTransitionGraph>;
90+
91+
template<class>
92+
inline constexpr static bool always_false_v = false;
93+
94+
void Save(TOutputSerializer& serializer) const {
95+
serializer(Transitions.size());
96+
for (ui64 i = 0; i < Transitions.size(); ++i) {
97+
serializer.Write(Transitions[i].index());
98+
std::visit(serializer, Transitions[i]);
99+
}
100+
serializer(Input, Output);
101+
}
102+
103+
void Load(TInputSerializer& serializer) {
104+
ui64 transitionSize = serializer.Read<TTransitions::size_type>();
105+
Transitions.resize(transitionSize);
106+
for (ui64 i = 0; i < transitionSize; ++i) {
107+
size_t index = serializer.Read<std::size_t>();
108+
Transitions[i] = TNfaTransitionHelper::getVariantByIndex(index);
109+
std::visit(serializer, Transitions[i]);
110+
}
111+
serializer(Input, Output);
112+
}
113+
114+
bool operator==(const TNfaTransitionGraph& other) {
115+
return Transitions == other.Transitions
116+
&& Input == other.Input
117+
&& Output == other.Output;
118+
}
69119
};
70120

71121
class TNfaTransitionGraphOptimizer {
@@ -78,6 +128,7 @@ class TNfaTransitionGraphOptimizer {
78128
EliminateSingleEpsilons();
79129
CollectGarbage();
80130
}
131+
81132
private:
82133
void EliminateEpsilonChains() {
83134
for (size_t node = 0; node != Graph->Transitions.size(); node++) {
@@ -250,14 +301,69 @@ class TNfaTransitionGraphBuilder {
250301
class TNfa {
251302
using TRange = TSparseList::TRange;
252303
using TMatchedVars = TMatchedVars<TRange>;
304+
305+
253306
struct TState {
307+
308+
TState() {}
309+
254310
TState(size_t index, const TMatchedVars& vars, std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>&& quantifiers)
255311
: Index(index)
256312
, Vars(vars)
257313
, Quantifiers(quantifiers) {}
258-
const size_t Index;
314+
size_t Index;
259315
TMatchedVars Vars;
260-
std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>> Quantifiers; //get rid of this
316+
317+
using TQuantifiersStdStack = std::stack<
318+
ui64,
319+
std::deque<ui64, TMKQLAllocator<ui64>>>; //get rid of this
320+
321+
struct TQuantifiersStack: public TQuantifiersStdStack {
322+
template<typename...TArgs>
323+
TQuantifiersStack(TArgs... args) : TQuantifiersStdStack(args...) {}
324+
325+
auto begin() const { return c.begin(); }
326+
auto end() const { return c.end(); }
327+
auto clear() { return c.clear(); }
328+
};
329+
330+
TQuantifiersStack Quantifiers;
331+
332+
void Save(TOutputSerializer& serializer) const {
333+
serializer.Write(Index);
334+
serializer.Write(Vars.size());
335+
for (const auto& vector : Vars) {
336+
serializer.Write(vector.size());
337+
for (const auto& range : vector) {
338+
range.Save(serializer);
339+
}
340+
}
341+
serializer.Write(Quantifiers.size());
342+
for (ui64 qnt : Quantifiers) {
343+
serializer.Write(qnt);
344+
}
345+
}
346+
347+
void Load(TInputSerializer& serializer) {
348+
serializer.Read(Index);
349+
350+
auto varsSize = serializer.Read<TMatchedVars::size_type>();
351+
Vars.clear();
352+
Vars.resize(varsSize);
353+
for (auto& subvec: Vars) {
354+
ui64 vectorSize = serializer.Read<ui64>();
355+
subvec.resize(vectorSize);
356+
for (auto& item : subvec) {
357+
item.Load(serializer);
358+
}
359+
}
360+
Quantifiers.clear();
361+
auto quantifiersSize = serializer.Read<ui64>();
362+
for (size_t i = 0; i < quantifiersSize; ++i) {
363+
ui64 qnt = serializer.Read<ui64>();
364+
Quantifiers.push(qnt);
365+
}
366+
}
261367

262368
friend inline bool operator<(const TState& lhs, const TState& rhs) {
263369
return std::tie(lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.Index, rhs.Quantifiers, rhs.Vars);
@@ -267,13 +373,14 @@ class TNfa {
267373
}
268374
};
269375
public:
376+
270377
TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines)
271378
: TransitionGraph(transitionGraph)
272379
, MatchedRangesArg(matchedRangesArg)
273380
, Defines(defines) {
274381
}
275382

276-
void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
383+
void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
277384
ActiveStates.emplace(TransitionGraph->Input, TMatchedVars(Defines.size()), std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>{});
278385
MakeEpsilonTransitions();
279386
std::set<TState, std::less<TState>, TMKQLAllocator<TState>> newStates;
@@ -329,6 +436,25 @@ class TNfa {
329436
return ActiveStates.size();
330437
}
331438

439+
void Save(TOutputSerializer& serializer) const {
440+
// TransitionGraph is not saved/loaded, passed in constructor.
441+
serializer.Write(ActiveStates.size());
442+
for (const auto& state : ActiveStates) {
443+
state.Save(serializer);
444+
}
445+
serializer.Write(EpsilonTransitionsLastRow);
446+
}
447+
448+
void Load(TInputSerializer& serializer) {
449+
auto stateSize = serializer.Read<ui64>();
450+
for (size_t i = 0; i < stateSize; ++i) {
451+
TState state;
452+
state.Load(serializer);
453+
ActiveStates.emplace(state);
454+
}
455+
serializer.Read(EpsilonTransitionsLastRow);
456+
}
457+
332458
private:
333459
//TODO (zverevgeny): Consider to change to std::vector for the sake of perf
334460
using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
@@ -376,6 +502,7 @@ class TNfa {
376502
TStateSet& NewStates;
377503
TStateSet& DeletedStates;
378504
};
505+
379506
bool MakeEpsilonTransitionsImpl() {
380507
TStateSet newStates;
381508
TStateSet deletedStates;

0 commit comments

Comments
 (0)