Skip to content

Commit feed72b

Browse files
committed
Working static pipelines for TextLoader
1 parent 46806aa commit feed72b

File tree

13 files changed

+361
-294
lines changed

13 files changed

+361
-294
lines changed

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

+10-4
Original file line numberDiff line numberDiff line change
@@ -201,15 +201,21 @@ public sealed class Range
201201
public Range() { }
202202

203203
public Range(int index)
204-
: this(index, index) { }
204+
{
205+
Contracts.CheckParam(index >= 0, nameof(index), "Must be non-negative");
206+
Min = index;
207+
Max = index;
208+
}
205209

206-
public Range(int min, int max)
210+
public Range(int min, int? max)
207211
{
208-
Contracts.CheckParam(min >= 0, nameof(min), "min must be non-negative.");
209-
Contracts.CheckParam(max >= min, nameof(max), "max must be greater than or equal to min.");
212+
Contracts.CheckParam(min >= 0, nameof(min), "Must be non-negative");
213+
Contracts.CheckParam(!(max < min), nameof(max), "If specified, must be greater than or equal to " + nameof(min));
210214

211215
Min = min;
212216
Max = max;
217+
ForceVector = true;
218+
AutoEnd = max == null;
213219
}
214220

215221
[Argument(ArgumentType.Required, HelpText = "First index in the range")]
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,176 @@
1-
using Float = System.Single;
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
24

35
using System;
4-
using System.Collections.Concurrent;
56
using System.Collections.Generic;
6-
using System.IO;
77
using System.Linq;
8-
using System.Text;
9-
using System.Threading;
10-
using Microsoft.ML.Runtime.Internal.Utilities;
11-
using Microsoft.ML.Data.StaticPipe.Runtime;
128
using Microsoft.ML.Core.Data;
9+
using Microsoft.ML.Data.StaticPipe.Runtime;
10+
using Microsoft.ML.Data.StaticPipe;
11+
using Microsoft.ML.Runtime.Internal.Utilities;
1312

1413
namespace Microsoft.ML.Runtime.Data
1514
{
1615
public sealed partial class TextLoader
1716
{
18-
private sealed class TextReconciler : ReaderReconciler<IMultiStreamSource>
17+
public static DataReader<IMultiStreamSource, TTupleShape> CreateReader<TTupleShape>(
18+
IHostEnvironment env, Func<Context, TTupleShape> func, IMultiStreamSource files = null,
19+
bool hasHeader = false, char separator = '\t', bool allowQuoting = true, bool allowSparse = true,
20+
bool trimWhitspace = false)
1921
{
20-
public static readonly TextReconciler Inst = new TextReconciler();
22+
Contracts.CheckValue(env, nameof(env));
23+
env.CheckValue(func, nameof(func));
24+
env.CheckValueOrNull(files);
2125

22-
public override IDataReaderEstimator<IMultiStreamSource, IDataReader<IMultiStreamSource>> Reconcile(
23-
PipelineColumn[] toOutput, Dictionary<PipelineColumn, string> outputNames)
26+
// Populate all args except the columns.
27+
var args = new Arguments();
28+
args.AllowQuoting = allowQuoting;
29+
args.AllowSparse = allowSparse;
30+
args.HasHeader = hasHeader;
31+
args.SeparatorChars = new[] { separator };
32+
args.TrimWhitespace = trimWhitspace;
33+
34+
var rec = new TextReconciler(args, files);
35+
var ctx = new Context(rec);
36+
37+
using (var ch = env.Start("Initializing " + nameof(TextLoader)))
2438
{
25-
//return new FakeReaderEstimator<IMultiStreamSource>();
26-
return null;
39+
var readerEst = StaticPipeUtils.ReaderEstimatorAnalyzerHelper(env, ch, ctx, rec, func);
40+
Contracts.AssertValue(readerEst);
41+
var reader = readerEst.Fit(files);
42+
ch.Done();
43+
return reader;
2744
}
2845
}
2946

30-
public sealed class Context
47+
private sealed class TextReconciler : ReaderReconciler<IMultiStreamSource>
3148
{
32-
private class MyScalar<T> : Scalar<T>
49+
private readonly Arguments _args;
50+
private readonly IMultiStreamSource _files;
51+
52+
public TextReconciler(Arguments args, IMultiStreamSource files)
3353
{
34-
public readonly int Ordinal;
54+
Contracts.AssertValue(args);
55+
Contracts.AssertValueOrNull(files);
3556

36-
public MyScalar(int ordinal)
37-
: base(TextReconciler.Inst, null)
38-
{
39-
Ordinal = ordinal;
40-
}
57+
_args = args;
58+
_files = files;
4159
}
4260

43-
private class MyVector<T> : Vector<T>
61+
public override IDataReaderEstimator<IMultiStreamSource, IDataReader<IMultiStreamSource>> Reconcile(
62+
IHostEnvironment env, PipelineColumn[] toOutput, IReadOnlyDictionary<PipelineColumn, string> outputNames)
4463
{
45-
public readonly int? Min;
46-
public readonly int? Max;
64+
Contracts.AssertValue(env);
65+
Contracts.AssertValue(toOutput);
66+
Contracts.AssertValue(outputNames);
67+
Contracts.Assert(_args.Column == null);
4768

48-
public MyVector(int? min, int? max)
49-
: base(TextReconciler.Inst, null)
69+
Column Create(PipelineColumn pipelineCol)
5070
{
51-
Min = min;
52-
Max = max;
71+
var pipelineArgCol = (IPipelineArgColumn)pipelineCol;
72+
var argCol = pipelineArgCol.Create();
73+
argCol.Name = outputNames[pipelineCol];
74+
return argCol;
5375
}
76+
77+
var cols = _args.Column = new Column[toOutput.Length];
78+
for (int i = 0; i < toOutput.Length; ++i)
79+
cols[i] = Create(toOutput[i]);
80+
81+
var orig = new TextLoader(env, _args, _files);
82+
return new TrivialReaderEstimator<IMultiStreamSource, TextLoader>(orig);
5483
}
84+
}
5585

56-
public Scalar<bool> LoadBool(int ordinal) => Load<bool>(ordinal);
57-
public Vector<bool> LoadBool(int minOrdinal, int? maxOrdinal) => Load<bool>(minOrdinal, maxOrdinal);
58-
public Scalar<float> LoadFloat(int ordinal) => Load<float>(ordinal);
59-
public Vector<float> LoadFloat(int minOrdinal, int? maxOrdinal) => Load<float>(minOrdinal, maxOrdinal);
60-
public Scalar<double> LoadDouble(int ordinal) => Load<double>(ordinal);
61-
public Vector<double> LoadDouble(int minOrdinal, int? maxOrdinal) => Load<double>(minOrdinal, maxOrdinal);
62-
public Scalar<string> LoadText(int ordinal) => Load<string>(ordinal);
63-
public Vector<string> LoadText(int minOrdinal, int? maxOrdinal) => Load<string>(minOrdinal, maxOrdinal);
86+
private interface IPipelineArgColumn
87+
{
88+
/// <summary>
89+
/// Creates a <see cref="Column"/> object corresponding to the <see cref="PipelineColumn"/>, with everything
90+
/// filled in except <see cref="Column.Name"/>, which the caller assigns afterwards.
91+
/// </summary>
92+
Column Create();
93+
}
6494

65-
private Scalar<T> Load<T>(int ordinal)
95+
public sealed class Context
96+
{
97+
private readonly Reconciler _rec;
98+
99+
internal Context(Reconciler rec)
100+
{
101+
Contracts.AssertValue(rec);
102+
_rec = rec;
103+
}
104+
105+
public Scalar<bool> LoadBool(int ordinal) => Load<bool>(DataKind.BL, ordinal);
106+
public Vector<bool> LoadBool(int minOrdinal, int? maxOrdinal) => Load<bool>(DataKind.BL, minOrdinal, maxOrdinal);
107+
public Scalar<float> LoadFloat(int ordinal) => Load<float>(DataKind.R4, ordinal);
108+
public Vector<float> LoadFloat(int minOrdinal, int? maxOrdinal) => Load<float>(DataKind.R4, minOrdinal, maxOrdinal);
109+
public Scalar<double> LoadDouble(int ordinal) => Load<double>(DataKind.R8, ordinal);
110+
public Vector<double> LoadDouble(int minOrdinal, int? maxOrdinal) => Load<double>(DataKind.R8, minOrdinal, maxOrdinal);
111+
public Scalar<string> LoadText(int ordinal) => Load<string>(DataKind.TX, ordinal);
112+
public Vector<string> LoadText(int minOrdinal, int? maxOrdinal) => Load<string>(DataKind.TX, minOrdinal, maxOrdinal);
113+
114+
private Scalar<T> Load<T>(DataKind kind, int ordinal)
66115
{
67116
Contracts.CheckParam(ordinal >= 0, nameof(ordinal), "Should be non-negative");
68-
return new MyScalar<T>(ordinal);
117+
return new MyScalar<T>(_rec, kind, ordinal);
69118
}
70119

71-
private Vector<T> Load<T>(int minOrdinal, int? maxOrdinal)
120+
private Vector<T> Load<T>(DataKind kind, int minOrdinal, int? maxOrdinal)
72121
{
73122
Contracts.CheckParam(minOrdinal >= 0, nameof(minOrdinal), "Should be non-negative");
74123
var v = maxOrdinal >= minOrdinal;
75124
Contracts.CheckParam(!(maxOrdinal < minOrdinal), nameof(maxOrdinal), "If specified, cannot be less than " + nameof(minOrdinal));
76-
return new MyVector<T>(minOrdinal, maxOrdinal);
125+
return new MyVector<T>(_rec, kind, minOrdinal, maxOrdinal);
126+
}
127+
128+
private class MyScalar<T> : Scalar<T>, IPipelineArgColumn
129+
{
130+
private readonly DataKind _kind;
131+
private readonly int _ordinal;
132+
133+
public MyScalar(Reconciler rec, DataKind kind, int ordinal)
134+
: base(rec, null)
135+
{
136+
_kind = kind;
137+
_ordinal = ordinal;
138+
}
139+
140+
public Column Create()
141+
{
142+
return new Column()
143+
{
144+
Type = _kind,
145+
Source = new[] { new Range(_ordinal) },
146+
};
147+
}
148+
}
149+
150+
private class MyVector<T> : Vector<T>, IPipelineArgColumn
151+
{
152+
private readonly DataKind _kind;
153+
private readonly int _min;
154+
private readonly int? _max;
155+
156+
public MyVector(Reconciler rec, DataKind kind, int min, int? max)
157+
: base(rec, null)
158+
{
159+
_kind = kind;
160+
_min = min;
161+
_max = max;
162+
}
163+
164+
public Column Create()
165+
{
166+
return new Column()
167+
{
168+
Type = _kind,
169+
Source = new[] { new Range(_min, _max) },
170+
};
171+
}
77172
}
78173
}
79174
}
80175
}
176+

src/Microsoft.ML.Data/DataLoadSave/TrivialReaderEstimator.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ namespace Microsoft.ML.Runtime.Data
1212
public sealed class TrivialReaderEstimator<TSource, TReader>: IDataReaderEstimator<TSource, TReader>
1313
where TReader: IDataReader<TSource>
1414
{
15-
private readonly TReader _reader;
15+
public TReader Reader { get; }
1616

1717
public TrivialReaderEstimator(TReader reader)
1818
{
19-
_reader = reader;
19+
Reader = reader;
2020
}
2121

22-
public TReader Fit(TSource input) => _reader;
22+
public TReader Fit(TSource input) => Reader;
2323

24-
public SchemaShape GetOutputSchema() => SchemaShape.Create(_reader.GetOutputSchema());
24+
public SchemaShape GetOutputSchema() => SchemaShape.Create(Reader.GetOutputSchema());
2525
}
2626
}

src/Microsoft.ML.Data/StaticPipe/PipelineColumnAnalyzer.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,12 @@ public sealed class Rec : ReaderReconciler<int>
5353
{
5454
public Rec() : base() { }
5555

56-
public override IDataReaderEstimator<int, IDataReader<int>> Reconcile(PipelineColumn[] toOutput, Dictionary<PipelineColumn, string> outputNames)
56+
public override IDataReaderEstimator<int, IDataReader<int>> Reconcile(
57+
IHostEnvironment env, PipelineColumn[] toOutput, IReadOnlyDictionary<PipelineColumn, string> outputNames)
5758
{
59+
Contracts.AssertValue(env);
5860
foreach (var col in toOutput)
59-
Contracts.Assert(col is IIsAnalysisColumn);
61+
env.Assert(col is IIsAnalysisColumn);
6062
return null;
6163
}
6264
}

src/Microsoft.ML.Data/StaticPipe/Reconciler.cs

+12-8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using Microsoft.ML.Core.Data;
8+
using Microsoft.ML.Runtime;
89

910
namespace Microsoft.ML.Data.StaticPipe.Runtime
1011
{
@@ -34,11 +35,12 @@ public ReaderReconciler() : base() { }
3435
/// Returns a data-reader estimator. Note that there are no input names because the columns from a data-reader
3536
/// estimator should have no dependencies.
3637
/// </summary>
37-
/// <param name="toOutput">The columns that the reconciler should output</param>
38-
/// <param name="outputNames"></param>
38+
/// <param name="env">The host environment to use to create the data-reader estimator</param>
39+
/// <param name="toOutput">The columns that the object created by the reconciler should output</param>
40+
/// <param name="outputNames">A map from each column in <paramref name="toOutput"/> to the name that column should have in the output</param>
3941
/// <returns></returns>
4042
public abstract IDataReaderEstimator<TIn, IDataReader<TIn>> Reconcile(
41-
PipelineColumn[] toOutput, Dictionary<PipelineColumn, string> outputNames);
43+
IHostEnvironment env, PipelineColumn[] toOutput, IReadOnlyDictionary<PipelineColumn, string> outputNames);
4244
}
4345

4446
/// <summary>
@@ -53,13 +55,15 @@ public DataInputReconciler() : base() { }
5355
/// <summary>
5456
/// Returns an estimator.
5557
/// </summary>
56-
/// <param name="toOutput"></param>
57-
/// <param name="inputNames"></param>
58-
/// <param name="outputNames"></param>
58+
/// <param name="env">The host environment to use to create the estimator</param>
59+
/// <param name="toOutput">The columns that the object created by the reconciler should output</param>
60+
/// <param name="inputNames">A map from each column that the columns in <paramref name="toOutput"/> depend on to the name of that input column</param>
61+
/// <param name="outputNames">A map from each column in <paramref name="toOutput"/> to the name that column should have in the output</param>
5962
/// <returns></returns>
6063
public abstract IEstimator<ITransformer> Reconcile(
64+
IHostEnvironment env,
6165
PipelineColumn[] toOutput,
62-
Dictionary<PipelineColumn, string> inputNames,
63-
Dictionary<PipelineColumn, string> outputNames);
66+
IReadOnlyDictionary<PipelineColumn, string> inputNames,
67+
IReadOnlyDictionary<PipelineColumn, string> outputNames);
6468
}
6569
}

src/Microsoft.ML.Data/StaticPipe/StaticPipeUtils.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ internal static IDataReaderEstimator<TIn, IDataReader<TIn>>
187187
}
188188

189189
// Call the reconciler to get the base reader estimator.
190-
var readerEstimator = baseReconciler.Reconcile(baseInputs, nameMap.AsOther(baseInputs));
190+
var readerEstimator = baseReconciler.Reconcile(env, baseInputs, nameMap.AsOther(baseInputs));
191191
ch.AssertValueOrNull(readerEstimator);
192192

193193
// Next we iteratively find those columns with zero dependencies, "create" them, and if anything depends on
@@ -223,7 +223,7 @@ internal static IDataReaderEstimator<TIn, IDataReader<TIn>>
223223

224224
var localInputNames = nameMap.AsOther(cols.SelectMany(c => c.Dependencies ?? Enumerable.Empty<PipelineColumn>()));
225225
var localOutputNames = nameMap.AsOther(cols);
226-
var localEstimator = rec.Reconcile(cols, localInputNames, localOutputNames);
226+
var localEstimator = rec.Reconcile(env, cols, localInputNames, localOutputNames);
227227
readerEstimator = readerEstimator?.Append(localEstimator);
228228
estimator = estimator?.Append(localEstimator) ?? localEstimator;
229229

@@ -346,15 +346,15 @@ public T2 this[T1 key]
346346
}
347347
}
348348

349-
public Dictionary<T1, T2> AsOther(IEnumerable<T1> keys)
349+
public IReadOnlyDictionary<T1, T2> AsOther(IEnumerable<T1> keys)
350350
{
351351
Dictionary<T1, T2> d = new Dictionary<T1, T2>();
352352
foreach (var v in keys)
353353
d[v] = _d12[v];
354354
return d;
355355
}
356356

357-
public Dictionary<T2, T1> AsOther(IEnumerable<T2> keys)
357+
public IReadOnlyDictionary<T2, T1> AsOther(IEnumerable<T2> keys)
358358
{
359359
Dictionary<T2, T1> d = new Dictionary<T2, T1>();
360360
foreach (var v in keys)

test/Microsoft.ML.Predictor.Tests/TestTransposer.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using Microsoft.ML.Runtime.Data;
99
using Microsoft.ML.Runtime.Data.IO;
1010
using Microsoft.ML.Runtime.Internal.Utilities;
11+
using Microsoft.ML.TestFramework;
1112
using Xunit;
1213
using Xunit.Abstractions;
1314

@@ -234,7 +235,7 @@ public void TransposerSaverLoaderTest()
234235
{
235236
TransposeSaver saver = new TransposeSaver(Env, new TransposeSaver.Arguments());
236237
saver.SaveData(mem, view, Utils.GetIdentityPermutation(view.Schema.ColumnCount));
237-
src = new BytesSource(mem.ToArray());
238+
src = new BytesStreamSource(mem.ToArray());
238239
}
239240
TransposeLoader loader = new TransposeLoader(Env, new TransposeLoader.Arguments(), src);
240241
// First check whether this as an IDataView yields the same values.

test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
<ItemGroup>
77
<ProjectReference Include="..\..\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
88
<ProjectReference Include="..\..\tools-local\Microsoft.ML.Analyzer\Microsoft.ML.Analyzer.csproj" />
9+
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />
910
</ItemGroup>
1011
</Project>

0 commit comments

Comments
 (0)