Skip to content

Commit 7c2636d

Browse files
committed
YQL-15891 Wide WithContext.
1 parent 6cd92ce commit 7c2636d

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

ydb/library/yql/minikql/comp_nodes/mkql_withcontext.cpp

+33-24
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ using TBaseComputation = TStatefulFlowCodegeneratorNode<TWithContextFlowWrapper>
162162
const std::string_view ContextType;
163163
};
164164

165-
class TWithContextWideFlowWrapper : public TStatefulWideFlowComputationNode<TWithContextWideFlowWrapper> {
166-
using TBaseComputation = TStatefulWideFlowComputationNode<TWithContextWideFlowWrapper>;
165+
class TWithContextWideFlowWrapper : public TStatefulWideFlowCodegeneratorNode<TWithContextWideFlowWrapper> {
166+
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWithContextWideFlowWrapper>;
167167
public:
168168
TWithContextWideFlowWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow,
169169
const std::string_view& contextType)
@@ -184,11 +184,8 @@ using TBaseComputation = TStatefulWideFlowComputationNode<TWithContextWideFlowWr
184184
state.Detach(status == EFetchResult::Finish);
185185
return status;
186186
}
187-
/*
188187
#ifndef MKQL_DISABLE_CODEGEN
189188
ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
190-
Cerr << Flow->DebugString() << Endl;
191-
Y_ABORT("bad");
192189
auto& context = ctx.Codegen.GetContext();
193190

194191
const auto valueType = Type::getInt128Ty(context);
@@ -198,6 +195,8 @@ using TBaseComputation = TStatefulWideFlowComputationNode<TWithContextWideFlowWr
198195

199196
const auto make = BasicBlock::Create(context, "make", ctx.Func);
200197
const auto main = BasicBlock::Create(context, "main", ctx.Func);
198+
const auto good = BasicBlock::Create(context, "good", ctx.Func);
199+
const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
201200

202201
BranchInst::Create(main, make, HasValue(statePtr, block), block);
203202
block = make;
@@ -207,7 +206,7 @@ using TBaseComputation = TStatefulWideFlowComputationNode<TWithContextWideFlowWr
207206
const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWithContextWideFlowWrapper::MakeState));
208207
const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
209208
const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
210-
CallInst::Create(makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
209+
CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
211210
BranchInst::Create(main, block);
212211

213212
block = main;
@@ -219,35 +218,46 @@ using TBaseComputation = TStatefulWideFlowComputationNode<TWithContextWideFlowWr
219218
const auto attachFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Attach));
220219
const auto attachFuncType = FunctionType::get(Type::getVoidTy(context), { statePtrType }, false);
221220
const auto attachFuncPtr = CastInst::Create(Instruction::IntToPtr, attachFunc, PointerType::getUnqual(attachFuncType), "attach", block);
222-
CallInst::Create(attachFuncPtr, { stateArg }, "", block);
221+
CallInst::Create(attachFuncType, attachFuncPtr, { stateArg }, "", block);
223222

224223
auto getres = GetNodeValues(Flow, ctx, block);
225-
const auto array = new AllocaInst(ArrayType::get(valueType, getres.second.size()), 0U, "array", &ctx.Func->getEntryBlock().back());
226-
auto i = 0;
227-
for (auto& getter : getres.second) {
228-
const auto itemPtr = GetElementPtrInst::CreateInBounds(array, {ConstantInt::get(indexType, 0), ConstantInt::get(indexType, i)}, "item_ptr", &ctx.Func->getEntryBlock().back());
229-
const auto item = getter(ctx, block);
230-
ValueAddRef(EValueRepresentation::Any, item, ctx, block);
231-
new StoreInst(item, itemPtr, block);
232-
getter = [itemPtr] (const TCodegenContext& ctx, BasicBlock*& block) {
233-
const auto item = new LoadInst(itemPtr, "item", block);
234-
ValueRelease(EValueRepresentation::Any, item, ctx, block);
235-
return item;
236-
};
237-
++i;
224+
225+
const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
226+
227+
BranchInst::Create(exit, good, special, block);
228+
229+
block = good;
230+
231+
const auto arrayType = ArrayType::get(valueType, getres.second.size());
232+
const auto arrayPtr = new AllocaInst(arrayType, 0U, "array_ptr", &ctx.Func->getEntryBlock().back());
233+
Value* array = UndefValue::get(arrayType);
234+
for (auto idx = 0U; idx < getres.second.size(); ++idx) {
235+
const auto item = getres.second[idx](ctx, block);
236+
array = InsertValueInst::Create(array, item, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
238237
}
238+
new StoreInst(array, arrayPtr, block);
239+
240+
BranchInst::Create(exit, block);
241+
242+
block = exit;
239243

240244
const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Finish)), "finish", block);
241245

242246
const auto detachFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Detach));
243247
const auto detachFuncType = FunctionType::get(Type::getVoidTy(context), { statePtrType, finish->getType() }, false);
244248
const auto detachFuncPtr = CastInst::Create(Instruction::IntToPtr, detachFunc, PointerType::getUnqual(detachFuncType), "detach", block);
245-
CallInst::Create(detachFuncPtr, { stateArg, finish }, "", block);
249+
CallInst::Create(detachFuncType, detachFuncPtr, { stateArg, finish }, "", block);
250+
251+
for (auto idx = 0U; idx < getres.second.size(); ++idx) {
252+
getres.second[idx] = [idx, arrayPtr, arrayType, indexType, valueType] (const TCodegenContext& ctx, BasicBlock*& block) {
253+
const auto itemPtr = GetElementPtrInst::CreateInBounds(arrayType, arrayPtr, {ConstantInt::get(indexType, 0), ConstantInt::get(indexType, idx)}, (TString("ptr_") += ToString(idx)).c_str(), block);
254+
return new LoadInst(valueType, itemPtr, (TString("item_") += ToString(idx)).c_str(), block);
255+
};
256+
}
246257

247258
return getres;
248259
}
249260
#endif
250-
*/
251261
private:
252262
void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
253263
state = ctx.HolderFactory.Create<TState>(ContextType);
@@ -267,11 +277,10 @@ IComputationNode* WrapWithContext(TCallable& callable, const TComputationNodeFac
267277
const auto contextTypeData = AS_VALUE(TDataLiteral, callable.GetInput(0));
268278
const auto contextType = contextTypeData->AsValue().AsStringRef();
269279
const auto arg = LocateNode(ctx.NodeLocator, callable, 1);
270-
if (callable.GetInput(1).GetStaticType()->IsFlow()) {
280+
if (const auto type = callable.GetType()->GetReturnType(); type->IsFlow()) {
271281
if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(arg)) {
272282
return new TWithContextWideFlowWrapper(ctx.Mutables, wide, contextType);
273283
} else {
274-
const auto type = callable.GetType()->GetReturnType();
275284
return new TWithContextFlowWrapper(ctx.Mutables, contextType, GetValueRepresentation(type), arg);
276285
}
277286
} else {

0 commit comments

Comments
 (0)