Skip to content

Conversion of prior and random trainers to estimators #876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Sep 19, 2018
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ public RoleMappedData(IDataView data, IEnumerable<KeyValuePair<RoleMappedSchema.
/// <param name="opt">Whether to consider the column names specified "optional" or not. If <c>false</c> then any non-empty
/// values for the column names that does not appear in <paramref name="data"/>'s schema will result in an exception being thrown,
/// but if <c>true</c> such values will be ignored</param>
public RoleMappedData(IDataView data, string label, string feature,
public RoleMappedData(IDataView data, string label, string feature = null,
Copy link
Contributor

@Zruty0 Zruty0 Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

= null [](start = 74, length = 7)

undo this #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will handle this in my code then! Sorry about that


In reply to: 218244002 [](ancestors = 218244002)

string group = null, string weight = null, string name = null,
IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null, bool opt = false)
: this(Contracts.CheckRef(data, nameof(data)),
Expand Down
30 changes: 20 additions & 10 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,18 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
{
Contracts.CheckValue(host, nameof(host));
Contracts.CheckValueOrNull(featureColumn);
Host = host;
Host.CheckValue(trainSchema, nameof(trainSchema));

Model = model;
FeatureColumn = featureColumn;
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
if (featureColumn == null)
FeatureColumnType = null;
else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
FeatureColumnType = trainSchema.GetColumnType(col);
else
FeatureColumnType = trainSchema.GetColumnType(col);

TrainSchema = trainSchema;
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
Expand Down Expand Up @@ -75,10 +79,13 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
TrainSchema = loader.Schema;

FeatureColumn = ctx.LoadString();
if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
FeatureColumn = ctx.LoadStringOrNull();
if (FeatureColumn == null)
FeatureColumnType = null;
else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
FeatureColumnType = TrainSchema.GetColumnType(col);
else
FeatureColumnType = TrainSchema.GetColumnType(col);

BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}
Expand All @@ -87,10 +94,13 @@ public ISchema GetOutputSchema(ISchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null);
if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString());
if(FeatureColumn != null)
{
if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null);
if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString());
}

return Transform(new EmptyDataView(Host, inputSchema)).Schema;
}
Expand Down Expand Up @@ -121,7 +131,7 @@ protected virtual void SaveCore(ModelSaveContext ctx)
}
});

ctx.SaveString(FeatureColumn);
ctx.SaveStringOrNull(FeatureColumn);
}
}

Expand Down
48 changes: 27 additions & 21 deletions src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,22 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
using (var ch = env.Register("SchemaBindableWrapper").Start("Bind"))
{
ch.CheckValue(schema, nameof(schema));
ch.CheckParam(schema.Feature != null, nameof(schema), "Need a features column");
// Ensure that the feature column type is compatible with the needed input type.
var type = schema.Feature.Type;
var typeIn = ValueMapper != null ? ValueMapper.InputType : new VectorType(NumberType.Float);
if (type != typeIn)
if (schema.Feature != null)
{
if (!type.ItemType.Equals(typeIn.ItemType))
throw ch.Except("Incompatible features column type item type: '{0}' vs '{1}'", type.ItemType, typeIn.ItemType);
if (type.IsVector != typeIn.IsVector)
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
// typeIn can legally have unknown size.
if (type.VectorSize != typeIn.VectorSize && typeIn.VectorSize > 0)
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
// Ensure that the feature column type is compatible with the needed input type.
var type = schema.Feature.Type;
var typeIn = ValueMapper != null ? ValueMapper.InputType : new VectorType(NumberType.Float);
if (type != typeIn)
{
if (!type.ItemType.Equals(typeIn.ItemType))
throw ch.Except("Incompatible features column type item type: '{0}' vs '{1}'", type.ItemType, typeIn.ItemType);
if (type.IsVector != typeIn.IsVector)
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
// typeIn can legally have unknown size.
if (type.VectorSize != typeIn.VectorSize && typeIn.VectorSize > 0)
throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", type, typeIn);
}
}

var mapper = BindCore(ch, schema);
ch.Done();
return mapper;
Expand Down Expand Up @@ -463,15 +464,18 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto
Contracts.AssertValue(parent);
Contracts.Assert(parent._distMapper != null);
Contracts.AssertValue(schema);
Contracts.AssertValue(schema.Feature);
Contracts.AssertValueOrNull(schema.Feature);

_parent = parent;
_inputSchema = schema;
_outputSchema = new BinaryClassifierSchema();

var typeSrc = _inputSchema.Feature.Type;
Contracts.Check(typeSrc.IsKnownSizeVector && typeSrc.ItemType == NumberType.Float,
"Invalid feature column type");
if (schema.Feature != null)
{
var typeSrc = _inputSchema.Feature.Type;
Contracts.Check(typeSrc.IsKnownSizeVector && typeSrc.ItemType == NumberType.Float,
"Invalid feature column type");
}
}

public RoleMappedSchema InputSchema { get { return _inputSchema; } }
Expand All @@ -484,15 +488,15 @@ public Func<int, bool> GetDependencies(Func<int, bool> predicate)
{
for (int i = 0; i < OutputSchema.ColumnCount; i++)
{
if (predicate(i))
if (predicate(i) && _inputSchema.Feature != null)
return col => col == _inputSchema.Feature.Index;
}
return col => false;
}

public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
{
yield return RoleMappedSchema.ColumnRole.Feature.Bind(_inputSchema.Feature.Name);
yield return RoleMappedSchema.ColumnRole.Feature.Bind(_inputSchema.Feature != null ? _inputSchema.Feature.Name : null);
}

private Delegate[] CreateGetters(IRow input, bool[] active)
Expand All @@ -504,7 +508,7 @@ private Delegate[] CreateGetters(IRow input, bool[] active)
if (active[0] || active[1])
{
// Put all captured locals at this scope.
var featureGetter = input.GetGetter<VBuffer<Float>>(_inputSchema.Feature.Index);
var featureGetter = _inputSchema.Feature!= null ? input.GetGetter<VBuffer<Float>>(_inputSchema.Feature.Index) : null;
Float prob = 0;
Float score = 0;
long cachedPosition = -1;
Expand Down Expand Up @@ -543,7 +547,9 @@ private static void EnsureCachedResultValueMapper(ValueMapper<VBuffer<Float>, Fl
Contracts.AssertValue(mapper);
if (cachedPosition != input.Position)
{
featureGetter(ref features);
if (featureGetter != null)
featureGetter(ref features);

mapper(ref features, ref score, ref prob);
cachedPosition = input.Position;
}
Expand Down
Loading