Skip to content

Commit 18a2915

Browse files
Merge 8f37671 into aa68b7c
2 parents aa68b7c + 8f37671 commit 18a2915

File tree

2 files changed

+84
-53
lines changed

2 files changed

+84
-53
lines changed

ydb/core/formats/arrow/program.cpp

Lines changed: 47 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ class TConstFunction : public IStepFunction<TAssign> {
8888
using TBase = IStepFunction<TAssign>;
8989
public:
9090
using TBase::TBase;
91-
arrow::Result<arrow::Datum> Call(const TAssign& assign, const TDatumBatch& batch) const override {
91+
arrow::Result<arrow::Datum> Call(const TAssign& assign, const TDatumBatch& batch) const override {
9292
Y_UNUSED(batch);
9393
return assign.GetConstant();
9494
}
@@ -531,7 +531,7 @@ class TFilterVisitor : public arrow::ArrayVisitor {
531531

532532

533533
arrow::Status TDatumBatch::AddColumn(const std::string& name, arrow::Datum&& column) {
534-
if (Schema->GetFieldIndex(name) != -1) {
534+
if (HasColumn(name)) {
535535
return arrow::Status::Invalid("Trying to add duplicate column '" + name + "'");
536536
}
537537

@@ -543,20 +543,27 @@ arrow::Status TDatumBatch::AddColumn(const std::string& name, arrow::Datum&& col
543543
return arrow::Status::Invalid("Wrong column length.");
544544
}
545545

546-
Schema = *Schema->AddField(Schema->num_fields(), field);
546+
NewColumnIds.emplace(name, NewColumnsPtr.size());
547+
NewColumnsPtr.emplace_back(field);
548+
547549
Datums.emplace_back(column);
548550
return arrow::Status::OK();
549551
}
550552

551553
arrow::Result<arrow::Datum> TDatumBatch::GetColumnByName(const std::string& name) const {
552-
auto i = Schema->GetFieldIndex(name);
554+
auto it = NewColumnIds.find(name);
555+
if (it != NewColumnIds.end()) {
556+
AFL_VERIFY(SchemaBase->num_fields() + it->second < Datums.size());
557+
return Datums[SchemaBase->num_fields() + it->second];
558+
}
559+
auto i = SchemaBase->GetFieldIndex(name);
553560
if (i < 0) {
554561
return arrow::Status::Invalid("Not found column '" + name + "' or duplicate");
555562
}
556563
return Datums[i];
557564
}
558565

559-
std::shared_ptr<arrow::Table> TDatumBatch::ToTable() const {
566+
std::shared_ptr<arrow::Table> TDatumBatch::ToTable() {
560567
std::vector<std::shared_ptr<arrow::ChunkedArray>> columns;
561568
columns.reserve(Datums.size());
562569
for (auto col : Datums) {
@@ -576,10 +583,10 @@ std::shared_ptr<arrow::Table> TDatumBatch::ToTable() const {
576583
AFL_VERIFY(false);
577584
}
578585
}
579-
return arrow::Table::Make(Schema, columns, Rows);
586+
return arrow::Table::Make(GetSchema(), columns, Rows);
580587
}
581588

582-
std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() const {
589+
std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() {
583590
std::vector<std::shared_ptr<arrow::Array>> columns;
584591
columns.reserve(Datums.size());
585592
for (auto col : Datums) {
@@ -594,7 +601,7 @@ std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() const {
594601
AFL_VERIFY(false);
595602
}
596603
}
597-
return arrow::RecordBatch::Make(Schema, Rows, columns);
604+
return arrow::RecordBatch::Make(GetSchema(), Rows, columns);
598605
}
599606

600607
std::shared_ptr<TDatumBatch> TDatumBatch::FromRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch) {
@@ -603,12 +610,7 @@ std::shared_ptr<TDatumBatch> TDatumBatch::FromRecordBatch(const std::shared_ptr<
603610
for (int64_t i = 0; i < batch->num_columns(); ++i) {
604611
datums.push_back(arrow::Datum(batch->column(i)));
605612
}
606-
return std::make_shared<TProgramStep::TDatumBatch>(
607-
TProgramStep::TDatumBatch{
608-
.Schema = std::make_shared<arrow::Schema>(*batch->schema()),
609-
.Datums = std::move(datums),
610-
.Rows = batch->num_rows()
611-
});
613+
return std::make_shared<TDatumBatch>(std::make_shared<arrow::Schema>(*batch->schema()), std::move(datums), batch->num_rows());
612614
}
613615

614616
std::shared_ptr<TDatumBatch> TDatumBatch::FromTable(const std::shared_ptr<arrow::Table>& batch) {
@@ -617,12 +619,15 @@ std::shared_ptr<TDatumBatch> TDatumBatch::FromTable(const std::shared_ptr<arrow:
617619
for (int64_t i = 0; i < batch->num_columns(); ++i) {
618620
datums.push_back(arrow::Datum(batch->column(i)));
619621
}
620-
return std::make_shared<TProgramStep::TDatumBatch>(
621-
TProgramStep::TDatumBatch{
622-
.Schema = std::make_shared<arrow::Schema>(*batch->schema()),
623-
.Datums = std::move(datums),
624-
.Rows = batch->num_rows()
625-
});
622+
return std::make_shared<TDatumBatch>(std::make_shared<arrow::Schema>(*batch->schema()), std::move(datums), batch->num_rows());
623+
}
624+
625+
TDatumBatch::TDatumBatch(const std::shared_ptr<arrow::Schema>& schema, std::vector<arrow::Datum>&& datums, const i64 rows)
626+
: SchemaBase(schema)
627+
, Rows(rows)
628+
, Datums(std::move(datums)) {
629+
AFL_VERIFY(SchemaBase);
630+
AFL_VERIFY(Datums.size() == (ui32)SchemaBase->num_fields());
626631
}
627632

628633
TAssign TAssign::MakeTimestamp(const TColumnInfo& column, ui64 value) {
@@ -680,7 +685,7 @@ arrow::Status TProgramStep::ApplyAssignes(TDatumBatch& batch, arrow::compute::Ex
680685
}
681686
batch.Datums.reserve(batch.Datums.size() + Assignes.size());
682687
for (auto& assign : Assignes) {
683-
if (batch.GetColumnByName(assign.GetName()).ok()) {
688+
if (batch.HasColumn(assign.GetName())) {
684689
return arrow::Status::Invalid("Assign to existing column '" + assign.GetName() + "'.");
685690
}
686691

@@ -703,8 +708,9 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
703708
}
704709

705710
ui32 numResultColumns = GroupBy.size() + GroupByKeys.size();
706-
TDatumBatch res;
707-
res.Datums.reserve(numResultColumns);
711+
std::vector<arrow::Datum> datums;
712+
datums.reserve(numResultColumns);
713+
std::optional<ui32> resultRecordsCount;
708714

709715
arrow::FieldVector fields;
710716
fields.reserve(numResultColumns);
@@ -715,13 +721,13 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
715721
if (!funcResult.ok()) {
716722
return funcResult.status();
717723
}
718-
res.Datums.push_back(*funcResult);
719-
fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), res.Datums.back().type()));
724+
datums.push_back(*funcResult);
725+
fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), datums.back().type()));
720726
}
721-
res.Rows = 1;
727+
resultRecordsCount = 1;
722728
} else {
723729
CH::GroupByOptions funcOpts;
724-
funcOpts.schema = batch.Schema;
730+
funcOpts.schema = batch.GetSchema();
725731
funcOpts.assigns.reserve(numResultColumns);
726732
funcOpts.has_nullable_key = false;
727733

@@ -759,19 +765,18 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
759765
return arrow::Status::Invalid("No expected column in GROUP BY result.");
760766
}
761767
fields.emplace_back(std::make_shared<arrow::Field>(assign.result_column, column->type()));
762-
res.Datums.push_back(column);
768+
datums.push_back(column);
763769
}
764770

765-
res.Rows = gbBatch->num_rows();
771+
resultRecordsCount = gbBatch->num_rows();
766772
}
767-
768-
res.Schema = std::make_shared<arrow::Schema>(std::move(fields));
769-
batch = std::move(res);
773+
AFL_VERIFY(resultRecordsCount);
774+
batch = TDatumBatch(std::make_shared<arrow::Schema>(std::move(fields)), std::move(datums), *resultRecordsCount);
770775
return arrow::Status::OK();
771776
}
772777

773778
arrow::Status TProgramStep::MakeCombinedFilter(TDatumBatch& batch, NArrow::TColumnFilter& result) const {
774-
TFilterVisitor filterVisitor(batch.Rows);
779+
TFilterVisitor filterVisitor(batch.GetRecordsCount());
775780
for (auto& colName : Filters) {
776781
auto column = batch.GetColumnByName(colName.GetColumnName());
777782
if (!column.ok()) {
@@ -821,13 +826,13 @@ arrow::Status TProgramStep::ApplyFilters(TDatumBatch& batch) const {
821826
}
822827
}
823828
std::vector<arrow::Datum*> filterDatums;
824-
for (int64_t i = 0; i < batch.Schema->num_fields(); ++i) {
825-
if (batch.Datums[i].is_arraylike() && (allColumns || neededColumns.contains(batch.Schema->field(i)->name()))) {
829+
for (int64_t i = 0; i < batch.GetSchema()->num_fields(); ++i) {
830+
if (batch.Datums[i].is_arraylike() && (allColumns || neededColumns.contains(batch.GetSchema()->field(i)->name()))) {
826831
filterDatums.emplace_back(&batch.Datums[i]);
827832
}
828833
}
829-
bits.Apply(batch.Rows, filterDatums);
830-
batch.Rows = bits.GetFilteredCount().value_or(batch.Rows);
834+
bits.Apply(batch.GetRecordsCount(), filterDatums);
835+
batch.SetRecordsCount(bits.GetFilteredCount().value_or(batch.GetRecordsCount()));
831836
return arrow::Status::OK();
832837
}
833838

@@ -838,15 +843,14 @@ arrow::Status TProgramStep::ApplyProjection(TDatumBatch& batch) const {
838843
std::vector<std::shared_ptr<arrow::Field>> newFields;
839844
std::vector<arrow::Datum> newDatums;
840845
for (size_t i = 0; i < Projection.size(); ++i) {
841-
int schemaFieldIndex = batch.Schema->GetFieldIndex(Projection[i].GetColumnName());
846+
int schemaFieldIndex = batch.GetSchema()->GetFieldIndex(Projection[i].GetColumnName());
842847
if (schemaFieldIndex == -1) {
843848
return arrow::Status::Invalid("Could not find column " + Projection[i].GetColumnName() + " in record batch schema.");
844849
}
845-
newFields.push_back(batch.Schema->field(schemaFieldIndex));
850+
newFields.push_back(batch.GetSchema()->field(schemaFieldIndex));
846851
newDatums.push_back(batch.Datums[schemaFieldIndex]);
847852
}
848-
batch.Schema = std::make_shared<arrow::Schema>(std::move(newFields));
849-
batch.Datums = std::move(newDatums);
853+
batch = TDatumBatch(std::make_shared<arrow::Schema>(std::move(newFields)), std::move(newDatums), batch.GetRecordsCount());
850854
return arrow::Status::OK();
851855
}
852856

@@ -919,14 +923,10 @@ std::set<std::string> TProgramStep::GetColumnsInUsage(const bool originalOnly/*
919923
}
920924

921925
arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(const std::shared_ptr<NArrow::TGeneralContainer>& t) const {
922-
return BuildFilter(t->BuildTableVerified(GetColumnsInUsage(true)));
923-
}
924-
925-
arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(const std::shared_ptr<arrow::Table>& t) const {
926926
if (Filters.empty()) {
927927
return nullptr;
928928
}
929-
std::vector<std::shared_ptr<arrow::RecordBatch>> batches = NArrow::SliceToRecordBatches(t);
929+
std::vector<std::shared_ptr<arrow::RecordBatch>> batches = NArrow::SliceToRecordBatches(t->BuildTableVerified(GetColumnsInUsage(true)));
930930
NArrow::TColumnFilter fullLocal = NArrow::TColumnFilter::BuildAllowFilter();
931931
for (auto&& rb : batches) {
932932
auto datumBatch = TDatumBatch::FromRecordBatch(rb);
@@ -938,7 +938,7 @@ arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(
938938
}
939939
NArrow::TColumnFilter local = NArrow::TColumnFilter::BuildAllowFilter();
940940
NArrow::TStatusValidator::Validate(MakeCombinedFilter(*datumBatch, local));
941-
AFL_VERIFY(local.Size() == datumBatch->Rows)("local", local.Size())("datum", datumBatch->Rows);
941+
AFL_VERIFY(local.Size() == datumBatch->GetRecordsCount())("local", local.Size())("datum", datumBatch->GetRecordsCount());
942942
fullLocal.Append(local);
943943
}
944944
AFL_VERIFY(fullLocal.Size() == t->num_rows())("filter", fullLocal.Size())("t", t->num_rows());

ydb/core/formats/arrow/program.h

Lines changed: 37 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -37,15 +37,47 @@ const char * GetHouseFunctionName(EAggregate op);
3737
inline const char * GetHouseGroupByName() { return "ch.group_by"; }
3838
EOperation ValidateOperation(EOperation op, ui32 argsSize);
3939

40-
struct TDatumBatch {
41-
std::shared_ptr<arrow::Schema> Schema;
42-
std::vector<arrow::Datum> Datums;
40+
class TDatumBatch {
41+
private:
42+
std::shared_ptr<arrow::Schema> SchemaBase;
43+
THashMap<std::string, ui32> NewColumnIds;
44+
std::vector<std::shared_ptr<arrow::Field>> NewColumnsPtr;
4345
int64_t Rows = 0;
4446

47+
public:
48+
std::vector<arrow::Datum> Datums;
49+
50+
ui64 GetRecordsCount() const {
51+
return Rows;
52+
}
53+
54+
void SetRecordsCount(const ui64 value) {
55+
Rows = value;
56+
}
57+
58+
TDatumBatch(const std::shared_ptr<arrow::Schema>& schema, std::vector<arrow::Datum>&& datums, const i64 rows);
59+
60+
const std::shared_ptr<arrow::Schema>& GetSchema() {
61+
if (NewColumnIds.size()) {
62+
std::vector<std::shared_ptr<arrow::Field>> fields = SchemaBase->fields();
63+
fields.insert(fields.end(), NewColumnsPtr.begin(), NewColumnsPtr.end());
64+
SchemaBase = std::make_shared<arrow::Schema>(fields);
65+
NewColumnIds.clear();
66+
NewColumnsPtr.clear();
67+
}
68+
return SchemaBase;
69+
}
70+
4571
arrow::Status AddColumn(const std::string& name, arrow::Datum&& column);
4672
arrow::Result<arrow::Datum> GetColumnByName(const std::string& name) const;
47-
std::shared_ptr<arrow::Table> ToTable() const;
48-
std::shared_ptr<arrow::RecordBatch> ToRecordBatch() const;
73+
bool HasColumn(const std::string& name) const {
74+
if (NewColumnIds.contains(name)) {
75+
return true;
76+
}
77+
return SchemaBase->GetFieldIndex(name) > -1;
78+
}
79+
std::shared_ptr<arrow::Table> ToTable();
80+
std::shared_ptr<arrow::RecordBatch> ToRecordBatch();
4981
static std::shared_ptr<TDatumBatch> FromRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch);
5082
static std::shared_ptr<TDatumBatch> FromTable(const std::shared_ptr<arrow::Table>& batch);
5183
};
@@ -405,7 +437,6 @@ class TProgramStep {
405437
return Filters.size() && (!GroupBy.size() && !GroupByKeys.size());
406438
}
407439

408-
[[nodiscard]] arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> BuildFilter(const std::shared_ptr<arrow::Table>& t) const;
409440
[[nodiscard]] arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> BuildFilter(const std::shared_ptr<NArrow::TGeneralContainer>& t) const;
410441
};
411442

0 commit comments

Comments
 (0)