Skip to content

Commit 7102e38

Browse files
authored
[CBO] Shuffle elimination (#14901)
1 parent 63aaaeb commit 7102e38

File tree

74 files changed

+2753
-480
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+2753
-480
lines changed

ydb/core/kqp/executer_actor/kqp_data_executer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2027,7 +2027,7 @@ class TKqpDataExecuter : public TKqpExecuterBase<TKqpDataExecuter, EExecType::Da
20272027
YQL_ENSURE(false, "unknown source type");
20282028
}
20292029
} else if ((AllowOlapDataQuery || StreamResult) && stageInfo.Meta.IsOlap() && stage.SinksSize() == 0) {
2030-
BuildScanTasksFromShards(stageInfo);
2030+
BuildScanTasksFromShards(stageInfo, tx.Body->EnableShuffleElimination());
20312031
} else if (stageInfo.Meta.IsSysView()) {
20322032
BuildSysViewScanTasks(stageInfo);
20332033
} else if (stageInfo.Meta.ShardOperations.empty() || stage.SinksSize() > 0) {
@@ -2041,7 +2041,7 @@ class TKqpDataExecuter : public TKqpExecuterBase<TKqpDataExecuter, EExecType::Da
20412041
}
20422042

20432043
TasksGraph.GetMeta().AllowWithSpilling |= stage.GetAllowWithSpilling();
2044-
BuildKqpStageChannels(TasksGraph, stageInfo, TxId, /* enableSpilling */ TasksGraph.GetMeta().AllowWithSpilling);
2044+
BuildKqpStageChannels(TasksGraph, stageInfo, TxId, /* enableSpilling */ TasksGraph.GetMeta().AllowWithSpilling, tx.Body->EnableShuffleElimination());
20452045
}
20462046

20472047
ResponseEv->InitTxResult(tx.Body);

ydb/core/kqp/executer_actor/kqp_executer_impl.h

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ class TKqpExecuterBase : public TActor<TDerived> {
136136
TKqpRequestCounters::TPtr counters,
137137
const NKikimrConfig::TTableServiceConfig& tableServiceConfig,
138138
const TIntrusivePtr<TUserRequestContext>& userRequestContext,
139-
ui32 statementResultIndex, ui64 spanVerbosity = 0, TString spanName = "KqpExecuterBase",
139+
ui32 statementResultIndex,
140+
ui64 spanVerbosity = 0, TString spanName = "KqpExecuterBase",
140141
bool streamResult = false, const TActorId bufferActorId = {}, const IKqpTransactionManagerPtr& txManager = nullptr)
141142
: NActors::TActor<TDerived>(&TDerived::ReadyState)
142143
, Request(std::move(request))
@@ -1421,6 +1422,10 @@ class TKqpExecuterBase : public TActor<TDerived> {
14211422
ui32 partitionsCount = 1;
14221423
ui32 inputTasks = 0;
14231424
bool isShuffle = false;
1425+
bool forceMapTasks = false;
1426+
ui32 mapCnt = 0;
1427+
1428+
14241429
for (ui32 inputIndex = 0; inputIndex < stage.InputsSize(); ++inputIndex) {
14251430
const auto& input = stage.GetInputs(inputIndex);
14261431

@@ -1434,6 +1439,7 @@ class TKqpExecuterBase : public TActor<TDerived> {
14341439
case NKqpProto::TKqpPhyConnection::kUnionAll:
14351440
case NKqpProto::TKqpPhyConnection::kMerge:
14361441
case NKqpProto::TKqpPhyConnection::kStreamLookup:
1442+
case NKqpProto::TKqpPhyConnection::kMap:
14371443
break;
14381444
default:
14391445
YQL_ENSURE(false, "Unexpected connection type: " << (ui32)input.GetTypeCase() << Endl
@@ -1451,18 +1457,23 @@ class TKqpExecuterBase : public TActor<TDerived> {
14511457
}
14521458

14531459
case NKqpProto::TKqpPhyConnection::kStreamLookup:
1460+
partitionsCount = originStageInfo.Tasks.size();
14541461
UnknownAffectedShardCount = true;
1455-
[[fallthrough]];
1456-
case NKqpProto::TKqpPhyConnection::kMap:
1462+
break;
1463+
case NKqpProto::TKqpPhyConnection::kMap:
14571464
partitionsCount = originStageInfo.Tasks.size();
1465+
forceMapTasks = true;
1466+
++mapCnt;
14581467
break;
1459-
14601468
default:
14611469
break;
14621470
}
1471+
14631472
}
14641473

1465-
if (isShuffle) {
1474+
Y_ENSURE(mapCnt < 2, "There can be only < 2 'Map' connections");
1475+
1476+
if (isShuffle && !forceMapTasks) {
14661477
if (stage.GetTaskCount()) {
14671478
partitionsCount = stage.GetTaskCount();
14681479
} else {
@@ -1671,17 +1682,26 @@ class TKqpExecuterBase : public TActor<TDerived> {
16711682
return true;
16721683
}
16731684

1674-
void BuildScanTasksFromShards(TStageInfo& stageInfo) {
1685+
void BuildScanTasksFromShards(TStageInfo& stageInfo, bool enableShuffleElimination) {
16751686
THashMap<ui64, std::vector<ui64>> nodeTasks;
16761687
THashMap<ui64, std::vector<TShardInfoWithId>> nodeShards;
16771688
THashMap<ui64, ui64> assignedShardsCount;
16781689
auto& stage = stageInfo.Meta.GetStage(stageInfo.Id);
16791690

1691+
if (enableShuffleElimination && stageInfo.Meta.ColumnTableInfoPtr) {
1692+
const auto& tableDesc = stageInfo.Meta.ColumnTableInfoPtr->Description;
1693+
stageInfo.Meta.SourceShardCount = tableDesc.GetColumnShardCount();
1694+
stageInfo.Meta.SourceTableKeyColumnTypes = std::make_shared<TVector<NScheme::TTypeInfo>>();
1695+
for (const auto& column: tableDesc.GetSharding().GetHashSharding().GetColumns()) {
1696+
auto columnType = stageInfo.Meta.TableConstInfo->Columns.at(column).Type;
1697+
stageInfo.Meta.SourceTableKeyColumnTypes->push_back(columnType);
1698+
}
1699+
}
1700+
16801701
YQL_ENSURE(Stats);
16811702

16821703
const auto& tableInfo = stageInfo.Meta.TableConstInfo;
16831704
const auto& keyTypes = tableInfo->KeyColumnTypes;
1684-
ui32 metaId = 0;
16851705
for (auto& op : stage.GetTableOps()) {
16861706
Y_DEBUG_ABORT_UNLESS(stageInfo.Meta.TablePath == op.GetTable().GetPath());
16871707

@@ -1732,7 +1752,66 @@ class TKqpExecuterBase : public TActor<TDerived> {
17321752
}
17331753
}
17341754

1755+
} else if (enableShuffleElimination /* save partitioning for shuffle elimination */) {
1756+
std::size_t stageInternalTaskId = 0;
1757+
stageInfo.Meta.TaskIdByHash = std::make_shared<TVector<ui64>>();
1758+
stageInfo.Meta.TaskIdByHash->resize(stageInfo.Meta.SourceShardCount);
1759+
1760+
for (auto&& pair : nodeShards) {
1761+
const auto nodeId = pair.first;
1762+
auto& shardsInfo = pair.second;
1763+
std::size_t maxTasksPerNode = std::min<std::size_t>(shardsInfo.size(), GetScanTasksPerNode(stageInfo, isOlapScan, nodeId));
1764+
std::vector<TTaskMeta> metas(maxTasksPerNode, TTaskMeta());
1765+
{
1766+
for (std::size_t i = 0; i < shardsInfo.size(); ++i) {
1767+
auto&& shardInfo = shardsInfo[i];
1768+
MergeReadInfoToTaskMeta(
1769+
metas[i % maxTasksPerNode],
1770+
shardInfo.ShardId,
1771+
shardInfo.KeyReadRanges,
1772+
readSettings,
1773+
columns, op,
1774+
/*isPersistentScan*/ true
1775+
);
1776+
}
1777+
1778+
for (auto& meta: metas) {
1779+
PrepareScanMetaForUsage(meta, keyTypes);
1780+
LOG_D("Stage " << stageInfo.Id << " create scan task meta for node: " << nodeId
1781+
<< ", meta: " << meta.ToString(keyTypes, *AppData()->TypeRegistry));
1782+
}
1783+
}
1784+
1785+
// at runtime we compute a hash, which will be in [0; shardcount]
1786+
// so we merge two mappings: hash -> shardID and shardID -> channelID, for use at runtime
1787+
THashMap<ui64, ui64> hashByShardId;
1788+
const auto& tableDesc = stageInfo.Meta.ColumnTableInfoPtr->Description;
1789+
const auto& sharding = tableDesc.GetSharding();
1790+
for (std::size_t i = 0; i < sharding.ColumnShardsSize(); ++i) {
1791+
hashByShardId.insert({sharding.GetColumnShards(i), i});
1792+
}
1793+
1794+
for (ui32 t = 0; t < maxTasksPerNode; ++t, ++stageInternalTaskId) {
1795+
auto& task = TasksGraph.AddTask(stageInfo);
1796+
task.Meta = metas[t];
1797+
task.Meta.SetEnableShardsSequentialScan(false);
1798+
task.Meta.ExecuterId = SelfId();
1799+
task.Meta.NodeId = nodeId;
1800+
task.Meta.ScanTask = true;
1801+
task.Meta.Type = TTaskMeta::TTaskType::Scan;
1802+
task.SetMetaId(t);
1803+
FillSecureParamsFromStage(task.Meta.SecureParams, stage);
1804+
BuildSinks(stage, task);
1805+
1806+
for (const auto& readInfo: *task.Meta.Reads) {
1807+
Y_ENSURE(hashByShardId.contains(readInfo.ShardId));
1808+
(*stageInfo.Meta.TaskIdByHash)[hashByShardId[readInfo.ShardId]] = stageInternalTaskId;
1809+
}
1810+
1811+
}
1812+
}
17351813
} else {
1814+
ui32 metaId = 0;
17361815
for (auto&& pair : nodeShards) {
17371816
const auto nodeId = pair.first;
17381817
auto& shardsInfo = pair.second;

ydb/core/kqp/executer_actor/kqp_scan_executer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ class TKqpScanExecuter : public TKqpExecuterBase<TKqpScanExecuter, EExecType::Sc
196196
BuildSysViewScanTasks(stageInfo);
197197
} else if (stageInfo.Meta.IsOlap() || stageInfo.Meta.IsDatashard()) {
198198
HasOlapTable = true;
199-
BuildScanTasksFromShards(stageInfo);
199+
BuildScanTasksFromShards(stageInfo, tx.Body->EnableShuffleElimination());
200200
} else {
201201
YQL_ENSURE(false, "Unexpected stage type " << (int) stageInfo.Meta.TableKind);
202202
}
@@ -219,7 +219,7 @@ class TKqpScanExecuter : public TKqpExecuterBase<TKqpScanExecuter, EExecType::Sc
219219
}
220220

221221
TasksGraph.GetMeta().AllowWithSpilling |= stage.GetAllowWithSpilling();
222-
BuildKqpStageChannels(TasksGraph, stageInfo, TxId, /* enableSpilling */ TasksGraph.GetMeta().AllowWithSpilling);
222+
BuildKqpStageChannels(TasksGraph, stageInfo, TxId, /* enableSpilling */ TasksGraph.GetMeta().AllowWithSpilling, tx.Body->EnableShuffleElimination());
223223
}
224224

225225
ResponseEv->InitTxResult(tx.Body);

ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp

Lines changed: 111 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,8 @@ void BuildStreamLookupChannels(TKqpTasksGraph& graph, const TStageInfo& stageInf
451451
}
452452
}
453453

454-
void BuildKqpStageChannels(TKqpTasksGraph& tasksGraph, const TStageInfo& stageInfo,
455-
ui64 txId, bool enableSpilling)
454+
void BuildKqpStageChannels(TKqpTasksGraph& tasksGraph, TStageInfo& stageInfo,
455+
ui64 txId, bool enableSpilling, bool enableShuffleElimination)
456456
{
457457
auto& stage = stageInfo.Meta.GetStage(stageInfo.Id);
458458

@@ -473,19 +473,91 @@ void BuildKqpStageChannels(TKqpTasksGraph& tasksGraph, const TStageInfo& stageIn
473473
<< (spilling ? " with spilling" : " without spilling"));
474474
};
475475

476-
for (const auto& input : stage.GetInputs()) {
476+
477+
bool hasMap = false;
478+
bool isFusedStage = (stageInfo.Meta.TaskIdByHash != nullptr);
479+
if (enableShuffleElimination && !isFusedStage) { // taskIdHash can be already set if it is a fused stage, so hashpartition will derive columnv1 parameters from there
480+
for (ui32 inputIndex = 0; inputIndex < stage.InputsSize(); ++inputIndex) {
481+
const auto& input = stage.GetInputs(inputIndex);
482+
auto& originStageInfo = tasksGraph.GetStageInfo(NYql::NDq::TStageId(stageInfo.Id.TxId, input.GetStageIndex()));
483+
stageInfo.Meta.TaskIdByHash = originStageInfo.Meta.TaskIdByHash;
484+
stageInfo.Meta.SourceShardCount = originStageInfo.Meta.SourceShardCount;
485+
stageInfo.Meta.SourceTableKeyColumnTypes = originStageInfo.Meta.SourceTableKeyColumnTypes;
486+
if (input.GetTypeCase() == NKqpProto::TKqpPhyConnection::kMap) {
487+
// We want to enforce sourceShardCount from the map connection, because there can be at most one map connection
488+
// and ColumnShardHash in Shuffle will use this parameter to shuffle on this map (same with taskIdByHash mapping)
489+
hasMap = true;
490+
break;
491+
}
492+
}
493+
}
494+
495+
// if this is a stage where we don't inherit parallelism.
496+
if (enableShuffleElimination && !hasMap && !isFusedStage && stageInfo.Tasks.size() > 0 && stage.InputsSize() > 0) {
497+
stageInfo.Meta.SourceShardCount = stageInfo.Tasks.size();
498+
stageInfo.Meta.TaskIdByHash = std::make_shared<TVector<ui64>>(stageInfo.Meta.SourceShardCount);
499+
for (std::size_t i = 0; i < stageInfo.Meta.SourceShardCount; ++i) {
500+
(*stageInfo.Meta.TaskIdByHash)[i] = i;
501+
}
502+
503+
for (auto& input : stage.GetInputs()) {
504+
if (input.GetTypeCase() != NKqpProto::TKqpPhyConnection::kHashShuffle) {
505+
continue;
506+
}
507+
508+
const auto& hashShuffle = input.GetHashShuffle();
509+
if (hashShuffle.GetHashKindCase() != NKqpProto::TKqpPhyCnHashShuffle::kColumnShardHashV1) {
510+
continue;
511+
}
512+
513+
Y_ENSURE(enableShuffleElimination, "OptShuffleElimination wasn't turned on, but ColumnShardHashV1 detected!");
514+
// ^ if the flag is false and kColumnShardHashV1 is detected, then the returned data would be incorrect,
515+
// because we didn't save partitioning in the BuildScanTasksFromShards.
516+
517+
auto columnShardHashV1 = hashShuffle.GetColumnShardHashV1();
518+
stageInfo.Meta.SourceTableKeyColumnTypes = std::make_shared<TVector<NScheme::TTypeInfo>>();
519+
stageInfo.Meta.SourceTableKeyColumnTypes->reserve(columnShardHashV1.KeyColumnTypesSize());
520+
for (const auto& keyColumnType: columnShardHashV1.GetKeyColumnTypes()) {
521+
auto typeId = static_cast<NScheme::TTypeId>(keyColumnType);
522+
auto typeInfo = NScheme::TTypeInfo{typeId};
523+
stageInfo.Meta.SourceTableKeyColumnTypes->push_back(typeInfo);
524+
}
525+
break;
526+
}
527+
}
528+
529+
for (auto& input : stage.GetInputs()) {
477530
ui32 inputIdx = input.GetInputIndex();
478-
const auto& inputStageInfo = tasksGraph.GetStageInfo(TStageId(stageInfo.Id.TxId, input.GetStageIndex()));
531+
auto& inputStageInfo = tasksGraph.GetStageInfo(TStageId(stageInfo.Id.TxId, input.GetStageIndex()));
479532
const auto& outputIdx = input.GetOutputIndex();
480533

481534
switch (input.GetTypeCase()) {
482535
case NKqpProto::TKqpPhyConnection::kUnionAll:
483536
BuildUnionAllChannels(tasksGraph, stageInfo, inputIdx, inputStageInfo, outputIdx, enableSpilling, log);
484537
break;
485-
case NKqpProto::TKqpPhyConnection::kHashShuffle:
538+
case NKqpProto::TKqpPhyConnection::kHashShuffle: {
539+
ui32 hashKind = NHashKind::EUndefined;
540+
switch (input.GetHashShuffle().GetHashKindCase()) {
541+
case NKqpProto::TKqpPhyCnHashShuffle::kHashV1: {
542+
hashKind = NHashKind::EHashV1;
543+
break;
544+
}
545+
case NKqpProto::TKqpPhyCnHashShuffle::kColumnShardHashV1: {
546+
Y_ENSURE(enableShuffleElimination, "OptShuffleElimination wasn't turned on, but ColumnShardHashV1 detected!");
547+
inputStageInfo.Meta.TaskIdByHash = stageInfo.Meta.TaskIdByHash;
548+
inputStageInfo.Meta.SourceShardCount = stageInfo.Meta.SourceShardCount;
549+
inputStageInfo.Meta.SourceTableKeyColumnTypes = stageInfo.Meta.SourceTableKeyColumnTypes;
550+
hashKind = NHashKind::EColumnShardHashV1;
551+
break;
552+
}
553+
default: {
554+
Y_ENSURE(false, "undefined type of hash for shuffle");
555+
}
556+
}
486557
BuildHashShuffleChannels(tasksGraph, stageInfo, inputIdx, inputStageInfo, outputIdx,
487-
input.GetHashShuffle().GetKeyColumns(), enableSpilling, log);
558+
input.GetHashShuffle().GetKeyColumns(), enableSpilling, log, hashKind);
488559
break;
560+
}
489561
case NKqpProto::TKqpPhyConnection::kBroadcast:
490562
BuildBroadcastChannels(tasksGraph, stageInfo, inputIdx, inputStageInfo, outputIdx, enableSpilling, log);
491563
break;
@@ -1045,7 +1117,13 @@ void FillTaskMeta(const TStageInfo& stageInfo, const TTask& task, NYql::NDqProto
10451117
}
10461118
}
10471119

1048-
void FillOutputDesc(const TKqpTasksGraph& tasksGraph, NYql::NDqProto::TTaskOutput& outputDesc, const TTaskOutput& output, bool enableSpilling) {
1120+
void FillOutputDesc(
1121+
const TKqpTasksGraph& tasksGraph,
1122+
NYql::NDqProto::TTaskOutput& outputDesc,
1123+
const TTaskOutput& output,
1124+
bool enableSpilling,
1125+
const TStageInfo& stageInfo
1126+
) {
10491127
switch (output.Type) {
10501128
case TTaskOutputType::Map:
10511129
YQL_ENSURE(output.Channels.size() == 1);
@@ -1058,6 +1136,31 @@ void FillOutputDesc(const TKqpTasksGraph& tasksGraph, NYql::NDqProto::TTaskOutpu
10581136
hashPartitionDesc.AddKeyColumns(column);
10591137
}
10601138
hashPartitionDesc.SetPartitionsCount(output.PartitionsCount);
1139+
1140+
switch (output.HashKind) {
1141+
case NHashKind::EHashV1: {
1142+
hashPartitionDesc.MutableHashV1();
1143+
break;
1144+
}
1145+
case NHashKind::EColumnShardHashV1: {
1146+
Y_ENSURE(stageInfo.Meta.SourceShardCount != 0, "ShardCount for ColumnShardHashV1 Shuffle can't be equal to 0");
1147+
Y_ENSURE(stageInfo.Meta.TaskIdByHash != nullptr, "TaskIdByHash for ColumnShardHashV1 wasn't propogated to this stage");
1148+
Y_ENSURE(stageInfo.Meta.SourceTableKeyColumnTypes != nullptr, "SourceTableKeyColumnTypes for ColumnShardHashV1 wasn't propogated to this stage");
1149+
auto& columnShardHashV1 = *hashPartitionDesc.MutableColumnShardHashV1();
1150+
columnShardHashV1.SetShardCount(stageInfo.Meta.SourceShardCount);
1151+
1152+
auto* columnTypes = columnShardHashV1.MutableKeyColumnTypes();
1153+
for (const auto& type: *stageInfo.Meta.SourceTableKeyColumnTypes) {
1154+
columnTypes->Add(type.GetTypeId());
1155+
}
1156+
1157+
auto* taskIdByHash = columnShardHashV1.MutableTaskIdByHash();
1158+
for (std::size_t taskID: *stageInfo.Meta.TaskIdByHash) {
1159+
taskIdByHash->Add(taskID);
1160+
}
1161+
break;
1162+
}
1163+
}
10611164
break;
10621165
}
10631166

@@ -1228,7 +1331,7 @@ void SerializeTaskToProto(
12281331
enableSpilling = tasksGraph.GetMeta().AllowWithSpilling;
12291332
}
12301333
for (const auto& output : task.Outputs) {
1231-
FillOutputDesc(tasksGraph, *result->AddOutputs(), output, enableSpilling);
1334+
FillOutputDesc(tasksGraph, *result->AddOutputs(), output, enableSpilling, stageInfo);
12321335
}
12331336

12341337
const NKqpProto::TKqpPhyStage& stage = stageInfo.Meta.GetStage(stageInfo.Id);

ydb/core/kqp/executer_actor/kqp_tasks_graph.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ struct TStageInfoMeta {
4444
THolder<TKeyDesc> ShardKey;
4545
NSchemeCache::TSchemeCacheRequest::EKind ShardKind = NSchemeCache::TSchemeCacheRequest::EKind::KindUnknown;
4646

47+
// used for ColumnV1Hashing
48+
ui64 SourceShardCount = 0;
49+
std::shared_ptr<TVector<NScheme::TTypeInfo>> SourceTableKeyColumnTypes = nullptr;
50+
std::shared_ptr<TVector<ui64>> TaskIdByHash = nullptr; // hash belongs [0; ShardCount]
51+
//
52+
4753
const NKqpProto::TKqpPhyStage& GetStage(const size_t idx) const {
4854
auto& txBody = Tx.Body;
4955
YQL_ENSURE(idx < txBody->StagesSize());
@@ -272,8 +278,7 @@ using TKqpTasksGraph = NYql::NDq::TDqTasksGraph<TGraphMeta, TStageInfoMeta, TTas
272278

273279
void FillKqpTasksGraphStages(TKqpTasksGraph& tasksGraph, const TVector<IKqpGateway::TPhysicalTxData>& txs);
274280
void BuildKqpTaskGraphResultChannels(TKqpTasksGraph& tasksGraph, const TKqpPhyTxHolder::TConstPtr& tx, ui64 txIdx);
275-
void BuildKqpStageChannels(TKqpTasksGraph& tasksGraph, const TStageInfo& stageInfo,
276-
ui64 txId, bool enableSpilling);
281+
void BuildKqpStageChannels(TKqpTasksGraph& tasksGraph, TStageInfo& stageInfo, ui64 txId, bool enableSpilling, bool enableShuffleElimination);
277282

278283
NYql::NDqProto::TDqTask* ArenaSerializeTaskToProto(TKqpTasksGraph& tasksGraph, const TTask& task, bool serializeAsyncIoSettings);
279284
void FillTableMeta(const TStageInfo& stageInfo, NKikimrTxDataShard::TKqpTransaction_TTableMeta* meta);

0 commit comments

Comments
 (0)