@@ -88,7 +88,7 @@ class TConstFunction : public IStepFunction<TAssign> {
88
88
using TBase = IStepFunction<TAssign>;
89
89
public:
90
90
using TBase::TBase;
91
// Step-function override that evaluates an assignment to its constant value.
// NOTE(review): the removed ('-') and added ('+') signature lines below read the same
// in this rendering; presumably the change was whitespace-only -- confirm against the raw diff.
- arrow::Result<arrow::Datum> Call (const TAssign& assign, const TDatumBatch& batch) const override {
91
+ arrow::Result<arrow::Datum> Call (const TAssign& assign, const TDatumBatch& batch) const override {
92
92
// The input batch is deliberately not consulted; the result comes from the assign alone.
Y_UNUSED (batch);
93
93
// Returns whatever constant the assign carries, wrapped into the arrow::Result return type.
return assign.GetConstant ();
94
94
}
@@ -531,7 +531,7 @@ class TFilterVisitor : public arrow::ArrayVisitor {
531
531
532
532
533
533
arrow::Status TDatumBatch::AddColumn (const std::string& name, arrow::Datum&& column) {
534
- if (Schema-> GetFieldIndex (name) != - 1 ) {
534
+ if (HasColumn (name)) {
535
535
return arrow::Status::Invalid (" Trying to add duplicate column '" + name + " '" );
536
536
}
537
537
@@ -543,20 +543,27 @@ arrow::Status TDatumBatch::AddColumn(const std::string& name, arrow::Datum&& col
543
543
return arrow::Status::Invalid (" Wrong column length." );
544
544
}
545
545
546
- Schema = *Schema->AddField (Schema->num_fields (), field);
546
+ NewColumnIds.emplace (name, NewColumnsPtr.size ());
547
+ NewColumnsPtr.emplace_back (field);
548
+
547
549
Datums.emplace_back (column);
548
550
return arrow::Status::OK ();
549
551
}
550
552
551
553
// Looks up a column's datum by name.
// New version ('+' lines): first consults NewColumnIds, then falls back to the base schema.
// Returns the matching arrow::Datum, or arrow::Status::Invalid when the base-schema lookup
// yields a negative index.
arrow::Result<arrow::Datum> TDatumBatch::GetColumnByName (const std::string& name) const {
552
- auto i = Schema->GetFieldIndex (name);
554
// Columns registered in NewColumnIds are checked first (AddColumn in this file fills that map
// with the column's position among the added columns and appends its datum to Datums).
+ auto it = NewColumnIds.find (name);
555
+ if (it != NewColumnIds.end ()) {
556
// The stored id is used as an offset past the base-schema field count, i.e. the code
// assumes Datums holds the base-schema columns first, followed by added columns -- TODO confirm.
+ AFL_VERIFY (SchemaBase->num_fields () + it->second < Datums.size ());
557
+ return Datums[SchemaBase->num_fields () + it->second ];
558
+ }
559
// Not an added column: resolve the name against the base schema.
+ auto i = SchemaBase->GetFieldIndex (name);
553
560
// A negative index is reported as "not found or duplicate", per the error text below.
if (i < 0 ) {
554
561
return arrow::Status::Invalid (" Not found column '" + name + " ' or duplicate" );
555
562
}
556
563
// The base-schema field index is used directly as the position in Datums.
return Datums[i];
557
564
}
558
565
559
- std::shared_ptr<arrow::Table> TDatumBatch::ToTable () const {
566
+ std::shared_ptr<arrow::Table> TDatumBatch::ToTable () {
560
567
std::vector<std::shared_ptr<arrow::ChunkedArray>> columns;
561
568
columns.reserve (Datums.size ());
562
569
for (auto col : Datums) {
@@ -576,10 +583,10 @@ std::shared_ptr<arrow::Table> TDatumBatch::ToTable() const {
576
583
AFL_VERIFY (false );
577
584
}
578
585
}
579
- return arrow::Table::Make (Schema , columns, Rows);
586
+ return arrow::Table::Make (GetSchema () , columns, Rows);
580
587
}
581
588
582
- std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch () const {
589
+ std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch () {
583
590
std::vector<std::shared_ptr<arrow::Array>> columns;
584
591
columns.reserve (Datums.size ());
585
592
for (auto col : Datums) {
@@ -594,7 +601,7 @@ std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() const {
594
601
AFL_VERIFY (false );
595
602
}
596
603
}
597
- return arrow::RecordBatch::Make (Schema , Rows, columns);
604
+ return arrow::RecordBatch::Make (GetSchema () , Rows, columns);
598
605
}
599
606
600
607
std::shared_ptr<TDatumBatch> TDatumBatch::FromRecordBatch (const std::shared_ptr<arrow::RecordBatch>& batch) {
@@ -603,12 +610,7 @@ std::shared_ptr<TDatumBatch> TDatumBatch::FromRecordBatch(const std::shared_ptr<
603
610
for (int64_t i = 0 ; i < batch->num_columns (); ++i) {
604
611
datums.push_back (arrow::Datum (batch->column (i)));
605
612
}
606
- return std::make_shared<TProgramStep::TDatumBatch>(
607
- TProgramStep::TDatumBatch{
608
- .Schema = std::make_shared<arrow::Schema>(*batch->schema ()),
609
- .Datums = std::move (datums),
610
- .Rows = batch->num_rows ()
611
- });
613
+ return std::make_shared<TDatumBatch>(std::make_shared<arrow::Schema>(*batch->schema ()), std::move (datums), batch->num_rows ());
612
614
}
613
615
614
616
std::shared_ptr<TDatumBatch> TDatumBatch::FromTable (const std::shared_ptr<arrow::Table>& batch) {
@@ -617,12 +619,15 @@ std::shared_ptr<TDatumBatch> TDatumBatch::FromTable(const std::shared_ptr<arrow:
617
619
for (int64_t i = 0 ; i < batch->num_columns (); ++i) {
618
620
datums.push_back (arrow::Datum (batch->column (i)));
619
621
}
620
- return std::make_shared<TProgramStep::TDatumBatch>(
621
- TProgramStep::TDatumBatch{
622
- .Schema = std::make_shared<arrow::Schema>(*batch->schema ()),
623
- .Datums = std::move (datums),
624
- .Rows = batch->num_rows ()
625
- });
622
+ return std::make_shared<TDatumBatch>(std::make_shared<arrow::Schema>(*batch->schema ()), std::move (datums), batch->num_rows ());
623
+ }
624
+
625
// Constructs a batch from a base schema, its column datums and a row count.
// `schema` is stored as SchemaBase, `datums` is moved into Datums, `rows` is stored as Rows.
// Verifies (AFL_VERIFY) that the schema pointer is non-null and that there is one datum per schema field.
+ TDatumBatch::TDatumBatch (const std::shared_ptr<arrow::Schema>& schema, std::vector<arrow::Datum>&& datums, const i64 rows)
626
+ : SchemaBase(schema)
627
+ , Rows(rows)
628
+ , Datums(std::move(datums)) {
629
// The schema must be checked for null before it is dereferenced in the next check.
+ AFL_VERIFY (SchemaBase);
630
// Invariant at construction: the datum count equals the number of fields in the base schema.
+ AFL_VERIFY (Datums.size () == (ui32)SchemaBase->num_fields ());
626
631
}
627
632
628
633
TAssign TAssign::MakeTimestamp (const TColumnInfo& column, ui64 value) {
@@ -680,7 +685,7 @@ arrow::Status TProgramStep::ApplyAssignes(TDatumBatch& batch, arrow::compute::Ex
680
685
}
681
686
batch.Datums .reserve (batch.Datums .size () + Assignes.size ());
682
687
for (auto & assign : Assignes) {
683
- if (batch.GetColumnByName (assign.GetName ()). ok ( )) {
688
+ if (batch.HasColumn (assign.GetName ())) {
684
689
return arrow::Status::Invalid (" Assign to existing column '" + assign.GetName () + " '." );
685
690
}
686
691
@@ -703,8 +708,9 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
703
708
}
704
709
705
710
ui32 numResultColumns = GroupBy.size () + GroupByKeys.size ();
706
- TDatumBatch res;
707
- res.Datums .reserve (numResultColumns);
711
+ std::vector<arrow::Datum> datums;
712
+ datums.reserve (numResultColumns);
713
+ std::optional<ui32> resultRecordsCount;
708
714
709
715
arrow::FieldVector fields;
710
716
fields.reserve (numResultColumns);
@@ -715,13 +721,13 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
715
721
if (!funcResult.ok ()) {
716
722
return funcResult.status ();
717
723
}
718
- res. Datums .push_back (*funcResult);
719
- fields.emplace_back (std::make_shared<arrow::Field>(assign.GetName (), res. Datums .back ().type ()));
724
+ datums .push_back (*funcResult);
725
+ fields.emplace_back (std::make_shared<arrow::Field>(assign.GetName (), datums .back ().type ()));
720
726
}
721
- res. Rows = 1 ;
727
+ resultRecordsCount = 1 ;
722
728
} else {
723
729
CH::GroupByOptions funcOpts;
724
- funcOpts.schema = batch.Schema ;
730
+ funcOpts.schema = batch.GetSchema () ;
725
731
funcOpts.assigns .reserve (numResultColumns);
726
732
funcOpts.has_nullable_key = false ;
727
733
@@ -759,19 +765,18 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
759
765
return arrow::Status::Invalid (" No expected column in GROUP BY result." );
760
766
}
761
767
fields.emplace_back (std::make_shared<arrow::Field>(assign.result_column , column->type ()));
762
- res. Datums .push_back (column);
768
+ datums .push_back (column);
763
769
}
764
770
765
- res. Rows = gbBatch->num_rows ();
771
+ resultRecordsCount = gbBatch->num_rows ();
766
772
}
767
-
768
- res.Schema = std::make_shared<arrow::Schema>(std::move (fields));
769
- batch = std::move (res);
773
+ AFL_VERIFY (resultRecordsCount);
774
+ batch = TDatumBatch (std::make_shared<arrow::Schema>(std::move (fields)), std::move (datums), *resultRecordsCount);
770
775
return arrow::Status::OK ();
771
776
}
772
777
773
778
arrow::Status TProgramStep::MakeCombinedFilter (TDatumBatch& batch, NArrow::TColumnFilter& result) const {
774
- TFilterVisitor filterVisitor (batch.Rows );
779
+ TFilterVisitor filterVisitor (batch.GetRecordsCount () );
775
780
for (auto & colName : Filters) {
776
781
auto column = batch.GetColumnByName (colName.GetColumnName ());
777
782
if (!column.ok ()) {
@@ -821,13 +826,13 @@ arrow::Status TProgramStep::ApplyFilters(TDatumBatch& batch) const {
821
826
}
822
827
}
823
828
std::vector<arrow::Datum*> filterDatums;
824
- for (int64_t i = 0 ; i < batch.Schema ->num_fields (); ++i) {
825
- if (batch.Datums [i].is_arraylike () && (allColumns || neededColumns.contains (batch.Schema ->field (i)->name ()))) {
829
+ for (int64_t i = 0 ; i < batch.GetSchema () ->num_fields (); ++i) {
830
+ if (batch.Datums [i].is_arraylike () && (allColumns || neededColumns.contains (batch.GetSchema () ->field (i)->name ()))) {
826
831
filterDatums.emplace_back (&batch.Datums [i]);
827
832
}
828
833
}
829
- bits.Apply (batch.Rows , filterDatums);
830
- batch.Rows = bits.GetFilteredCount ().value_or (batch.Rows );
834
+ bits.Apply (batch.GetRecordsCount () , filterDatums);
835
+ batch.SetRecordsCount ( bits.GetFilteredCount ().value_or (batch.GetRecordsCount ()) );
831
836
return arrow::Status::OK ();
832
837
}
833
838
@@ -838,15 +843,14 @@ arrow::Status TProgramStep::ApplyProjection(TDatumBatch& batch) const {
838
843
std::vector<std::shared_ptr<arrow::Field>> newFields;
839
844
std::vector<arrow::Datum> newDatums;
840
845
for (size_t i = 0 ; i < Projection.size (); ++i) {
841
- int schemaFieldIndex = batch.Schema ->GetFieldIndex (Projection[i].GetColumnName ());
846
+ int schemaFieldIndex = batch.GetSchema () ->GetFieldIndex (Projection[i].GetColumnName ());
842
847
if (schemaFieldIndex == -1 ) {
843
848
return arrow::Status::Invalid (" Could not find column " + Projection[i].GetColumnName () + " in record batch schema." );
844
849
}
845
- newFields.push_back (batch.Schema ->field (schemaFieldIndex));
850
+ newFields.push_back (batch.GetSchema () ->field (schemaFieldIndex));
846
851
newDatums.push_back (batch.Datums [schemaFieldIndex]);
847
852
}
848
- batch.Schema = std::make_shared<arrow::Schema>(std::move (newFields));
849
- batch.Datums = std::move (newDatums);
853
+ batch = TDatumBatch (std::make_shared<arrow::Schema>(std::move (newFields)), std::move (newDatums), batch.GetRecordsCount ());
850
854
return arrow::Status::OK ();
851
855
}
852
856
@@ -919,14 +923,10 @@ std::set<std::string> TProgramStep::GetColumnsInUsage(const bool originalOnly/*
919
923
}
920
924
921
925
arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter (const std::shared_ptr<NArrow::TGeneralContainer>& t) const {
922
- return BuildFilter (t->BuildTableVerified (GetColumnsInUsage (true )));
923
- }
924
-
925
- arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter (const std::shared_ptr<arrow::Table>& t) const {
926
926
if (Filters.empty ()) {
927
927
return nullptr ;
928
928
}
929
- std::vector<std::shared_ptr<arrow::RecordBatch>> batches = NArrow::SliceToRecordBatches (t);
929
+ std::vector<std::shared_ptr<arrow::RecordBatch>> batches = NArrow::SliceToRecordBatches (t-> BuildTableVerified ( GetColumnsInUsage ( true )) );
930
930
NArrow::TColumnFilter fullLocal = NArrow::TColumnFilter::BuildAllowFilter ();
931
931
for (auto && rb : batches) {
932
932
auto datumBatch = TDatumBatch::FromRecordBatch (rb);
@@ -938,7 +938,7 @@ arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(
938
938
}
939
939
NArrow::TColumnFilter local = NArrow::TColumnFilter::BuildAllowFilter ();
940
940
NArrow::TStatusValidator::Validate (MakeCombinedFilter (*datumBatch, local));
941
- AFL_VERIFY (local.Size () == datumBatch->Rows ) (" local" , local.Size ())(" datum" , datumBatch->Rows );
941
+ AFL_VERIFY (local.Size () == datumBatch->GetRecordsCount ()) (" local" , local.Size ())(" datum" , datumBatch->GetRecordsCount () );
942
942
fullLocal.Append (local);
943
943
}
944
944
AFL_VERIFY (fullLocal.Size () == t->num_rows ())(" filter" , fullLocal.Size ())(" t" , t->num_rows ());
0 commit comments