Skip to content

Commit 0f07d90

Browse files
authored
Merge e96445f into fbc9d17
2 parents fbc9d17 + e96445f commit 0f07d90

File tree

5 files changed

+114
-26
lines changed

5 files changed

+114
-26
lines changed

ydb/library/yql/core/type_ann/type_ann_blocks.cpp

+16-5
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,24 @@ IGraphTransformer::TStatus BlockExpandChunkedWrapper(const TExprNode::TPtr& inpu
115115
return IGraphTransformer::TStatus::Error;
116116
}
117117

118+
TTypeAnnotationNode::TListType itemTypes;
118119
TTypeAnnotationNode::TListType blockItemTypes;
119-
if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
120-
return IGraphTransformer::TStatus::Error;
121-
}
122120

123-
auto flowItemTypes = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
124-
bool allScalars = AllOf(flowItemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; });
121+
if (input->Head().GetTypeAnn()->GetKind() == ETypeAnnotationKind::Stream) {
122+
if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
123+
return IGraphTransformer::TStatus::Error;
124+
}
125+
126+
itemTypes = input->Head().GetTypeAnn()->Cast<TStreamExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
127+
} else {
128+
if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
129+
return IGraphTransformer::TStatus::Error;
130+
}
131+
132+
itemTypes = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
133+
}
134+
135+
bool allScalars = AllOf(itemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; });
125136
if (allScalars) {
126137
output = input->HeadPtr();
127138
return IGraphTransformer::TStatus::Repeat;

ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp

+61-7
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,52 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockExpandChunkedW
11181118
const size_t WideFieldsIndex_;
11191119
};
11201120

1121+
class TBlockExpandChunkedStreamWrapper : public TMutableComputationNode<TBlockExpandChunkedStreamWrapper> {
1122+
using TBaseComputation = TMutableComputationNode<TBlockExpandChunkedStreamWrapper>;
1123+
class TExpanderState : public TComputationValue<TExpanderState> {
1124+
using TBase = TComputationValue<TExpanderState>;
1125+
public:
1126+
TExpanderState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, NUdf::TUnboxedValue&& stream, size_t width)
1127+
: TBase(memInfo), HolderFactory_(ctx.HolderFactory), State_(ctx.HolderFactory.Create<TBlockState>(width)), Stream_(stream) {}
1128+
1129+
NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
1130+
auto& s = *static_cast<TBlockState*>(State_.AsBoxed().Get());
1131+
if (!s.Count) {
1132+
s.ClearValues();
1133+
auto result = Stream_.WideFetch(s.Values.data(), width);
1134+
if (NUdf::EFetchStatus::Ok != result) {
1135+
return result;
1136+
}
1137+
s.FillArrays();
1138+
}
1139+
1140+
const auto sliceSize = s.Slice();
1141+
for (size_t i = 0; i < width; ++i) {
1142+
output[i] = s.Get(sliceSize, HolderFactory_, i);
1143+
}
1144+
return NUdf::EFetchStatus::Ok;
1145+
}
1146+
1147+
private:
1148+
const THolderFactory& HolderFactory_;
1149+
NUdf::TUnboxedValue State_;
1150+
NUdf::TUnboxedValue Stream_;
1151+
};
1152+
public:
1153+
TBlockExpandChunkedStreamWrapper(TComputationMutables& mutables, IComputationNode* stream, size_t width)
1154+
: TBaseComputation(mutables, EValueRepresentation::Boxed)
1155+
, Stream_(stream)
1156+
, Width_(width) {}
1157+
1158+
NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
1159+
return ctx.HolderFactory.Create<TExpanderState>(ctx, std::move(Stream_->GetValue(ctx)), Width_);
1160+
}
1161+
void RegisterDependencies() const override {}
1162+
private:
1163+
IComputationNode* const Stream_;
1164+
const size_t Width_;
1165+
};
1166+
11211167
} // namespace
11221168

11231169
IComputationNode* WrapToBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -1182,13 +1228,21 @@ IComputationNode* WrapReplicateScalar(TCallable& callable, const TComputationNod
11821228

11831229
IComputationNode* WrapBlockExpandChunked(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
11841230
MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
1185-
1186-
const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
1187-
const auto wideComponents = GetWideComponents(flowType);
1188-
1189-
const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
1190-
MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
1191-
return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size());
1231+
if (callable.GetInput(0).GetStaticType()->IsStream()) {
1232+
const auto streamType = AS_TYPE(TStreamType, callable.GetInput(0).GetStaticType());
1233+
const auto wideComponents = GetWideComponents(streamType);
1234+
const auto computation = dynamic_cast<IComputationNode*>(LocateNode(ctx.NodeLocator, callable, 0));
1235+
1236+
MKQL_ENSURE(computation != nullptr, "Expected computation node");
1237+
return new TBlockExpandChunkedStreamWrapper(ctx.Mutables, computation, wideComponents.size());
1238+
} else {
1239+
const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
1240+
const auto wideComponents = GetWideComponents(flowType);
1241+
1242+
const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
1243+
MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
1244+
return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size());
1245+
}
11921246
}
11931247

11941248
}

ydb/library/yql/minikql/mkql_program_builder.cpp

+26-4
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,24 @@ bool ReduceOptionalElements(const TType* type, const TArrayRef<const ui32>& test
228228
return multiOptional;
229229
}
230230

231+
std::vector<TType*> ValidateBlockStreamType(const TType* streamType) {
232+
const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType));
233+
MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
234+
std::vector<TType*> streamItems;
235+
streamItems.reserve(wideComponents.size());
236+
bool isScalar;
237+
for (size_t i = 0; i < wideComponents.size(); ++i) {
238+
auto blockType = AS_TYPE(TBlockType, wideComponents[i]);
239+
isScalar = blockType->GetShape() == TBlockType::EShape::Scalar;
240+
auto withoutBlock = blockType->GetItemType();
241+
streamItems.push_back(withoutBlock);
242+
}
243+
244+
MKQL_ENSURE(isScalar, "Last column should be scalar");
245+
MKQL_ENSURE(AS_TYPE(TDataType, streamItems.back())->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
246+
return streamItems;
247+
}
248+
231249
std::vector<TType*> ValidateBlockFlowType(const TType* flowType) {
232250
const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType));
233251
MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
@@ -1550,10 +1568,14 @@ TRuntimeNode TProgramBuilder::BlockCompress(TRuntimeNode flow, ui32 bitmapIndex)
15501568
return TRuntimeNode(callableBuilder.Build(), false);
15511569
}
15521570

1553-
TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode flow) {
1554-
ValidateBlockFlowType(flow.GetStaticType());
1555-
TCallableBuilder callableBuilder(Env, __func__, flow.GetStaticType());
1556-
callableBuilder.Add(flow);
1571+
TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode comp) {
1572+
if (comp.GetStaticType()->IsStream()) {
1573+
ValidateBlockStreamType(comp.GetStaticType());
1574+
} else {
1575+
ValidateBlockFlowType(comp.GetStaticType());
1576+
}
1577+
TCallableBuilder callableBuilder(Env, __func__, comp.GetStaticType());
1578+
callableBuilder.Add(comp);
15571579
return TRuntimeNode(callableBuilder.Build(), false);
15581580
}
15591581

ydb/library/yql/providers/dq/opt/dqs_opt.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,13 @@ namespace NYql::NDqs {
9494

9595
YQL_CLOG(INFO, ProviderDq) << "DqsRewritePhyBlockReadOnDqIntegration";
9696
return Build<TCoWideFromBlocks>(ctx, node->Pos())
97-
.Input(Build<TCoToFlow>(ctx, node->Pos())
97+
.Input(
98+
Build<TCoToFlow>(ctx, node->Pos())
9899
.Input(Build<TDqReadBlockWideWrap>(ctx, node->Pos())
99-
.Input(readWideWrap.Input())
100-
.Flags(readWideWrap.Flags())
101-
.Token(readWideWrap.Token())
102-
.Done())
100+
.Input(readWideWrap.Input())
101+
.Flags(readWideWrap.Flags())
102+
.Token(readWideWrap.Token())
103+
.Done().Ptr())
103104
.Done())
104105
.Done().Ptr();
105106
}, ctx, optSettings);

ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -496,11 +496,11 @@ void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, con
496496
for (const auto& flag : wrapper.Flags())
497497
if (solid = flag.Value() == "Solid")
498498
break;
499-
500-
if (solid)
501-
return BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight);
502-
else
503-
return BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight);
499+
return ctx.ProgramBuilder.BlockExpandChunked(
500+
solid
501+
? BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight)
502+
: BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight)
503+
);
504504
}
505505

506506
return TRuntimeNode();

0 commit comments

Comments
 (0)