Skip to content

Commit e88073e

Browse files
fix
1 parent 3c9ba45 commit e88073e

File tree

11 files changed

+80
-6
lines changed

11 files changed

+80
-6
lines changed

ydb/core/formats/arrow/program/abstract.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ TConclusionStatus IResourceProcessor::Execute(const std::shared_ptr<TAccessorsCo
2727
return DoExecute(resources);
2828
}
2929

30+
bool IResourceProcessor::DoHasExecutionData(const ui32 columnId, const std::shared_ptr<TAccessorsCollection>& resources) const {
31+
return resources->HasColumn(columnId);
32+
}
33+
3034
NJson::TJsonValue TResourceProcessorStep::DebugJson() const {
3135
NJson::TJsonValue result = NJson::JSON_MAP;
3236
if (ColumnsToFetch.size()) {

ydb/core/formats/arrow/program/abstract.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class IColumnResolver {
5050
virtual ~IColumnResolver() = default;
5151
virtual TString GetColumnName(ui32 id, bool required = true) const = 0;
5252
virtual std::optional<ui32> GetColumnIdOptional(const TString& name) const = 0;
53+
bool HasColumn(const ui32 id) const {
54+
return !!GetColumnName(id, false);
55+
}
56+
5357
ui32 GetColumnIdVerified(const char* name) const {
5458
auto result = GetColumnIdOptional(name);
5559
AFL_VERIFY(!!result);
@@ -174,12 +178,23 @@ class IResourceProcessor {
174178
virtual NJson::TJsonValue DoDebugJson() const {
175179
return NJson::JSON_MAP;
176180
}
181+
virtual bool DoHasExecutionData(const ui32 columnId, const std::shared_ptr<TAccessorsCollection>& resources) const;
177182

178183
public:
184+
virtual bool IsAggregation() const = 0;
185+
179186
virtual ~IResourceProcessor() = default;
180187

188+
virtual TString GetKernelClassNameDef(const TString& defaultValue) const {
189+
return defaultValue;
190+
}
191+
181192
NJson::TJsonValue DebugJson() const;
182193

194+
bool HasExecutionData(const ui32 columnId, const std::shared_ptr<TAccessorsCollection>& resources) const {
195+
return DoHasExecutionData(columnId, resources);
196+
}
197+
183198
ui32 GetOutputColumnIdOnce() const {
184199
AFL_VERIFY(Output.size() == 1)("size", Output.size());
185200
return Output.front().GetColumnId();
@@ -202,15 +217,17 @@ class IResourceProcessor {
202217
class TResourceProcessorStep {
203218
private:
204219
YDB_READONLY_DEF(std::vector<TColumnChainInfo>, ColumnsToFetch);
220+
YDB_READONLY_DEF(std::vector<TColumnChainInfo>, OriginalColumnsToUse);
205221
YDB_READONLY_DEF(std::shared_ptr<IResourceProcessor>, Processor);
206222
YDB_READONLY_DEF(std::vector<TColumnChainInfo>, ColumnsToDrop);
207223

208224
public:
209225
NJson::TJsonValue DebugJson() const;
210226

211-
TResourceProcessorStep(
212-
std::vector<TColumnChainInfo>&& toFetch, std::shared_ptr<IResourceProcessor>&& processor, std::vector<TColumnChainInfo>&& toDrop)
227+
TResourceProcessorStep(std::vector<TColumnChainInfo>&& toFetch, std::vector<TColumnChainInfo>&& originalToUse,
228+
std::shared_ptr<IResourceProcessor>&& processor, std::vector<TColumnChainInfo>&& toDrop)
213229
: ColumnsToFetch(std::move(toFetch))
230+
, OriginalColumnsToUse(std::move(originalToUse))
214231
, Processor(std::move(processor))
215232
, ColumnsToDrop(std::move(toDrop)) {
216233
AFL_VERIFY(Processor);

ydb/core/formats/arrow/program/aggr_keys.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ class TAggregateFunction: public TInternalFunction {
4444
}
4545

4646
public:
47+
virtual bool IsAggregation() const override {
48+
return true;
49+
}
50+
4751
TAggregateFunction(const EAggregate aggregationType, const std::shared_ptr<arrow::compute::FunctionOptions>& functionOptions = nullptr)
4852
: TBase(functionOptions, true)
4953
, AggregationType(aggregationType) {
@@ -153,6 +157,9 @@ class TWithKeysAggregationProcessor: public IResourceProcessor {
153157
, AggregationKeys(std::move(aggregationKeys))
154158
, Aggregations(std::move(aggregations)) {
155159
}
160+
virtual bool IsAggregation() const override {
161+
return true;
162+
}
156163

157164
public:
158165
static const char* GetHouseGroupByName() {

ydb/core/formats/arrow/program/assign_const.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ class TConstProcessor: public IResourceProcessor {
1010

1111
virtual TConclusionStatus DoExecute(const std::shared_ptr<TAccessorsCollection>& resources) const override;
1212

13+
virtual bool IsAggregation() const override {
14+
return false;
15+
}
16+
1317
public:
1418
TConstProcessor(const std::shared_ptr<arrow::Scalar>& scalar, const ui32 columnId)
1519
: TBase(std::vector<TColumnChainInfo>(), std::vector<TColumnChainInfo>({ TColumnChainInfo(columnId) }), EProcessorType::Const)

ydb/core/formats/arrow/program/assign_internal.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ class TCalculationProcessor: public IResourceProcessor {
2323
, Function(function) {
2424
}
2525

26+
virtual TString GetKernelClassNameDef(const TString& defaultValue) const override {
27+
return KernelLogic ? KernelLogic->GetClassName() : defaultValue;
28+
}
29+
30+
virtual bool IsAggregation() const override {
31+
return Function->IsAggregation();
32+
}
33+
2634
public:
2735
static TConclusion<std::shared_ptr<TCalculationProcessor>> Build(std::vector<TColumnChainInfo>&& input, const TColumnChainInfo& output,
2836
const std::shared_ptr<IStepFunction>& function, const std::shared_ptr<IKernelLogic>& kernelLogic = nullptr);

ydb/core/formats/arrow/program/chain.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ TConclusion<TProgramChain> TProgramChain::Build(std::vector<std::shared_ptr<IRes
4545
THashSet<TColumnChainInfo> sourceColumns;
4646
std::optional<ui32> lastFilter;
4747
std::optional<ui32> firstAggregation;
48+
std::vector<std::vector<TColumnChainInfo>> originalsToUse;
49+
originalsToUse.resize(processors.size());
4850
for (auto&& i : processors) {
49-
if (i->GetProcessorType() == EProcessorType::Aggregation) {
51+
if (!firstAggregation && i->IsAggregation()) {
5052
firstAggregation = stepIdx;
5153
}
5254
if (!firstAggregation && i->GetProcessorType() == EProcessorType::Filter) {
@@ -62,9 +64,12 @@ TConclusion<TProgramChain> TProgramChain::Build(std::vector<std::shared_ptr<IRes
6264
}
6365
for (auto&& c : i->GetInput()) {
6466
auto it = contextUsage.find(c);
67+
const bool isOriginalColumn = resolver.HasColumn(c);
68+
if (isOriginalColumn) {
69+
originalsToUse[stepIdx].emplace_back(c);
70+
}
6571
if (it == contextUsage.end()) {
66-
if (!resolver.GetColumnName(c, false)) {
67-
resolver.GetColumnName(c, true);
72+
if (!isOriginalColumn) {
6873
return TConclusionStatus::Fail("incorrect input column: " + ::ToString(c));
6974
}
7075
it = contextUsage.emplace(c, TColumnUsage::Fetch(stepIdx, i)).first;
@@ -94,7 +99,8 @@ TConclusion<TProgramChain> TProgramChain::Build(std::vector<std::shared_ptr<IRes
9499
}
95100
TProgramChain result;
96101
for (ui32 i = 0; i < processors.size(); ++i) {
97-
result.Processors.emplace_back(std::move(columnsToFetch[i]), std::move(processors[i]), std::move(columnsToDrop[i]));
102+
result.Processors.emplace_back(
103+
std::move(columnsToFetch[i]), std::move(originalsToUse[i]), std::move(processors[i]), std::move(columnsToDrop[i]));
98104
}
99105
auto initStatus = result.Initialize();
100106
result.LastOriginalDataFilter = lastFilter;

ydb/core/formats/arrow/program/chain.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ class TProgramChain {
1919
public:
2020
TProgramChain() = default;
2121

22+
bool HasAggregations() const {
23+
return !!FirstAggregation;
24+
}
25+
2226
bool IsGenerated(const ui32 columnId) const {
2327
auto it = SourcesByColumnId.find(columnId);
2428
AFL_VERIFY(it != SourcesByColumnId.end());

ydb/core/formats/arrow/program/filter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ class TFilterProcessor: public IResourceProcessor {
99

1010
virtual TConclusionStatus DoExecute(const std::shared_ptr<TAccessorsCollection>& resources) const override;
1111

12+
virtual bool IsAggregation() const override {
13+
return false;
14+
}
15+
1216
public:
1317
TFilterProcessor(std::vector<TColumnChainInfo>&& input)
1418
: TBase(std::move(input), {}, EProcessorType::Filter) {

ydb/core/formats/arrow/program/functions.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class IStepFunction {
2626
bool NeedConcatenation = false;
2727

2828
public:
29+
virtual bool IsAggregation() const = 0;
30+
2931
arrow::compute::ExecContext* GetContext() const {
3032
return GetCustomExecContext();
3133
}
@@ -70,6 +72,10 @@ class TSimpleFunction: public TInternalFunction {
7072
return { GetFunctionName(OperationId) };
7173
}
7274

75+
virtual bool IsAggregation() const override {
76+
return false;
77+
}
78+
7379
public:
7480
static const char* GetFunctionName(const EOperation op) {
7581
switch (op) {
@@ -321,6 +327,10 @@ class TKernelFunction: public IStepFunction {
321327
const std::shared_ptr<arrow::compute::ScalarFunction> Function;
322328
std::shared_ptr<arrow::compute::FunctionOptions> FunctionOptions;
323329

330+
virtual bool IsAggregation() const override {
331+
return false;
332+
}
333+
324334
public:
325335
TKernelFunction(const std::shared_ptr<arrow::compute::ScalarFunction> kernelsFunction,
326336
const std::shared_ptr<arrow::compute::FunctionOptions>& functionOptions = nullptr, const bool needConcatenation = false)

ydb/core/formats/arrow/program/kernel_logic.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class IKernelLogic {
1616

1717
using TFactory = NObjectFactory::TObjectFactory<IKernelLogic, TString>;
1818

19+
virtual TString GetClassName() const = 0;
20+
1921
TConclusion<bool> Execute(const std::vector<TColumnChainInfo>& input, const std::vector<TColumnChainInfo>& output,
2022
const std::shared_ptr<TAccessorsCollection>& resources) const {
2123
if (!resources) {
@@ -31,6 +33,10 @@ class TGetJsonPath: public IKernelLogic {
3133
return "JsonValue";
3234
}
3335
private:
36+
virtual TString GetClassName() const override {
37+
return GetClassNameStatic();
38+
}
39+
3440
static const inline TFactory::TRegistrator<TGetJsonPath> Registrator = TFactory::TRegistrator<TGetJsonPath>(GetClassNameStatic());
3541

3642
virtual TConclusion<bool> DoExecute(const std::vector<TColumnChainInfo>& input, const std::vector<TColumnChainInfo>& output,

ydb/core/formats/arrow/program/projection.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ class TProjectionProcessor: public IResourceProcessor {
99

1010
virtual TConclusionStatus DoExecute(const std::shared_ptr<TAccessorsCollection>& resources) const override;
1111

12+
virtual bool IsAggregation() const override {
13+
return false;
14+
}
15+
1216
public:
1317
TProjectionProcessor(std::vector<TColumnChainInfo>&& columns)
1418
: TBase(std::vector<TColumnChainInfo>(columns), {}, EProcessorType::Projection) {

0 commit comments

Comments
 (0)