@@ -451,8 +451,8 @@ void BuildStreamLookupChannels(TKqpTasksGraph& graph, const TStageInfo& stageInf
451
451
}
452
452
}
453
453
454
- void BuildKqpStageChannels (TKqpTasksGraph& tasksGraph, const TStageInfo& stageInfo,
455
- ui64 txId, bool enableSpilling)
454
+ void BuildKqpStageChannels (TKqpTasksGraph& tasksGraph, TStageInfo& stageInfo,
455
+ ui64 txId, bool enableSpilling, bool enableShuffleElimination )
456
456
{
457
457
auto & stage = stageInfo.Meta .GetStage (stageInfo.Id );
458
458
@@ -473,19 +473,91 @@ void BuildKqpStageChannels(TKqpTasksGraph& tasksGraph, const TStageInfo& stageIn
473
473
<< (spilling ? " with spilling" : " without spilling" ));
474
474
};
475
475
476
- for (const auto & input : stage.GetInputs ()) {
476
+
477
+ bool hasMap = false ;
478
+ bool isFusedStage = (stageInfo.Meta .TaskIdByHash != nullptr );
479
+ if (enableShuffleElimination && !isFusedStage) { // taskIdHash can be already set if it is a fused stage, so hashpartition will derive columnv1 parameters from there
480
+ for (ui32 inputIndex = 0 ; inputIndex < stage.InputsSize (); ++inputIndex) {
481
+ const auto & input = stage.GetInputs (inputIndex);
482
+ auto & originStageInfo = tasksGraph.GetStageInfo (NYql::NDq::TStageId (stageInfo.Id .TxId , input.GetStageIndex ()));
483
+ stageInfo.Meta .TaskIdByHash = originStageInfo.Meta .TaskIdByHash ;
484
+ stageInfo.Meta .SourceShardCount = originStageInfo.Meta .SourceShardCount ;
485
+ stageInfo.Meta .SourceTableKeyColumnTypes = originStageInfo.Meta .SourceTableKeyColumnTypes ;
486
+ if (input.GetTypeCase () == NKqpProto::TKqpPhyConnection::kMap ) {
487
+ // We want to enforce sourceShardCount from map connection, cause it can be at most one map connection
488
+ // and ColumnShardHash in Shuffle will use this parameter to shuffle on this map (same with taskIdByHash mapping)
489
+ hasMap = true ;
490
+ break ;
491
+ }
492
+ }
493
+ }
494
+
495
+ // if it is stage, where we don't inherit parallelism.
496
+ if (enableShuffleElimination && !hasMap && !isFusedStage && stageInfo.Tasks .size () > 0 && stage.InputsSize () > 0 ) {
497
+ stageInfo.Meta .SourceShardCount = stageInfo.Tasks .size ();
498
+ stageInfo.Meta .TaskIdByHash = std::make_shared<TVector<ui64>>(stageInfo.Meta .SourceShardCount );
499
+ for (std::size_t i = 0 ; i < stageInfo.Meta .SourceShardCount ; ++i) {
500
+ (*stageInfo.Meta .TaskIdByHash )[i] = i;
501
+ }
502
+
503
+ for (auto & input : stage.GetInputs ()) {
504
+ if (input.GetTypeCase () != NKqpProto::TKqpPhyConnection::kHashShuffle ) {
505
+ continue ;
506
+ }
507
+
508
+ const auto & hashShuffle = input.GetHashShuffle ();
509
+ if (hashShuffle.GetHashKindCase () != NKqpProto::TKqpPhyCnHashShuffle::kColumnShardHashV1 ) {
510
+ continue ;
511
+ }
512
+
513
+ Y_ENSURE (enableShuffleElimination, " OptShuffleElimination wasn't turned on, but ColumnShardHashV1 detected!" );
514
+ // ^ if the flag if false, and kColumnShardHashV1 detected - then the data which would be returned - would be incorrect,
515
+ // because we didn't save partitioning in the BuildScanTasksFromShards.
516
+
517
+ auto columnShardHashV1 = hashShuffle.GetColumnShardHashV1 ();
518
+ stageInfo.Meta .SourceTableKeyColumnTypes = std::make_shared<TVector<NScheme::TTypeInfo>>();
519
+ stageInfo.Meta .SourceTableKeyColumnTypes ->reserve (columnShardHashV1.KeyColumnTypesSize ());
520
+ for (const auto & keyColumnType: columnShardHashV1.GetKeyColumnTypes ()) {
521
+ auto typeId = static_cast <NScheme::TTypeId>(keyColumnType);
522
+ auto typeInfo = NScheme::TTypeInfo{typeId};
523
+ stageInfo.Meta .SourceTableKeyColumnTypes ->push_back (typeInfo);
524
+ }
525
+ break ;
526
+ }
527
+ }
528
+
529
+ for (auto & input : stage.GetInputs ()) {
477
530
ui32 inputIdx = input.GetInputIndex ();
478
- const auto & inputStageInfo = tasksGraph.GetStageInfo (TStageId (stageInfo.Id .TxId , input.GetStageIndex ()));
531
+ auto & inputStageInfo = tasksGraph.GetStageInfo (TStageId (stageInfo.Id .TxId , input.GetStageIndex ()));
479
532
const auto & outputIdx = input.GetOutputIndex ();
480
533
481
534
switch (input.GetTypeCase ()) {
482
535
case NKqpProto::TKqpPhyConnection::kUnionAll :
483
536
BuildUnionAllChannels (tasksGraph, stageInfo, inputIdx, inputStageInfo, outputIdx, enableSpilling, log );
484
537
break ;
485
- case NKqpProto::TKqpPhyConnection::kHashShuffle :
538
+ case NKqpProto::TKqpPhyConnection::kHashShuffle : {
539
+ ui32 hashKind = NHashKind::EUndefined;
540
+ switch (input.GetHashShuffle ().GetHashKindCase ()) {
541
+ case NKqpProto::TKqpPhyCnHashShuffle::kHashV1 : {
542
+ hashKind = NHashKind::EHashV1;
543
+ break ;
544
+ }
545
+ case NKqpProto::TKqpPhyCnHashShuffle::kColumnShardHashV1 : {
546
+ Y_ENSURE (enableShuffleElimination, " OptShuffleElimination wasn't turned on, but ColumnShardHashV1 detected!" );
547
+ inputStageInfo.Meta .TaskIdByHash = stageInfo.Meta .TaskIdByHash ;
548
+ inputStageInfo.Meta .SourceShardCount = stageInfo.Meta .SourceShardCount ;
549
+ inputStageInfo.Meta .SourceTableKeyColumnTypes = stageInfo.Meta .SourceTableKeyColumnTypes ;
550
+ hashKind = NHashKind::EColumnShardHashV1;
551
+ break ;
552
+ }
553
+ default : {
554
+ Y_ENSURE (false , " undefined type of hash for shuffle" );
555
+ }
556
+ }
486
557
BuildHashShuffleChannels (tasksGraph, stageInfo, inputIdx, inputStageInfo, outputIdx,
487
- input.GetHashShuffle ().GetKeyColumns (), enableSpilling, log );
558
+ input.GetHashShuffle ().GetKeyColumns (), enableSpilling, log , hashKind );
488
559
break ;
560
+ }
489
561
case NKqpProto::TKqpPhyConnection::kBroadcast :
490
562
BuildBroadcastChannels (tasksGraph, stageInfo, inputIdx, inputStageInfo, outputIdx, enableSpilling, log );
491
563
break ;
@@ -1045,7 +1117,13 @@ void FillTaskMeta(const TStageInfo& stageInfo, const TTask& task, NYql::NDqProto
1045
1117
}
1046
1118
}
1047
1119
1048
- void FillOutputDesc (const TKqpTasksGraph& tasksGraph, NYql::NDqProto::TTaskOutput& outputDesc, const TTaskOutput& output, bool enableSpilling) {
1120
+ void FillOutputDesc (
1121
+ const TKqpTasksGraph& tasksGraph,
1122
+ NYql::NDqProto::TTaskOutput& outputDesc,
1123
+ const TTaskOutput& output,
1124
+ bool enableSpilling,
1125
+ const TStageInfo& stageInfo
1126
+ ) {
1049
1127
switch (output.Type ) {
1050
1128
case TTaskOutputType::Map:
1051
1129
YQL_ENSURE (output.Channels .size () == 1 );
@@ -1058,6 +1136,31 @@ void FillOutputDesc(const TKqpTasksGraph& tasksGraph, NYql::NDqProto::TTaskOutpu
1058
1136
hashPartitionDesc.AddKeyColumns (column);
1059
1137
}
1060
1138
hashPartitionDesc.SetPartitionsCount (output.PartitionsCount );
1139
+
1140
+ switch (output.HashKind ) {
1141
+ case NHashKind::EHashV1: {
1142
+ hashPartitionDesc.MutableHashV1 ();
1143
+ break ;
1144
+ }
1145
+ case NHashKind::EColumnShardHashV1: {
1146
+ Y_ENSURE (stageInfo.Meta .SourceShardCount != 0 , " ShardCount for ColumnShardHashV1 Shuffle can't be equal to 0" );
1147
+ Y_ENSURE (stageInfo.Meta .TaskIdByHash != nullptr , " TaskIdByHash for ColumnShardHashV1 wasn't propogated to this stage" );
1148
+ Y_ENSURE (stageInfo.Meta .SourceTableKeyColumnTypes != nullptr , " SourceTableKeyColumnTypes for ColumnShardHashV1 wasn't propogated to this stage" );
1149
+ auto & columnShardHashV1 = *hashPartitionDesc.MutableColumnShardHashV1 ();
1150
+ columnShardHashV1.SetShardCount (stageInfo.Meta .SourceShardCount );
1151
+
1152
+ auto * columnTypes = columnShardHashV1.MutableKeyColumnTypes ();
1153
+ for (const auto & type: *stageInfo.Meta .SourceTableKeyColumnTypes ) {
1154
+ columnTypes->Add (type.GetTypeId ());
1155
+ }
1156
+
1157
+ auto * taskIdByHash = columnShardHashV1.MutableTaskIdByHash ();
1158
+ for (std::size_t taskID: *stageInfo.Meta .TaskIdByHash ) {
1159
+ taskIdByHash->Add (taskID);
1160
+ }
1161
+ break ;
1162
+ }
1163
+ }
1061
1164
break ;
1062
1165
}
1063
1166
@@ -1228,7 +1331,7 @@ void SerializeTaskToProto(
1228
1331
enableSpilling = tasksGraph.GetMeta ().AllowWithSpilling ;
1229
1332
}
1230
1333
for (const auto & output : task.Outputs ) {
1231
- FillOutputDesc (tasksGraph, *result->AddOutputs (), output, enableSpilling);
1334
+ FillOutputDesc (tasksGraph, *result->AddOutputs (), output, enableSpilling, stageInfo );
1232
1335
}
1233
1336
1234
1337
const NKqpProto::TKqpPhyStage& stage = stageInfo.Meta .GetStage (stageInfo.Id );
0 commit comments