|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
5 | 5 | using Microsoft.ML.Core.Data;
|
6 |
| -using Microsoft.ML.Runtime; |
7 | 6 | using Microsoft.ML.Runtime.Data;
|
8 | 7 | using Microsoft.ML.Runtime.Internal.Utilities;
|
9 |
| -using System.Collections.Generic; |
10 |
| -using System.Linq; |
| 8 | +using System; |
11 | 9 |
|
12 | 10 | namespace Microsoft.ML.Data.DataLoadSave
|
13 | 11 | {
|
14 |
| - |
15 | 12 | /// <summary>
|
16 | 13 | /// A fake schema that is manufactured out of a SchemaShape.
|
17 | 14 | /// It will pretend that all vector sizes are equal to 10, all key value counts are equal to 10,
|
18 | 15 | /// and all values are defaults (for metadata).
|
19 | 16 | /// </summary>
|
20 |
| - internal sealed class FakeSchema : ISchema |
| 17 | + internal static class FakeSchemaFactory |
21 | 18 | {
|
22 | 19 | private const int AllVectorSizes = 10;
|
23 | 20 | private const int AllKeySizes = 10;
|
24 | 21 |
|
25 |
| - private readonly IHostEnvironment _env; |
26 |
| - private readonly SchemaShape _shape; |
27 |
| - private readonly Dictionary<string, int> _colMap; |
28 |
| - |
29 |
| - public FakeSchema(IHostEnvironment env, SchemaShape inputShape) |
| 22 | + public static Schema Create(SchemaShape shape) |
30 | 23 | {
|
31 |
| - _env = env; |
32 |
| - _shape = inputShape; |
33 |
| - _colMap = Enumerable.Range(0, _shape.Count) |
34 |
| - .ToDictionary(idx => _shape[idx].Name, idx => idx); |
35 |
| - } |
| 24 | + var builder = new SchemaBuilder(); |
36 | 25 |
|
37 |
| - public int ColumnCount => _shape.Count; |
38 |
| - |
39 |
| - public string GetColumnName(int col) |
40 |
| - { |
41 |
| - _env.Check(0 <= col && col < ColumnCount); |
42 |
| - return _shape[col].Name; |
| 26 | + for (int i = 0; i < shape.Count; ++i) |
| 27 | + { |
| 28 | + var metaBuilder = new MetadataBuilder(); |
| 29 | + var partialMetadata = shape[i].Metadata; |
| 30 | + for (int j = 0; j < partialMetadata.Count; ++j) |
| 31 | + { |
| 32 | + var metaColumnType = MakeColumnType(partialMetadata[i]); |
| 33 | + Delegate del; |
| 34 | + if (metaColumnType.IsVector) |
| 35 | + del = Utils.MarshalInvoke(GetDefaultVectorGetter<int>, metaColumnType.ItemType.RawType); |
| 36 | + else |
| 37 | + del = Utils.MarshalInvoke(GetDefaultGetter<int>, metaColumnType.RawType); |
| 38 | + metaBuilder.Add(partialMetadata[j].Name, metaColumnType, del); |
| 39 | + } |
| 40 | + builder.AddColumn(shape[i].Name, MakeColumnType(shape[i])); |
| 41 | + } |
| 42 | + return builder.GetSchema(); |
43 | 43 | }
|
44 | 44 |
|
45 |
| - public ColumnType GetColumnType(int col) |
| 45 | + private static ColumnType MakeColumnType(SchemaShape.Column column) |
46 | 46 | {
|
47 |
| - _env.Check(0 <= col && col < ColumnCount); |
48 |
| - var inputCol = _shape[col]; |
49 |
| - return MakeColumnType(inputCol); |
50 |
| - } |
51 |
| - |
52 |
| - public bool TryGetColumnIndex(string name, out int col) => _colMap.TryGetValue(name, out col); |
53 |
| - |
54 |
| - private static ColumnType MakeColumnType(SchemaShape.Column inputCol) |
55 |
| - { |
56 |
| - ColumnType curType = inputCol.ItemType; |
57 |
| - if (inputCol.IsKey) |
| 47 | + ColumnType curType = column.ItemType; |
| 48 | + if (column.IsKey) |
58 | 49 | curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes);
|
59 |
| - if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector) |
| 50 | + if (column.Kind == SchemaShape.Column.VectorKind.VariableVector) |
60 | 51 | curType = new VectorType((PrimitiveType)curType, 0);
|
61 |
| - else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector) |
| 52 | + else if (column.Kind == SchemaShape.Column.VectorKind.Vector) |
62 | 53 | curType = new VectorType((PrimitiveType)curType, AllVectorSizes);
|
63 | 54 | return curType;
|
64 | 55 | }
|
65 | 56 |
|
66 |
| - public void GetMetadata<TValue>(string kind, int col, ref TValue value) |
| 57 | + private static Delegate GetDefaultVectorGetter<TValue>() |
67 | 58 | {
|
68 |
| - _env.Check(0 <= col && col < ColumnCount); |
69 |
| - var inputCol = _shape[col]; |
70 |
| - var metaShape = inputCol.Metadata; |
71 |
| - if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) |
72 |
| - throw _env.ExceptGetMetadata(); |
73 |
| - |
74 |
| - var colType = MakeColumnType(metaColumn); |
75 |
| - _env.Check(colType.RawType.Equals(typeof(TValue))); |
76 |
| - |
77 |
| - if (colType.IsVector) |
78 |
| - { |
79 |
| - // This as an atypical use of VBuffer: we create it in GetMetadataVec, and then pass through |
80 |
| - // via boxing to be returned out of this method. This is intentional. |
81 |
| - value = (TValue)Utils.MarshalInvoke(GetMetadataVec<int>, colType.ItemType.RawType); |
82 |
| - } |
83 |
| - else |
84 |
| - value = default; |
| 59 | + ValueGetter<VBuffer<TValue>> getter = (ref VBuffer<TValue> value) => value = new VBuffer<TValue>(AllVectorSizes, 0, null, null); |
| 60 | + return getter; |
85 | 61 | }
|
86 | 62 |
|
87 |
| - private object GetMetadataVec<TItem>() => new VBuffer<TItem>(AllVectorSizes, 0, null, null); |
88 |
| - |
89 |
| - public ColumnType GetMetadataTypeOrNull(string kind, int col) |
| 63 | + private static Delegate GetDefaultGetter<TValue>() |
90 | 64 | {
|
91 |
| - _env.Check(0 <= col && col < ColumnCount); |
92 |
| - var inputCol = _shape[col]; |
93 |
| - var metaShape = inputCol.Metadata; |
94 |
| - if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn)) |
95 |
| - return null; |
96 |
| - return MakeColumnType(metaColumn); |
| 65 | + ValueGetter<TValue> getter = (ref TValue value) => value = default; |
| 66 | + return getter; |
97 | 67 | }
|
98 | 68 |
|
99 |
| - public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col) |
100 |
| - { |
101 |
| - _env.Check(0 <= col && col < ColumnCount); |
102 |
| - var inputCol = _shape[col]; |
103 |
| - var metaShape = inputCol.Metadata; |
104 |
| - if (metaShape == null) |
105 |
| - return Enumerable.Empty<KeyValuePair<string, ColumnType>>(); |
106 |
| - |
107 |
| - return metaShape.Select(c => new KeyValuePair<string, ColumnType>(c.Name, MakeColumnType(c))); |
108 |
| - } |
109 | 69 | }
|
110 | 70 | }
|
0 commit comments