Skip to content

Commit 60ae981

Browse files
authored
SDCA trainers become Estimators (#716)
Converted three SDCA trainers to estimators
1 parent 4fd8a9c commit 60ae981

26 files changed

+1362
-583
lines changed

src/Microsoft.ML.Core/Data/IEstimator.cs

+41-5
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,59 @@ public enum VectorKind
3434
public readonly bool IsKey;
3535
public readonly string[] MetadataKinds;
3636

37-
public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds)
37+
public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds = null)
3838
{
3939
Contracts.CheckNonEmpty(name, nameof(name));
40-
Contracts.CheckValue(metadataKinds, nameof(metadataKinds));
40+
Contracts.CheckValueOrNull(metadataKinds);
4141

4242
Name = name;
4343
Kind = vecKind;
4444
ItemKind = itemKind;
4545
IsKey = isKey;
46-
MetadataKinds = metadataKinds;
46+
MetadataKinds = metadataKinds ?? new string[0];
47+
}
48+
49+
/// <summary>
50+
/// Returns whether <paramref name="inputColumn"/> is a valid input, if this object represents a
51+
/// requirement.
52+
///
53+
/// Namely, it returns true iff:
54+
/// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemKind"/>, <see cref="IsKey"/> fields match.
55+
/// - The <see cref="MetadataKinds"/> of <paramref name="inputColumn"/> is a superset of our <see cref="MetadataKinds"/>.
56+
/// </summary>
57+
public bool IsCompatibleWith(Column inputColumn)
58+
{
59+
Contracts.CheckValue(inputColumn, nameof(inputColumn));
60+
if (Name != inputColumn.Name)
61+
return false;
62+
if (Kind != inputColumn.Kind)
63+
return false;
64+
if (ItemKind != inputColumn.ItemKind)
65+
return false;
66+
if (IsKey != inputColumn.IsKey)
67+
return false;
68+
if (inputColumn.MetadataKinds.Except(MetadataKinds).Any())
69+
return false;
70+
return true;
71+
}
72+
73+
public string GetTypeString()
74+
{
75+
string result = ItemKind.ToString();
76+
if (IsKey)
77+
result = $"Key<{result}>";
78+
if (Kind == VectorKind.Vector)
79+
result = $"Vector<{result}>";
80+
else if (Kind == VectorKind.VariableVector)
81+
result = $"VarVector<{result}>";
82+
return result;
4783
}
4884
}
4985

50-
public SchemaShape(Column[] columns)
86+
public SchemaShape(IEnumerable<Column> columns)
5187
{
5288
Contracts.CheckValue(columns, nameof(columns));
53-
Columns = columns;
89+
Columns = columns.ToArray();
5490
}
5591

5692
/// <summary>

0 commit comments

Comments
 (0)