Skip to content

Conversion of prior and random trainers to estimators #876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 20 commits on Sep 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,18 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
{
Contracts.CheckValue(host, nameof(host));
Contracts.CheckValueOrNull(featureColumn);
Host = host;
Host.CheckValue(trainSchema, nameof(trainSchema));

Model = model;
FeatureColumn = featureColumn;
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
if (featureColumn == null)
FeatureColumnType = null;
else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
FeatureColumnType = trainSchema.GetColumnType(col);
else
FeatureColumnType = trainSchema.GetColumnType(col);

TrainSchema = trainSchema;
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
Expand Down Expand Up @@ -75,10 +79,13 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
TrainSchema = loader.Schema;

FeatureColumn = ctx.LoadString();
if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
FeatureColumn = ctx.LoadStringOrNull();
if (FeatureColumn == null)
FeatureColumnType = null;
else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
FeatureColumnType = TrainSchema.GetColumnType(col);
else
FeatureColumnType = TrainSchema.GetColumnType(col);

BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}
Expand All @@ -87,10 +94,13 @@ public ISchema GetOutputSchema(ISchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null);
if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString());
if(FeatureColumn != null)
{
if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null);
if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString());
}

return Transform(new EmptyDataView(Host, inputSchema)).Schema;
}
Expand Down Expand Up @@ -121,7 +131,7 @@ protected virtual void SaveCore(ModelSaveContext ctx)
}
});

ctx.SaveString(FeatureColumn);
ctx.SaveStringOrNull(FeatureColumn);
}
}

Expand Down
48 changes: 27 additions & 21 deletions src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,22 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
using (var ch = env.Register("SchemaBindableWrapper").Start("Bind"))
{
ch.CheckValue(schema, nameof(schema));
ch.CheckParam(schema.Feature != null, nameof(schema), "Need a features column");
// Ensure that the feature column type is compatible with the needed input type.
var type = schema.Feature.Type;
var typeIn = ValueMapper != null ? ValueMapper.InputType : new VectorType(NumberType.Float);
if (type != typeIn)
if (schema.Feature != null)
{
if (!type.ItemType.Equals(typeIn.ItemType))
throw ch.Except("Incompatible features column type item type: '{0}' vs '{1}'", type.ItemType, typeIn.ItemType);
if (type.IsVector != typeIn.IsVector)
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
// typeIn can legally have unknown size.
if (type.VectorSize != typeIn.VectorSize && typeIn.VectorSize > 0)
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
// Ensure that the feature column type is compatible with the needed input type.
var type = schema.Feature.Type;
var typeIn = ValueMapper != null ? ValueMapper.InputType : new VectorType(NumberType.Float);
if (type != typeIn)
{
if (!type.ItemType.Equals(typeIn.ItemType))
throw ch.Except("Incompatible features column type item type: '{0}' vs '{1}'", type.ItemType, typeIn.ItemType);
if (type.IsVector != typeIn.IsVector)
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
// typeIn can legally have unknown size.
if (type.VectorSize != typeIn.VectorSize && typeIn.VectorSize > 0)
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
}
}

var mapper = BindCore(ch, schema);
ch.Done();
return mapper;
Expand Down Expand Up @@ -463,15 +464,18 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto
Contracts.AssertValue(parent);
Contracts.Assert(parent._distMapper != null);
Contracts.AssertValue(schema);
Contracts.AssertValue(schema.Feature);
Contracts.AssertValueOrNull(schema.Feature);

_parent = parent;
_inputSchema = schema;
_outputSchema = new BinaryClassifierSchema();

var typeSrc = _inputSchema.Feature.Type;
Contracts.Check(typeSrc.IsKnownSizeVector && typeSrc.ItemType == NumberType.Float,
"Invalid feature column type");
if (schema.Feature != null)
{
var typeSrc = _inputSchema.Feature.Type;
Contracts.Check(typeSrc.IsKnownSizeVector && typeSrc.ItemType == NumberType.Float,
"Invalid feature column type");
}
}

public RoleMappedSchema InputSchema { get { return _inputSchema; } }
Expand All @@ -484,15 +488,15 @@ public Func<int, bool> GetDependencies(Func<int, bool> predicate)
{
for (int i = 0; i < OutputSchema.ColumnCount; i++)
{
if (predicate(i))
if (predicate(i) && _inputSchema.Feature != null)
return col => col == _inputSchema.Feature.Index;
}
return col => false;
}

public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
{
yield return RoleMappedSchema.ColumnRole.Feature.Bind(_inputSchema.Feature.Name);
yield return RoleMappedSchema.ColumnRole.Feature.Bind(_inputSchema.Feature != null ? _inputSchema.Feature.Name : null);
}

private Delegate[] CreateGetters(IRow input, bool[] active)
Expand All @@ -504,7 +508,7 @@ private Delegate[] CreateGetters(IRow input, bool[] active)
if (active[0] || active[1])
{
// Put all captured locals at this scope.
var featureGetter = input.GetGetter<VBuffer<Float>>(_inputSchema.Feature.Index);
var featureGetter = _inputSchema.Feature!= null ? input.GetGetter<VBuffer<Float>>(_inputSchema.Feature.Index) : null;
Float prob = 0;
Float score = 0;
long cachedPosition = -1;
Expand Down Expand Up @@ -543,7 +547,9 @@ private static void EnsureCachedResultValueMapper(ValueMapper<VBuffer<Float>, Fl
Contracts.AssertValue(mapper);
if (cachedPosition != input.Position)
{
featureGetter(ref features);
if (featureGetter != null)
featureGetter(ref features);

mapper(ref features, ref score, ref prob);
cachedPosition = input.Position;
}
Expand Down
Loading