diff --git a/src/Microsoft.Data.Analysis/IDataView.Extension.cs b/src/Microsoft.Data.Analysis/IDataView.Extension.cs index 303eaf3b76..545765f010 100644 --- a/src/Microsoft.Data.Analysis/IDataView.Extension.cs +++ b/src/Microsoft.Data.Analysis/IDataView.Extension.cs @@ -204,6 +204,10 @@ private static DataFrameColumn GetVectorDataFrame(VectorDataViewType vectorType, { return new VBufferDataFrameColumn(name); } + else if (itemType.RawType == typeof(ReadOnlyMemory)) + { + return new VBufferDataFrameColumn>(name); + } throw new NotSupportedException(String.Format(Microsoft.Data.Strings.VectorSubTypeNotSupported, itemType.ToString())); } diff --git a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs index caa786afdb..35456f34c4 100644 --- a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs @@ -390,6 +390,10 @@ private static VectorDataViewType GetDataViewType() { return new VectorDataViewType(NumberDataViewType.Double); } + else if (typeof(T) == typeof(ReadOnlyMemory)) + { + return new VectorDataViewType(TextDataViewType.Instance); + } throw new NotSupportedException(); } diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs index 4576870707..83cfb67fec 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs @@ -461,6 +461,7 @@ public void TestDataFrameFromIDataView_VBufferType() ushortFeatures = new ushort[] {0, 0}, uintFeatures = new uint[] {0, 0}, ulongFeatures = new ulong[] {0, 0}, + stringFeatures = new string[]{ "A", "B"}, }, new { boolFeature = new bool[] {false, false}, @@ -474,13 +475,14 @@ public void TestDataFrameFromIDataView_VBufferType() ushortFeatures = new ushort[] {0, 0}, uintFeatures = new uint[] {0, 0}, ulongFeatures = new ulong[] {0, 0}, + stringFeatures = new string[]{ "A", "B"}, } }; var data = mlContext.Data.LoadFromEnumerable(inputData); var df = data.ToDataFrame(); - Assert.Equal(11, df.Columns.Count); + Assert.Equal(12, df.Columns.Count); Assert.Equal(2, df.Rows.Count); } }