Skip to content

Commit 7b1c7d7

Browse files
artidoroZruty0
authored andcommitted
Conversion of prior and random trainers to estimators (#876)
1 parent 731381c commit 7b1c7d7

30 files changed

+6272
-83
lines changed

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

+20-10
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,18 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
4242
public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
4343
{
4444
Contracts.CheckValue(host, nameof(host));
45+
Contracts.CheckValueOrNull(featureColumn);
4546
Host = host;
4647
Host.CheckValue(trainSchema, nameof(trainSchema));
4748

4849
Model = model;
4950
FeatureColumn = featureColumn;
50-
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
51+
if (featureColumn == null)
52+
FeatureColumnType = null;
53+
else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
5154
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
52-
FeatureColumnType = trainSchema.GetColumnType(col);
55+
else
56+
FeatureColumnType = trainSchema.GetColumnType(col);
5357

5458
TrainSchema = trainSchema;
5559
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
@@ -78,10 +82,13 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
7882
var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
7983
TrainSchema = loader.Schema;
8084

81-
FeatureColumn = ctx.LoadString();
82-
if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
85+
FeatureColumn = ctx.LoadStringOrNull();
86+
if (FeatureColumn == null)
87+
FeatureColumnType = null;
88+
else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
8389
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
84-
FeatureColumnType = TrainSchema.GetColumnType(col);
90+
else
91+
FeatureColumnType = TrainSchema.GetColumnType(col);
8592

8693
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
8794
}
@@ -90,10 +97,13 @@ public ISchema GetOutputSchema(ISchema inputSchema)
9097
{
9198
Host.CheckValue(inputSchema, nameof(inputSchema));
9299

93-
if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col))
94-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null);
95-
if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType))
96-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString());
100+
if(FeatureColumn != null)
101+
{
102+
if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col))
103+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null);
104+
if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType))
105+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString());
106+
}
97107

98108
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
99109
}
@@ -124,7 +134,7 @@ protected virtual void SaveCore(ModelSaveContext ctx)
124134
}
125135
});
126136

127-
ctx.SaveString(FeatureColumn);
137+
ctx.SaveStringOrNull(FeatureColumn);
128138
}
129139
}
130140

src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs

+27-21
Original file line numberDiff line numberDiff line change
@@ -108,21 +108,22 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
108108
using (var ch = env.Register("SchemaBindableWrapper").Start("Bind"))
109109
{
110110
ch.CheckValue(schema, nameof(schema));
111-
ch.CheckParam(schema.Feature != null, nameof(schema), "Need a features column");
112-
// Ensure that the feature column type is compatible with the needed input type.
113-
var type = schema.Feature.Type;
114-
var typeIn = ValueMapper != null ? ValueMapper.InputType : new VectorType(NumberType.Float);
115-
if (type != typeIn)
111+
if (schema.Feature != null)
116112
{
117-
if (!type.ItemType.Equals(typeIn.ItemType))
118-
throw ch.Except("Incompatible features column type item type: '{0}' vs '{1}'", type.ItemType, typeIn.ItemType);
119-
if (type.IsVector != typeIn.IsVector)
120-
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
121-
// typeIn can legally have unknown size.
122-
if (type.VectorSize != typeIn.VectorSize && typeIn.VectorSize > 0)
123-
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
113+
// Ensure that the feature column type is compatible with the needed input type.
114+
var type = schema.Feature.Type;
115+
var typeIn = ValueMapper != null ? ValueMapper.InputType : new VectorType(NumberType.Float);
116+
if (type != typeIn)
117+
{
118+
if (!type.ItemType.Equals(typeIn.ItemType))
119+
throw ch.Except("Incompatible features column type item type: '{0}' vs '{1}'", type.ItemType, typeIn.ItemType);
120+
if (type.IsVector != typeIn.IsVector)
121+
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
122+
// typeIn can legally have unknown size.
123+
if (type.VectorSize != typeIn.VectorSize && typeIn.VectorSize > 0)
124+
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
125+
}
124126
}
125-
126127
var mapper = BindCore(ch, schema);
127128
ch.Done();
128129
return mapper;
@@ -463,15 +464,18 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto
463464
Contracts.AssertValue(parent);
464465
Contracts.Assert(parent._distMapper != null);
465466
Contracts.AssertValue(schema);
466-
Contracts.AssertValue(schema.Feature);
467+
Contracts.AssertValueOrNull(schema.Feature);
467468

468469
_parent = parent;
469470
_inputSchema = schema;
470471
_outputSchema = new BinaryClassifierSchema();
471472

472-
var typeSrc = _inputSchema.Feature.Type;
473-
Contracts.Check(typeSrc.IsKnownSizeVector && typeSrc.ItemType == NumberType.Float,
474-
"Invalid feature column type");
473+
if (schema.Feature != null)
474+
{
475+
var typeSrc = _inputSchema.Feature.Type;
476+
Contracts.Check(typeSrc.IsKnownSizeVector && typeSrc.ItemType == NumberType.Float,
477+
"Invalid feature column type");
478+
}
475479
}
476480

477481
public RoleMappedSchema InputSchema { get { return _inputSchema; } }
@@ -484,15 +488,15 @@ public Func<int, bool> GetDependencies(Func<int, bool> predicate)
484488
{
485489
for (int i = 0; i < OutputSchema.ColumnCount; i++)
486490
{
487-
if (predicate(i))
491+
if (predicate(i) && _inputSchema.Feature != null)
488492
return col => col == _inputSchema.Feature.Index;
489493
}
490494
return col => false;
491495
}
492496

493497
public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
494498
{
495-
yield return RoleMappedSchema.ColumnRole.Feature.Bind(_inputSchema.Feature.Name);
499+
yield return RoleMappedSchema.ColumnRole.Feature.Bind(_inputSchema.Feature != null ? _inputSchema.Feature.Name : null);
496500
}
497501

498502
private Delegate[] CreateGetters(IRow input, bool[] active)
@@ -504,7 +508,7 @@ private Delegate[] CreateGetters(IRow input, bool[] active)
504508
if (active[0] || active[1])
505509
{
506510
// Put all captured locals at this scope.
507-
var featureGetter = input.GetGetter<VBuffer<Float>>(_inputSchema.Feature.Index);
511+
var featureGetter = _inputSchema.Feature!= null ? input.GetGetter<VBuffer<Float>>(_inputSchema.Feature.Index) : null;
508512
Float prob = 0;
509513
Float score = 0;
510514
long cachedPosition = -1;
@@ -543,7 +547,9 @@ private static void EnsureCachedResultValueMapper(ValueMapper<VBuffer<Float>, Fl
543547
Contracts.AssertValue(mapper);
544548
if (cachedPosition != input.Position)
545549
{
546-
featureGetter(ref features);
550+
if (featureGetter != null)
551+
featureGetter(ref features);
552+
547553
mapper(ref features, ref score, ref prob);
548554
cachedPosition = input.Position;
549555
}

0 commit comments

Comments
 (0)