Skip to content

Commit 29e1e45

Browse files
authored
Update LLVM part of combiner with spilling. (#6277)
1 parent f256aea commit 29e1e45

File tree

1 file changed

+80
-62
lines changed

1 file changed

+80
-62
lines changed

ydb/library/yql/minikql/comp_nodes/mkql_wide_combine.cpp

+80-62
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
343343
};
344344

345345
public:
346-
enum class ETasteResult: ui8 {
347-
Init,
346+
enum class ETasteResult: i8 {
347+
Init = -1,
348348
Update,
349349
Skip
350350
};
@@ -372,15 +372,9 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
372372
Tongue = InMemoryProcessingState.Tongue;
373373
Throat = InMemoryProcessingState.Throat;
374374
}
375-
~TSpillingSupportState() {
376-
}
377-
378-
bool IsFetchRequired() const {
379-
return InputStatus != EFetchResult::Finish;
380-
}
381375

382376
bool HasAnyData() const {
383-
return SpilledBuckets.size();
377+
return !SpilledBuckets.empty();
384378
}
385379

386380
bool IsProcessingRequired() const {
@@ -456,6 +450,20 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
456450
return ETasteResult::Skip;
457451
}
458452

453+
NUdf::TUnboxedValuePod* Extract() {
454+
if (GetMode() == EOperatingMode::InMemory) return static_cast<NUdf::TUnboxedValue*>(InMemoryProcessingState.Extract());
455+
456+
MKQL_ENSURE(SpilledBuckets.front().BucketState == TSpilledBucket::EBucketState::InMemory, "Internal logic error");
457+
MKQL_ENSURE(SpilledBuckets.size() > 0, "Internal logic error");
458+
459+
auto value = static_cast<NUdf::TUnboxedValue*>(SpilledBuckets.front().InMemoryProcessingState->Extract());
460+
if (!value) {
461+
SpilledBuckets.pop_front();
462+
}
463+
464+
return value;
465+
}
466+
private:
459467
void MoveKeyToBucket(TSpilledBucket& bucket) {
460468
for (size_t i = 0; i < KeyWidth; ++i) {
461469
//jumping into unsafe world, refusing ownership
@@ -483,20 +491,6 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
483491
BufferForUsedInputItems.resize(0);
484492
}
485493

486-
NUdf::TUnboxedValuePod* Extract() {
487-
if (GetMode() == EOperatingMode::InMemory) return static_cast<NUdf::TUnboxedValue*>(InMemoryProcessingState.Extract());
488-
489-
MKQL_ENSURE(SpilledBuckets.front().BucketState == TSpilledBucket::EBucketState::InMemory, "Internal logic error");
490-
MKQL_ENSURE(SpilledBuckets.size() > 0, "Internal logic error");
491-
492-
auto value = static_cast<NUdf::TUnboxedValue*>(SpilledBuckets.front().InMemoryProcessingState->Extract());
493-
if (!value) {
494-
SpilledBuckets.pop_front();
495-
}
496-
497-
return value;
498-
}
499-
500494
bool FlushSpillingBuffersAndWait() {
501495
UpdateSpillingBuckets();
502496

@@ -521,7 +515,6 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
521515
return ProcessSpilledDataAndWait();
522516
}
523517

524-
private:
525518
void SplitStateIntoBuckets() {
526519
while (const auto keyAndState = static_cast<NUdf::TUnboxedValue *>(InMemoryProcessingState.Extract())) {
527520
auto hash = Hasher(keyAndState); //Hasher uses only key for hashing
@@ -1246,7 +1239,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
12461239

12471240
EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
12481241
if (!state.HasValue()) {
1249-
MakeSpillingSupportState(ctx, state);
1242+
MakeState(ctx, state);
12501243
}
12511244

12521245
if (const auto ptr = static_cast<TSpillingSupportState*>(state.AsBoxed().Get())) {
@@ -1306,6 +1299,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
13061299
const auto valueType = Type::getInt128Ty(context);
13071300
const auto ptrValueType = PointerType::getUnqual(valueType);
13081301
const auto statusType = Type::getInt32Ty(context);
1302+
const auto wayType = Type::getInt8Ty(context);
13091303

13101304
TLLVMFieldsStructureState stateFields(context);
13111305

@@ -1332,26 +1326,39 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
13321326
const auto state = new LoadInst(valueType, statePtr, "state", block);
13331327
const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
13341328
const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
1329+
const auto boolFuncType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false);
13351330
BranchInst::Create(more, block);
13361331

1337-
block = more;
1338-
1339-
const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
13401332
const auto full = BasicBlock::Create(context, "full", ctx.Func);
13411333
const auto over = BasicBlock::Create(context, "over", ctx.Func);
1342-
const auto result = PHINode::Create(statusType, 3U, "result", over);
1343-
1344-
const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
1345-
const auto last = new LoadInst(statusType, statusPtr, "last", block);
1346-
const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, last, ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), "finish", block);
1347-
1348-
BranchInst::Create(full, loop, finish, block);
1334+
const auto result = PHINode::Create(statusType, 4U, "result", over);
13491335

13501336
{
1337+
const auto test = BasicBlock::Create(context, "test", ctx.Func);
1338+
const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
13511339
const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
1340+
const auto proc = BasicBlock::Create(context, "proc", ctx.Func);
13521341
const auto good = BasicBlock::Create(context, "good", ctx.Func);
13531342

1354-
block = loop;
1343+
block = more;
1344+
1345+
const auto waitMoreFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::UpdateAndWait));
1346+
const auto waitMoreFuncPtr = CastInst::Create(Instruction::IntToPtr, waitMoreFunc, PointerType::getUnqual(boolFuncType), "wait_more_func", block);
1347+
const auto waitMore = CallInst::Create(boolFuncType, waitMoreFuncPtr, { stateArg }, "wait_more", block);
1348+
1349+
result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
1350+
1351+
BranchInst::Create(over, test, waitMore, block);
1352+
1353+
block = test;
1354+
1355+
const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
1356+
const auto last = new LoadInst(statusType, statusPtr, "last", block);
1357+
const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, last, ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), "finish", block);
1358+
1359+
BranchInst::Create(good, pull, finish, block);
1360+
1361+
block = pull;
13551362

13561363
const auto getres = GetNodeValues(Flow, ctx, block);
13571364

@@ -1362,12 +1369,19 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
13621369
choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), rest);
13631370

13641371
block = rest;
1365-
new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), statusPtr, block);
1366-
1367-
BranchInst::Create(full, block);
1372+
new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), statusPtr, block);
1373+
BranchInst::Create(more, block);
13681374

13691375
block = good;
13701376

1377+
const auto processingFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::IsProcessingRequired));
1378+
const auto processingFuncPtr = CastInst::Create(Instruction::IntToPtr, processingFunc, PointerType::getUnqual(boolFuncType), "processing_func", block);
1379+
const auto processing = CallInst::Create(boolFuncType, processingFuncPtr, { stateArg }, "processing", block);
1380+
1381+
BranchInst::Create(proc, full, processing, block);
1382+
1383+
block = proc;
1384+
13711385
std::vector<Value*> items(Nodes.ItemNodes.size(), nullptr);
13721386
for (ui32 i = 0U; i < items.size(); ++i) {
13731387
if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
@@ -1398,10 +1412,10 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
13981412
new StoreInst(key, keyPtr, block);
13991413
}
14001414

1401-
const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::TasteIt));
1402-
const auto atType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false);
1415+
const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::TasteIt));
1416+
const auto atType = FunctionType::get(wayType, {stateArg->getType()}, false);
14031417
const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
1404-
const auto newKey = CallInst::Create(atType, atPtr, {stateArg}, "new_key", block);
1418+
const auto taste= CallInst::Create(atType, atPtr, {stateArg}, "taste", block);
14051419

14061420
const auto init = BasicBlock::Create(context, "init", ctx.Func);
14071421
const auto next = BasicBlock::Create(context, "next", ctx.Func);
@@ -1415,7 +1429,9 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
14151429
pointers.emplace_back(GetElementPtrInst::CreateInBounds(valueType, throat, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("state_") += ToString(i)).c_str(), block));
14161430
}
14171431

1418-
BranchInst::Create(init, next, newKey, block);
1432+
const auto way = SwitchInst::Create(taste, more, 2U, block);
1433+
way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Init)), init);
1434+
way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Update)), next);
14191435

14201436
block = init;
14211437

@@ -1439,7 +1455,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
14391455
}
14401456
}
14411457

1442-
BranchInst::Create(loop, block);
1458+
BranchInst::Create(more, block);
14431459

14441460
block = next;
14451461

@@ -1484,23 +1500,22 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
14841500
}
14851501
}
14861502

1487-
BranchInst::Create(loop, block);
1503+
BranchInst::Create(more, block);
14881504
}
14891505

14901506
{
14911507
block = full;
14921508

14931509
const auto good = BasicBlock::Create(context, "good", ctx.Func);
1510+
const auto last = BasicBlock::Create(context, "last", ctx.Func);
14941511

1495-
const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract));
1512+
const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::Extract));
14961513
const auto extractType = FunctionType::get(ptrValueType, {stateArg->getType()}, false);
14971514
const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
14981515
const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block);
14991516
const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(ptrValueType), "has", block);
15001517

1501-
result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
1502-
1503-
BranchInst::Create(good, over, has, block);
1518+
BranchInst::Create(good, last, has, block);
15041519

15051520
block = good;
15061521

@@ -1514,6 +1529,16 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
15141529

15151530
result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
15161531
BranchInst::Create(over, block);
1532+
1533+
block = last;
1534+
1535+
const auto hasDataFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::HasAnyData));
1536+
const auto hasDataFuncPtr = CastInst::Create(Instruction::IntToPtr, hasDataFunc, PointerType::getUnqual(boolFuncType), "has_data_func", block);
1537+
const auto hasData = CallInst::Create(boolFuncType, hasDataFuncPtr, { stateArg }, "has_data", block);
1538+
1539+
result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
1540+
1541+
BranchInst::Create(more, over, hasData, block);
15171542
}
15181543

15191544
block = over;
@@ -1528,23 +1553,17 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
15281553
#endif
15291554
private:
15301555
void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
1531-
#ifdef MKQL_DISABLE_CODEGEN
1532-
state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(), TMyValueHasher(KeyTypes), TMyValueEqual(KeyTypes));
1533-
#else
1534-
state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(),
1535-
ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
1536-
ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes))
1537-
);
1538-
#endif
1539-
}
1540-
1541-
void MakeSpillingSupportState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
15421556
state = ctx.HolderFactory.Create<TSpillingSupportState>(WideFieldsIndex,
15431557
UsedInputItemType, KeyAndStateType,
15441558
Nodes.KeyNodes.size(),
15451559
Nodes.ItemNodes.size(),
1560+
#ifdef MKQL_DISABLE_CODEGEN
15461561
TMyValueHasher(KeyTypes),
15471562
TMyValueEqual(KeyTypes),
1563+
#else
1564+
ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
1565+
ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes)),
1566+
#endif
15481567
AllowSpilling,
15491568
ctx
15501569
);
@@ -1569,7 +1588,6 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
15691588
const ui32 WideFieldsIndex;
15701589

15711590
const bool AllowSpilling;
1572-
15731591
#ifndef MKQL_DISABLE_CODEGEN
15741592
TEqualsPtr Equals = nullptr;
15751593
THashPtr Hash = nullptr;
@@ -1626,7 +1644,7 @@ IComputationNode* WrapWideCombinerT(TCallable& callable, const TComputationNodeF
16261644
keyTypes.reserve(keysSize);
16271645
for (ui32 i = index; i < index + keysSize; ++i) {
16281646
TType *type = callable.GetInput(i).GetStaticType();
1629-
keyAndStateItemTypes.push_back(type);
1647+
keyAndStateItemTypes.push_back(type);
16301648
bool optional;
16311649
keyTypes.emplace_back(*UnpackOptionalData(callable.GetInput(i).GetStaticType(), optional)->GetDataSlot(), optional);
16321650
}

0 commit comments

Comments
 (0)