Skip to content

Add save/load APIs for IDataLoader #2858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ public static void Example()
j.Features = features;
};

var engine = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text")
var model = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text")
.Append(mlContext.Transforms.Conversion.MapValue(lookupMap, "Words", "Ids", new ColumnOptions[] { ("VariableLenghtFeatures", "TokenizedWords") }))
.Append(mlContext.Transforms.CustomMapping(ResizeFeaturesAction, "Resize"))
.Append(tensorFlowModel.ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" }))
.Append(mlContext.Transforms.CopyColumns(("Prediction", "Prediction/Softmax")))
.Fit(dataView)
.CreatePredictionEngine<IMDBSentiment, OutputScores>(mlContext);
.Fit(dataView);
var engine = mlContext.Model.CreatePredictionEngine<IMDBSentiment, OutputScores>(model);

// Predict with TensorFlow pipeline.
var prediction = engine.Predict(data[0]);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ internal bool TryFindColumn(string name, out Column column)
/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
/// </summary>
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
public interface IDataLoader<in TSource>
public interface IDataLoader<in TSource> : ICanSaveModel
{
/// <summary>
/// Produce the data view from the specified input.
Expand Down
95 changes: 44 additions & 51 deletions src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.IO;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

[assembly: LoadableClass(CompositeDataLoader<IMultiStreamSource, ITransformer>.Summary, typeof(CompositeDataLoader<IMultiStreamSource, ITransformer>), null, typeof(SignatureLoadModel),
"Composite Loader", CompositeDataLoader<IMultiStreamSource, ITransformer>.LoaderSignature)]

namespace Microsoft.ML.Data
{
/// <summary>
Expand All @@ -14,6 +18,10 @@ namespace Microsoft.ML.Data
public sealed class CompositeDataLoader<TSource, TLastTransformer> : IDataLoader<TSource>
where TLastTransformer : class, ITransformer
{
internal const string TransformerDirectory = TransformerChain.LoaderSignature;
private const string LoaderDirectory = "Loader";
private const string LegacyLoaderDirectory = "Reader";

/// <summary>
/// The underlying data loader.
/// </summary>
Expand All @@ -32,6 +40,24 @@ public CompositeDataLoader(IDataLoader<TSource> loader, TransformerChain<TLastTr
Transformer = transformerChain ?? new TransformerChain<TLastTransformer>();
}

/// <summary>
/// Deserialization constructor: populates <see cref="Loader"/> and <see cref="Transformer"/>
/// from sub-models stored under the given <paramref name="ctx"/>.
/// </summary>
/// <param name="host">The host passed through to the model-loading calls.</param>
/// <param name="ctx">The load context the loader and transformer chain are read from.</param>
private CompositeDataLoader(IHost host, ModelLoadContext ctx)
{
    // The loader is looked up under the legacy directory name first; only when that
    // attempt reports that nothing was loaded is the current directory name used.
    bool loadedFromLegacyDirectory =
        ctx.LoadModelOrNull<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LegacyLoaderDirectory);
    if (!loadedFromLegacyDirectory)
    {
        ctx.LoadModel<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LoaderDirectory);
    }

    // The transformer chain is always read from the transformer directory.
    ctx.LoadModel<TransformerChain<TLastTransformer>, SignatureLoadModel>(host, out Transformer, TransformerDirectory);
}

/// <summary>
/// Builds a <see cref="CompositeDataLoader{TSource, TLastTransformer}"/> from a load context.
/// NOTE(review): this method is private with no caller in view; presumably it is the entry point
/// located through the <c>LoadableClass</c> / <c>SignatureLoadModel</c> registration - confirm.
/// </summary>
/// <param name="env">The environment used to register a host; checked before use.</param>
/// <param name="ctx">The load context; checked and verified against <see cref="GetVersionInfo"/>.</param>
/// <returns>The loader constructed from <paramref name="ctx"/>.</returns>
private static CompositeDataLoader<TSource, TLastTransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));

    // Register a host under the loader signature, then validate the context through it.
    IHost host = env.Register(LoaderSignature);
    host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // The actual construction runs inside host.Apply under the "Loading Model" label.
    return host.Apply(
        "Loading Model",
        channel => new CompositeDataLoader<TSource, TLastTransformer>(host, ctx));
}

/// <summary>
/// Produce the data view from the specified input.
/// Note that <see cref="IDataView"/> instances are lazy, so no actual loading happens here, just schema validation.
Expand Down Expand Up @@ -61,61 +87,28 @@ public CompositeDataLoader<TSource, TNewLast> AppendTransformer<TNewLast>(TNewLa
return new CompositeDataLoader<TSource, TNewLast>(Loader, Transformer.Append(transformer));
}

/// <summary>
/// Save the contents to a stream, as a "model file".
/// </summary>
public void SaveTo(IHostEnvironment env, Stream outputStream)
void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(outputStream, nameof(outputStream));

env.Check(outputStream.CanWrite && outputStream.CanSeek, "Need a writable and seekable stream to save");
using (var ch = env.Start("Saving pipeline"))
{
using (var rep = RepositoryWriter.CreateNew(outputStream, ch))
{
ch.Trace("Saving data loader");
ModelSaveContext.SaveModel(rep, Loader, "Reader");

ch.Trace("Saving transformer chain");
ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature);
rep.Commit();
}
}
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

ctx.SaveModel(Loader, LoaderDirectory);
ctx.SaveModel(Transformer, TransformerDirectory);
}
}

/// <summary>
/// Utility class to facilitate loading from a stream.
/// </summary>
[BestFriend]
internal static class CompositeDataLoader
{
/// <summary>
/// Save the contents to a stream, as a "model file".
/// </summary>
public static void SaveTo<TSource>(this IDataLoader<TSource> loader, IHostEnvironment env, Stream outputStream)
=> new CompositeDataLoader<TSource, ITransformer>(loader).SaveTo(env, outputStream);
internal const string Summary = "A model loader that encapsulates a data loader and a transformer chain.";

/// <summary>
/// Load the pipeline from stream.
/// </summary>
public static CompositeDataLoader<IMultiStreamSource, ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
internal const string LoaderSignature = "CompositeLoader";
private static VersionInfo GetVersionInfo()
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(stream, nameof(stream));

env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load");
using (var rep = RepositoryReader.Open(stream, env))
using (var ch = env.Start("Loading pipeline"))
{
ch.Trace("Loading data loader");
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(env, out var loader, rep, "Reader");

ch.Trace("Loader transformer chain");
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature);
return new CompositeDataLoader<IMultiStreamSource, ITransformer>(loader, transformerChain);
}
return new VersionInfo(
modelSignature: "CMPSTLDR",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(CompositeDataLoader<,>).Assembly.FullName);
}
}
}
8 changes: 8 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, SchemaDefiniti
return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schemaDefinition);
}

/// <summary>
/// Create an <see cref="IDataView"/> over an enumerable of <typeparamref name="TRow"/> items,
/// using an explicitly supplied <see cref="DataViewSchema"/>.
/// </summary>
/// <typeparam name="TRow">The reference type of the items in <paramref name="data"/>.</typeparam>
/// <param name="data">The items to expose as a data view; checked with <c>CheckValue</c>.</param>
/// <param name="schema">The schema passed to the data view construction; checked with <c>CheckValue</c>.</param>
/// <returns>The data view produced by <c>DataViewConstructionUtils.CreateFromEnumerable</c>.</returns>
public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, DataViewSchema schema)
    where TRow : class
{
    // Validate both arguments before handing them to the construction utility.
    _env.CheckValue(data, nameof(data));
    _env.CheckValue(schema, nameof(schema));

    IDataView view = DataViewConstructionUtils.CreateFromEnumerable(_env, data, schema);
    return view;
}

/// <summary>
/// Convert an <see cref="IDataView"/> into a strongly-typed <see cref="IEnumerable{TRow}"/>.
/// </summary>
Expand Down
79 changes: 41 additions & 38 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
[assembly: LoadableClass(TextLoader.Summary, typeof(ILegacyDataLoader), typeof(TextLoader), null, typeof(SignatureLoadDataLoader),
"Text Loader", TextLoader.LoaderSignature)]

[assembly: LoadableClass(TextLoader.Summary, typeof(TextLoader), null, typeof(SignatureLoadModel),
"Text Loader", TextLoader.LoaderSignature)]

namespace Microsoft.ML.Data
{
/// <summary>
/// Loads a text file into an IDataView. Supports basic mapping from input columns to <see cref="IDataView"/> columns.
/// </summary>
public sealed partial class TextLoader : IDataLoader<IMultiStreamSource>, ICanSaveModel
public sealed partial class TextLoader : IDataLoader<IMultiStreamSource>
{
/// <summary>
/// Describes how an input column should be mapped to an <see cref="IDataView"/> column.
Expand Down Expand Up @@ -1189,31 +1192,31 @@ private char NormalizeSeparator(string sep)
{
switch (sep)
{
case "space":
case " ":
return ' ';
case "tab":
case "\t":
return '\t';
case "comma":
case ",":
return ',';
case "colon":
case ":":
_host.CheckUserArg((_flags & OptionFlags.AllowSparse) == 0, nameof(Options.Separator),
"When the separator is colon, turn off allowSparse");
return ':';
case "semicolon":
case ";":
return ';';
case "bar":
case "|":
return '|';
default:
char ch = sep[0];
if (sep.Length != 1 || ch < ' ' || '0' <= ch && ch <= '9' || ch == '"')
throw _host.ExceptUserArg(nameof(Options.Separator), "Illegal separator: '{0}'", sep);
return sep[0];
case "space":
case " ":
return ' ';
case "tab":
case "\t":
return '\t';
case "comma":
case ",":
return ',';
case "colon":
case ":":
_host.CheckUserArg((_flags & OptionFlags.AllowSparse) == 0, nameof(Options.Separator),
"When the separator is colon, turn off allowSparse");
return ':';
case "semicolon":
case ";":
return ';';
case "bar":
case "|":
return '|';
default:
char ch = sep[0];
if (sep.Length != 1 || ch < ' ' || '0' <= ch && ch <= '9' || ch == '"')
throw _host.ExceptUserArg(nameof(Options.Separator), "Illegal separator: '{0}'", sep);
return sep[0];
}
}

Expand Down Expand Up @@ -1310,7 +1313,7 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files,
error = false;
options = optionsNew;

LDone:
LDone:
return !error;
}
}
Expand Down Expand Up @@ -1470,20 +1473,20 @@ internal static TextLoader CreateTextLoader<TInput>(IHostEnvironment host,
InternalDataKind dk;
switch (memberInfo)
{
case FieldInfo field:
if (!InternalDataKindExtensions.TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk))
throw Contracts.Except($"Field {memberInfo.Name} is of unsupported type.");
case FieldInfo field:
if (!InternalDataKindExtensions.TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk))
throw Contracts.Except($"Field {memberInfo.Name} is of unsupported type.");

break;
break;

case PropertyInfo property:
if (!InternalDataKindExtensions.TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk))
throw Contracts.Except($"Property {memberInfo.Name} is of unsupported type.");
break;
case PropertyInfo property:
if (!InternalDataKindExtensions.TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk))
throw Contracts.Except($"Property {memberInfo.Name} is of unsupported type.");
break;

default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}

column.Type = dk;
Expand Down
Loading