Skip to content

Commit d24a594

Browse files
authored
Merge pull request #2 from asmirnov82/jakerad_generic_math
Integrate latest dataframe changes from asmirnov/machinelearning into local branch
2 parents 2ba8944 + b78c0a7 commit d24a594

11 files changed

+157
-48
lines changed

eng/Versions.props

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
<SystemTextJsonVersion>6.0.1</SystemTextJsonVersion>
3131
<SystemThreadingChannelsVersion>4.7.1</SystemThreadingChannelsVersion>
3232
<!-- Other product dependencies -->
33-
<ApacheArrowVersion>2.0.0</ApacheArrowVersion>
33+
<ApacheArrowVersion>11.0.0</ApacheArrowVersion>
3434
<GoogleProtobufVersion>3.19.6</GoogleProtobufVersion>
3535
<LightGBMVersion>2.3.1</LightGBMVersion>
3636
<MicrosoftCodeAnalysisAnalyzersVersion>3.3.0</MicrosoftCodeAnalysisAnalyzersVersion>

src/Microsoft.Data.Analysis/DataFrame.Arrow.cs

+12-3
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,18 @@ private static void AppendDataFrameColumnFromArrowArray(Field field, IArrowArray
101101
AppendDataFrameColumnFromArrowArray(fieldsEnumerator.Current, structArrayEnumerator.Current, ret, field.Name + "_");
102102
}
103103
break;
104-
case ArrowTypeId.Decimal:
104+
case ArrowTypeId.Date64:
105+
Date64Array arrowDate64Array = (Date64Array)arrowArray;
106+
dataFrameColumn = new PrimitiveDataFrameColumn<DateTime>(fieldName, arrowDate64Array.Data.Length);
107+
for (int i = 0; i < arrowDate64Array.Data.Length; i++)
108+
{
109+
dataFrameColumn[i] = arrowDate64Array.GetDateTime(i);
110+
}
111+
break;
112+
case ArrowTypeId.Decimal128:
113+
case ArrowTypeId.Decimal256:
105114
case ArrowTypeId.Binary:
106115
case ArrowTypeId.Date32:
107-
case ArrowTypeId.Date64:
108116
case ArrowTypeId.Dictionary:
109117
case ArrowTypeId.FixedSizedBinary:
110118
case ArrowTypeId.HalfFloat:
@@ -114,6 +122,7 @@ private static void AppendDataFrameColumnFromArrowArray(Field field, IArrowArray
114122
case ArrowTypeId.Null:
115123
case ArrowTypeId.Time32:
116124
case ArrowTypeId.Time64:
125+
case ArrowTypeId.Timestamp:
117126
default:
118127
throw new NotImplementedException($"{fieldType.Name}");
119128
}
@@ -145,7 +154,7 @@ public static DataFrame FromArrowRecordBatch(RecordBatch recordBatch)
145154
}
146155

147156
/// <summary>
148-
/// Returns an <see cref="IEnumerable{RecordBatch}"/> without copying data
157+
/// Returns an <see cref="IEnumerable{RecordBatch}"/> mostly without copying data
149158
/// </summary>
150159
public IEnumerable<RecordBatch> ToArrowRecordBatches()
151160
{

src/Microsoft.Data.Analysis/DataFrame.Join.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ private void SetSuffixForDuplicatedColumnNames(DataFrame dataFrame, DataFrameCol
3030
{
3131
// Pre-existing column. Change name
3232
DataFrameColumn existingColumn = dataFrame.Columns[index];
33-
dataFrame._columnCollection.SetColumnName(existingColumn, existingColumn.Name + leftSuffix);
33+
existingColumn.SetName(existingColumn.Name + leftSuffix);
3434
column.SetName(column.Name + rightSuffix);
3535
index = dataFrame._columnCollection.IndexOf(column.Name);
3636
}

src/Microsoft.Data.Analysis/DataFrame.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ public DataFrame AddPrefix(string prefix, bool inPlace = false)
301301
for (int i = 0; i < df.Columns.Count; i++)
302302
{
303303
DataFrameColumn column = df.Columns[i];
304-
df._columnCollection.SetColumnName(column, prefix + column.Name);
304+
column.SetName(prefix + column.Name);
305305
df.OnColumnsChanged();
306306
}
307307
return df;
@@ -316,7 +316,7 @@ public DataFrame AddSuffix(string suffix, bool inPlace = false)
316316
for (int i = 0; i < df.Columns.Count; i++)
317317
{
318318
DataFrameColumn column = df.Columns[i];
319-
df._columnCollection.SetColumnName(column, column.Name + suffix);
319+
column.SetName(column.Name + suffix);
320320
df.OnColumnsChanged();
321321
}
322322
return df;

src/Microsoft.Data.Analysis/DataFrameColumn.cs

+34-8
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,26 @@ protected set
8484
}
8585
}
8686

87+
// List of ColumnCollections that owns the column
88+
// Current API allows column to be added into multiple dataframes, that's why the list is needed
89+
private readonly List<DataFrameColumnCollection> _ownerColumnCollections = new();
90+
91+
internal void AddOwner(DataFrameColumnCollection columCollection)
92+
{
93+
if (!_ownerColumnCollections.Contains(columCollection))
94+
{
95+
_ownerColumnCollections.Add(columCollection);
96+
}
97+
}
98+
99+
internal void RemoveOwner(DataFrameColumnCollection columCollection)
100+
{
101+
if (_ownerColumnCollections.Contains(columCollection))
102+
{
103+
_ownerColumnCollections.Remove(columCollection);
104+
}
105+
}
106+
87107
/// <summary>
88108
/// The number of <see langword="null" /> values in this column.
89109
/// </summary>
@@ -95,24 +115,30 @@ public abstract long NullCount
95115
private string _name;
96116

97117
/// <summary>
98-
/// The name of this column.
118+
/// The column name.
99119
/// </summary>
100120
public string Name => _name;
101121

102122
/// <summary>
103-
/// Updates the name of this column.
123+
/// Updates the column name.
104124
/// </summary>
105125
/// <param name="newName">The new name.</param>
106-
/// <param name="dataFrame">If passed in, update the column name in <see cref="DataFrame.Columns"/></param>
107-
public void SetName(string newName, DataFrame dataFrame = null)
126+
public void SetName(string newName)
108127
{
109-
if (!(dataFrame is null))
110-
{
111-
dataFrame.Columns.SetColumnName(this, newName);
112-
}
128+
foreach (var owner in _ownerColumnCollections)
129+
owner.UpdateColumnNameMetadata(this, newName);
130+
113131
_name = newName;
114132
}
115133

134+
/// <summary>
135+
/// Updates the name of this column.
136+
/// </summary>
137+
/// <param name="newName">The new name.</param>
138+
/// <param name="dataFrame">Ignored (for backward compatibility)</param>
139+
[Obsolete]
140+
public void SetName(string newName, DataFrame dataFrame) => SetName(newName);
141+
116142
/// <summary>
117143
/// The type of data this column holds.
118144
/// </summary>

src/Microsoft.Data.Analysis/DataFrameColumnCollection.cs

+23-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

@@ -38,11 +38,23 @@ internal IReadOnlyList<string> GetColumnNames()
3838
return ret;
3939
}
4040

41+
public void RenameColumn(string currentName, string newName)
42+
{
43+
var column = this[currentName];
44+
column.SetName(newName);
45+
}
46+
47+
[Obsolete]
4148
public void SetColumnName(DataFrameColumn column, string newName)
49+
{
50+
column.SetName(newName);
51+
}
52+
53+
//Updates column's metadata (is used as a callback from Column class)
54+
internal void UpdateColumnNameMetadata(DataFrameColumn column, string newName)
4255
{
4356
string currentName = column.Name;
4457
int currentIndex = _columnNameToIndexDictionary[currentName];
45-
column.SetName(newName);
4658
_columnNameToIndexDictionary.Remove(currentName);
4759
_columnNameToIndexDictionary.Add(newName, currentIndex);
4860
ColumnsChanged?.Invoke();
@@ -66,7 +78,7 @@ protected override void InsertItem(int columnIndex, DataFrameColumn column)
6678
}
6779
else if (column.Length != RowCount)
6880
{
69-
//check all columns in the dataframe have the same length (amount of rows)
81+
//check all columns in the dataframe have the same lenght (amount of rows)
7082
throw new ArgumentException(Strings.MismatchedColumnLengths, nameof(column));
7183
}
7284

@@ -75,7 +87,7 @@ protected override void InsertItem(int columnIndex, DataFrameColumn column)
7587
throw new ArgumentException(string.Format(Strings.DuplicateColumnName, column.Name), nameof(column));
7688
}
7789

78-
RowCount = column.Length;
90+
column.AddOwner(this);
7991

8092
_columnNameToIndexDictionary[column.Name] = columnIndex;
8193
for (int i = columnIndex + 1; i < Count; i++)
@@ -100,7 +112,10 @@ protected override void SetItem(int columnIndex, DataFrameColumn column)
100112
}
101113
_columnNameToIndexDictionary.Remove(this[columnIndex].Name);
102114
_columnNameToIndexDictionary[column.Name] = columnIndex;
115+
116+
this[columnIndex].RemoveOwner(this);
103117
base.SetItem(columnIndex, column);
118+
104119
ColumnsChanged?.Invoke();
105120
}
106121

@@ -111,6 +126,8 @@ protected override void RemoveItem(int columnIndex)
111126
{
112127
_columnNameToIndexDictionary[this[i].Name]--;
113128
}
129+
130+
this[columnIndex].RemoveOwner(this);
114131
base.RemoveItem(columnIndex);
115132

116133
//Reset RowCount if the last column was removed and dataframe is empty
@@ -204,10 +221,10 @@ public PrimitiveDataFrameColumn<T> GetPrimitiveColumn<T>(string name)
204221
}
205222

206223
/// <summary>
207-
/// Gets the <see cref="PrimitiveDataFrameColumn{DateTime}"/> with the specified <paramref name="name"/>.
224+
/// Gets the <see cref="PrimitiveDataFrameColumn{T}"/> with the specified <paramref name="name"/>.
208225
/// </summary>
209226
/// <param name="name">The name of the column</param>
210-
/// <returns><see cref="PrimitiveDataFrameColumn{DateTime}"/>.</returns>
227+
/// <returns><see cref="PrimitiveDataFrameColumn{T}"/>.</returns>
211228
/// <exception cref="ArgumentException">A column named <paramref name="name"/> cannot be found, or if the column's type doesn't match.</exception>
212229
public PrimitiveDataFrameColumn<DateTime> GetDateTimeColumn(string name)
213230
{

src/Microsoft.Data.Analysis/PrimitiveColumnContainer.cs

-12
Original file line numberDiff line numberDiff line change
@@ -374,18 +374,6 @@ internal int MaxRecordBatchLength(long startIndex)
374374
return Buffers[arrayIndex].Length - (int)startIndex;
375375
}
376376

377-
internal ReadOnlyMemory<byte> GetValueBuffer(long startIndex)
378-
{
379-
int arrayIndex = GetArrayContainingRowIndex(startIndex);
380-
return Buffers[arrayIndex].ReadOnlyBuffer;
381-
}
382-
383-
internal ReadOnlyMemory<byte> GetNullBuffer(long startIndex)
384-
{
385-
int arrayIndex = GetArrayContainingRowIndex(startIndex);
386-
return NullBitMapBuffers[arrayIndex].ReadOnlyBuffer;
387-
}
388-
389377
public IReadOnlyList<T?> this[long startIndex, int length]
390378
{
391379
get

src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs

+44-13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Collections.Generic;
88
using System.Diagnostics;
99
using System.Runtime.CompilerServices;
10+
using System.Runtime.InteropServices;
1011
using Apache.Arrow;
1112
using Apache.Arrow.Types;
1213
using Microsoft.ML;
@@ -104,6 +105,8 @@ private IArrowType GetArrowType()
104105
return UInt64Type.Default;
105106
else if (typeof(T) == typeof(ushort))
106107
return UInt16Type.Default;
108+
else if (typeof(T) == typeof(DateTime))
109+
return Date64Type.Default;
107110
else
108111
throw new NotImplementedException(nameof(T));
109112
}
@@ -127,36 +130,64 @@ protected internal override Apache.Arrow.Array ToArrowArray(long startIndex, int
127130
{
128131
int arrayIndex = numberOfRows == 0 ? 0 : _columnContainer.GetArrayContainingRowIndex(startIndex);
129132
int offset = (int)(startIndex - arrayIndex * ReadOnlyDataFrameBuffer<T>.MaxCapacity);
133+
130134
if (numberOfRows != 0 && numberOfRows > _columnContainer.Buffers[arrayIndex].Length - offset)
131135
{
132136
throw new ArgumentException(Strings.SpansMultipleBuffers, nameof(numberOfRows));
133137
}
134-
ArrowBuffer valueBuffer = numberOfRows == 0 ? ArrowBuffer.Empty : new ArrowBuffer(_columnContainer.GetValueBuffer(startIndex));
135-
ArrowBuffer nullBuffer = numberOfRows == 0 ? ArrowBuffer.Empty : new ArrowBuffer(_columnContainer.GetNullBuffer(startIndex));
138+
136139
int nullCount = GetNullCount(startIndex, numberOfRows);
140+
141+
//DateTime requires convertion
142+
if (this.DataType == typeof(DateTime))
143+
{
144+
if (numberOfRows == 0)
145+
return new Date64Array(ArrowBuffer.Empty, ArrowBuffer.Empty, numberOfRows, nullCount, offset);
146+
147+
ReadOnlyDataFrameBuffer<T> valueBuffer = (numberOfRows == 0) ? null : _columnContainer.Buffers[arrayIndex];
148+
ReadOnlyDataFrameBuffer<byte> nullBuffer = (numberOfRows == 0) ? null : _columnContainer.NullBitMapBuffers[arrayIndex];
149+
150+
ReadOnlySpan<DateTime> valueSpan = MemoryMarshal.Cast<T, DateTime>(valueBuffer.ReadOnlySpan);
151+
Date64Array.Builder builder = new Date64Array.Builder().Reserve(valueBuffer.Length);
152+
153+
for (int i = 0; i < valueBuffer.Length; i++)
154+
{
155+
if (BitUtility.GetBit(nullBuffer.ReadOnlySpan, i))
156+
builder.Append(valueSpan[i]);
157+
else
158+
builder.AppendNull();
159+
}
160+
161+
return builder.Build();
162+
}
163+
164+
//No convertion
165+
ArrowBuffer arrowValueBuffer = numberOfRows == 0 ? ArrowBuffer.Empty : new ArrowBuffer(_columnContainer.Buffers[arrayIndex].ReadOnlyBuffer);
166+
ArrowBuffer arrowNullBuffer = numberOfRows == 0 ? ArrowBuffer.Empty : new ArrowBuffer(_columnContainer.NullBitMapBuffers[arrayIndex].ReadOnlyBuffer);
167+
137168
Type type = this.DataType;
138169
if (type == typeof(bool))
139-
return new BooleanArray(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
170+
return new BooleanArray(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
140171
else if (type == typeof(double))
141-
return new DoubleArray(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
172+
return new DoubleArray(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
142173
else if (type == typeof(float))
143-
return new FloatArray(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
174+
return new FloatArray(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
144175
else if (type == typeof(int))
145-
return new Int32Array(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
176+
return new Int32Array(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
146177
else if (type == typeof(long))
147-
return new Int64Array(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
178+
return new Int64Array(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
148179
else if (type == typeof(sbyte))
149-
return new Int8Array(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
180+
return new Int8Array(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
150181
else if (type == typeof(short))
151-
return new Int16Array(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
182+
return new Int16Array(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
152183
else if (type == typeof(uint))
153-
return new UInt32Array(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
184+
return new UInt32Array(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
154185
else if (type == typeof(ulong))
155-
return new UInt64Array(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
186+
return new UInt64Array(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
156187
else if (type == typeof(ushort))
157-
return new UInt16Array(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
188+
return new UInt16Array(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
158189
else if (type == typeof(byte))
159-
return new UInt8Array(valueBuffer, nullBuffer, numberOfRows, nullCount, offset);
190+
return new UInt8Array(arrowValueBuffer, arrowNullBuffer, numberOfRows, nullCount, offset);
160191
else
161192
throw new NotImplementedException(type.ToString());
162193
}

test/Microsoft.Data.Analysis.Tests/ArrowIntegrationTests.cs

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public void TestArrowIntegration()
4848
.Append("ULongColumn", false, new UInt64Array.Builder().AppendRange(Enumerable.Repeat((ulong)1, 10)).Build())
4949
.Append("ByteColumn", false, new Int8Array.Builder().AppendRange(Enumerable.Repeat((sbyte)1, 10)).Build())
5050
.Append("UByteColumn", false, new UInt8Array.Builder().AppendRange(Enumerable.Repeat((byte)1, 10)).Build())
51+
.Append("Date64Column", false, new Date64Array.Builder().AppendRange(Enumerable.Repeat(DateTime.Now, 10)).Build())
5152
.Build();
5253

5354
DataFrame df = DataFrame.FromArrowRecordBatch(originalBatch);

test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

@@ -388,6 +388,44 @@ public void ClearColumnsTests()
388388
Assert.Equal(0, dataFrame.Columns.LongCount());
389389
}
390390

391+
[Fact]
392+
public void RenameColumnWithSetNameTests()
393+
{
394+
StringDataFrameColumn city = new StringDataFrameColumn("City", new string[] { "London", "Berlin" });
395+
PrimitiveDataFrameColumn<int> temp = new PrimitiveDataFrameColumn<int>("Temperature", new int[] { 12, 13 });
396+
397+
DataFrame dataframe = new DataFrame(city, temp);
398+
399+
// Change the name of the column:
400+
dataframe["City"].SetName("Town");
401+
var renamedColumn = dataframe["Town"];
402+
403+
Assert.Throws<ArgumentException>(() => dataframe["City"]);
404+
405+
Assert.NotNull(renamedColumn);
406+
Assert.Equal("Town", renamedColumn.Name);
407+
Assert.True(ReferenceEquals(city, renamedColumn));
408+
}
409+
410+
[Fact]
411+
public void RenameColumnWithRenameColumnTests()
412+
{
413+
StringDataFrameColumn city = new StringDataFrameColumn("City", new string[] { "London", "Berlin" });
414+
PrimitiveDataFrameColumn<int> temp = new PrimitiveDataFrameColumn<int>("Temperature", new int[] { 12, 13 });
415+
416+
DataFrame dataframe = new DataFrame(city, temp);
417+
418+
// Change the name of the column:
419+
dataframe.Columns.RenameColumn("City", "Town");
420+
var renamedColumn = dataframe["Town"];
421+
422+
Assert.Throws<ArgumentException>(() => dataframe["City"]);
423+
424+
Assert.NotNull(renamedColumn);
425+
Assert.Equal("Town", renamedColumn.Name);
426+
Assert.True(ReferenceEquals(city, renamedColumn));
427+
}
428+
391429
[Fact]
392430
public void TestBinaryOperations()
393431
{

0 commit comments

Comments
 (0)