Skip to content

[KQP] Fix recursion problem when computing SimplifiedPlan (#9519) #9631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 184 additions & 114 deletions ydb/core/kqp/opt/kqp_query_plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1945,6 +1945,9 @@ TVector<NJson::TJsonValue> RemoveRedundantNodes(NJson::TJsonValue& plan, const T
}
}

if (!planMap.contains("Node Type")) {
return {};
}
const auto typeName = planMap.at("Node Type").GetStringSafe();
if (redundantNodes.contains(typeName) || typeName.find("Precompute") != TString::npos) {
return children;
Expand All @@ -1953,167 +1956,235 @@ TVector<NJson::TJsonValue> RemoveRedundantNodes(NJson::TJsonValue& plan, const T
return {plan};
}

NJson::TJsonValue ReconstructQueryPlanRec(const NJson::TJsonValue& plan,
int operatorIndex,
const THashMap<int, NJson::TJsonValue>& planIndex,
const THashMap<TString, NJson::TJsonValue>& precomputes,
int& nodeCounter) {

int currentNodeId = nodeCounter++;

NJson::TJsonValue result;
result["PlanNodeId"] = currentNodeId;

if (plan.GetMapSafe().contains("PlanNodeType")) {
result["PlanNodeType"] = plan.GetMapSafe().at("PlanNodeType").GetStringSafe();
}
struct TQueryPlanReconstructor {
TQueryPlanReconstructor(
const THashMap<int, NJson::TJsonValue>& planIndex,
const THashMap<TString, NJson::TJsonValue>& precomputes
)
: PlanIndex(planIndex)
, Precomputes(precomputes)
, NodeIDCounter(0)
, Budget(10'000)
{}

if (plan.GetMapSafe().contains("Stats") && operatorIndex==0) {
result["Stats"] = plan.GetMapSafe().at("Stats");
}
NJson::TJsonValue Reconstruct(
const NJson::TJsonValue& plan,
int operatorIndex
) {
int currentNodeId = NodeIDCounter++;

if (!plan.GetMapSafe().contains("Operators")) {
NJson::TJsonValue planInputs;
NJson::TJsonValue result;
result["PlanNodeId"] = currentNodeId;

result["Node Type"] = plan.GetMapSafe().at("Node Type").GetStringSafe();
if (--Budget <= 0) {
YQL_CLOG(DEBUG, ProviderKqp) << "Can't build the plan - recursion depth has been exceeded!";
return result;
}

if (plan.GetMapSafe().contains("CTE Name")) {
auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe();
if (precomputes.contains(precompute)) {
planInputs.AppendValue(ReconstructQueryPlanRec(precomputes.at(precompute), 0, planIndex, precomputes, nodeCounter));
}
if (plan.GetMapSafe().contains("PlanNodeType")) {
result["PlanNodeType"] = plan.GetMapSafe().at("PlanNodeType").GetStringSafe();
}

if (!plan.GetMapSafe().contains("Plans")) {
result["Plans"] = planInputs;
return result;
if (plan.GetMapSafe().contains("Stats") && operatorIndex==0) {
result["Stats"] = plan.GetMapSafe().at("Stats");
}

if (plan.GetMapSafe().at("Node Type").GetStringSafe() == "TableLookup") {
if (plan.GetMapSafe().at("Node Type") == "TableLookupJoin" && plan.GetMapSafe().contains("Table")) {
result["Node Type"] = "LookupJoin";
NJson::TJsonValue newOps;
NJson::TJsonValue op;

op["Name"] = "TableLookup";
op["Columns"] = plan.GetMapSafe().at("Columns");
op["Name"] = "LookupJoin";
op["LookupKeyColumns"] = plan.GetMapSafe().at("LookupKeyColumns");
op["Table"] = plan.GetMapSafe().at("Table");

newOps.AppendValue(std::move(op));
result["Operators"] = std::move(newOps);

NJson::TJsonValue newPlans;

NJson::TJsonValue lookupPlan;
lookupPlan["Node Type"] = "TableLookup";
lookupPlan["PlanNodeType"] = "TableLookup";

NJson::TJsonValue lookupOps;
NJson::TJsonValue lookupOp;

lookupOp["Name"] = "TableLookup";
lookupOp["Columns"] = plan.GetMapSafe().at("Columns");
lookupOp["LookupKeyColumns"] = plan.GetMapSafe().at("LookupKeyColumns");
lookupOp["Table"] = plan.GetMapSafe().at("Table");

if (plan.GetMapSafe().contains("E-Cost")) {
op["E-Cost"] = plan.GetMapSafe().at("E-Cost");
}
lookupOp["E-Cost"] = plan.GetMapSafe().at("E-Cost");
}
if (plan.GetMapSafe().contains("E-Rows")) {
op["E-Rows"] = plan.GetMapSafe().at("E-Rows");
lookupOp["E-Rows"] = plan.GetMapSafe().at("E-Rows");
}
if (plan.GetMapSafe().contains("E-Size")) {
op["E-Size"] = plan.GetMapSafe().at("E-Size");
lookupOp["E-Size"] = plan.GetMapSafe().at("E-Size");
}

newOps.AppendValue(op);
lookupOps.AppendValue(std::move(lookupOp));
lookupPlan["Operators"] = std::move(lookupOps);

newPlans.AppendValue(Reconstruct(plan.GetMapSafe().at("Plans").GetArraySafe()[0], 0));

newPlans.AppendValue(std::move(lookupPlan));

result["Plans"] = std::move(newPlans);

result["Operators"] = newOps;
return result;
}

for (auto p : plan.GetMapSafe().at("Plans").GetArraySafe()) {
if (!p.GetMapSafe().contains("Operators") && p.GetMapSafe().contains("CTE Name")) {
auto precompute = p.GetMapSafe().at("CTE Name").GetStringSafe();
if (precomputes.contains(precompute)) {
planInputs.AppendValue(ReconstructQueryPlanRec(precomputes.at(precompute), 0, planIndex, precomputes, nodeCounter));
if (!plan.GetMapSafe().contains("Operators")) {
NJson::TJsonValue planInputs;

result["Node Type"] = plan.GetMapSafe().at("Node Type").GetStringSafe();

if (plan.GetMapSafe().contains("CTE Name")) {
auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe();
if (Precomputes.contains(precompute)) {
planInputs.AppendValue(Reconstruct(Precomputes.at(precompute), 0));
}
} else if (p.GetMapSafe().at("Node Type").GetStringSafe().find("Precompute") == TString::npos) {
planInputs.AppendValue(ReconstructQueryPlanRec(p, 0, planIndex, precomputes, nodeCounter));
}
}
result["Plans"] = planInputs;
return result;
}

if (plan.GetMapSafe().contains("CTE Name") && plan.GetMapSafe().at("Node Type").GetStringSafe() == "ConstantExpr") {
auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe();
if (!precomputes.contains(precompute)) {
result["Node Type"] = plan.GetMapSafe().at("Node Type");
if (!plan.GetMapSafe().contains("Plans")) {
result["Plans"] = std::move(planInputs);
return result;
}

if (plan.GetMapSafe().at("Node Type").GetStringSafe() == "TableLookup") {
NJson::TJsonValue newOps;
NJson::TJsonValue op;

op["Name"] = "TableLookup";
op["Columns"] = plan.GetMapSafe().at("Columns");
op["LookupKeyColumns"] = plan.GetMapSafe().at("LookupKeyColumns");
op["Table"] = plan.GetMapSafe().at("Table");

if (plan.GetMapSafe().contains("E-Cost")) {
op["E-Cost"] = plan.GetMapSafe().at("E-Cost");
}
if (plan.GetMapSafe().contains("E-Rows")) {
op["E-Rows"] = plan.GetMapSafe().at("E-Rows");
}
if (plan.GetMapSafe().contains("E-Size")) {
op["E-Size"] = plan.GetMapSafe().at("E-Size");
}

newOps.AppendValue(std::move(op));

result["Operators"] = std::move(newOps);
return result;
}

for (auto p : plan.GetMapSafe().at("Plans").GetArraySafe()) {
if (!p.GetMapSafe().contains("Operators") && p.GetMapSafe().contains("CTE Name")) {
auto precompute = p.GetMapSafe().at("CTE Name").GetStringSafe();
if (Precomputes.contains(precompute)) {
planInputs.AppendValue(Reconstruct(Precomputes.at(precompute), 0));
}
} else if (p.GetMapSafe().at("Node Type").GetStringSafe().find("Precompute") == TString::npos) {
planInputs.AppendValue(Reconstruct(p, 0));
}
}
result["Plans"] = planInputs;
return result;
}

return ReconstructQueryPlanRec(precomputes.at(precompute), 0, planIndex, precomputes, nodeCounter);
}
if (plan.GetMapSafe().contains("CTE Name") && plan.GetMapSafe().at("Node Type").GetStringSafe() == "ConstantExpr") {
auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe();
if (!Precomputes.contains(precompute)) {
result["Node Type"] = plan.GetMapSafe().at("Node Type");
return result;
}

auto ops = plan.GetMapSafe().at("Operators").GetArraySafe();
auto op = ops[operatorIndex];
return Reconstruct(Precomputes.at(precompute), 0);
}

TVector<NJson::TJsonValue> planInputs;
auto ops = plan.GetMapSafe().at("Operators").GetArraySafe();
auto op = ops[operatorIndex];

auto opName = op.GetMapSafe().at("Name").GetStringSafe();
TVector<NJson::TJsonValue> planInputs;

THashSet<ui32> processedExternalOperators;
THashSet<ui32> processedInternalOperators;
for (auto opInput : op.GetMapSafe().at("Inputs").GetArraySafe()) {
auto opName = op.GetMapSafe().at("Name").GetStringSafe();

if (opInput.GetMapSafe().contains("ExternalPlanNodeId")) {
auto inputPlanKey = opInput.GetMapSafe().at("ExternalPlanNodeId").GetIntegerSafe();
THashSet<ui32> processedExternalOperators;
THashSet<ui32> processedInternalOperators;
for (auto opInput : op.GetMapSafe().at("Inputs").GetArraySafe()) {

if (processedExternalOperators.contains(inputPlanKey)) {
continue;
}
processedExternalOperators.insert(inputPlanKey);
if (opInput.GetMapSafe().contains("ExternalPlanNodeId")) {
auto inputPlanKey = opInput.GetMapSafe().at("ExternalPlanNodeId").GetIntegerSafe();

auto inputPlan = planIndex.at(inputPlanKey);
planInputs.push_back( ReconstructQueryPlanRec(inputPlan, 0, planIndex, precomputes, nodeCounter));
} else if (opInput.GetMapSafe().contains("InternalOperatorId")) {
auto inputPlanId = opInput.GetMapSafe().at("InternalOperatorId").GetIntegerSafe();
if (processedExternalOperators.contains(inputPlanKey)) {
continue;
}
processedExternalOperators.insert(inputPlanKey);

if (processedInternalOperators.contains(inputPlanId)) {
continue;
}
processedInternalOperators.insert(inputPlanId);
auto inputPlan = PlanIndex.at(inputPlanKey);
planInputs.push_back( Reconstruct(inputPlan, 0) );
} else if (opInput.GetMapSafe().contains("InternalOperatorId")) {
auto inputPlanId = opInput.GetMapSafe().at("InternalOperatorId").GetIntegerSafe();

planInputs.push_back( ReconstructQueryPlanRec(plan, inputPlanId, planIndex, precomputes, nodeCounter));
if (processedInternalOperators.contains(inputPlanId)) {
continue;
}
processedInternalOperators.insert(inputPlanId);

planInputs.push_back( Reconstruct(plan, inputPlanId) );
}
}
}

if (op.GetMapSafe().contains("Inputs")) {
op.GetMapSafe().erase("Inputs");
}
if (op.GetMapSafe().contains("Inputs")) {
op.GetMapSafe().erase("Inputs");
}

if (op.GetMapSafe().contains("Input")
|| op.GetMapSafe().contains("ToFlow")
|| op.GetMapSafe().contains("Member")
|| op.GetMapSafe().contains("AssumeSorted")
|| op.GetMapSafe().contains("Iterator")) {
if (op.GetMapSafe().contains("Input")
|| op.GetMapSafe().contains("ToFlow")
|| op.GetMapSafe().contains("Member")
|| op.GetMapSafe().contains("AssumeSorted")
|| op.GetMapSafe().contains("Iterator")) {

TString maybePrecompute = "";
if (op.GetMapSafe().contains("Input")) {
maybePrecompute = op.GetMapSafe().at("Input").GetStringSafe();
} else if (op.GetMapSafe().contains("ToFlow")) {
maybePrecompute = op.GetMapSafe().at("ToFlow").GetStringSafe();
} else if (op.GetMapSafe().contains("Member")) {
maybePrecompute = op.GetMapSafe().at("Member").GetStringSafe();
} else if (op.GetMapSafe().contains("AssumeSorted")) {
maybePrecompute = op.GetMapSafe().at("AssumeSorted").GetStringSafe();
} else if (op.GetMapSafe().contains("Iterator")) {
maybePrecompute = op.GetMapSafe().at("Iterator").GetStringSafe();
}
TString maybePrecompute = "";
if (op.GetMapSafe().contains("Input")) {
maybePrecompute = op.GetMapSafe().at("Input").GetStringSafe();
} else if (op.GetMapSafe().contains("ToFlow")) {
maybePrecompute = op.GetMapSafe().at("ToFlow").GetStringSafe();
} else if (op.GetMapSafe().contains("Member")) {
maybePrecompute = op.GetMapSafe().at("Member").GetStringSafe();
} else if (op.GetMapSafe().contains("AssumeSorted")) {
maybePrecompute = op.GetMapSafe().at("AssumeSorted").GetStringSafe();
} else if (op.GetMapSafe().contains("Iterator")) {
maybePrecompute = op.GetMapSafe().at("Iterator").GetStringSafe();
}

if (precomputes.contains(maybePrecompute) && planInputs.empty()) {
planInputs.push_back(ReconstructQueryPlanRec(precomputes.at(maybePrecompute), 0, planIndex, precomputes, nodeCounter));
if (Precomputes.contains(maybePrecompute) && planInputs.empty()) {
planInputs.push_back(Reconstruct(Precomputes.at(maybePrecompute), 0));
}
}
}

result["Node Type"] = opName;
NJson::TJsonValue newOps;
newOps.AppendValue(op);
result["Operators"] = newOps;
result["Node Type"] = std::move(opName);
NJson::TJsonValue newOps;
newOps.AppendValue(std::move(op));
result["Operators"] = std::move(newOps);

if (planInputs.size()){
NJson::TJsonValue plans;
for( auto i : planInputs) {
plans.AppendValue(i);
if (!planInputs.empty()){
NJson::TJsonValue plans;
for(auto&& i : planInputs) {
plans.AppendValue(std::move(i));
}
result["Plans"] = std::move(plans);
}
result["Plans"] = plans;

return result;
}

return result;
}
private:
const THashMap<int, NJson::TJsonValue>& PlanIndex;
const THashMap<TString, NJson::TJsonValue>& Precomputes;
ui32 NodeIDCounter;
i32 Budget; // Prevent bugs with inf recursion
};

double ComputeCpuTimes(NJson::TJsonValue& plan) {
double currCpuTime = 0;
Expand Down Expand Up @@ -2209,8 +2280,7 @@ NJson::TJsonValue SimplifyQueryPlan(NJson::TJsonValue& plan) {

BuildPlanIndex(plan, planIndex, precomputes);

int nodeCounter = 0;
plan = ReconstructQueryPlanRec(plan, 0, planIndex, precomputes, nodeCounter);
plan = TQueryPlanReconstructor(planIndex, precomputes).Reconstruct(plan, 0);

RemoveRedundantNodes(plan, redundantNodes);
ComputeCpuTimes(plan);
Expand Down
Loading
Loading