1
1
#pragma once
2
2
3
3
#include " mkql_match_recognize_matched_vars.h"
4
+ #include " mkql_match_recognize_save_load.h"
4
5
#include " ../computation/mkql_computation_node_holders.h"
5
6
#include " ../computation/mkql_computation_node_impl.h"
6
7
#include < ydb/library/yql/core/sql_types/match_recognize.h>
@@ -12,20 +13,38 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {
12
13
using namespace NYql ::NMatchRecognize;
13
14
14
15
struct TVoidTransition {
16
+ friend bool operator ==(const TVoidTransition&, const TVoidTransition&) {
17
+ return true ;
18
+ }
15
19
};
16
20
using TEpsilonTransition = size_t ; // to
17
21
using TEpsilonTransitions = std::vector<TEpsilonTransition, TMKQLAllocator<TEpsilonTransition>>;
18
22
using TMatchedVarTransition = std::pair<std::pair<ui32, bool >, size_t >; // {{varIndex, saveState}, to}
19
23
using TQuantityEnterTransition = size_t ; // to
20
24
using TQuantityExitTransition = std::pair<std::pair<ui64, ui64>, std::pair<size_t , size_t >>; // {{min, max}, {foFindMore, toMatched}}
21
- using TNfaTransition = std::variant<
25
+
26
+ template <typename ... Ts>
27
+ struct TVariantHelper {
28
+ using TVariant = std::variant<Ts...>;
29
+ using TTuple = std::tuple<Ts...>;
30
+
31
+ static std::variant<Ts...> getVariantByIndex (size_t i) {
32
+ MKQL_ENSURE (i < sizeof ...(Ts), " Wrong variant index" );
33
+ static std::variant<Ts...> table[] = { Ts{ }... };
34
+ return table[i];
35
+ }
36
+ };
37
+
38
+ using TNfaTransitionHelper = TVariantHelper<
22
39
TVoidTransition,
23
40
TMatchedVarTransition,
24
41
TEpsilonTransitions,
25
42
TQuantityEnterTransition,
26
43
TQuantityExitTransition
27
44
>;
28
45
46
+ using TNfaTransition = TNfaTransitionHelper::TVariant;
47
+
29
48
struct TNfaTransitionDestinationVisitor {
30
49
std::function<size_t (size_t )> callback;
31
50
@@ -61,11 +80,42 @@ struct TNfaTransitionDestinationVisitor {
61
80
};
62
81
63
82
struct TNfaTransitionGraph {
64
- std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>> Transitions;
83
+ using TTransitions = std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>>;
84
+
85
+ TTransitions Transitions;
65
86
size_t Input;
66
87
size_t Output;
67
88
68
89
using TPtr = std::shared_ptr<TNfaTransitionGraph>;
90
+
91
+ template <class >
92
+ inline constexpr static bool always_false_v = false ;
93
+
94
+ void Save (TOutputSerializer& serializer) const {
95
+ serializer (Transitions.size ());
96
+ for (ui64 i = 0 ; i < Transitions.size (); ++i) {
97
+ serializer.Write (Transitions[i].index ());
98
+ std::visit (serializer, Transitions[i]);
99
+ }
100
+ serializer (Input, Output);
101
+ }
102
+
103
+ void Load (TInputSerializer& serializer) {
104
+ ui64 transitionSize = serializer.Read <TTransitions::size_type>();
105
+ Transitions.resize (transitionSize);
106
+ for (ui64 i = 0 ; i < transitionSize; ++i) {
107
+ size_t index = serializer.Read <std::size_t >();
108
+ Transitions[i] = TNfaTransitionHelper::getVariantByIndex (index );
109
+ std::visit (serializer, Transitions[i]);
110
+ }
111
+ serializer (Input, Output);
112
+ }
113
+
114
+ bool operator ==(const TNfaTransitionGraph& other) {
115
+ return Transitions == other.Transitions
116
+ && Input == other.Input
117
+ && Output == other.Output ;
118
+ }
69
119
};
70
120
71
121
class TNfaTransitionGraphOptimizer {
@@ -78,6 +128,7 @@ class TNfaTransitionGraphOptimizer {
78
128
EliminateSingleEpsilons ();
79
129
CollectGarbage ();
80
130
}
131
+
81
132
private:
82
133
void EliminateEpsilonChains () {
83
134
for (size_t node = 0 ; node != Graph->Transitions .size (); node++) {
@@ -250,14 +301,69 @@ class TNfaTransitionGraphBuilder {
250
301
class TNfa {
251
302
using TRange = TSparseList::TRange;
252
303
using TMatchedVars = TMatchedVars<TRange>;
304
+
305
+
253
306
struct TState {
307
+
308
+ TState () {}
309
+
254
310
TState (size_t index, const TMatchedVars& vars, std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>&& quantifiers)
255
311
: Index(index)
256
312
, Vars(vars)
257
313
, Quantifiers(quantifiers) {}
258
- const size_t Index;
314
+ size_t Index;
259
315
TMatchedVars Vars;
260
- std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>> Quantifiers; // get rid of this
316
+
317
+ using TQuantifiersStdStack = std::stack<
318
+ ui64,
319
+ std::deque<ui64, TMKQLAllocator<ui64>>>; // get rid of this
320
+
321
+ struct TQuantifiersStack : public TQuantifiersStdStack {
322
+ template <typename ...TArgs>
323
+ TQuantifiersStack (TArgs... args) : TQuantifiersStdStack(args...) {}
324
+
325
+ auto begin () const { return c.begin (); }
326
+ auto end () const { return c.end (); }
327
+ auto clear () { return c.clear (); }
328
+ };
329
+
330
+ TQuantifiersStack Quantifiers;
331
+
332
+ void Save (TOutputSerializer& serializer) const {
333
+ serializer.Write (Index);
334
+ serializer.Write (Vars.size ());
335
+ for (const auto & vector : Vars) {
336
+ serializer.Write (vector.size ());
337
+ for (const auto & range : vector) {
338
+ range.Save (serializer);
339
+ }
340
+ }
341
+ serializer.Write (Quantifiers.size ());
342
+ for (ui64 qnt : Quantifiers) {
343
+ serializer.Write (qnt);
344
+ }
345
+ }
346
+
347
+ void Load (TInputSerializer& serializer) {
348
+ serializer.Read (Index);
349
+
350
+ auto varsSize = serializer.Read <TMatchedVars::size_type>();
351
+ Vars.clear ();
352
+ Vars.resize (varsSize);
353
+ for (auto & subvec: Vars) {
354
+ ui64 vectorSize = serializer.Read <ui64>();
355
+ subvec.resize (vectorSize);
356
+ for (auto & item : subvec) {
357
+ item.Load (serializer);
358
+ }
359
+ }
360
+ Quantifiers.clear ();
361
+ auto quantifiersSize = serializer.Read <ui64>();
362
+ for (size_t i = 0 ; i < quantifiersSize; ++i) {
363
+ ui64 qnt = serializer.Read <ui64>();
364
+ Quantifiers.push (qnt);
365
+ }
366
+ }
261
367
262
368
friend inline bool operator <(const TState& lhs, const TState& rhs) {
263
369
return std::tie (lhs.Index , lhs.Quantifiers , lhs.Vars ) < std::tie (rhs.Index , rhs.Quantifiers , rhs.Vars );
@@ -267,13 +373,14 @@ class TNfa {
267
373
}
268
374
};
269
375
public:
376
+
270
377
TNfa (TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines)
271
378
: TransitionGraph(transitionGraph)
272
379
, MatchedRangesArg(matchedRangesArg)
273
380
, Defines(defines) {
274
381
}
275
382
276
- void ProcessRow (TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
383
+ void ProcessRow (TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
277
384
ActiveStates.emplace (TransitionGraph->Input , TMatchedVars (Defines.size ()), std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>{});
278
385
MakeEpsilonTransitions ();
279
386
std::set<TState, std::less<TState>, TMKQLAllocator<TState>> newStates;
@@ -329,6 +436,25 @@ class TNfa {
329
436
return ActiveStates.size ();
330
437
}
331
438
439
+ void Save (TOutputSerializer& serializer) const {
440
+ // TransitionGraph is not saved/loaded, passed in constructor.
441
+ serializer.Write (ActiveStates.size ());
442
+ for (const auto & state : ActiveStates) {
443
+ state.Save (serializer);
444
+ }
445
+ serializer.Write (EpsilonTransitionsLastRow);
446
+ }
447
+
448
+ void Load (TInputSerializer& serializer) {
449
+ auto stateSize = serializer.Read <ui64>();
450
+ for (size_t i = 0 ; i < stateSize; ++i) {
451
+ TState state;
452
+ state.Load (serializer);
453
+ ActiveStates.emplace (state);
454
+ }
455
+ serializer.Read (EpsilonTransitionsLastRow);
456
+ }
457
+
332
458
private:
333
459
// TODO (zverevgeny): Consider to change to std::vector for the sake of perf
334
460
using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
@@ -376,6 +502,7 @@ class TNfa {
376
502
TStateSet& NewStates;
377
503
TStateSet& DeletedStates;
378
504
};
505
+
379
506
bool MakeEpsilonTransitionsImpl () {
380
507
TStateSet newStates;
381
508
TStateSet deletedStates;
0 commit comments