Skip to content

Commit 616dca2

Browse files
authored
Remove ISchema in FakeSchema (#1872)
* Remove ISchema in FakeSchema * Address comments * Cleaning
1 parent 9db16c8 commit 616dca2

File tree

2 files changed

+34
-74
lines changed

2 files changed

+34
-74
lines changed

src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs

Lines changed: 32 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,108 +3,68 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.ML.Core.Data;
6-
using Microsoft.ML.Runtime;
76
using Microsoft.ML.Runtime.Data;
87
using Microsoft.ML.Runtime.Internal.Utilities;
9-
using System.Collections.Generic;
10-
using System.Linq;
8+
using System;
119

1210
namespace Microsoft.ML.Data.DataLoadSave
1311
{
14-
1512
/// <summary>
1613
/// A fake schema that is manufactured out of a SchemaShape.
1714
/// It will pretend that all vector sizes are equal to 10, all key value counts are equal to 10,
1815
/// and all values are defaults (for metadata).
1916
/// </summary>
20-
internal sealed class FakeSchema : ISchema
17+
internal static class FakeSchemaFactory
2118
{
2219
private const int AllVectorSizes = 10;
2320
private const int AllKeySizes = 10;
2421

25-
private readonly IHostEnvironment _env;
26-
private readonly SchemaShape _shape;
27-
private readonly Dictionary<string, int> _colMap;
28-
29-
public FakeSchema(IHostEnvironment env, SchemaShape inputShape)
22+
public static Schema Create(SchemaShape shape)
3023
{
31-
_env = env;
32-
_shape = inputShape;
33-
_colMap = Enumerable.Range(0, _shape.Count)
34-
.ToDictionary(idx => _shape[idx].Name, idx => idx);
35-
}
24+
var builder = new SchemaBuilder();
3625

37-
public int ColumnCount => _shape.Count;
38-
39-
public string GetColumnName(int col)
40-
{
41-
_env.Check(0 <= col && col < ColumnCount);
42-
return _shape[col].Name;
26+
for (int i = 0; i < shape.Count; ++i)
27+
{
28+
var metaBuilder = new MetadataBuilder();
29+
var partialMetadata = shape[i].Metadata;
30+
for (int j = 0; j < partialMetadata.Count; ++j)
31+
{
32+
var metaColumnType = MakeColumnType(partialMetadata[j]); // index by j: the inner loop walks this column's metadata entries, not the outer column index
33+
Delegate del;
34+
if (metaColumnType.IsVector)
35+
del = Utils.MarshalInvoke(GetDefaultVectorGetter<int>, metaColumnType.ItemType.RawType);
36+
else
37+
del = Utils.MarshalInvoke(GetDefaultGetter<int>, metaColumnType.RawType);
38+
metaBuilder.Add(partialMetadata[j].Name, metaColumnType, del);
39+
}
40+
builder.AddColumn(shape[i].Name, MakeColumnType(shape[i]), metaBuilder.GetMetadata()); // attach the metadata built above; otherwise metaBuilder is dead code and the fake schema carries no metadata
41+
}
42+
return builder.GetSchema();
4343
}
4444

45-
public ColumnType GetColumnType(int col)
45+
private static ColumnType MakeColumnType(SchemaShape.Column column)
4646
{
47-
_env.Check(0 <= col && col < ColumnCount);
48-
var inputCol = _shape[col];
49-
return MakeColumnType(inputCol);
50-
}
51-
52-
public bool TryGetColumnIndex(string name, out int col) => _colMap.TryGetValue(name, out col);
53-
54-
private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
55-
{
56-
ColumnType curType = inputCol.ItemType;
57-
if (inputCol.IsKey)
47+
ColumnType curType = column.ItemType;
48+
if (column.IsKey)
5849
curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes);
59-
if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector)
50+
if (column.Kind == SchemaShape.Column.VectorKind.VariableVector)
6051
curType = new VectorType((PrimitiveType)curType, 0);
61-
else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector)
52+
else if (column.Kind == SchemaShape.Column.VectorKind.Vector)
6253
curType = new VectorType((PrimitiveType)curType, AllVectorSizes);
6354
return curType;
6455
}
6556

66-
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
57+
private static Delegate GetDefaultVectorGetter<TValue>()
6758
{
68-
_env.Check(0 <= col && col < ColumnCount);
69-
var inputCol = _shape[col];
70-
var metaShape = inputCol.Metadata;
71-
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
72-
throw _env.ExceptGetMetadata();
73-
74-
var colType = MakeColumnType(metaColumn);
75-
_env.Check(colType.RawType.Equals(typeof(TValue)));
76-
77-
if (colType.IsVector)
78-
{
79-
// This is an atypical use of VBuffer: we create it in GetMetadataVec, and then pass through
80-
// via boxing to be returned out of this method. This is intentional.
81-
value = (TValue)Utils.MarshalInvoke(GetMetadataVec<int>, colType.ItemType.RawType);
82-
}
83-
else
84-
value = default;
59+
ValueGetter<VBuffer<TValue>> getter = (ref VBuffer<TValue> value) => value = new VBuffer<TValue>(AllVectorSizes, 0, null, null);
60+
return getter;
8561
}
8662

87-
private object GetMetadataVec<TItem>() => new VBuffer<TItem>(AllVectorSizes, 0, null, null);
88-
89-
public ColumnType GetMetadataTypeOrNull(string kind, int col)
63+
private static Delegate GetDefaultGetter<TValue>()
9064
{
91-
_env.Check(0 <= col && col < ColumnCount);
92-
var inputCol = _shape[col];
93-
var metaShape = inputCol.Metadata;
94-
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
95-
return null;
96-
return MakeColumnType(metaColumn);
65+
ValueGetter<TValue> getter = (ref TValue value) => value = default;
66+
return getter;
9767
}
9868

99-
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
100-
{
101-
_env.Check(0 <= col && col < ColumnCount);
102-
var inputCol = _shape[col];
103-
var metaShape = inputCol.Metadata;
104-
if (metaShape == null)
105-
return Enumerable.Empty<KeyValuePair<string, ColumnType>>();
106-
107-
return metaShape.Select(c => new KeyValuePair<string, ColumnType>(c.Name, MakeColumnType(c)));
108-
}
10969
}
11070
}

src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
160160
{
161161
Host.CheckValue(inputSchema, nameof(inputSchema));
162162

163-
var fakeSchema = Schema.Create(new FakeSchema(Host, inputSchema));
163+
var fakeSchema = FakeSchemaFactory.Create(inputSchema);
164164
var transformer = Fit(new EmptyDataView(Host, fakeSchema));
165165
return SchemaShape.Create(transformer.GetOutputSchema(fakeSchema));
166166
}
@@ -179,7 +179,7 @@ protected TrivialWrapperEstimator(IHost host, TransformWrapper transformer)
179179
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
180180
{
181181
Host.CheckValue(inputSchema, nameof(inputSchema));
182-
var fakeSchema = Schema.Create(new FakeSchema(Host, inputSchema));
182+
var fakeSchema = FakeSchemaFactory.Create(inputSchema);
183183
return SchemaShape.Create(Transformer.GetOutputSchema(fakeSchema));
184184
}
185185
}

0 commit comments

Comments
 (0)