Skip to content

Commit 63660ad

Browse files
authored
YQL-16896: Common type inferring for SELECT combinators (#843)
1 parent 3df4be0 commit 63660ad

File tree

26 files changed

+656
-127
lines changed

26 files changed

+656
-127
lines changed

ydb/library/yql/core/common_opt/yql_co_pgselect.cpp

+62-7
Original file line numberDiff line numberDiff line change
@@ -1683,7 +1683,18 @@ TExprNode::TPtr BuildCrossJoinsBetweenGroups(TPositionHandle pos, const TExprNod
16831683
return ctx.NewCallable(pos, "EquiJoin", std::move(args));
16841684
}
16851685

1686-
TExprNode::TPtr BuildProjectionLambda(TPositionHandle pos, const TExprNode::TPtr& result, bool subLink, bool emitPgStar, TExprContext& ctx) {
1686+
TExprNode::TPtr BuildProjectionLambda(TPositionHandle pos, const TExprNode::TPtr& result, const TStructExprType* finalType,
1687+
const TColumnOrder& nodeColumnOrder, const TColumnOrder& setItemColumnOrder,
1688+
bool subLink, bool emitPgStar, TExprContext& ctx) {
1689+
1690+
YQL_ENSURE(nodeColumnOrder.size() == setItemColumnOrder.size());
1691+
TMap<TStringBuf, TStringBuf> columnNamesMap;
1692+
if (!emitPgStar) {
1693+
for (size_t i = 0; i < nodeColumnOrder.size(); ++i) {
1694+
columnNamesMap[setItemColumnOrder[i]] = nodeColumnOrder[i];
1695+
}
1696+
}
1697+
16871698
return ctx.Builder(pos)
16881699
.Lambda()
16891700
.Param("row")
@@ -1705,26 +1716,68 @@ TExprNode::TPtr BuildProjectionLambda(TPositionHandle pos, const TExprNode::TPtr
17051716
.Seal();
17061717
listBuilder.Seal();
17071718
};
1719+
1720+
auto addAtomToListWithCast = [&addAtomToList] (TExprNodeBuilder& listBuilder, TExprNode* x,
1721+
const TTypeAnnotationNode* expectedTypeNode) -> void {
1722+
auto actualType = x->GetTypeAnn()->Cast<TPgExprType>();
1723+
Y_ENSURE(expectedTypeNode);
1724+
const auto expectedType = expectedTypeNode->Cast<TPgExprType>();
1725+
1726+
if (actualType == expectedType) {
1727+
addAtomToList(listBuilder, x);
1728+
return;
1729+
}
1730+
listBuilder.Add(0, x->HeadPtr());
1731+
listBuilder.Callable(1, "PgCast")
1732+
.Apply(0, x->TailPtr())
1733+
.With(0, "row")
1734+
.Seal()
1735+
.Callable(1, "PgType")
1736+
.Atom(0, NPg::LookupType(expectedType->GetId()).Name)
1737+
.Seal();
1738+
listBuilder.Seal();
1739+
};
1740+
17081741
for (const auto& x : result->Tail().Children()) {
17091742
if (x->HeadPtr()->IsAtom()) {
17101743
if (!emitPgStar) {
1744+
const auto& columnName = x->Child(0)->Content();
17111745
auto listBuilder = parent.List(index++);
1712-
addAtomToList(listBuilder, x.Get());
1746+
addAtomToListWithCast(listBuilder, x.Get(), finalType->FindItemType(columnNamesMap[columnName]));
17131747
}
17141748
} else {
17151749
auto type = x->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TStructExprType>();
1750+
Y_ENSURE(type);
1751+
17161752
for (const auto& item : type->GetItems()) {
17171753
TStringBuf column = item->GetName();
17181754
auto columnName = subLink ? column : NTypeAnnImpl::RemoveAlias(column);
1755+
17191756
auto listBuilder = parent.List(index++);
17201757
if (overrideColumns.contains(columnName)) {
1758+
// we never get here while processing SELECTs,
1759+
// so no need to add PgCasts due to query combining with UNION ALL et al
17211760
addAtomToList(listBuilder, overrideColumns[columnName]);
17221761
} else {
17231762
listBuilder.Atom(0, columnName);
1724-
listBuilder.Callable(1, "Member")
1725-
.Arg(0, "row")
1726-
.Atom(1, column);
1727-
listBuilder.Seal();
1763+
1764+
const auto expectedType = finalType->FindItemType(columnNamesMap[columnName]);
1765+
if (item->GetItemType() == expectedType) {
1766+
listBuilder.Callable(1, "Member")
1767+
.Arg(0, "row")
1768+
.Atom(1, column)
1769+
.Seal();
1770+
} else {
1771+
listBuilder.Callable(1, "PgCast")
1772+
.Callable(0, "Member")
1773+
.Arg(0, "row")
1774+
.Atom(1, column)
1775+
.Seal()
1776+
.Callable(1, "PgType")
1777+
.Atom(0, NPg::LookupType(expectedType->Cast<TPgExprType>()->GetId()).Name)
1778+
.Seal()
1779+
.Seal();
1780+
}
17281781
}
17291782
}
17301783
}
@@ -3159,7 +3212,9 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct
31593212
}
31603213
} else {
31613214
YQL_ENSURE(result);
3162-
TExprNode::TPtr projectionLambda = BuildProjectionLambda(node->Pos(), result, subLinkId.Defined(), emitPgStar, ctx);
3215+
auto finalType = node->GetTypeAnn()->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>();
3216+
Y_ENSURE(finalType);
3217+
TExprNode::TPtr projectionLambda = BuildProjectionLambda(node->Pos(), result, finalType, *order, *childOrder, subLinkId.Defined(), emitPgStar, ctx);
31633218
TExprNode::TPtr projectionArg = projectionLambda->Head().HeadPtr();
31643219
TExprNode::TPtr projectionRoot = projectionLambda->TailPtr();
31653220
TVector<TString> inputAliases;

ydb/library/yql/core/type_ann/type_ann_pg.cpp

+125-13
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ydb/library/yql/parser/pg_wrapper/interface/utils.h>
1313

1414
#include <util/generic/set.h>
15+
#include <util/generic/hash.h>
1516

1617
namespace NYql {
1718

@@ -64,8 +65,8 @@ bool ValidateInputTypes(TExprNode& node, TExprContext& ctx) {
6465
return true;
6566
}
6667

67-
TExprNodePtr WrapWithPgCast(TExprNodePtr&& node, ui32 targetTypeId, TContext& ctx) {
68-
return ctx.Expr.Builder(node->Pos())
68+
TExprNodePtr WrapWithPgCast(TExprNodePtr&& node, ui32 targetTypeId, TExprContext& ctx) {
69+
return ctx.Builder(node->Pos())
6970
.Callable("PgCast")
7071
.Add(0, std::move(node))
7172
.Callable(1, "PgType")
@@ -75,6 +76,113 @@ TExprNodePtr WrapWithPgCast(TExprNodePtr&& node, ui32 targetTypeId, TContext& ct
7576
.Build();
7677
};
7778

79+
// Scans the op list in `setOps` and returns the first non-"push" op that is
// encountered while the set item with index `n` is on top of the stack.
//
// Every "push" op places the index of the next set item on the stack; every
// other op drops the top entry and leaves the entry beneath it in place
// (presumably standing for the combined result of a binary combinator such as
// UNION / INTERSECT / EXCEPT -- NOTE(review): confirm against the set_ops producer).
//
// setItems - only its children count is used, to size the stack.
// setOps   - list of atoms: "push" or a combinator name.
// n        - index of the set item whose combinator is wanted.
//
// Fails with Y_ENSURE when the op list is malformed: a combinator with fewer
// than two stack entries, or more pushes than set items. Reaching the end of
// the op list without a match is treated as unreachable.
TExprNodePtr FindLeftCombinatorOfNthSetItem(const TExprNode* setItems, const TExprNode* setOps, ui32 n) {
    TVector<ui32> setItemsStack(setItems->ChildrenSize());
    i32 sp = -1;
    ui32 itemIdx = 0;
    for (const auto& op : setOps->Children()) {
        if (op->Content() == "push") {
            // The stack is preallocated with one slot per set item; never write past it.
            Y_ENSURE(static_cast<size_t>(sp + 1) < setItemsStack.size());
            setItemsStack[++sp] = itemIdx++;
        } else {
            // A combinator needs two entries. Validate the depth BEFORE reading the
            // top so that a malformed op list cannot index the stack at -1.
            Y_ENSURE(1 <= sp);
            if (setItemsStack[sp] == n) {
                return op;
            }
            --sp;
        }
    }
    Y_UNREACHABLE();
}
96+
97+
// Computes the result row type of a PG SELECT built from several set items by
// resolving, position by position, a common PG type across all inputs.
//
// pos               - position passed to the validation of the resulting struct type.
// setItems          - node whose children are the combined inputs; each must be a
//                     list of structs with a known column order.
// setOps            - op list, used only to name the combinator in the
//                     "same number of columns" error.
// resultColumnOrder - out: column order of the first input.
// resultStructType  - out: struct with the first input's column names and the
//                     common PG type of every column position.
//
// Returns Ok on success, Error (with an issue added to ctx.Expr) otherwise.
IGraphTransformer::TStatus InferPgCommonType(TPositionHandle pos, const TExprNode* setItems, const TExprNode* setOps,
    TColumnOrder& resultColumnOrder, const TStructExprType*& resultStructType, TExtContext& ctx)
{
    // typeIdsByColumn[column][input] is the PG type id of that column in that input.
    TVector<TVector<ui32>> typeIdsByColumn;
    size_t columnsCount = 0;

    const size_t inputsCount = setItems->ChildrenSize();
    for (size_t inputIdx = 0; inputIdx < inputsCount; ++inputIdx) {
        const auto* input = setItems->Child(inputIdx);

        if (!EnsureListType(*input, ctx.Expr)) {
            return IGraphTransformer::TStatus::Error;
        }

        auto rowType = input->GetTypeAnn()->Cast<TListExprType>()->GetItemType();
        YQL_ENSURE(rowType);
        if (!EnsureStructType(input->Pos(), *rowType, ctx.Expr)) {
            return IGraphTransformer::TStatus::Error;
        }

        auto inputColumnOrder = ctx.Types.LookupColumnOrder(*input);
        if (!inputColumnOrder) {
            TStringBuilder message;
            message << "Input #" << inputIdx << " does not have ordered columns. "
                << "Consider making column order explicit by using SELECT with column names";
            ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), message));
            return IGraphTransformer::TStatus::Error;
        }
        const auto& inputColumns = *inputColumnOrder;

        if (inputIdx > 0) {
            if (inputColumns.size() != columnsCount) {
                // Name the combinator that joins this input to the previous ones.
                TExprNodePtr combinator = FindLeftCombinatorOfNthSetItem(setItems, setOps, inputIdx);
                Y_ENSURE(combinator);

                TString opName(combinator->Content());
                if (opName.EndsWith("_all")) {
                    opName.erase(opName.length() - 4);
                }
                opName.to_upper();

                TStringBuilder message;
                message << "each " << opName << " query must have the same number of columns";
                ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), message));
                return IGraphTransformer::TStatus::Error;
            }
        } else {
            // The first input defines the column names and the column count.
            resultColumnOrder = inputColumns;
            columnsCount = resultColumnOrder.size();

            typeIdsByColumn.resize(columnsCount);
            for (auto& columnTypeIds : typeIdsByColumn) {
                columnTypeIds.reserve(inputsCount);
            }
        }

        const auto rowStruct = rowType->Cast<TStructExprType>();
        for (size_t columnIdx = 0; columnIdx < inputColumns.size(); ++columnIdx) {
            const auto memberIdx = rowStruct->FindItem(inputColumns[columnIdx]);
            YQL_ENSURE(memberIdx);

            const auto* memberType = rowStruct->GetItems()[*memberIdx]->GetItemType();
            typeIdsByColumn[columnIdx].push_back(memberType->Cast<TPgExprType>()->GetId());
        }
    }

    TVector<const TItemExprType*> resultItems;
    for (size_t columnIdx = 0; columnIdx < columnsCount; ++columnIdx) {
        // Maps an input index to the source position reported in a type mismatch issue.
        const auto positionOfInput = [columnIdx, &setItems, &ctx](size_t inputIdx) {
            return ctx.Expr.GetPosition(setItems->Child(inputIdx)->Child(columnIdx)->Pos());
        };

        const NPg::TTypeDesc* commonType;
        const auto issue = NPg::LookupCommonType(typeIdsByColumn[columnIdx], positionOfInput, commonType);
        if (issue) {
            ctx.Expr.AddError(*issue);
            return IGraphTransformer::TStatus::Error;
        }

        const auto* columnType = ctx.Expr.MakeType<TPgExprType>(commonType->TypeId);
        resultItems.push_back(ctx.Expr.MakeType<TItemExprType>(resultColumnOrder[columnIdx], columnType));
    }

    resultStructType = ctx.Expr.MakeType<TStructExprType>(resultItems);
    if (!resultStructType->Validate(pos, ctx.Expr)) {
        return IGraphTransformer::TStatus::Error;
    }

    return IGraphTransformer::TStatus::Ok;
}
185+
78186
IGraphTransformer::TStatus PgStarWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
79187
Y_UNUSED(output);
80188
if (!EnsureArgsCount(*input, 0, ctx.Expr)) {
@@ -212,12 +320,12 @@ IGraphTransformer::TStatus PgCallWrapper(const TExprNode::TPtr& input, TExprNode
212320
const auto& fargTypes = (*procPtr)->ArgTypes;
213321
for (size_t i = 0; i < argTypes.size(); ++i) {
214322
if (IsCastRequired(argTypes[i], fargTypes[i])) {
215-
children[i+3] = WrapWithPgCast(std::move(children[i+3]), fargTypes[i], ctx);
323+
children[i+3] = WrapWithPgCast(std::move(children[i+3]), fargTypes[i], ctx.Expr);
216324
}
217325
}
218326
output = ctx.Expr.NewCallable(input->Pos(), "PgResolvedCall", std::move(children));
219327
} else if (const auto* typePtr = std::get_if<const NPg::TTypeDesc*>(&procOrType)) {
220-
output = WrapWithPgCast(std::move(children[2]), (*typePtr)->TypeId, ctx);
328+
output = WrapWithPgCast(std::move(children[2]), (*typePtr)->TypeId, ctx.Expr);
221329
} else {
222330
Y_UNREACHABLE();
223331
}
@@ -454,16 +562,16 @@ IGraphTransformer::TStatus PgOpWrapper(const TExprNode::TPtr& input, TExprNode::
454562
switch(oper.Kind) {
455563
case NPg::EOperKind::LeftUnary:
456564
if (IsCastRequired(argTypes[0], oper.RightType)) {
457-
children[1] = WrapWithPgCast(std::move(children[1]), oper.RightType, ctx);
565+
children[1] = WrapWithPgCast(std::move(children[1]), oper.RightType, ctx.Expr);
458566
}
459567
break;
460568

461569
case NYql::NPg::EOperKind::Binary:
462570
if (IsCastRequired(argTypes[0], oper.LeftType)) {
463-
children[1] = WrapWithPgCast(std::move(children[1]), oper.LeftType, ctx);
571+
children[1] = WrapWithPgCast(std::move(children[1]), oper.LeftType, ctx.Expr);
464572
}
465573
if (IsCastRequired(argTypes[1], oper.RightType)) {
466-
children[2] = WrapWithPgCast(std::move(children[2]), oper.RightType, ctx);
574+
children[2] = WrapWithPgCast(std::move(children[2]), oper.RightType, ctx.Expr);
467575
}
468576
break;
469577

@@ -648,7 +756,7 @@ IGraphTransformer::TStatus PgAggWrapper(const TExprNode::TPtr& input, TExprNode:
648756
for (ui32 i = 0; i < argTypes.size(); ++i, ++argIdx) {
649757
if (IsCastRequired(argTypes[i], aggDesc.ArgTypes[i])) {
650758
auto& argNode = input->ChildRef(argIdx);
651-
argNode = WrapWithPgCast(std::move(argNode), aggDesc.ArgTypes[i], ctx);
759+
argNode = WrapWithPgCast(std::move(argNode), aggDesc.ArgTypes[i], ctx.Expr);
652760
needRetype = true;
653761
}
654762
}
@@ -4155,7 +4263,7 @@ IGraphTransformer::TStatus PgValuesListWrapper(const TExprNode::TPtr& input, TEx
41554263
if (item->GetTypeAnn()->Cast<TPgExprType>()->GetId() == commonTypes[j]) {
41564264
rowValues.push_back(item);
41574265
} else {
4158-
rowValues.push_back(WrapWithPgCast(std::move(item), commonTypes[j], ctx));
4266+
rowValues.push_back(WrapWithPgCast(std::move(item), commonTypes[j], ctx.Expr));
41594267
}
41604268
}
41614269
resultValues.push_back(ctx.Expr.NewList(value->Pos(), std::move(rowValues)));
@@ -4338,7 +4446,11 @@ IGraphTransformer::TStatus PgSelectWrapper(const TExprNode::TPtr& input, TExprNo
43384446

43394447
TColumnOrder resultColumnOrder;
43404448
const TStructExprType* resultStructType = nullptr;
4341-
auto status = InferPositionalUnionType(input->Pos(), setItems->ChildrenList(), resultColumnOrder, resultStructType, ctx);
4449+
4450+
auto status = (1 == setItems->ChildrenSize() && HasSetting(*setItems->Child(0)->Child(0), "unknowns_allowed"))
4451+
? InferPositionalUnionType(input->Pos(), setItems->ChildrenList(), resultColumnOrder, resultStructType, ctx)
4452+
: InferPgCommonType(input->Pos(), setItems, setOps, resultColumnOrder, resultStructType, ctx);
4453+
43424454
if (status != IGraphTransformer::TStatus::Ok) {
43434455
return status;
43444456
}
@@ -4471,7 +4583,7 @@ IGraphTransformer::TStatus PgArrayWrapper(const TExprNode::TPtr& input, TExprNod
44714583
if (argTypes[i] == elemType) {
44724584
castArrayElems.push_back(child);
44734585
} else {
4474-
castArrayElems.push_back(WrapWithPgCast(std::move(child), elemType, ctx));
4586+
castArrayElems.push_back(WrapWithPgCast(std::move(child), elemType, ctx.Expr));
44754587
}
44764588
}
44774589
output = ctx.Expr.NewCallable(input->Pos(), "PgArray", std::move(castArrayElems));
@@ -4587,7 +4699,7 @@ IGraphTransformer::TStatus PgLikeWrapper(const TExprNode::TPtr& input, TExprNode
45874699
if (argTypes[i] != textTypeId) {
45884700
if (argTypes[i] == NPg::UnknownOid) {
45894701
auto& argNode = input->ChildRef(i);
4590-
argNode = WrapWithPgCast(std::move(argNode), textTypeId, ctx);
4702+
argNode = WrapWithPgCast(std::move(argNode), textTypeId, ctx.Expr);
45914703
return IGraphTransformer::TStatus::Repeat;
45924704
}
45934705
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
@@ -4656,7 +4768,7 @@ IGraphTransformer::TStatus PgInWrapper(const TExprNode::TPtr& input, TExprNode::
46564768
if (itemTypePg && inputTypePg && itemTypePg != inputTypePg) {
46574769
if (inputTypePg == NPg::UnknownOid) {
46584770

4659-
input->ChildRef(0) = WrapWithPgCast(std::move(input->Child(0)), itemTypePg, ctx);
4771+
input->ChildRef(0) = WrapWithPgCast(std::move(input->Child(0)), itemTypePg, ctx.Expr);
46604772
return IGraphTransformer::TStatus::Repeat;
46614773
}
46624774
if (itemTypePg == NPg::UnknownOid) {

ydb/library/yql/core/type_ann/type_ann_pg.h

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
namespace NYql {
99
namespace NTypeAnnImpl {
1010

11+
TExprNodePtr WrapWithPgCast(TExprNodePtr&& node, ui32 targetTypeId, TExprContext& ctx);
1112
TString MakeAliasedColumn(TStringBuf alias, TStringBuf column);
1213
const TItemExprType* AddAlias(const TString& alias, const TItemExprType* item, TExprContext& ctx);
1314
TStringBuf RemoveAlias(TStringBuf column);

ydb/library/yql/sql/pg/pg_sql.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ class TConverter : public IPGParseEvents {
705705
}
706706
}
707707

708+
bool hasCombiningQueries = (1 < setItems.size());
708709

709710
TAstNode* sort = nullptr;
710711
if (ListLength(value->sortClause) > 0) {
@@ -716,7 +717,7 @@ class TConverter : public IPGParseEvents {
716717
return nullptr;
717718
}
718719

719-
auto sort = ParseSortBy(CAST_NODE_EXT(PG_SortBy, T_SortBy, node), setItems.size() == 1, true);
720+
auto sort = ParseSortBy(CAST_NODE_EXT(PG_SortBy, T_SortBy, node), !hasCombiningQueries, true);
720721
if (!sort) {
721722
return nullptr;
722723
}
@@ -728,7 +729,7 @@ class TConverter : public IPGParseEvents {
728729
}
729730

730731
TVector<TAstNode*> setItemNodes;
731-
for (size_t id = 0; id < setItems.size(); id++) {
732+
for (size_t id = 0; id < setItems.size(); ++id) {
732733
const auto& x = setItems[id];
733734
bool hasDistinctAll = false;
734735
TVector<TAstNode*> distinctOnItems;
@@ -1051,11 +1052,11 @@ class TConverter : public IPGParseEvents {
10511052
setItemOptions.push_back(QL(QA("distinct_on"), distinctOn));
10521053
}
10531054

1054-
if (setItems.size() == 1 && sort) {
1055+
if (!hasCombiningQueries && sort) {
10551056
setItemOptions.push_back(QL(QA("sort"), sort));
10561057
}
10571058

1058-
if (unknownsAllowed) {
1059+
if (unknownsAllowed || hasCombiningQueries) {
10591060
setItemOptions.push_back(QL(QA("unknowns_allowed")));
10601061
}
10611062

@@ -1106,7 +1107,7 @@ class TConverter : public IPGParseEvents {
11061107
selectOptions.push_back(QL(QA("set_items"), QVL(setItemNodes.data(), setItemNodes.size())));
11071108
selectOptions.push_back(QL(QA("set_ops"), QVL(setOpsNodes.data(), setOpsNodes.size())));
11081109

1109-
if (setItems.size() > 1 && sort) {
1110+
if (hasCombiningQueries && sort) {
11101111
selectOptions.push_back(QL(QA("sort"), sort));
11111112
}
11121113

0 commit comments

Comments
 (0)