Skip to content

API scenarios implementation with Estimators #688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 21, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.ML.Core.Data
{
/// <summary>
/// Describes what a component 'requires' of its incoming schema and what it 'promises' about its outgoing schema.
/// It is looser than a full <see cref="ISchema"/>: it may cover only a subset of the columns,
/// and it does not pin down the exact <see cref="ColumnType"/> of vectors and keys.
/// </summary>
public sealed class SchemaShape
{
    /// <summary>
    /// The columns that make up this shape. Never <c>null</c> (enforced by the constructor).
    /// </summary>
    public readonly Column[] Columns;

    /// <summary>
    /// The relaxed description of a single column: its name, vector-ness, item kind, key-ness
    /// and the kinds of metadata attached to it.
    /// </summary>
    public sealed class Column
    {
        /// <summary>
        /// How 'vector-like' a column is, as far as the shape is concerned.
        /// </summary>
        public enum VectorKind
        {
            /// <summary>The column is not a vector.</summary>
            Scalar,
            /// <summary>The column is a vector of known size.</summary>
            Vector,
            /// <summary>The column is a vector whose size is not known.</summary>
            VariableVector
        }

        /// <summary>The column name. Never null or empty.</summary>
        public readonly string Name;
        /// <summary>Whether the column is a scalar, a known-size vector or a variable-size vector.</summary>
        public readonly VectorKind Kind;
        /// <summary>The raw kind of the column's item type.</summary>
        public readonly DataKind ItemKind;
        /// <summary>Whether the column's item type is a key.</summary>
        public readonly bool IsKey;
        /// <summary>The kinds (names) of metadata present on the column. Never <c>null</c>.</summary>
        public readonly string[] MetadataKinds;

        /// <summary>
        /// Creates a column description.
        /// </summary>
        /// <param name="name">The column name; must be non-empty.</param>
        /// <param name="vecKind">The vector kind of the column.</param>
        /// <param name="itemKind">The raw kind of the item type.</param>
        /// <param name="isKey">Whether the item type is a key.</param>
        /// <param name="metadataKinds">The metadata kinds of the column; must not be <c>null</c>.</param>
        public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds)
        {
            // Validate first (name, then metadata) so that nothing is assigned on failure.
            Contracts.CheckNonEmpty(name, nameof(name));
            Contracts.CheckValue(metadataKinds, nameof(metadataKinds));

            MetadataKinds = metadataKinds;
            IsKey = isKey;
            ItemKind = itemKind;
            Kind = vecKind;
            Name = name;
        }
    }

    /// <summary>
    /// Wraps the given columns into a schema shape. The array is stored as is, without copying.
    /// </summary>
    /// <param name="columns">The columns of the shape; must not be <c>null</c>.</param>
    public SchemaShape(Column[] columns)
    {
        Contracts.CheckValue(columns, nameof(columns));
        Columns = columns;
    }

    /// <summary>
    /// Builds the schema shape that corresponds to a fully defined schema.
    /// Hidden columns of <paramref name="schema"/> are left out.
    /// </summary>
    /// <param name="schema">The schema to summarize; must not be <c>null</c>.</param>
    /// <returns>A shape with one column per non-hidden column of <paramref name="schema"/>, in schema order.</returns>
    public static SchemaShape Create(ISchema schema)
    {
        Contracts.CheckValue(schema, nameof(schema));

        var visibleColumns = new List<Column>();
        for (int col = 0; col < schema.ColumnCount; ++col)
        {
            // Hidden columns are not part of the shape.
            if (schema.IsHidden(col))
                continue;

            var colType = schema.GetColumnType(col);

            // Known-size vectors are checked first, then any remaining vectors, then everything else is scalar.
            var vectorKind = colType.IsKnownSizeVector
                ? Column.VectorKind.Vector
                : colType.IsVector
                    ? Column.VectorKind.VariableVector
                    : Column.VectorKind.Scalar;

            var itemKind = colType.ItemType.RawKind;
            var itemIsKey = colType.ItemType.IsKey;

            // Only the metadata kinds (the keys) are retained, not their types.
            var metadataKinds = (from entry in schema.GetMetadataTypes(col)
                                 select entry.Key).ToArray();

            var columnName = schema.GetColumnName(col);
            visibleColumns.Add(new Column(columnName, vectorKind, itemKind, itemIsKey, metadataKinds));
        }

        return new SchemaShape(visibleColumns.ToArray());
    }

    /// <summary>
    /// Looks up the first column named <paramref name="name"/>.
    /// </summary>
    /// <param name="name">The column name to look for; must not be <c>null</c>.</param>
    /// <returns>The matching column, or <c>null</c> if there is no such column.</returns>
    public Column FindColumn(string name)
    {
        Contracts.CheckValue(name, nameof(name));

        foreach (var candidate in Columns)
        {
            if (candidate.Name == name)
                return candidate;
        }

        return null;
    }

    // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
    // as an input to another schema shape. I started writing, but realized that there's more than one way to check for
    // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
}

/// <summary>
/// The 'data reader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
/// </summary>
/// <typeparam name="TSource">The type of input the reader takes.</typeparam>
public interface IDataReader<in TSource>
{
/// <summary>
/// Produce the data view from the specified input.
/// Note that <see cref="IDataView"/>'s are lazy, so no actual reading happens here, just schema validation.
/// </summary>
/// <param name="input">The input to produce the data view from.</param>
/// <returns>The data view over <paramref name="input"/>.</returns>
IDataView Read(TSource input);

/// <summary>
/// The output schema of the reader.
/// </summary>
/// <returns>The schema of the data views produced by this reader.</returns>
ISchema GetOutputSchema();
}

/// <summary>
/// Sometimes we need to 'fit' an <see cref="IDataReader{TSource}"/>.
/// A DataReader estimator is the object that does it.
/// </summary>
/// <typeparam name="TSource">The type of input that is used for fitting, and that the resulting reader takes.</typeparam>
/// <typeparam name="TReader">The type of the data reader produced by <see cref="Fit"/>.</typeparam>
public interface IDataReaderEstimator<in TSource, out TReader>
where TReader : IDataReader<TSource>
{
/// <summary>
/// Train and return a data reader.
///
/// REVIEW: the fitted reader could conceivably take a different source type than <typeparamref name="TSource"/>,
/// but we don't have such components yet, so why complicate matters?
/// </summary>
/// <param name="input">The input to fit the reader on.</param>
/// <returns>The fitted data reader.</returns>
TReader Fit(TSource input);

/// <summary>
/// The 'promise' of the output schema.
/// It will be used for schema propagation.
/// </summary>
/// <returns>The shape of the schema that the fitted reader is expected to output.</returns>
SchemaShape GetOutputSchema();
}

/// <summary>
/// The transformer is a component that transforms data.
/// It also supports 'schema propagation' to answer the question of 'how will data with this schema look after you transform it?'.
/// </summary>
public interface ITransformer
{
/// <summary>
/// Schema propagation for transformers.
/// Returns the output schema of the data, if the input schema is like the one provided.
/// Returns <c>null</c> iff the schema is invalid (then a call to Transform with this data will fail).
/// </summary>
/// <param name="inputSchema">The schema of the data that would be passed to <see cref="Transform"/>.</param>
/// <returns>The schema of the transformed data, or <c>null</c> if <paramref name="inputSchema"/> is invalid for this transformer.</returns>
ISchema GetOutputSchema(ISchema inputSchema);

/// <summary>
/// Take the data in, make transformations, output the data.
/// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
/// </summary>
/// <param name="input">The data to transform.</param>
/// <returns>The transformed data view.</returns>
IDataView Transform(IDataView input);
}

/// <summary>
/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
/// a transformer.
/// It also provides the 'schema propagation' like transformers do, but over <see cref="SchemaShape"/> instead of <see cref="ISchema"/>.
/// </summary>
/// <typeparam name="TTransformer">The type of the transformer produced by <see cref="Fit"/>.</typeparam>
public interface IEstimator<out TTransformer>
where TTransformer : ITransformer
{
/// <summary>
/// Train and return a transformer.
/// </summary>
/// <param name="input">The data to fit on.</param>
/// <returns>The fitted transformer.</returns>
TTransformer Fit(IDataView input);

/// <summary>
/// Schema propagation for estimators.
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
/// Returns <c>null</c> iff the schema shape is invalid (then a call to <see cref="Fit"/> with this data will fail).
/// </summary>
/// <param name="inputSchema">The shape of the schema of the data that would be passed to <see cref="Fit"/>.</param>
/// <returns>The output schema shape, or <c>null</c> if <paramref name="inputSchema"/> is invalid for this estimator.</returns>
SchemaShape GetOutputSchema(SchemaShape inputSchema);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
<ProjectReference Include="..\..\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML\Microsoft.ML.csproj" />
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />
<ProjectReference Include="..\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj" />
</ItemGroup>

<ItemGroup>
Expand Down
7 changes: 7 additions & 0 deletions test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
<PropertyGroup>
<TargetFramework>netcoreapp2.0</TargetFramework>
</PropertyGroup>
<ItemGroup>
<Compile Remove="Scenarios\Api\AspirationalExamples.cs" />
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Aug 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AspirationalExamples [](start = 35, length = 20)

why? #Pending

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because right now it doesn't compile, and may never compile, since it uses my own imaginary version of Pigsty, which may differ from the real one, when it appears.


In reply to: 210986324 [](ancestors = 210986324)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we even adding this file? What's the point?


In reply to: 211030708 [](ancestors = 211030708,210986324)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Umm, as an aspirational example? It's a way to document what we want to reach at the end. Another way would be to put it into Markdown somewhere, but I think I like this way somewhat better.


In reply to: 211688654 [](ancestors = 211688654,211030708,210986324)

</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj" />
Expand All @@ -26,4 +29,8 @@
<NativeAssemblyReference Include="SymSgdNative" />
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Aug 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[](start = 0, length = 1)

since you're editing this file...
TAB! #Resolved

<NativeAssemblyReference Include="MklImports" />
</ItemGroup>

<ItemGroup>
<None Include="Scenarios\Api\AspirationalExamples.cs" />
</ItemGroup>
</Project>
60 changes: 60 additions & 0 deletions test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.ML.Tests.Scenarios.Api
{
/// <summary>
/// An 'aspirational' sketch of how the getting-started experience should look in the new API.
/// The code is written against an API that does not exist yet, so it is not expected to compile;
/// it serves as a discussion piece for the intended syntax.
/// </summary>
public class AspirationalExamples
{
/// <summary>
/// The prediction produced by the model in <see cref="FirstExperienceWithML"/>.
/// </summary>
public class IrisPrediction
{
// The label predicted by the model, as a string.
public string PredictedLabel;
}

/// <summary>
/// One input example for the model in <see cref="FirstExperienceWithML"/>: four float measurements.
/// </summary>
public class IrisExample
{
public float SepalWidth { get; set; }
public float SepalLength { get; set; }
public float PetalWidth { get; set; }
public float PetalLength { get; set; }
}

/// <summary>
/// Walks through the intended workflow: read a text file, build a preprocessing estimator,
/// append a trainer, fit the pipeline, and predict on a single example.
/// </summary>
public void FirstExperienceWithML()
{
// This is the 'getting started with ML' example, how we see it in our new API.
// It currently doesn't compile, let alone work, but we still can discuss and improve the syntax.

// Load the data into the system.
// NOTE(review): 'env' is not declared anywhere in this method or class; presumably an ambient
// host environment is intended - confirm once the real API exists.
// NOTE(review): the integer arguments to ReadString/ReadFloat look like column indices in the
// text file - confirm against the eventual reader API.
string dataPath = "iris-data.txt";
var data = TextReader.FitAndRead(env, dataPath, row => (
Label: row.ReadString(0),
SepalWidth: row.ReadFloat(1),
SepalLength: row.ReadFloat(2),
PetalWidth: row.ReadFloat(3),
PetalLength: row.ReadFloat(4)));


// Build the preprocessing estimator from the schema of the loaded data.
var preprocess = data.Schema.MakeEstimator(row => (
// Convert string label to key.
Label: row.Label.DictionarizeLabel(),
// Concatenate all features into a vector.
Features: row.SepalWidth.ConcatWith(row.SepalLength, row.PetalWidth, row.PetalLength)));

var pipeline = preprocess
// Append the trainer to the training pipeline.
.AppendEstimator(row => row.Label.PredictWithSdca(row.Features))
// Convert the predicted label from a key back to its value.
.AppendEstimator(row => row.PredictedLabel.KeyToValue());

// Train the model and make some predictions.
var model = pipeline.Fit<IrisExample, IrisPrediction>(data);

// Predict on a single hand-written example.
IrisPrediction prediction = model.Predict(new IrisExample
{
SepalWidth = 3.3f,
SepalLength = 1.6f,
PetalWidth = 0.2f,
PetalLength = 5.1f
});
}
}
}
Loading