Skip to content

Commit 5a819d3

Browse files
author
Pete Luferenko
committed
Typed estimators
1 parent 91dc0f2 commit 5a819d3

File tree

2 files changed

+287
-82
lines changed

2 files changed

+287
-82
lines changed

src/Microsoft.ML.Core/Data/IEstimator.cs

Lines changed: 41 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -16,19 +16,9 @@ namespace Microsoft.ML.Core.Data
1616
/// </summary>
1717
public sealed class SchemaShape
1818
{
19-
public readonly ColumnBase[] Columns;
19+
public readonly Column[] Columns;
2020

21-
public abstract class ColumnBase
22-
{
23-
public readonly string Name;
24-
public ColumnBase(string name)
25-
{
26-
Contracts.CheckNonEmpty(name, nameof(name));
27-
Name = name;
28-
}
29-
}
30-
31-
public sealed class RelaxedColumn : ColumnBase
21+
public sealed class Column
3222
{
3323
public enum VectorKind
3424
{
@@ -37,33 +27,22 @@ public enum VectorKind
3727
VariableVector
3828
}
3929

30+
public readonly string Name;
4031
public readonly VectorKind Kind;
4132
public readonly DataKind ItemKind;
4233
public readonly bool IsKey;
4334

44-
public RelaxedColumn(string name, VectorKind kind, DataKind itemKind, bool isKey)
45-
: base(name)
35+
public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey)
4636
{
47-
Kind = kind;
37+
Contracts.CheckNonEmpty(name, nameof(name));
38+
Name = name;
39+
Kind = vecKind;
4840
ItemKind = itemKind;
4941
IsKey = isKey;
5042
}
5143
}
5244

53-
public sealed class StrictColumn : ColumnBase
54-
{
55-
// REVIEW: do we ever need strict columns? Maybe we should only have relaxed?
56-
public readonly ColumnType ColumnType;
57-
58-
public StrictColumn(string name, ColumnType columnType)
59-
: base(name)
60-
{
61-
Contracts.CheckValue(columnType, nameof(columnType));
62-
ColumnType = columnType;
63-
}
64-
}
65-
66-
public SchemaShape(ColumnBase[] columns)
45+
public SchemaShape(Column[] columns)
6746
{
6847
Contracts.CheckValue(columns, nameof(columns));
6948
Columns = columns;
@@ -75,20 +54,32 @@ public SchemaShape(ColumnBase[] columns)
7554
public static SchemaShape Create(ISchema schema)
7655
{
7756
Contracts.CheckValue(schema, nameof(schema));
78-
var cols = new List<ColumnBase>();
57+
var cols = new List<Column>();
7958

8059
for (int iCol = 0; iCol < schema.ColumnCount; iCol++)
8160
{
8261
if (!schema.IsHidden(iCol))
83-
cols.Append(new StrictColumn(schema.GetColumnName(iCol), schema.GetColumnType(iCol)));
62+
{
63+
Column.VectorKind vecKind;
64+
var type = schema.GetColumnType(iCol);
65+
if (type.IsKnownSizeVector)
66+
vecKind = Column.VectorKind.Vector;
67+
else if (type.IsVector)
68+
vecKind = Column.VectorKind.VariableVector;
69+
else
70+
vecKind = Column.VectorKind.Scalar;
71+
var kind = type.ItemType.RawKind;
72+
var isKey = type.ItemType.IsKey;
73+
cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey));
74+
}
8475
}
8576
return new SchemaShape(cols.ToArray());
8677
}
8778

8879
/// <summary>
8980
/// Returns the column with the specified <paramref name="name"/>, or <c>null</c> if there is no such column.
9081
/// </summary>
91-
public ColumnBase FindColumn(string name)
82+
public Column FindColumn(string name)
9283
{
9384
Contracts.CheckValue(name, nameof(name));
9485
return Columns.FirstOrDefault(x => x.Name == name);
@@ -140,6 +131,15 @@ public interface IEstimator<TIn>
140131
SchemaShape GetOutputSchema();
141132
}
142133

134+
/// <summary>
135+
/// An estimator whose <c>Fit</c> method returns the more specific transformer type <typeparamref name="TTransformer"/>, which must implement <see cref="ITransformer{TIn}"/>.
136+
/// </summary>
137+
public interface IEstimator<TIn, TTransformer>: IEstimator<TIn>
138+
where TTransformer: ITransformer<TIn>
139+
{
140+
new TTransformer Fit(TIn input);
141+
}
142+
143143
/// <summary>
144144
/// The data transformer, in addition to being a transformer, also exposes the input schema shape. It is handy for
145145
/// evaluating what kind of columns the transformer expects.
@@ -174,4 +174,13 @@ public interface IDataEstimator
174174
/// </summary>
175175
SchemaShape GetOutputSchema(SchemaShape inputSchema);
176176
}
177+
178+
/// <summary>
179+
/// A data estimator whose <c>Fit</c> method returns the more specific transformer type <typeparamref name="TTransformer"/>, which must implement <see cref="IDataTransformer"/>.
180+
/// </summary>
181+
public interface IDataEstimator<TTransformer>: IDataEstimator
182+
where TTransformer: IDataTransformer
183+
{
184+
new TTransformer Fit(IDataView input);
185+
}
177186
}

0 commit comments

Comments
 (0)