@@ -343,8 +343,8 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
343
343
};
344
344
345
345
public:
346
- enum class ETasteResult : ui8 {
347
- Init,
346
+ enum class ETasteResult : i8 { // Result code of TasteIt; the generated code reads it as an 8-bit integer and switches on it.
347
+ Init = - 1 , // explicitly -1, so the following enumerators are Update == 0 and Skip == 1; the generated switch routes it to the "init" block
348
348
Update, // the generated switch routes it to the "next" block
349
349
Skip // has no explicit case in the generated switch, so it takes the default branch back to "more"
350
350
};
@@ -372,15 +372,9 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
372
372
Tongue = InMemoryProcessingState.Tongue ;
373
373
Throat = InMemoryProcessingState.Throat ;
374
374
}
375
- ~TSpillingSupportState () {
376
- }
377
-
378
- bool IsFetchRequired () const {
379
- return InputStatus != EFetchResult::Finish;
380
- }
381
375
382
376
bool HasAnyData () const { // True while SpilledBuckets still holds at least one bucket; the generated code calls this after Extract returns null to choose between looping again and finishing.
383
- return SpilledBuckets.size ();
377
+ return ! SpilledBuckets.empty ();
384
378
}
385
379
386
380
bool IsProcessingRequired () const {
@@ -456,6 +450,20 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
456
450
return ETasteResult::Skip;
457
451
}
458
452
453
+ NUdf::TUnboxedValuePod* Extract () { // Returns the next extracted entry, or nullptr when the current source (the in-memory state or the front spilled bucket) has nothing left.
454
+ if (GetMode () == EOperatingMode::InMemory) return static_cast <NUdf::TUnboxedValue*>(InMemoryProcessingState.Extract ()); // in-memory mode: delegate to the single in-memory state
455
+
456
+ MKQL_ENSURE (SpilledBuckets.size () > 0 , " Internal logic error" ); // checked first: front () must not be evaluated on an empty container (undefined behaviour for standard sequence containers)
457
+ MKQL_ENSURE (SpilledBuckets.front ().BucketState == TSpilledBucket::EBucketState::InMemory, " Internal logic error" ); // only a bucket marked InMemory may be extracted from
458
+
459
+ auto value = static_cast <NUdf::TUnboxedValue*>(SpilledBuckets.front ().InMemoryProcessingState ->Extract ());
460
+ if (!value) { // the front bucket returned nothing more
461
+ SpilledBuckets.pop_front (); // drop it so the next bucket becomes the front; the caller still receives nullptr for this call
462
+ }
463
+
464
+ return value;
465
+ }
466
+ private:
459
467
void MoveKeyToBucket (TSpilledBucket& bucket) {
460
468
for (size_t i = 0 ; i < KeyWidth; ++i) {
461
469
// jumping into unsafe world, refusing ownership
@@ -483,20 +491,6 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
483
491
BufferForUsedInputItems.resize (0 );
484
492
}
485
493
486
- NUdf::TUnboxedValuePod* Extract () {
487
- if (GetMode () == EOperatingMode::InMemory) return static_cast <NUdf::TUnboxedValue*>(InMemoryProcessingState.Extract ());
488
-
489
- MKQL_ENSURE (SpilledBuckets.front ().BucketState == TSpilledBucket::EBucketState::InMemory, " Internal logic error" );
490
- MKQL_ENSURE (SpilledBuckets.size () > 0 , " Internal logic error" );
491
-
492
- auto value = static_cast <NUdf::TUnboxedValue*>(SpilledBuckets.front ().InMemoryProcessingState ->Extract ());
493
- if (!value) {
494
- SpilledBuckets.pop_front ();
495
- }
496
-
497
- return value;
498
- }
499
-
500
494
bool FlushSpillingBuffersAndWait () {
501
495
UpdateSpillingBuckets ();
502
496
@@ -521,7 +515,6 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
521
515
return ProcessSpilledDataAndWait ();
522
516
}
523
517
524
- private:
525
518
void SplitStateIntoBuckets () {
526
519
while (const auto keyAndState = static_cast <NUdf::TUnboxedValue *>(InMemoryProcessingState.Extract ())) {
527
520
auto hash = Hasher (keyAndState); // Hasher uses only key for hashing
@@ -1246,7 +1239,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1246
1239
1247
1240
EFetchResult DoCalculate (NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const * output) const {
1248
1241
if (!state.HasValue ()) {
1249
- MakeSpillingSupportState (ctx, state);
1242
+ MakeState (ctx, state);
1250
1243
}
1251
1244
1252
1245
if (const auto ptr = static_cast <TSpillingSupportState*>(state.AsBoxed ().Get ())) {
@@ -1306,6 +1299,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1306
1299
const auto valueType = Type::getInt128Ty (context);
1307
1300
const auto ptrValueType = PointerType::getUnqual (valueType);
1308
1301
const auto statusType = Type::getInt32Ty (context);
1302
+ const auto wayType = Type::getInt8Ty (context);
1309
1303
1310
1304
TLLVMFieldsStructureState stateFields (context);
1311
1305
@@ -1332,26 +1326,39 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1332
1326
const auto state = new LoadInst (valueType, statePtr, " state" , block);
1333
1327
const auto half = CastInst::Create (Instruction::Trunc, state, Type::getInt64Ty (context), " half" , block);
1334
1328
const auto stateArg = CastInst::Create (Instruction::IntToPtr, half, statePtrType, " state_arg" , block);
1329
+ const auto boolFuncType = FunctionType::get (Type::getInt1Ty (context), {stateArg->getType ()}, false );
1335
1330
BranchInst::Create (more, block);
1336
1331
1337
- block = more;
1338
-
1339
- const auto loop = BasicBlock::Create (context, " loop" , ctx.Func );
1340
1332
const auto full = BasicBlock::Create (context, " full" , ctx.Func );
1341
1333
const auto over = BasicBlock::Create (context, " over" , ctx.Func );
1342
- const auto result = PHINode::Create (statusType, 3U , " result" , over);
1343
-
1344
- const auto statusPtr = GetElementPtrInst::CreateInBounds (stateType, stateArg, { stateFields.This (), stateFields.GetStatus () }, " last" , block);
1345
- const auto last = new LoadInst (statusType, statusPtr, " last" , block);
1346
- const auto finish = CmpInst::Create (Instruction::ICmp, ICmpInst::ICMP_EQ, last, ConstantInt::get (last->getType (), static_cast <i32>(EFetchResult::Finish)), " finish" , block);
1347
-
1348
- BranchInst::Create (full, loop, finish, block);
1334
+ const auto result = PHINode::Create (statusType, 4U , " result" , over);
1349
1335
1350
1336
{
1337
+ const auto test = BasicBlock::Create (context, " test" , ctx.Func );
1338
+ const auto pull = BasicBlock::Create (context, " pull" , ctx.Func );
1351
1339
const auto rest = BasicBlock::Create (context, " rest" , ctx.Func );
1340
+ const auto proc = BasicBlock::Create (context, " proc" , ctx.Func );
1352
1341
const auto good = BasicBlock::Create (context, " good" , ctx.Func );
1353
1342
1354
- block = loop;
1343
+ block = more;
1344
+
1345
+ const auto waitMoreFunc = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr (&TSpillingSupportState::UpdateAndWait));
1346
+ const auto waitMoreFuncPtr = CastInst::Create (Instruction::IntToPtr, waitMoreFunc, PointerType::getUnqual (boolFuncType), " wait_more_func" , block);
1347
+ const auto waitMore = CallInst::Create (boolFuncType, waitMoreFuncPtr, { stateArg }, " wait_more" , block);
1348
+
1349
+ result->addIncoming (ConstantInt::get (statusType, static_cast <i32>(EFetchResult::Yield)), block);
1350
+
1351
+ BranchInst::Create (over, test, waitMore, block);
1352
+
1353
+ block = test;
1354
+
1355
+ const auto statusPtr = GetElementPtrInst::CreateInBounds (stateType, stateArg, { stateFields.This (), stateFields.GetStatus () }, " last" , block);
1356
+ const auto last = new LoadInst (statusType, statusPtr, " last" , block);
1357
+ const auto finish = CmpInst::Create (Instruction::ICmp, ICmpInst::ICMP_EQ, last, ConstantInt::get (last->getType (), static_cast <i32>(EFetchResult::Finish)), " finish" , block);
1358
+
1359
+ BranchInst::Create (good, pull, finish, block);
1360
+
1361
+ block = pull;
1355
1362
1356
1363
const auto getres = GetNodeValues (Flow, ctx, block);
1357
1364
@@ -1362,12 +1369,19 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1362
1369
choise->addCase (ConstantInt::get (statusType, static_cast <i32>(EFetchResult::Finish)), rest);
1363
1370
1364
1371
block = rest;
1365
- new StoreInst (ConstantInt::get (last->getType (), static_cast <i32>(EFetchResult::Finish)), statusPtr, block);
1366
-
1367
- BranchInst::Create (full, block);
1372
+ new StoreInst (ConstantInt::get (statusType, static_cast <i32>(EFetchResult::Finish)), statusPtr, block);
1373
+ BranchInst::Create (more, block);
1368
1374
1369
1375
block = good;
1370
1376
1377
+ const auto processingFunc = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr (&TSpillingSupportState::IsProcessingRequired));
1378
+ const auto processingFuncPtr = CastInst::Create (Instruction::IntToPtr, processingFunc, PointerType::getUnqual (boolFuncType), " processing_func" , block);
1379
+ const auto processing = CallInst::Create (boolFuncType, processingFuncPtr, { stateArg }, " processing" , block);
1380
+
1381
+ BranchInst::Create (proc, full, processing, block);
1382
+
1383
+ block = proc;
1384
+
1371
1385
std::vector<Value*> items (Nodes.ItemNodes .size (), nullptr );
1372
1386
for (ui32 i = 0U ; i < items.size (); ++i) {
1373
1387
if (Nodes.ItemNodes [i]->GetDependencesCount () > 0U )
@@ -1398,10 +1412,10 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1398
1412
new StoreInst (key, keyPtr, block);
1399
1413
}
1400
1414
1401
- const auto atFunc = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr (&TState ::TasteIt));
1402
- const auto atType = FunctionType::get (Type::getInt1Ty (context) , {stateArg->getType ()}, false );
1415
+ const auto atFunc = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr (&TSpillingSupportState ::TasteIt));
1416
+ const auto atType = FunctionType::get (wayType , {stateArg->getType ()}, false );
1403
1417
const auto atPtr = CastInst::Create (Instruction::IntToPtr, atFunc, PointerType::getUnqual (atType), " function" , block);
1404
- const auto newKey = CallInst::Create (atType, atPtr, {stateArg}, " new_key " , block);
1418
+ const auto taste = CallInst::Create (atType, atPtr, {stateArg}, " taste " , block);
1405
1419
1406
1420
const auto init = BasicBlock::Create (context, " init" , ctx.Func );
1407
1421
const auto next = BasicBlock::Create (context, " next" , ctx.Func );
@@ -1415,7 +1429,9 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1415
1429
pointers.emplace_back (GetElementPtrInst::CreateInBounds (valueType, throat, {ConstantInt::get (Type::getInt32Ty (context), i)}, (TString (" state_" ) += ToString (i)).c_str (), block));
1416
1430
}
1417
1431
1418
- BranchInst::Create (init, next, newKey, block);
1432
+ const auto way = SwitchInst::Create (taste, more, 2U , block);
1433
+ way->addCase (ConstantInt::get (wayType, static_cast <i8>(TSpillingSupportState::ETasteResult::Init)), init);
1434
+ way->addCase (ConstantInt::get (wayType, static_cast <i8>(TSpillingSupportState::ETasteResult::Update)), next);
1419
1435
1420
1436
block = init;
1421
1437
@@ -1439,7 +1455,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1439
1455
}
1440
1456
}
1441
1457
1442
- BranchInst::Create (loop , block);
1458
+ BranchInst::Create (more , block);
1443
1459
1444
1460
block = next;
1445
1461
@@ -1484,23 +1500,22 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1484
1500
}
1485
1501
}
1486
1502
1487
- BranchInst::Create (loop , block);
1503
+ BranchInst::Create (more , block);
1488
1504
}
1489
1505
1490
1506
{
1491
1507
block = full;
1492
1508
1493
1509
const auto good = BasicBlock::Create (context, " good" , ctx.Func );
1510
+ const auto last = BasicBlock::Create (context, " last" , ctx.Func );
1494
1511
1495
- const auto extractFunc = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr (&TState ::Extract));
1512
+ const auto extractFunc = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr (&TSpillingSupportState ::Extract));
1496
1513
const auto extractType = FunctionType::get (ptrValueType, {stateArg->getType ()}, false );
1497
1514
const auto extractPtr = CastInst::Create (Instruction::IntToPtr, extractFunc, PointerType::getUnqual (extractType), " extract" , block);
1498
1515
const auto out = CallInst::Create (extractType, extractPtr, {stateArg}, " out" , block);
1499
1516
const auto has = CmpInst::Create (Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get (ptrValueType), " has" , block);
1500
1517
1501
- result->addIncoming (ConstantInt::get (statusType, static_cast <i32>(EFetchResult::Finish)), block);
1502
-
1503
- BranchInst::Create (good, over, has, block);
1518
+ BranchInst::Create (good, last, has, block);
1504
1519
1505
1520
block = good;
1506
1521
@@ -1514,6 +1529,16 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1514
1529
1515
1530
result->addIncoming (ConstantInt::get (statusType, static_cast <i32>(EFetchResult::One)), block);
1516
1531
BranchInst::Create (over, block);
1532
+
1533
+ block = last;
1534
+
1535
+ const auto hasDataFunc = ConstantInt::get (Type::getInt64Ty (context), GetMethodPtr (&TSpillingSupportState::HasAnyData));
1536
+ const auto hasDataFuncPtr = CastInst::Create (Instruction::IntToPtr, hasDataFunc, PointerType::getUnqual (boolFuncType), " has_data_func" , block);
1537
+ const auto hasData = CallInst::Create (boolFuncType, hasDataFuncPtr, { stateArg }, " has_data" , block);
1538
+
1539
+ result->addIncoming (ConstantInt::get (statusType, static_cast <i32>(EFetchResult::Finish)), block);
1540
+
1541
+ BranchInst::Create (more, over, hasData, block);
1517
1542
}
1518
1543
1519
1544
block = over;
@@ -1528,23 +1553,17 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1528
1553
#endif
1529
1554
private:
1530
1555
void MakeState (TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
1531
- #ifdef MKQL_DISABLE_CODEGEN
1532
- state = ctx.HolderFactory .Create <TState>(Nodes.KeyNodes .size (), Nodes.StateNodes .size (), TMyValueHasher (KeyTypes), TMyValueEqual (KeyTypes));
1533
- #else
1534
- state = ctx.HolderFactory .Create <TState>(Nodes.KeyNodes .size (), Nodes.StateNodes .size (),
1535
- ctx.ExecuteLLVM && Hash ? THashFunc (std::ptr_fun (Hash)) : THashFunc (TMyValueHasher (KeyTypes)),
1536
- ctx.ExecuteLLVM && Equals ? TEqualsFunc (std::ptr_fun (Equals)) : TEqualsFunc (TMyValueEqual (KeyTypes))
1537
- );
1538
- #endif
1539
- }
1540
-
1541
- void MakeSpillingSupportState (TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
1542
1556
state = ctx.HolderFactory .Create <TSpillingSupportState>(WideFieldsIndex,
1543
1557
UsedInputItemType, KeyAndStateType,
1544
1558
Nodes.KeyNodes .size (),
1545
1559
Nodes.ItemNodes .size (),
1560
+ #ifdef MKQL_DISABLE_CODEGEN
1546
1561
TMyValueHasher (KeyTypes),
1547
1562
TMyValueEqual (KeyTypes),
1563
+ #else
1564
+ ctx.ExecuteLLVM && Hash ? THashFunc (std::ptr_fun (Hash)) : THashFunc (TMyValueHasher (KeyTypes)),
1565
+ ctx.ExecuteLLVM && Equals ? TEqualsFunc (std::ptr_fun (Equals)) : TEqualsFunc (TMyValueEqual (KeyTypes)),
1566
+ #endif
1548
1567
AllowSpilling,
1549
1568
ctx
1550
1569
);
@@ -1569,7 +1588,6 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
1569
1588
const ui32 WideFieldsIndex;
1570
1589
1571
1590
const bool AllowSpilling;
1572
-
1573
1591
#ifndef MKQL_DISABLE_CODEGEN
1574
1592
TEqualsPtr Equals = nullptr ;
1575
1593
THashPtr Hash = nullptr ;
@@ -1626,7 +1644,7 @@ IComputationNode* WrapWideCombinerT(TCallable& callable, const TComputationNodeF
1626
1644
keyTypes.reserve (keysSize);
1627
1645
for (ui32 i = index ; i < index + keysSize; ++i) {
1628
1646
TType *type = callable.GetInput (i).GetStaticType ();
1629
- keyAndStateItemTypes.push_back (type);
1647
+ keyAndStateItemTypes.push_back (type);
1630
1648
bool optional;
1631
1649
keyTypes.emplace_back (*UnpackOptionalData (callable.GetInput (i).GetStaticType (), optional)->GetDataSlot (), optional);
1632
1650
}
0 commit comments