@@ -16,19 +16,9 @@ namespace Microsoft.ML.Core.Data
16
16
/// </summary>
17
17
public sealed class SchemaShape
18
18
{
19
- public readonly ColumnBase [ ] Columns ;
19
+ public readonly Column [ ] Columns ;
20
20
21
- public abstract class ColumnBase
22
- {
23
- public readonly string Name ;
24
- public ColumnBase ( string name )
25
- {
26
- Contracts . CheckNonEmpty ( name , nameof ( name ) ) ;
27
- Name = name ;
28
- }
29
- }
30
-
31
- public sealed class RelaxedColumn : ColumnBase
21
+ public sealed class Column
32
22
{
33
23
public enum VectorKind
34
24
{
@@ -37,33 +27,22 @@ public enum VectorKind
37
27
VariableVector
38
28
}
39
29
30
+ public readonly string Name ;
40
31
public readonly VectorKind Kind ;
41
32
public readonly DataKind ItemKind ;
42
33
public readonly bool IsKey ;
43
34
44
- public RelaxedColumn ( string name , VectorKind kind , DataKind itemKind , bool isKey )
45
- : base ( name )
35
+ public Column ( string name , VectorKind vecKind , DataKind itemKind , bool isKey )
46
36
{
47
- Kind = kind ;
37
+ Contracts . CheckNonEmpty ( name , nameof ( name ) ) ;
38
+ Name = name ;
39
+ Kind = vecKind ;
48
40
ItemKind = itemKind ;
49
41
IsKey = isKey ;
50
42
}
51
43
}
52
44
53
- public sealed class StrictColumn : ColumnBase
54
- {
55
- // REVIEW: do we ever need strict columns? Maybe we should only have relaxed?
56
- public readonly ColumnType ColumnType ;
57
-
58
- public StrictColumn ( string name , ColumnType columnType )
59
- : base ( name )
60
- {
61
- Contracts . CheckValue ( columnType , nameof ( columnType ) ) ;
62
- ColumnType = columnType ;
63
- }
64
- }
65
-
66
- public SchemaShape ( ColumnBase [ ] columns )
45
+ public SchemaShape ( Column [ ] columns )
67
46
{
68
47
Contracts . CheckValue ( columns , nameof ( columns ) ) ;
69
48
Columns = columns ;
@@ -75,20 +54,32 @@ public SchemaShape(ColumnBase[] columns)
75
54
public static SchemaShape Create ( ISchema schema )
76
55
{
77
56
Contracts . CheckValue ( schema , nameof ( schema ) ) ;
78
- var cols = new List < ColumnBase > ( ) ;
57
+ var cols = new List < Column > ( ) ;
79
58
80
59
for ( int iCol = 0 ; iCol < schema . ColumnCount ; iCol ++ )
81
60
{
82
61
if ( ! schema . IsHidden ( iCol ) )
83
- cols . Append ( new StrictColumn ( schema . GetColumnName ( iCol ) , schema . GetColumnType ( iCol ) ) ) ;
62
+ {
63
+ Column . VectorKind vecKind ;
64
+ var type = schema . GetColumnType ( iCol ) ;
65
+ if ( type . IsKnownSizeVector )
66
+ vecKind = Column . VectorKind . Vector ;
67
+ else if ( type . IsVector )
68
+ vecKind = Column . VectorKind . VariableVector ;
69
+ else
70
+ vecKind = Column . VectorKind . Scalar ;
71
+ var kind = type . ItemType . RawKind ;
72
+ var isKey = type . ItemType . IsKey ;
73
+ cols . Add ( new Column ( schema . GetColumnName ( iCol ) , vecKind , kind , isKey ) ) ;
74
+ }
84
75
}
85
76
return new SchemaShape ( cols . ToArray ( ) ) ;
86
77
}
87
78
88
79
/// <summary>
89
80
/// Returns the column with a specified <paramref name="name"/>, and <c>null</c> if there is no such column.
90
81
/// </summary>
91
- public ColumnBase FindColumn ( string name )
82
+ public Column FindColumn ( string name )
92
83
{
93
84
Contracts . CheckValue ( name , nameof ( name ) ) ;
94
85
return Columns . FirstOrDefault ( x => x . Name == name ) ;
@@ -140,6 +131,15 @@ public interface IEstimator<TIn>
140
131
SchemaShape GetOutputSchema ( ) ;
141
132
}
142
133
134
+ /// <summary>
135
+ /// An estimator that provides more details about the produced transformer, in the form of <typeparamref name="TTransformer"/>.
136
+ /// </summary>
137
+ public interface IEstimator < TIn , TTransformer > : IEstimator < TIn >
138
+ where TTransformer : ITransformer < TIn >
139
+ {
140
+ new TTransformer Fit ( TIn input ) ;
141
+ }
142
+
143
143
/// <summary>
144
144
/// The data transformer, in addition to being a transformer, also exposes the input schema shape. It is handy for
145
145
/// evaluating what kind of columns the transformer expects.
@@ -174,4 +174,13 @@ public interface IDataEstimator
174
174
/// </summary>
175
175
SchemaShape GetOutputSchema ( SchemaShape inputSchema ) ;
176
176
}
177
+
178
+ /// <summary>
179
+ /// A data estimator that provides more details about the produced transformer, in the form of <typeparamref name="TTransformer"/>.
180
+ /// </summary>
181
+ public interface IDataEstimator < TTransformer > : IDataEstimator
182
+ where TTransformer : IDataTransformer
183
+ {
184
+ new TTransformer Fit ( IDataView input ) ;
185
+ }
177
186
}
0 commit comments