Skip to content

Commit 93dee2f

Browse files
committed
move logic to finalizing transformer
1 parent e9a91df commit 93dee2f

File tree

2 files changed

+112
-85
lines changed

2 files changed

+112
-85
lines changed

ydb/core/kqp/opt/kqp_opt_build_txs.cpp

Lines changed: 1 addition & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ class TKqpBuildTxsTransformer : public TSyncTransformerBase {
771771

772772
TNodeOnNodeOwnedMap phaseStagesMap;
773773
TVector<TKqlQueryResult> phaseResults;
774-
TVector<TExprBase> computedInputs;
774+
TVector<TDqPhyPrecompute> computedInputs;
775775
TNodeSet computedInputsSet;
776776

777777
// Gather all Precompute stages, that are independent of any other stage and form phase of execution
@@ -785,96 +785,12 @@ class TKqpBuildTxsTransformer : public TSyncTransformerBase {
785785
phaseStagesMap.emplace(raw, ptr);
786786
}
787787

788-
{
789-
TNodeOnNodeOwnedMap fullPhaseStagesMap;
790-
for (auto& [_, stagePtr] : phaseStagesMap) {
791-
VisitExpr(stagePtr,
792-
[&](const TExprNode::TPtr& node) {
793-
if (TExprBase(node).Maybe<TDqStage>()) {
794-
fullPhaseStagesMap[node.Get()] = node;
795-
}
796-
return true;
797-
});
798-
}
799-
phaseStagesMap.swap(fullPhaseStagesMap);
800-
801-
TNodeOnNodeOwnedMap fullDependantMap;
802-
VisitExpr(query.Ptr(),
803-
[&](const TExprNode::TPtr& node) {
804-
if (phaseStagesMap.contains(node.Get())) {
805-
return false;
806-
}
807-
if (TExprBase(node).Maybe<TDqStage>()) {
808-
fullDependantMap[node.Get()] = node;
809-
}
810-
return true;
811-
});
812-
dependantStagesMap.swap(fullDependantMap);
813-
}
814-
815788
if (phaseStagesMap.empty()) {
816789
output = query.Ptr();
817790
ctx.AddError(TIssue(ctx.GetPosition(query.Pos()), "Phase stages is empty"));
818791
return TStatus::Error;
819792
}
820793

821-
// so that all outputs to dependent stages are precomputes
822-
{
823-
TSet<TExprNode*> buildingTxStages;
824-
825-
TNodeOnNodeOwnedMap replaces;
826-
827-
for (auto& [_, stagePtr] : dependantStagesMap) {
828-
TDqStage stage(stagePtr);
829-
for (size_t i = 0; i < stage.Inputs().Size(); ++i) {
830-
auto input = stage.Inputs().Item(i);
831-
if (auto maybeConn = input.Maybe<TDqConnection>()) {
832-
auto conn = maybeConn.Cast();
833-
if (!conn.Maybe<TDqCnValue>() && !conn.Maybe<TDqCnUnionAll>()) {
834-
continue;
835-
}
836-
837-
if (phaseStagesMap.contains(conn.Output().Stage().Raw())) {
838-
auto oldArg = stage.Program().Args().Arg(i);
839-
auto newArg = Build<TCoArgument>(ctx, stage.Program().Args().Arg(i).Pos())
840-
.Name("_replaced_arg")
841-
.Done();
842-
843-
TVector<TCoArgument> newArgs;
844-
TNodeOnNodeOwnedMap programReplaces;
845-
for (size_t j = 0; j < stage.Program().Args().Size(); ++j) {
846-
auto oldArg = stage.Program().Args().Arg(j);
847-
newArgs.push_back(Build<TCoArgument>(ctx, stage.Program().Args().Arg(i).Pos())
848-
.Name("_replaced_arg_" + ToString(j))
849-
.Done());
850-
if (i == j) {
851-
programReplaces[oldArg.Raw()] = Build<TCoToFlow>(ctx, oldArg.Pos()).Input(newArgs.back()).Done().Ptr();
852-
} else {
853-
programReplaces[oldArg.Raw()] = newArgs.back().Ptr();
854-
}
855-
}
856-
857-
replaces[stage.Raw()] =
858-
Build<TDqStage>(ctx, stage.Pos())
859-
.Inputs(ctx.ReplaceNode(stage.Inputs().Ptr(), input.Ref(), Build<TDqPhyPrecompute>(ctx, input.Pos()).Connection(conn).Done().Ptr()))
860-
.Outputs(stage.Outputs())
861-
.Settings(stage.Settings())
862-
.Program()
863-
.Args(newArgs)
864-
.Body(TExprBase(ctx.ReplaceNodes(stage.Program().Body().Ptr(), programReplaces)))
865-
.Build()
866-
.Done().Ptr();
867-
}
868-
}
869-
}
870-
}
871-
872-
if (!replaces.empty()) {
873-
output = ctx.ReplaceNodes(query.Ptr(), replaces);
874-
return TStatus(TStatus::Repeat, true);
875-
}
876-
}
877-
878794
for (auto& [_, stagePtr] : dependantStagesMap) {
879795
TDqStage stage(stagePtr);
880796
auto precomputes = PrecomputeInputs(stage);

ydb/core/kqp/opt/kqp_opt_phy_finalize.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,115 @@ TStatus KqpDuplicateResults(const TExprNode::TPtr& input, TExprNode::TPtr& outpu
222222
return TStatus::Ok;
223223
}
224224

225+
template <typename TExpr>
226+
TVector<TExpr> CollectNodes(const TExprNode::TPtr& input) {
227+
TVector<TExpr> result;
228+
229+
VisitExpr(input, [&result](const TExprNode::TPtr& node) {
230+
if (TExpr::Match(node.Get())) {
231+
result.emplace_back(TExpr(node));
232+
}
233+
return true;
234+
});
235+
236+
return result;
237+
}
238+
239+
// Reports whether some output of `stage` is consumed through a precompute.
//
// The walk follows `parentsMap` upwards for exactly three hops:
//   stage -> parent matching TDqOutput -> its parents -> their parents,
// and succeeds as soon as a node on the last hop matches TDqPrecompute or
// TDqPhyPrecompute. Nodes without an entry in `parentsMap` are skipped.
bool FindPrecomputedOutputs(TDqStageBase stage, const TParentsMap& parentsMap) {
    const auto stageParents = parentsMap.find(stage.Raw());
    if (stageParents == parentsMap.end()) {
        return false;
    }

    for (const auto& outputNode : stageParents->second) {
        // Only TDqOutput parents are of interest on the first hop.
        if (!TDqOutput::Match(outputNode)) {
            continue;
        }

        const auto outputParents = parentsMap.find(outputNode);
        if (outputParents == parentsMap.end()) {
            continue;
        }

        // Second hop: consumers of the output (not type-checked here).
        for (const auto& consumer : outputParents->second) {
            const auto consumerParents = parentsMap.find(consumer);
            if (consumerParents == parentsMap.end()) {
                continue;
            }

            // Third hop: look for a precompute wrapping the consumer.
            for (const auto& wrapper : consumerParents->second) {
                if (TDqPrecompute::Match(wrapper) || TDqPhyPrecompute::Match(wrapper)) {
                    return true;
                }
            }
        }
    }

    return false;
}
265+
266+
267+
// Rewrites the first eligible input of `stage` so that it is read through a
// TDqPhyPrecompute instead of directly through the connection.
//
// An input is eligible when all of the following hold:
//   * it is a TDqConnection of kind TDqCnValue or TDqCnUnionAll;
//   * the program body of the stage it reads from is a TDqReplicate;
//   * FindPrecomputedOutputs() reports that the source stage already has an
//     output consumed through a precompute.
//
// On a match a new TDqStage is built: the connection is wrapped into
// TDqPhyPrecompute, every program argument is replaced by a fresh
// "_replaced_arg_<j>" argument, and uses of the rewritten argument inside the
// body are wrapped into TCoToFlow. Only one input is rewritten per call; the
// caller is expected to invoke the function again for the remaining ones.
//
// Returns the rebuilt stage, or `stage` itself (same raw node) when nothing
// was eligible.
//
// NOTE(review): the result is always built as TDqStage even though the
// parameter is TDqStageBase — confirm no other stage kind can reach here.
TExprBase ReplicatePrecompute(TDqStageBase stage, TExprContext& ctx, const TParentsMap& parentsMap) {
    for (size_t i = 0; i < stage.Inputs().Size(); ++i) {
        auto input = stage.Inputs().Item(i);

        auto maybeConn = input.Maybe<TDqConnection>();
        if (!maybeConn) {
            continue;
        }

        auto conn = maybeConn.Cast();
        if (!conn.Maybe<TDqCnValue>() && !conn.Maybe<TDqCnUnionAll>()) {
            continue;
        }

        auto sourceStage = conn.Output().Stage();
        if (!sourceStage.Program().Body().Maybe<TDqReplicate>()) {
            continue;
        }
        if (!FindPrecomputedOutputs(sourceStage, parentsMap)) {
            continue;
        }

        // Recreate the whole argument list: lambda arguments cannot be shared
        // between the old and the new program.
        const size_t argsCount = stage.Program().Args().Size();
        TVector<TCoArgument> newArgs;
        newArgs.reserve(argsCount);
        TNodeOnNodeOwnedMap programReplaces;
        for (size_t j = 0; j < argsCount; ++j) {
            auto oldArg = stage.Program().Args().Arg(j);
            // Each replacement keeps the source position of the argument it
            // stands in for.
            newArgs.push_back(Build<TCoArgument>(ctx, oldArg.Pos())
                .Name("_replaced_arg_" + ToString(j))
                .Done());
            if (i == j) {
                // The precomputed input is wrapped into ToFlow so the body
                // keeps seeing what it saw from the connection —
                // presumably a flow; verify against TDqPhyPrecompute typing.
                programReplaces[oldArg.Raw()] = Build<TCoToFlow>(ctx, oldArg.Pos()).Input(newArgs.back()).Done().Ptr();
            } else {
                programReplaces[oldArg.Raw()] = newArgs.back().Ptr();
            }
        }

        return
            Build<TDqStage>(ctx, stage.Pos())
                .Inputs(ctx.ReplaceNode(stage.Inputs().Ptr(), input.Ref(), Build<TDqPhyPrecompute>(ctx, input.Pos()).Connection(conn).Done().Ptr()))
                .Outputs(stage.Outputs())
                .Settings(stage.Settings())
                .Program()
                    .Args(newArgs)
                    .Body(TExprBase(ctx.ReplaceNodes(stage.Program().Body().Ptr(), programReplaces)))
                    .Build()
                .Done();
    }
    return stage;
}
318+
319+
// Global rule: applies ReplicatePrecompute() to the stages found under
// `input`, one rewrite per invocation.
//
// The first stage that ReplicatePrecompute() actually changes is substituted
// into the graph, the result is stored in `output`, and TStatus::Repeat is
// returned. When no stage changes, `output` is set to `input` and the rule
// returns TStatus::Ok.
NYql::IGraphTransformer::TStatus ReplicatePrecomputeRule(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) {
    TParentsMap parentsMap;
    GatherParents(*input, parentsMap, true);

    for (auto& candidate : CollectNodes<TDqStageBase>(input)) {
        auto rewritten = ReplicatePrecompute(candidate, ctx, parentsMap);
        // An unchanged stage comes back as the very same node.
        if (rewritten.Raw() == candidate.Raw()) {
            continue;
        }

        output = ctx.ReplaceNode(input.Get(), candidate.Ref(), rewritten.Ptr());
        return TStatus::Repeat;
    }

    output = input;
    return TStatus::Ok;
}
333+
225334
template <typename TFunctor>
226335
NYql::IGraphTransformer::TStatus PerformGlobalRule(const TString& ruleName, const NYql::TExprNode::TPtr& input,
227336
NYql::TExprNode::TPtr& output, NYql::TExprContext& ctx, TFunctor func)
@@ -251,6 +360,8 @@ TAutoPtr<IGraphTransformer> CreateKqpFinalizingOptTransformer(const TIntrusivePt
251360
[kqpCtx](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) -> TStatus {
252361
output = input;
253362

363+
PERFORM_GLOBAL_RULE("ReplicatePrecompute", input, output, ctx, ReplicatePrecomputeRule);
364+
254365
PERFORM_GLOBAL_RULE("ReplicateMultiUsedConnection", input, output, ctx,
255366
[](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) {
256367
YQL_ENSURE(TKqlQuery::Match(input.Get()));

0 commit comments

Comments
 (0)