Skip to content

Commit cb7ab00

Browse files
author
Prashanth Govindarajan
authored
Update FromArrowRecordBatches for dotnet-spark (dotnet#2978)
* Support for RecordBatches with StructArrays * Sq * Address comments * Nits * Nits
1 parent db5c49e commit cb7ab00

File tree

4 files changed

+208
-109
lines changed

4 files changed

+208
-109
lines changed

src/Microsoft.Data.Analysis/DataFrame.Arrow.cs

+114-99
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,127 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Collections;
76
using System.Collections.Generic;
8-
using System.Diagnostics;
9-
using System.Text;
10-
using System.Threading.Tasks;
117
using Apache.Arrow;
128
using Apache.Arrow.Types;
139

1410
namespace Microsoft.Data.Analysis
1511
{
1612
public partial class DataFrame
1713
{
14+
private static void AppendDataFrameColumnFromArrowArray(Field field, IArrowArray arrowArray, DataFrame ret, string fieldNamePrefix = "")
15+
{
16+
IArrowType fieldType = field.DataType;
17+
DataFrameColumn dataFrameColumn = null;
18+
string fieldName = fieldNamePrefix + field.Name;
19+
switch (fieldType.TypeId)
20+
{
21+
case ArrowTypeId.Boolean:
22+
BooleanArray arrowBooleanArray = (BooleanArray)arrowArray;
23+
ReadOnlyMemory<byte> valueBuffer = arrowBooleanArray.ValueBuffer.Memory;
24+
ReadOnlyMemory<byte> nullBitMapBuffer = arrowBooleanArray.NullBitmapBuffer.Memory;
25+
dataFrameColumn = new BooleanDataFrameColumn(fieldName, valueBuffer, nullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
26+
break;
27+
case ArrowTypeId.Double:
28+
PrimitiveArray<double> arrowDoubleArray = (PrimitiveArray<double>)arrowArray;
29+
ReadOnlyMemory<byte> doubleValueBuffer = arrowDoubleArray.ValueBuffer.Memory;
30+
ReadOnlyMemory<byte> doubleNullBitMapBuffer = arrowDoubleArray.NullBitmapBuffer.Memory;
31+
dataFrameColumn = new DoubleDataFrameColumn(fieldName, doubleValueBuffer, doubleNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
32+
break;
33+
case ArrowTypeId.Float:
34+
PrimitiveArray<float> arrowFloatArray = (PrimitiveArray<float>)arrowArray;
35+
ReadOnlyMemory<byte> floatValueBuffer = arrowFloatArray.ValueBuffer.Memory;
36+
ReadOnlyMemory<byte> floatNullBitMapBuffer = arrowFloatArray.NullBitmapBuffer.Memory;
37+
dataFrameColumn = new SingleDataFrameColumn(fieldName, floatValueBuffer, floatNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
38+
break;
39+
case ArrowTypeId.Int8:
40+
PrimitiveArray<sbyte> arrowsbyteArray = (PrimitiveArray<sbyte>)arrowArray;
41+
ReadOnlyMemory<byte> sbyteValueBuffer = arrowsbyteArray.ValueBuffer.Memory;
42+
ReadOnlyMemory<byte> sbyteNullBitMapBuffer = arrowsbyteArray.NullBitmapBuffer.Memory;
43+
dataFrameColumn = new SByteDataFrameColumn(fieldName, sbyteValueBuffer, sbyteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
44+
break;
45+
case ArrowTypeId.Int16:
46+
PrimitiveArray<short> arrowshortArray = (PrimitiveArray<short>)arrowArray;
47+
ReadOnlyMemory<byte> shortValueBuffer = arrowshortArray.ValueBuffer.Memory;
48+
ReadOnlyMemory<byte> shortNullBitMapBuffer = arrowshortArray.NullBitmapBuffer.Memory;
49+
dataFrameColumn = new Int16DataFrameColumn(fieldName, shortValueBuffer, shortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
50+
break;
51+
case ArrowTypeId.Int32:
52+
PrimitiveArray<int> arrowIntArray = (PrimitiveArray<int>)arrowArray;
53+
ReadOnlyMemory<byte> intValueBuffer = arrowIntArray.ValueBuffer.Memory;
54+
ReadOnlyMemory<byte> intNullBitMapBuffer = arrowIntArray.NullBitmapBuffer.Memory;
55+
dataFrameColumn = new Int32DataFrameColumn(fieldName, intValueBuffer, intNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
56+
break;
57+
case ArrowTypeId.Int64:
58+
PrimitiveArray<long> arrowLongArray = (PrimitiveArray<long>)arrowArray;
59+
ReadOnlyMemory<byte> longValueBuffer = arrowLongArray.ValueBuffer.Memory;
60+
ReadOnlyMemory<byte> longNullBitMapBuffer = arrowLongArray.NullBitmapBuffer.Memory;
61+
dataFrameColumn = new Int64DataFrameColumn(fieldName, longValueBuffer, longNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
62+
break;
63+
case ArrowTypeId.String:
64+
StringArray stringArray = (StringArray)arrowArray;
65+
ReadOnlyMemory<byte> dataMemory = stringArray.ValueBuffer.Memory;
66+
ReadOnlyMemory<byte> offsetsMemory = stringArray.ValueOffsetsBuffer.Memory;
67+
ReadOnlyMemory<byte> nullMemory = stringArray.NullBitmapBuffer.Memory;
68+
dataFrameColumn = new ArrowStringDataFrameColumn(fieldName, dataMemory, offsetsMemory, nullMemory, stringArray.Length, stringArray.NullCount);
69+
break;
70+
case ArrowTypeId.UInt8:
71+
PrimitiveArray<byte> arrowbyteArray = (PrimitiveArray<byte>)arrowArray;
72+
ReadOnlyMemory<byte> byteValueBuffer = arrowbyteArray.ValueBuffer.Memory;
73+
ReadOnlyMemory<byte> byteNullBitMapBuffer = arrowbyteArray.NullBitmapBuffer.Memory;
74+
dataFrameColumn = new ByteDataFrameColumn(fieldName, byteValueBuffer, byteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
75+
break;
76+
case ArrowTypeId.UInt16:
77+
PrimitiveArray<ushort> arrowUshortArray = (PrimitiveArray<ushort>)arrowArray;
78+
ReadOnlyMemory<byte> ushortValueBuffer = arrowUshortArray.ValueBuffer.Memory;
79+
ReadOnlyMemory<byte> ushortNullBitMapBuffer = arrowUshortArray.NullBitmapBuffer.Memory;
80+
dataFrameColumn = new UInt16DataFrameColumn(fieldName, ushortValueBuffer, ushortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
81+
break;
82+
case ArrowTypeId.UInt32:
83+
PrimitiveArray<uint> arrowUintArray = (PrimitiveArray<uint>)arrowArray;
84+
ReadOnlyMemory<byte> uintValueBuffer = arrowUintArray.ValueBuffer.Memory;
85+
ReadOnlyMemory<byte> uintNullBitMapBuffer = arrowUintArray.NullBitmapBuffer.Memory;
86+
dataFrameColumn = new UInt32DataFrameColumn(fieldName, uintValueBuffer, uintNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
87+
break;
88+
case ArrowTypeId.UInt64:
89+
PrimitiveArray<ulong> arrowUlongArray = (PrimitiveArray<ulong>)arrowArray;
90+
ReadOnlyMemory<byte> ulongValueBuffer = arrowUlongArray.ValueBuffer.Memory;
91+
ReadOnlyMemory<byte> ulongNullBitMapBuffer = arrowUlongArray.NullBitmapBuffer.Memory;
92+
dataFrameColumn = new UInt64DataFrameColumn(fieldName, ulongValueBuffer, ulongNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
93+
break;
94+
case ArrowTypeId.Struct:
95+
StructArray structArray = (StructArray)arrowArray;
96+
StructType structType = (StructType)field.DataType;
97+
IEnumerator<Field> fieldsEnumerator = structType.Fields.GetEnumerator();
98+
IEnumerator<IArrowArray> structArrayEnumerator = structArray.Fields.GetEnumerator();
99+
while (fieldsEnumerator.MoveNext() && structArrayEnumerator.MoveNext())
100+
{
101+
AppendDataFrameColumnFromArrowArray(fieldsEnumerator.Current, structArrayEnumerator.Current, ret, field.Name + "_");
102+
}
103+
break;
104+
case ArrowTypeId.Decimal:
105+
case ArrowTypeId.Binary:
106+
case ArrowTypeId.Date32:
107+
case ArrowTypeId.Date64:
108+
case ArrowTypeId.Dictionary:
109+
case ArrowTypeId.FixedSizedBinary:
110+
case ArrowTypeId.HalfFloat:
111+
case ArrowTypeId.Interval:
112+
case ArrowTypeId.List:
113+
case ArrowTypeId.Map:
114+
case ArrowTypeId.Null:
115+
case ArrowTypeId.Time32:
116+
case ArrowTypeId.Time64:
117+
default:
118+
throw new NotImplementedException(nameof(fieldType.Name));
119+
}
120+
121+
if (dataFrameColumn != null)
122+
{
123+
ret.Columns.Insert(ret.Columns.Count, dataFrameColumn);
124+
}
125+
}
126+
18127
/// <summary>
19128
/// Wraps a <see cref="DataFrame"/> around an Arrow <see cref="RecordBatch"/> without copying data
20129
/// </summary>
@@ -29,101 +138,7 @@ public static DataFrame FromArrowRecordBatch(RecordBatch recordBatch)
29138
foreach (IArrowArray arrowArray in arrowArrays)
30139
{
31140
Field field = arrowSchema.GetFieldByIndex(fieldIndex);
32-
IArrowType fieldType = field.DataType;
33-
DataFrameColumn dataFrameColumn = null;
34-
switch (fieldType.TypeId)
35-
{
36-
case ArrowTypeId.Boolean:
37-
BooleanArray arrowBooleanArray = (BooleanArray)arrowArray;
38-
ReadOnlyMemory<byte> valueBuffer = arrowBooleanArray.ValueBuffer.Memory;
39-
ReadOnlyMemory<byte> nullBitMapBuffer = arrowBooleanArray.NullBitmapBuffer.Memory;
40-
dataFrameColumn = new BooleanDataFrameColumn(field.Name, valueBuffer, nullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
41-
break;
42-
case ArrowTypeId.Double:
43-
PrimitiveArray<double> arrowDoubleArray = (PrimitiveArray<double>)arrowArray;
44-
ReadOnlyMemory<byte> doubleValueBuffer = arrowDoubleArray.ValueBuffer.Memory;
45-
ReadOnlyMemory<byte> doubleNullBitMapBuffer = arrowDoubleArray.NullBitmapBuffer.Memory;
46-
dataFrameColumn = new DoubleDataFrameColumn(field.Name, doubleValueBuffer, doubleNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
47-
break;
48-
case ArrowTypeId.Float:
49-
PrimitiveArray<float> arrowFloatArray = (PrimitiveArray<float>)arrowArray;
50-
ReadOnlyMemory<byte> floatValueBuffer = arrowFloatArray.ValueBuffer.Memory;
51-
ReadOnlyMemory<byte> floatNullBitMapBuffer = arrowFloatArray.NullBitmapBuffer.Memory;
52-
dataFrameColumn = new SingleDataFrameColumn(field.Name, floatValueBuffer, floatNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
53-
break;
54-
case ArrowTypeId.Int8:
55-
PrimitiveArray<sbyte> arrowsbyteArray = (PrimitiveArray<sbyte>)arrowArray;
56-
ReadOnlyMemory<byte> sbyteValueBuffer = arrowsbyteArray.ValueBuffer.Memory;
57-
ReadOnlyMemory<byte> sbyteNullBitMapBuffer = arrowsbyteArray.NullBitmapBuffer.Memory;
58-
dataFrameColumn = new SByteDataFrameColumn(field.Name, sbyteValueBuffer, sbyteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
59-
break;
60-
case ArrowTypeId.Int16:
61-
PrimitiveArray<short> arrowshortArray = (PrimitiveArray<short>)arrowArray;
62-
ReadOnlyMemory<byte> shortValueBuffer = arrowshortArray.ValueBuffer.Memory;
63-
ReadOnlyMemory<byte> shortNullBitMapBuffer = arrowshortArray.NullBitmapBuffer.Memory;
64-
dataFrameColumn = new Int16DataFrameColumn(field.Name, shortValueBuffer, shortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
65-
break;
66-
case ArrowTypeId.Int32:
67-
PrimitiveArray<int> arrowIntArray = (PrimitiveArray<int>)arrowArray;
68-
ReadOnlyMemory<byte> intValueBuffer = arrowIntArray.ValueBuffer.Memory;
69-
ReadOnlyMemory<byte> intNullBitMapBuffer = arrowIntArray.NullBitmapBuffer.Memory;
70-
dataFrameColumn = new Int32DataFrameColumn(field.Name, intValueBuffer, intNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
71-
break;
72-
case ArrowTypeId.Int64:
73-
PrimitiveArray<long> arrowLongArray = (PrimitiveArray<long>)arrowArray;
74-
ReadOnlyMemory<byte> longValueBuffer = arrowLongArray.ValueBuffer.Memory;
75-
ReadOnlyMemory<byte> longNullBitMapBuffer = arrowLongArray.NullBitmapBuffer.Memory;
76-
dataFrameColumn = new Int64DataFrameColumn(field.Name, longValueBuffer, longNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
77-
break;
78-
case ArrowTypeId.String:
79-
StringArray stringArray = (StringArray)arrowArray;
80-
ReadOnlyMemory<byte> dataMemory = stringArray.ValueBuffer.Memory;
81-
ReadOnlyMemory<byte> offsetsMemory = stringArray.ValueOffsetsBuffer.Memory;
82-
ReadOnlyMemory<byte> nullMemory = stringArray.NullBitmapBuffer.Memory;
83-
dataFrameColumn = new ArrowStringDataFrameColumn(field.Name, dataMemory, offsetsMemory, nullMemory, stringArray.Length, stringArray.NullCount);
84-
break;
85-
case ArrowTypeId.UInt8:
86-
PrimitiveArray<byte> arrowbyteArray = (PrimitiveArray<byte>)arrowArray;
87-
ReadOnlyMemory<byte> byteValueBuffer = arrowbyteArray.ValueBuffer.Memory;
88-
ReadOnlyMemory<byte> byteNullBitMapBuffer = arrowbyteArray.NullBitmapBuffer.Memory;
89-
dataFrameColumn = new ByteDataFrameColumn(field.Name, byteValueBuffer, byteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
90-
break;
91-
case ArrowTypeId.UInt16:
92-
PrimitiveArray<ushort> arrowUshortArray = (PrimitiveArray<ushort>)arrowArray;
93-
ReadOnlyMemory<byte> ushortValueBuffer = arrowUshortArray.ValueBuffer.Memory;
94-
ReadOnlyMemory<byte> ushortNullBitMapBuffer = arrowUshortArray.NullBitmapBuffer.Memory;
95-
dataFrameColumn = new UInt16DataFrameColumn(field.Name, ushortValueBuffer, ushortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
96-
break;
97-
case ArrowTypeId.UInt32:
98-
PrimitiveArray<uint> arrowUintArray = (PrimitiveArray<uint>)arrowArray;
99-
ReadOnlyMemory<byte> uintValueBuffer = arrowUintArray.ValueBuffer.Memory;
100-
ReadOnlyMemory<byte> uintNullBitMapBuffer = arrowUintArray.NullBitmapBuffer.Memory;
101-
dataFrameColumn = new UInt32DataFrameColumn(field.Name, uintValueBuffer, uintNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
102-
break;
103-
case ArrowTypeId.UInt64:
104-
PrimitiveArray<ulong> arrowUlongArray = (PrimitiveArray<ulong>)arrowArray;
105-
ReadOnlyMemory<byte> ulongValueBuffer = arrowUlongArray.ValueBuffer.Memory;
106-
ReadOnlyMemory<byte> ulongNullBitMapBuffer = arrowUlongArray.NullBitmapBuffer.Memory;
107-
dataFrameColumn = new UInt64DataFrameColumn(field.Name, ulongValueBuffer, ulongNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
108-
break;
109-
case ArrowTypeId.Decimal:
110-
case ArrowTypeId.Binary:
111-
case ArrowTypeId.Date32:
112-
case ArrowTypeId.Date64:
113-
case ArrowTypeId.Dictionary:
114-
case ArrowTypeId.FixedSizedBinary:
115-
case ArrowTypeId.HalfFloat:
116-
case ArrowTypeId.Interval:
117-
case ArrowTypeId.List:
118-
case ArrowTypeId.Map:
119-
case ArrowTypeId.Null:
120-
case ArrowTypeId.Struct:
121-
case ArrowTypeId.Time32:
122-
case ArrowTypeId.Time64:
123-
default:
124-
throw new NotImplementedException(nameof(fieldType.Name));
125-
}
126-
ret.Columns.Insert(ret.Columns.Count, dataFrameColumn);
141+
AppendDataFrameColumnFromArrowArray(field, arrowArray, ret);
127142
fieldIndex++;
128143
}
129144
return ret;

src/Microsoft.Data.Analysis/Microsoft.Data.Analysis.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
</ItemGroup>
4848

4949
<ItemGroup>
50-
<PackageReference Include="Apache.Arrow" Version="0.14.1" />
50+
<PackageReference Include="Apache.Arrow" Version="2.0.0" />
5151
<PackageReference Include="System.Memory" Version="4.5.3" />
5252
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="4.5.2" />
5353
<PackageReference Include="System.Buffers" Version="4.5.0" />

tests/Microsoft.Data.Analysis.Tests/ArrayComparer.cs

+27-9
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ public class ArrayComparer :
2626
IArrowArrayVisitor<Date64Array>,
2727
IArrowArrayVisitor<ListArray>,
2828
IArrowArrayVisitor<StringArray>,
29-
IArrowArrayVisitor<BinaryArray>
29+
IArrowArrayVisitor<BinaryArray>,
30+
IArrowArrayVisitor<StructArray>
3031
{
3132
private readonly IArrowArray _expectedArray;
3233

@@ -54,6 +55,23 @@ public ArrayComparer(IArrowArray expectedArray)
5455
public void Visit(BinaryArray array) => throw new NotImplementedException();
5556
public void Visit(IArrowArray array) => throw new NotImplementedException();
5657

58+
public void Visit(StructArray array)
59+
{
60+
Assert.IsAssignableFrom<StructArray>(_expectedArray);
61+
StructArray expectedArray = (StructArray)_expectedArray;
62+
63+
Assert.Equal(expectedArray.Length, array.Length);
64+
Assert.Equal(expectedArray.NullCount, array.NullCount);
65+
Assert.Equal(expectedArray.Offset, array.Offset);
66+
Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length);
67+
Assert.Equal(expectedArray.Fields.Count, array.Fields.Count);
68+
69+
for (int i = 0; i < array.Fields.Count; i++)
70+
{
71+
array.Fields[i].Accept(new ArrayComparer(expectedArray.Fields[i]));
72+
}
73+
}
74+
5775
private void CompareArrays<T>(PrimitiveArray<T> actualArray)
5876
where T : struct, IEquatable<T>
5977
{
@@ -68,15 +86,15 @@ private void CompareArrays<T>(PrimitiveArray<T> actualArray)
6886
{
6987
Assert.True(expectedArray.NullBitmapBuffer.Span.SequenceEqual(actualArray.NullBitmapBuffer.Span));
7088
}
71-
else
72-
{
89+
else
90+
{
7391
// expectedArray may have passed in a null bitmap. DataFrame might have populated it with Length set bits
74-
Assert.Equal(0, expectedArray.NullCount);
75-
Assert.Equal(0, actualArray.NullCount);
76-
for (int i = 0; i < actualArray.Length; i++)
77-
{
78-
Assert.True(actualArray.IsValid(i));
79-
}
92+
Assert.Equal(0, expectedArray.NullCount);
93+
Assert.Equal(0, actualArray.NullCount);
94+
for (int i = 0; i < actualArray.Length; i++)
95+
{
96+
Assert.True(actualArray.IsValid(i));
97+
}
8098
}
8199
Assert.True(expectedArray.Values.Slice(0, expectedArray.Length).SequenceEqual(actualArray.Values.Slice(0, actualArray.Length)));
82100
}

0 commit comments

Comments
 (0)