Skip to content

Commit d942001

Browse files
author
Ivan Matantsev
committed
merge with master
2 parents ba52135 + 73b0308 commit d942001

28 files changed

+2030
-29
lines changed

src/Microsoft.ML.Core/Data/IEstimator.cs

-6
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,6 @@ public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, st
4545
IsKey = isKey;
4646
MetadataKinds = metadataKinds;
4747
}
48-
49-
public Column CloneWithNewName(string newName)
50-
{
51-
Contracts.Check(newName != Name, "Should be different name");
52-
return new Column(newName, Kind, ItemKind, IsKey, MetadataKinds);
53-
}
5448
}
5549

5650
public SchemaShape(Column[] columns)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime.Model;
7+
using System.IO;
8+
9+
namespace Microsoft.ML.Runtime.Data
10+
{
11+
/// <summary>
/// A data reader that first reads with an underlying reader and then runs the result
/// through a chain of transformers. It can also save itself to a stream.
/// </summary>
public sealed class CompositeDataReader<TSource, TLastTransformer> : IDataReader<TSource>
    where TLastTransformer : class, ITransformer
{
    /// <summary>
    /// The reader that produces the initial data view.
    /// </summary>
    public readonly IDataReader<TSource> Reader;
    /// <summary>
    /// The transformer chain (possibly empty) applied to the data right after it is read.
    /// </summary>
    public readonly TransformerChain<TLastTransformer> Transformer;

    /// <summary>
    /// Create a composite reader from a reader and an optional transformer chain.
    /// When <paramref name="transformerChain"/> is null, an empty chain is used.
    /// </summary>
    public CompositeDataReader(IDataReader<TSource> reader, TransformerChain<TLastTransformer> transformerChain = null)
    {
        Contracts.CheckValue(reader, nameof(reader));
        Contracts.CheckValueOrNull(transformerChain);

        Transformer = transformerChain ?? new TransformerChain<TLastTransformer>();
        Reader = reader;
    }

    /// <summary>
    /// Read the input with the underlying reader, then apply the transformer chain to the result.
    /// </summary>
    public IDataView Read(TSource input)
    {
        return Transformer.Transform(Reader.Read(input));
    }

    /// <summary>
    /// The schema produced by pushing the reader's output schema through the transformer chain.
    /// </summary>
    public ISchema GetOutputSchema()
    {
        var readerSchema = Reader.GetOutputSchema();
        var outputSchema = Transformer.GetOutputSchema(readerSchema);
        return outputSchema;
    }

    /// <summary>
    /// Build a new composite data reader that has one more transformer at the end of the chain.
    /// This instance is left as it was.
    /// </summary>
    /// <returns>The new composite data reader</returns>
    public CompositeDataReader<TSource, TNewLast> AppendTransformer<TNewLast>(TNewLast transformer)
        where TNewLast : class, ITransformer
    {
        Contracts.CheckValue(transformer, nameof(transformer));

        var extendedChain = Transformer.Append(transformer);
        return new CompositeDataReader<TSource, TNewLast>(Reader, extendedChain);
    }

    /// <summary>
    /// Write the reader and the transformer chain to a stream, as a "model file".
    /// The stream has to be both writable and seekable.
    /// </summary>
    public void SaveTo(IHostEnvironment env, Stream outputStream)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(outputStream, nameof(outputStream));
        env.Check(outputStream.CanWrite && outputStream.CanSeek, "Need a writable and seekable stream to save");

        // The repository is disposed before the channel, same as with nested using blocks.
        using (var ch = env.Start("Saving pipeline"))
        using (var rep = RepositoryWriter.CreateNew(outputStream, ch))
        {
            ch.Trace("Saving data reader");
            ModelSaveContext.SaveModel(rep, Reader, "Reader");

            ch.Trace("Saving transformer chain");
            ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature);

            rep.Commit();
        }
    }
}
84+
85+
/// <summary>
/// Utility class to facilitate loading a composite data reader from a stream.
/// </summary>
public static class CompositeDataReader
{
    /// <summary>
    /// Load the pipeline from stream: first the data reader stored under the "Reader" entry,
    /// then the transformer chain stored under <c>TransformerChain.LoaderSignature</c>.
    /// </summary>
    /// <param name="env">The host environment used for argument checks, the channel and model loading.</param>
    /// <param name="stream">The stream to load from. It must be readable and seekable.</param>
    /// <returns>A composite data reader built from the loaded reader and transformer chain.</returns>
    public static CompositeDataReader<IMultiStreamSource, ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(stream, nameof(stream));

        env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load");
        using (var rep = RepositoryReader.Open(stream, env))
        using (var ch = env.Start("Loading pipeline"))
        {
            ch.Trace("Loading data reader");
            ModelLoadContext.LoadModel<IDataReader<IMultiStreamSource>, SignatureLoadModel>(env, out var reader, rep, "Reader");

            // Message fixed from "Loader transformer chain" to match the other trace messages.
            ch.Trace("Loading transformer chain");
            ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature);
            return new CompositeDataReader<IMultiStreamSource, ITransformer>(reader, transformerChain);
        }
    }
}
111+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
7+
namespace Microsoft.ML.Runtime.Data
8+
{
9+
/// <summary>
/// Estimator counterpart of the composite data reader: a reader estimator followed by an estimator chain.
/// It can be used to build a 'trainable smart data reader', although this pattern is not very common.
/// </summary>
public sealed class CompositeReaderEstimator<TSource, TLastTransformer> : IDataReaderEstimator<TSource, CompositeDataReader<TSource, TLastTransformer>>
    where TLastTransformer : class, ITransformer
{
    private readonly IDataReaderEstimator<TSource, IDataReader<TSource>> _start;
    private readonly EstimatorChain<TLastTransformer> _estimatorChain;

    /// <summary>
    /// Create a composite reader estimator from a reader estimator and an optional estimator chain.
    /// When <paramref name="estimatorChain"/> is null, an empty chain is used.
    /// </summary>
    public CompositeReaderEstimator(IDataReaderEstimator<TSource, IDataReader<TSource>> start, EstimatorChain<TLastTransformer> estimatorChain = null)
    {
        Contracts.CheckValue(start, nameof(start));
        Contracts.CheckValueOrNull(estimatorChain);

        _estimatorChain = estimatorChain ?? new EstimatorChain<TLastTransformer>();
        _start = start;

        // REVIEW: enforce that estimator chain can read the reader's schema.
        // Right now it throws.
        // GetOutputSchema();
    }

    /// <summary>
    /// Fit the reader estimator on the input, read the input with the fitted reader,
    /// and fit the estimator chain on the data that was read.
    /// </summary>
    public CompositeDataReader<TSource, TLastTransformer> Fit(TSource input)
    {
        var fittedReader = _start.Fit(input);
        var fittedChain = _estimatorChain.Fit(fittedReader.Read(input));
        return new CompositeDataReader<TSource, TLastTransformer>(fittedReader, fittedChain);
    }

    /// <summary>
    /// The schema shape obtained by passing the reader estimator's output shape through the estimator chain.
    /// </summary>
    public SchemaShape GetOutputSchema()
    {
        return _estimatorChain.GetOutputSchema(_start.GetOutputSchema());
    }

    /// <summary>
    /// Build a new composite reader estimator with one more estimator at the end of the chain.
    /// </summary>
    public CompositeReaderEstimator<TSource, TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator)
        where TNewTrans : class, ITransformer
    {
        Contracts.CheckValue(estimator, nameof(estimator));

        var extendedChain = _estimatorChain.Append(estimator);
        return new CompositeReaderEstimator<TSource, TNewTrans>(_start, extendedChain);
    }
}
58+
59+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime.Internal.Utilities;
7+
using System.Linq;
8+
9+
namespace Microsoft.ML.Runtime.Data
10+
{
11+
/// <summary>
/// Represents a chain (potentially empty) of estimators that end with a <typeparamref name="TLastTransformer"/>.
/// If the chain is empty, <typeparamref name="TLastTransformer"/> is always <see cref="ITransformer"/>.
/// </summary>
public sealed class EstimatorChain<TLastTransformer> : IEstimator<TransformerChain<TLastTransformer>>
    where TLastTransformer : class, ITransformer
{
    // One scope per estimator; the two arrays always have the same length.
    private readonly TransformerScope[] _scopes;
    private readonly IEstimator<ITransformer>[] _estimators;

    /// <summary>
    /// The last estimator of the chain, or null when the chain is empty.
    /// </summary>
    public readonly IEstimator<TLastTransformer> LastEstimator;

    /// <summary>
    /// Create a chain from parallel arrays of estimators and scopes.
    /// Either array may be null, which is treated as empty.
    /// </summary>
    private EstimatorChain(IEstimator<ITransformer>[] estimators, TransformerScope[] scopes)
    {
        Contracts.AssertValueOrNull(estimators);
        Contracts.AssertValueOrNull(scopes);
        Contracts.Assert(Utils.Size(estimators) == Utils.Size(scopes));

        _estimators = estimators ?? new IEstimator<ITransformer>[0];
        _scopes = scopes ?? new TransformerScope[0];
        // Use the normalized field here: 'estimators' is allowed to be null,
        // and LastOrDefault throws on a null source.
        LastEstimator = _estimators.LastOrDefault() as IEstimator<TLastTransformer>;

        Contracts.Assert((_estimators.Length > 0) == (LastEstimator != null));
    }

    /// <summary>
    /// Create an empty estimator chain.
    /// </summary>
    public EstimatorChain()
    {
        _estimators = new IEstimator<ITransformer>[0];
        _scopes = new TransformerScope[0];
        LastEstimator = null;
    }

    /// <summary>
    /// Fit every estimator in order. Each estimator is fitted on the data produced by
    /// applying all previously fitted transformers to <paramref name="input"/>.
    /// </summary>
    /// <returns>A transformer chain holding the fitted transformers and this chain's scopes.</returns>
    public TransformerChain<TLastTransformer> Fit(IDataView input)
    {
        // REVIEW: before fitting, run schema propagation.
        // Currently, it throws.
        // GetOutputSchema(SchemaShape.Create(input.Schema));

        IDataView current = input;
        var xfs = new ITransformer[_estimators.Length];
        for (int i = 0; i < _estimators.Length; i++)
        {
            var est = _estimators[i];
            xfs[i] = est.Fit(current);
            // The next estimator is trained on the output of this transformer.
            current = xfs[i].Transform(current);
        }

        return new TransformerChain<TLastTransformer>(xfs, _scopes);
    }

    /// <summary>
    /// Propagate a schema shape through every estimator in order.
    /// For an empty chain, <paramref name="inputSchema"/> is returned as is.
    /// </summary>
    public SchemaShape GetOutputSchema(SchemaShape inputSchema)
    {
        var s = inputSchema;
        foreach (var est in _estimators)
            s = est.GetOutputSchema(s);
        return s;
    }

    /// <summary>
    /// Build a new chain with <paramref name="estimator"/> and its <paramref name="scope"/> added at the end.
    /// This chain is left as it was.
    /// </summary>
    public EstimatorChain<TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator, TransformerScope scope = TransformerScope.Everything)
        where TNewTrans : class, ITransformer
    {
        Contracts.CheckValue(estimator, nameof(estimator));
        return new EstimatorChain<TNewTrans>(_estimators.Append(estimator).ToArray(), _scopes.Append(scope).ToArray());
    }
}
78+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
7+
namespace Microsoft.ML.Runtime.Data
8+
{
9+
/// <summary>
/// Extension methods for chaining readers, estimators and transformers into pipelines.
/// </summary>
public static class LearningPipelineExtensions
{
    /// <summary>
    /// Append an estimator to a reader estimator, producing a composite reader estimator.
    /// </summary>
    public static CompositeReaderEstimator<TSource, TTrans> Append<TSource, TTrans>(
        this IDataReaderEstimator<TSource, IDataReader<TSource>> start, IEstimator<TTrans> estimator)
        where TTrans : class, ITransformer
    {
        Contracts.CheckValue(start, nameof(start));
        Contracts.CheckValue(estimator, nameof(estimator));

        // Start from a composite with an empty estimator chain, then add the estimator.
        var readerOnly = new CompositeReaderEstimator<TSource, ITransformer>(start);
        return readerOnly.Append(estimator);
    }

    /// <summary>
    /// Append an estimator to another estimator, producing an estimator chain.
    /// The given scope applies to the appended estimator.
    /// </summary>
    public static EstimatorChain<TTrans> Append<TTrans>(
        this IEstimator<ITransformer> start, IEstimator<TTrans> estimator,
        TransformerScope scope = TransformerScope.Everything)
        where TTrans : class, ITransformer
    {
        Contracts.CheckValue(start, nameof(start));
        Contracts.CheckValue(estimator, nameof(estimator));

        // The first estimator is added with the default scope.
        var chainWithStart = new EstimatorChain<ITransformer>().Append(start);
        return chainWithStart.Append(estimator, scope);
    }

    /// <summary>
    /// Append a transformer to a data reader, producing a composite data reader.
    /// </summary>
    public static CompositeDataReader<TSource, TTrans> Append<TSource, TTrans>(this IDataReader<TSource> reader, TTrans transformer)
        where TTrans : class, ITransformer
    {
        Contracts.CheckValue(reader, nameof(reader));
        Contracts.CheckValue(transformer, nameof(transformer));

        // Start from a composite with an empty transformer chain, then add the transformer.
        var readerOnly = new CompositeDataReader<TSource, ITransformer>(reader);
        return readerOnly.AppendTransformer(transformer);
    }

    /// <summary>
    /// Append a transformer to another transformer, producing a transformer chain.
    /// </summary>
    public static TransformerChain<TTrans> Append<TTrans>(this ITransformer start, TTrans transformer)
        where TTrans : class, ITransformer
    {
        Contracts.CheckValue(start, nameof(start));
        Contracts.CheckValue(transformer, nameof(transformer));

        var chain = new TransformerChain<TTrans>(start, transformer);
        return chain;
    }
}
65+
}

0 commit comments

Comments
 (0)