Skip to content

Commit 9157cea

Browse files
authored
ML Context to create them all (#1252)
* ML Context and a couple extensions
1 parent c6d4e62 commit 9157cea

File tree

56 files changed

+1647
-756
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+1647
-756
lines changed

docs/code/MlNetCookBook.md

Lines changed: 93 additions & 119 deletions
Large diffs are not rendered by default.

docs/samples/Microsoft.ML.Samples/Trainers.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// the alignment of the usings with the methods is intentional so they can display on the same level in the docs site.
66
using Microsoft.ML.Runtime.Data;
77
using Microsoft.ML.Runtime.Learners;
8-
using Microsoft.ML.Trainers;
8+
using Microsoft.ML.StaticPipe;
99
using System;
1010

1111
// NOTE: WHEN ADDING TO THE FILE, ALWAYS APPEND TO THE END OF IT.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
namespace Microsoft.ML.Runtime
6+
{
7+
/// <summary>
8+
/// A catalog of operations to load and save data.
9+
/// </summary>
10+
public sealed class DataLoadSaveOperations
11+
{
12+
internal IHostEnvironment Environment { get; }
13+
14+
internal DataLoadSaveOperations(IHostEnvironment env)
15+
{
16+
Contracts.AssertValue(env);
17+
Environment = env;
18+
}
19+
}
20+
}

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ public Column() { }
4646
public Column(string name, DataKind? type, int index)
4747
: this(name, type, new[] { new Range(index) }) { }
4848

49+
public Column(string name, DataKind? type, int minIndex, int maxIndex)
50+
: this(name, type, new[] { new Range(minIndex, maxIndex) })
51+
{
52+
}
53+
4954
public Column(string name, DataKind? type, Range[] source, KeyRange keyRange = null)
5055
{
5156
Contracts.CheckValue(name, nameof(name));
@@ -1003,6 +1008,18 @@ private bool HasHeader
10031008
private readonly IHost _host;
10041009
private const string RegistrationName = "TextLoader";
10051010

1011+
public TextLoader(IHostEnvironment env, Column[] columns, Action<Arguments> advancedSettings, IMultiStreamSource dataSample = null)
1012+
: this(env, MakeArgs(columns, advancedSettings), dataSample)
1013+
{
1014+
}
1015+
1016+
private static Arguments MakeArgs(Column[] columns, Action<Arguments> advancedSettings)
1017+
{
1018+
var result = new Arguments { Column = columns };
1019+
advancedSettings?.Invoke(result);
1020+
return result;
1021+
}
1022+
10061023
public TextLoader(IHostEnvironment env, Arguments args, IMultiStreamSource dataSample = null)
10071024
{
10081025
Contracts.CheckValue(env, nameof(env));
@@ -1320,6 +1337,8 @@ public void Save(ModelSaveContext ctx)
13201337

13211338
public IDataView Read(IMultiStreamSource source) => new BoundLoader(this, source);
13221339

1340+
public IDataView Read(string path) => Read(new MultiFileSource(path));
1341+
13231342
private sealed class BoundLoader : IDataLoader
13241343
{
13251344
private readonly TextLoader _reader;
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Data.IO;
8+
using Microsoft.ML.Runtime.Internal.Utilities;
9+
using System;
10+
using System.Collections.Generic;
11+
using System.IO;
12+
using System.Linq;
13+
using System.Text;
14+
15+
namespace Microsoft.ML
16+
{
17+
public static class TextLoaderSaverCatalog
18+
{
19+
/// <summary>
20+
/// Create a text reader.
21+
/// </summary>
22+
/// <param name="catalog">The catalog.</param>
23+
/// <param name="args">The arguments to text reader, describing the data schema.</param>
24+
/// <param name="dataSample">The optional location of a data sample.</param>
25+
public static TextLoader TextReader(this DataLoadSaveOperations catalog,
26+
TextLoader.Arguments args, IMultiStreamSource dataSample = null)
27+
=> new TextLoader(CatalogUtils.GetEnvironment(catalog), args, dataSample);
28+
29+
/// <summary>
30+
/// Create a text reader.
31+
/// </summary>
32+
/// <param name="catalog">The catalog.</param>
33+
/// <param name="columns">The columns of the schema.</param>
34+
/// <param name="advancedSettings">The delegate to set additional settings.</param>
35+
/// <param name="dataSample">The optional location of a data sample.</param>
36+
public static TextLoader TextReader(this DataLoadSaveOperations catalog,
37+
TextLoader.Column[] columns, Action<TextLoader.Arguments> advancedSettings = null, IMultiStreamSource dataSample = null)
38+
=> new TextLoader(CatalogUtils.GetEnvironment(catalog), columns, advancedSettings, dataSample);
39+
40+
/// <summary>
41+
/// Read a data view from a text file using <see cref="TextLoader"/>.
42+
/// </summary>
43+
/// <param name="catalog">The catalog.</param>
44+
/// <param name="columns">The columns of the schema.</param>
45+
/// <param name="advancedSettings">The delegate to set additional settings</param>
46+
/// <param name="path">The path to the file</param>
47+
/// <returns>The data view.</returns>
48+
public static IDataView ReadFromTextFile(this DataLoadSaveOperations catalog,
49+
TextLoader.Column[] columns, string path, Action<TextLoader.Arguments> advancedSettings = null)
50+
{
51+
Contracts.CheckNonEmpty(path, nameof(path));
52+
53+
var env = catalog.GetEnvironment();
54+
55+
// REVIEW: it is almost always a mistake to have a 'trainable' text loader here.
56+
// Therefore, we are going to disallow data sample.
57+
var reader = new TextLoader(env, columns, advancedSettings, dataSample: null);
58+
return reader.Read(new MultiFileSource(path));
59+
}
60+
61+
/// <summary>
62+
/// Save the data view as text.
63+
/// </summary>
64+
/// <param name="catalog">The catalog.</param>
65+
/// <param name="data">The data view to save.</param>
66+
/// <param name="stream">The stream to write to.</param>
67+
/// <param name="separator">The column separator.</param>
68+
/// <param name="headerRow">Whether to write the header row.</param>
69+
/// <param name="schema">Whether to write the header comment with the schema.</param>
70+
/// <param name="keepHidden">Whether to keep hidden columns in the dataset.</param>
71+
public static void SaveAsText(this DataLoadSaveOperations catalog, IDataView data, Stream stream,
72+
char separator = '\t', bool headerRow = true, bool schema = true, bool keepHidden = false)
73+
{
74+
Contracts.CheckValue(catalog, nameof(catalog));
75+
Contracts.CheckValue(data, nameof(data));
76+
Contracts.CheckValue(stream, nameof(stream));
77+
78+
var env = catalog.GetEnvironment();
79+
var saver = new TextSaver(env, new TextSaver.Arguments { Separator = separator.ToString(), OutputHeader = headerRow, OutputSchema = schema });
80+
81+
using (var ch = env.Start("Saving data"))
82+
DataSaverUtils.SaveDataView(ch, saver, data, stream, keepHidden);
83+
}
84+
}
85+
}

src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ public static RegressionEvaluator.Result Evaluate<T>(
213213
/// <param name="score">The index delegate for predicted score column.</param>
214214
/// <returns>The evaluation metrics.</returns>
215215
public static RankerEvaluator.Result Evaluate<T, TVal>(
216-
this RankerContext ctx,
216+
this RankingContext ctx,
217217
DataView<T> data,
218218
Func<T, Scalar<float>> label,
219219
Func<T, Key<uint, TVal>> groupId,

src/Microsoft.ML.Data/MLContext.cs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using System;
8+
9+
namespace Microsoft.ML
10+
{
11+
/// <summary>
12+
/// The <see cref="MLContext"/> is a starting point for all ML.NET operations. It is instantiated by user,
13+
/// provides mechanisms for logging and entry points for training, prediction, model operations etc.
14+
/// </summary>
15+
public sealed class MLContext : IHostEnvironment
16+
{
17+
// REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation.
18+
private readonly LocalEnvironment _env;
19+
20+
/// <summary>
21+
/// Trainers and tasks specific to binary classification problems.
22+
/// </summary>
23+
public BinaryClassificationContext BinaryClassification { get; }
24+
/// <summary>
25+
/// Trainers and tasks specific to multiclass classification problems.
26+
/// </summary>
27+
public MulticlassClassificationContext MulticlassClassification { get; }
28+
/// <summary>
29+
/// Trainers and tasks specific to regression problems.
30+
/// </summary>
31+
public RegressionContext Regression { get; }
32+
/// <summary>
33+
/// Trainers and tasks specific to clustering problems.
34+
/// </summary>
35+
public ClusteringContext Clustering { get; }
36+
/// <summary>
37+
/// Trainers and tasks specific to ranking problems.
38+
/// </summary>
39+
public RankingContext Ranking { get; }
40+
41+
/// <summary>
42+
/// Data processing operations.
43+
/// </summary>
44+
public TransformsCatalog Transforms { get; }
45+
46+
/// <summary>
47+
/// Operations with trained models.
48+
/// </summary>
49+
public ModelOperationsCatalog Model { get; }
50+
51+
/// <summary>
52+
/// Data loading and saving.
53+
/// </summary>
54+
public DataLoadSaveOperations Data { get; }
55+
56+
// REVIEW: I think it's valuable to have the simplest possible interface for logging interception here,
57+
// and expand if and when necessary. Exposing classes like ChannelMessage, MessageSensitivity and so on
58+
// looks premature at this point.
59+
/// <summary>
60+
/// The handler for the log messages.
61+
/// </summary>
62+
public Action<string> Log { get; set; }
63+
64+
/// <summary>
65+
/// Create the ML context.
66+
/// </summary>
67+
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
68+
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
69+
public MLContext(int? seed = null, int conc = 0)
70+
{
71+
_env = new LocalEnvironment(seed, conc);
72+
_env.AddListener(ProcessMessage);
73+
74+
BinaryClassification = new BinaryClassificationContext(_env);
75+
MulticlassClassification = new MulticlassClassificationContext(_env);
76+
Regression = new RegressionContext(_env);
77+
Clustering = new ClusteringContext(_env);
78+
Ranking = new RankingContext(_env);
79+
Transforms = new TransformsCatalog(_env);
80+
Model = new ModelOperationsCatalog(_env);
81+
Data = new DataLoadSaveOperations(_env);
82+
}
83+
84+
private void ProcessMessage(IMessageSource source, ChannelMessage message)
85+
{
86+
if (Log == null)
87+
return;
88+
89+
var msg = $"[Source={source.FullName}, Kind={message.Kind}] {message.Message}";
90+
// Log may have been reset from another thread.
91+
// We don't care which logger we send the message to, just making sure we don't crash.
92+
Log?.Invoke(msg);
93+
}
94+
95+
int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor;
96+
bool IHostEnvironment.IsCancelled => _env.IsCancelled;
97+
ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog;
98+
string IExceptionContext.ContextDescription => _env.ContextDescription;
99+
IFileHandle IHostEnvironment.CreateOutputFile(string path) => _env.CreateOutputFile(path);
100+
IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix);
101+
IFileHandle IHostEnvironment.OpenInputFile(string path) => _env.OpenInputFile(path);
102+
TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex);
103+
IHost IHostEnvironment.Register(string name, int? seed, bool? verbose, int? conc) => _env.Register(name, seed, verbose, conc);
104+
IChannel IChannelProvider.Start(string name) => _env.Start(name);
105+
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
106+
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
107+
}
108+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime.Data;
7+
using System.IO;
8+
9+
namespace Microsoft.ML.Runtime
10+
{
11+
/// <summary>
12+
/// An object serving as a 'catalog' of available model operations.
13+
/// </summary>
14+
public sealed class ModelOperationsCatalog
15+
{
16+
internal IHostEnvironment Environment { get; }
17+
18+
internal ModelOperationsCatalog(IHostEnvironment env)
19+
{
20+
Contracts.AssertValue(env);
21+
Environment = env;
22+
}
23+
24+
/// <summary>
25+
/// Save the model to the stream.
26+
/// </summary>
27+
/// <param name="model">The trained model to be saved.</param>
28+
/// <param name="stream">A writeable, seekable stream to save to.</param>
29+
public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream);
30+
31+
/// <summary>
32+
/// Load the model from the stream.
33+
/// </summary>
34+
/// <param name="stream">A readable, seekable stream to load from.</param>
35+
/// <returns>The loaded model.</returns>
36+
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream);
37+
}
38+
}

0 commit comments

Comments
 (0)