diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index be0b2b7edc35..a6350053839d 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -115,13 +115,24 @@ IGraphTransformer::TStatus BlockExpandChunkedWrapper(const TExprNode::TPtr& inpu return IGraphTransformer::TStatus::Error; } + TTypeAnnotationNode::TListType itemTypes; TTypeAnnotationNode::TListType blockItemTypes; - if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - auto flowItemTypes = input->Head().GetTypeAnn()->Cast()->GetItemType()->Cast()->GetItems(); - bool allScalars = AllOf(flowItemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; }); + if (input->Head().GetTypeAnn()->GetKind() == ETypeAnnotationKind::Stream) { + if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + itemTypes = input->Head().GetTypeAnn()->Cast()->GetItemType()->Cast()->GetItems(); + } else { + if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + itemTypes = input->Head().GetTypeAnn()->Cast()->GetItemType()->Cast()->GetItems(); + } + + bool allScalars = AllOf(itemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; }); if (allScalars) { output = input->HeadPtr(); return IGraphTransformer::TStatus::Repeat; diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp index a719c116f9ab..9e86008d66c3 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp @@ -1120,6 +1120,52 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode { +using TBaseComputation = TMutableComputationNode; +class TExpanderState : public TComputationValue { +using TBase = TComputationValue; +public: + TExpanderState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, NUdf::TUnboxedValue&& stream, size_t width) + : TBase(memInfo), HolderFactory_(ctx.HolderFactory), State_(ctx.HolderFactory.Create(width)), Stream_(stream) {} + + NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) { + auto& s = *static_cast(State_.AsBoxed().Get()); + if (!s.Count) { + s.ClearValues(); + auto result = Stream_.WideFetch(s.Values.data(), width); + if (NUdf::EFetchStatus::Ok != result) { + return result; + } + s.FillArrays(); + } + + const auto sliceSize = s.Slice(); + for (size_t i = 0; i < width; ++i) { + output[i] = s.Get(sliceSize, HolderFactory_, i); + } + return NUdf::EFetchStatus::Ok; + } + +private: + const THolderFactory& HolderFactory_; + NUdf::TUnboxedValue State_; + NUdf::TUnboxedValue Stream_; +}; +public: + TBlockExpandChunkedStreamWrapper(TComputationMutables& mutables, IComputationNode* stream, size_t width) + : TBaseComputation(mutables, EValueRepresentation::Boxed) + , Stream_(stream) + , Width_(width) {} + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + return ctx.HolderFactory.Create(ctx, std::move(Stream_->GetValue(ctx)), Width_); + } + void RegisterDependencies() const override {} +private: + IComputationNode* const Stream_; + const size_t Width_; +}; + } // namespace IComputationNode* WrapToBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) { @@ -1184,13 +1230,21 @@ IComputationNode* WrapReplicateScalar(TCallable& callable, const TComputationNod IComputationNode* WrapBlockExpandChunked(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount()); - - const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); - const auto wideComponents = GetWideComponents(flowType); - - const auto wideFlow = dynamic_cast(LocateNode(ctx.NodeLocator, callable, 0)); - MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); - return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size()); + if (callable.GetInput(0).GetStaticType()->IsStream()) { + const auto streamType = AS_TYPE(TStreamType, callable.GetInput(0).GetStaticType()); + const auto wideComponents = GetWideComponents(streamType); + const auto computation = dynamic_cast(LocateNode(ctx.NodeLocator, callable, 0)); + + MKQL_ENSURE(computation != nullptr, "Expected computation node"); + return new TBlockExpandChunkedStreamWrapper(ctx.Mutables, computation, wideComponents.size()); + } else { + const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); + const auto wideComponents = GetWideComponents(flowType); + + const auto wideFlow = dynamic_cast(LocateNode(ctx.NodeLocator, callable, 0)); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size()); + } } } diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index c12742d9a606..c57e118ab126 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -228,6 +228,24 @@ bool ReduceOptionalElements(const TType* type, const TArrayRef& test return multiOptional; } +std::vector ValidateBlockStreamType(const TType* streamType) { + const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType)); + MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column"); + std::vector streamItems; + streamItems.reserve(wideComponents.size()); + bool isScalar; + for (size_t i = 0; i < wideComponents.size(); ++i) { + auto blockType = AS_TYPE(TBlockType, wideComponents[i]); + isScalar = blockType->GetShape() == TBlockType::EShape::Scalar; + auto withoutBlock = blockType->GetItemType(); + streamItems.push_back(withoutBlock); + } + + MKQL_ENSURE(isScalar, "Last column should be scalar"); + MKQL_ENSURE(AS_TYPE(TDataType, streamItems.back())->GetSchemeType() == NUdf::TDataType::Id, "Expected Uint64"); + return streamItems; +} + std::vector ValidateBlockFlowType(const TType* flowType) { const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType)); MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column"); @@ -1550,10 +1568,14 @@ TRuntimeNode TProgramBuilder::BlockCompress(TRuntimeNode flow, ui32 bitmapIndex) return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode flow) { - ValidateBlockFlowType(flow.GetStaticType()); - TCallableBuilder callableBuilder(Env, __func__, flow.GetStaticType()); - callableBuilder.Add(flow); +TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode comp) { + if (comp.GetStaticType()->IsStream()) { + ValidateBlockStreamType(comp.GetStaticType()); + } else { + ValidateBlockFlowType(comp.GetStaticType()); + } + TCallableBuilder callableBuilder(Env, __func__, comp.GetStaticType()); + callableBuilder.Add(comp); return TRuntimeNode(callableBuilder.Build(), false); } diff --git a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp index 994de78b22b5..f6b7d10acdc7 100644 --- a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp +++ b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp @@ -94,12 +94,13 @@ namespace NYql::NDqs { YQL_CLOG(INFO, ProviderDq) << "DqsRewritePhyBlockReadOnDqIntegration"; return Build(ctx, node->Pos()) - .Input(Build(ctx, node->Pos()) + .Input( + Build(ctx, node->Pos()) .Input(Build(ctx, node->Pos()) - .Input(readWideWrap.Input()) - .Flags(readWideWrap.Flags()) - .Token(readWideWrap.Token()) - .Done()) + .Input(readWideWrap.Input()) + .Flags(readWideWrap.Flags()) + .Token(readWideWrap.Token()) + .Done().Ptr()) .Done()) .Done().Ptr(); }, ctx, optSettings); diff --git a/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp b/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp index d4cf3f1d5a68..88babe788429 100644 --- a/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp +++ b/ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp @@ -495,11 +495,11 @@ void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, con for (const auto& flag : wrapper.Flags()) if (solid = flag.Value() == "Solid") break; - - if (solid) - return BuildDqYtInputCall(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight); - else - return BuildDqYtInputCall(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight); + return ctx.ProgramBuilder.BlockExpandChunked( + solid + ? BuildDqYtInputCall(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight) + : BuildDqYtInputCall(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight) + ); } return TRuntimeNode();