Skip to content

Commit fe71bb8

Browse files
authored
Turn TextLoader into a data reader (#723)
TextLoader is now an `IDataReader<IMultiStreamSource>`
1 parent fe5c0e2 commit fe71bb8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+276
-239
lines changed

src/Microsoft.ML.Data/Commands/DataCommand.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,11 @@ protected IDataLoader CreateRawLoader(
371371
var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
372372
var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
373373

374-
return isText ? new TextLoader(Host, new TextLoader.Arguments(), fileSource) :
374+
return isText ? TextLoader.Create(Host, new TextLoader.Arguments(), fileSource) :
375375
isBinary ? new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource) :
376376
isTranspose ? new TransposeLoader(Host, new TransposeLoader.Arguments(), fileSource) :
377377
defaultLoaderFactory != null ? defaultLoaderFactory(Host, fileSource) :
378-
new TextLoader(Host, new TextLoader.Arguments(), fileSource);
378+
TextLoader.Create(Host, new TextLoader.Arguments(), fileSource);
379379
}
380380
else
381381
{

src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs

+13
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ public static CompositeReaderEstimator<TSource, TTrans> Append<TSource, TTrans>(
2424
return new CompositeReaderEstimator<TSource, ITransformer>(start).Append(estimator);
2525
}
2626

27+
/// <summary>
28+
/// Create a composite reader estimator by appending an estimator to a reader.
29+
/// </summary>
30+
public static CompositeReaderEstimator<TSource, TTrans> Append<TSource, TTrans>(
31+
this IDataReader<TSource> start, IEstimator<TTrans> estimator)
32+
where TTrans : class, ITransformer
33+
{
34+
Contracts.CheckValue(start, nameof(start));
35+
Contracts.CheckValue(estimator, nameof(estimator));
36+
37+
return new TrivialReaderEstimator<TSource, IDataReader<TSource>>(start).Append(estimator);
38+
}
39+
2740
/// <summary>
2841
/// Create an estimator chain by appending an estimator to an estimator.
2942
/// </summary>

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

+116-96
Large diffs are not rendered by default.

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs

+7-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
namespace Microsoft.ML.Runtime.Data
1717
{
18-
public sealed partial class TextLoader : IDataLoader
18+
public sealed partial class TextLoader
1919
{
2020
private sealed class Cursor : RootCursorBase, IRowCursor
2121
{
@@ -49,7 +49,6 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n,
4949
{
5050
// Note that files is allowed to be empty.
5151
Contracts.AssertValue(parent);
52-
Contracts.AssertValue(parent._files);
5352
Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
5453

5554
var bindings = parent._bindings;
@@ -88,7 +87,6 @@ private static void SetupCursor(TextLoader parent, bool[] active, int n,
8887
private Cursor(TextLoader parent, ParseStats stats, bool[] active, LineReader reader, int srcNeeded, int cthd)
8988
: base(parent._host)
9089
{
91-
Ch.AssertValue(parent._files);
9290
Ch.Assert(active == null || active.Length == parent._bindings.Infos.Length);
9391
Ch.AssertValue(reader);
9492
Ch.AssertValue(stats);
@@ -138,37 +136,37 @@ private Cursor(TextLoader parent, ParseStats stats, bool[] active, LineReader re
138136
}
139137
}
140138

141-
public static IRowCursor Create(TextLoader parent, bool[] active)
139+
public static IRowCursor Create(TextLoader parent, IMultiStreamSource files, bool[] active)
142140
{
143141
// Note that files is allowed to be empty.
144142
Contracts.AssertValue(parent);
145-
Contracts.AssertValue(parent._files);
143+
Contracts.AssertValue(files);
146144
Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
147145

148146
int srcNeeded;
149147
int cthd;
150148
SetupCursor(parent, active, 0, out srcNeeded, out cthd);
151149
Contracts.Assert(cthd > 0);
152150

153-
var reader = new LineReader(parent._files, BatchSize, 100, parent.HasHeader, parent._maxRows, 1);
151+
var reader = new LineReader(files, BatchSize, 100, parent.HasHeader, parent._maxRows, 1);
154152
var stats = new ParseStats(parent._host, 1);
155153
return new Cursor(parent, stats, active, reader, srcNeeded, cthd);
156154
}
157155

158156
public static IRowCursor[] CreateSet(out IRowCursorConsolidator consolidator,
159-
TextLoader parent, bool[] active, int n)
157+
TextLoader parent, IMultiStreamSource files, bool[] active, int n)
160158
{
161159
// Note that files is allowed to be empty.
162160
Contracts.AssertValue(parent);
163-
Contracts.AssertValue(parent._files);
161+
Contracts.AssertValue(files);
164162
Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);
165163

166164
int srcNeeded;
167165
int cthd;
168166
SetupCursor(parent, active, n, out srcNeeded, out cthd);
169167
Contracts.Assert(cthd > 0);
170168

171-
var reader = new LineReader(parent._files, BatchSize, 100, parent.HasHeader, parent._maxRows, cthd);
169+
var reader = new LineReader(files, BatchSize, 100, parent.HasHeader, parent._maxRows, cthd);
172170
var stats = new ParseStats(parent._host, cthd);
173171
if (cthd <= 1)
174172
{

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace Microsoft.ML.Runtime.Data
2020
{
2121
using Conditional = System.Diagnostics.ConditionalAttribute;
2222

23-
public sealed partial class TextLoader : IDataLoader
23+
public sealed partial class TextLoader
2424
{
2525
/// <summary>
2626
/// This type exists to provide efficient delegates for creating a ColumnValue specific to a DataKind.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
7+
namespace Microsoft.ML.Runtime.Data
8+
{
9+
/// <summary>
10+
/// The trivial wrapper for a <see cref="IDataReader{TSource}"/> that acts as an estimator and ignores the source.
11+
/// </summary>
12+
public sealed class TrivialReaderEstimator<TSource, TReader>: IDataReaderEstimator<TSource, TReader>
13+
where TReader: IDataReader<TSource>
14+
{
15+
private readonly TReader _reader;
16+
17+
public TrivialReaderEstimator(TReader reader)
18+
{
19+
_reader = reader;
20+
}
21+
22+
public TReader Fit(TSource input) => _reader;
23+
24+
public SchemaShape GetOutputSchema() => SchemaShape.Create(_reader.GetOutputSchema());
25+
}
26+
}

src/Microsoft.ML.Data/Transforms/TermTransform.cs

+9-9
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,10 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu
316316
// file, then we assume the user knows what they're doing and do not attempt to convert
317317
// to the desired type ourselves.
318318
bool autoConvert = false;
319-
IDataLoader loader;
319+
IDataView termData;
320320
if (loaderFactory != null)
321321
{
322-
loader = loaderFactory.CreateComponent(env, fileSource);
322+
termData = loaderFactory.CreateComponent(env, fileSource);
323323
}
324324
else
325325
{
@@ -333,11 +333,11 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu
333333
ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(args.TermsColumn),
334334
"Must be specified");
335335
if (isBinary)
336-
loader = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource);
336+
termData = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource);
337337
else
338338
{
339339
ch.Assert(isTranspose);
340-
loader = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource);
340+
termData = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource);
341341
}
342342
}
343343
else
@@ -348,7 +348,7 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu
348348
"{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}",
349349
nameof(Arguments.TermsColumn), src);
350350
}
351-
loader = new TextLoader(env,
351+
termData = TextLoader.ReadFile(env,
352352
new TextLoader.Arguments()
353353
{
354354
Separator = "tab",
@@ -362,18 +362,18 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu
362362
ch.AssertNonEmpty(src);
363363

364364
int colSrc;
365-
if (!loader.Schema.TryGetColumnIndex(src, out colSrc))
365+
if (!termData.Schema.TryGetColumnIndex(src, out colSrc))
366366
throw ch.ExceptUserArg(nameof(args.TermsColumn), "Unknown column '{0}'", src);
367-
var typeSrc = loader.Schema.GetColumnType(colSrc);
367+
var typeSrc = termData.Schema.GetColumnType(colSrc);
368368
if (!autoConvert && !typeSrc.Equals(bldr.ItemType))
369369
throw ch.ExceptUserArg(nameof(args.TermsColumn), "Must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc);
370370

371-
using (var cursor = loader.GetRowCursor(col => col == colSrc))
371+
using (var cursor = termData.GetRowCursor(col => col == colSrc))
372372
using (var pch = env.StartProgressChannel("Building term dictionary from file"))
373373
{
374374
var header = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" });
375375
var trainer = Trainer.Create(cursor, colSrc, autoConvert, int.MaxValue, bldr);
376-
double rowCount = loader.GetRowCount(true) ?? double.NaN;
376+
double rowCount = termData.GetRowCount(true) ?? double.NaN;
377377
long rowCur = 0;
378378
pch.SetHeader(header,
379379
e =>

src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ public static IEnumerable<KeyValuePair<ColumnRole, string>> LoadRoleMappingsOrNu
282282
{
283283
// REVIEW: Should really validate the schema here, and consider
284284
// ignoring this stream if it isn't as expected.
285-
var loader = new TextLoader(env, new TextLoader.Arguments(),
285+
var loader = TextLoader.ReadFile(env, new TextLoader.Arguments(),
286286
new RepositoryStreamWrapper(rep, DirTrainingInfo, RoleMappingFile));
287287

288288
using (var cursor = loader.GetRowCursor(c => true))

src/Microsoft.ML.PipelineInference/ColumnTypeInference.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,8 @@ private static InferenceResult InferTextFileColumnTypesCore(IHostEnvironment env
259259
AllowSparse = args.AllowSparse,
260260
AllowQuoting = args.AllowQuote,
261261
};
262-
var textLoader = new TextLoader(env, textLoaderArgs, fileSource);
263-
var idv = textLoader.Take(args.MaxRowsToRead);
262+
var idv = TextLoader.ReadFile(env, textLoaderArgs, fileSource);
263+
idv = idv.Take(args.MaxRowsToRead);
264264

265265
// Read all the data into memory.
266266
// List items are rows of the dataset.

src/Microsoft.ML.PipelineInference/InferenceUtils.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ public static ColumnGroupingInference.GroupingColumn[] InferColumnPurposes(IChan
121121
AllowQuoting = splitResult.AllowQuote,
122122
HasHeader = typeInferenceResult.HasHeader
123123
};
124-
var typedLoader = new TextLoader(env, typedLoaderArgs, sample);
124+
var typedData = TextLoader.ReadFile(env, typedLoaderArgs, sample);
125125

126-
var purposeInferenceResult = PurposeInference.InferPurposes(env, typedLoader,
126+
var purposeInferenceResult = PurposeInference.InferPurposes(env, typedData,
127127
Utils.GetIdentityPermutation(typedLoaderArgs.Column.Length), new PurposeInference.Arguments());
128128
ch.Info("Detecting column grouping and generating column names");
129129

src/Microsoft.ML.PipelineInference/RecipeInference.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,8 @@ public static SuggestedRecipe[] InferRecipesFromData(IHostEnvironment env, strin
407407
ch.Info($"Loader options: {settingsString}");
408408

409409
ch.Info("Inferring recipes");
410-
var finalLoader = new TextLoader(h, finalLoaderArgs, sample);
411-
var cached = new CacheDataView(h, finalLoader,
410+
var finalData = TextLoader.ReadFile(h, finalLoaderArgs, sample);
411+
var cached = new CacheDataView(h, finalData,
412412
Enumerable.Range(0, finalLoaderArgs.Column.Length).ToArray());
413413

414414
var purposeColumns = columns.Select((x, i) => new PurposeInference.Column(i, x.Purpose, x.ItemKind)).ToArray();

src/Microsoft.ML.PipelineInference/TextFileContents.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ private static bool TryParseFile(IChannel ch, TextLoader.Arguments args, IMultiS
123123
{
124124
messages.Add(msg);
125125
});
126-
var idv = new TextLoader(loaderEnv, args, source).Take(1000);
126+
var idv = TextLoader.ReadFile(loaderEnv, args, source).Take(1000);
127127
var columnCounts = new List<int>();
128128
int columnIndex;
129129
bool found = idv.Schema.TryGetColumnIndex("C", out columnIndex);

src/Microsoft.ML.Transforms/TermLookupTransform.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ private static IComponentFactory<IMultiStreamSource, IDataLoader> GetLoaderFacto
352352
// If the user specified non-key values, we define the value column to be numeric.
353353
if (!keyValues)
354354
return ComponentFactoryUtils.CreateFromFunction<IMultiStreamSource, IDataLoader>(
355-
(env, files) => new TextLoader(
355+
(env, files) => TextLoader.Create(
356356
env,
357357
new TextLoader.Arguments()
358358
{
@@ -372,8 +372,8 @@ private static IComponentFactory<IMultiStreamSource, IDataLoader> GetLoaderFacto
372372
var txtArgs = new TextLoader.Arguments();
373373
bool parsed = CmdParser.ParseArguments(host, "col=Term:TX:0 col=Value:TX:1", txtArgs);
374374
host.Assert(parsed);
375-
var txtLoader = new TextLoader(host, txtArgs, new MultiFileSource(filename));
376-
using (var cursor = txtLoader.GetRowCursor(c => true))
375+
var data = TextLoader.ReadFile(host, txtArgs, new MultiFileSource(filename));
376+
using (var cursor = data.GetRowCursor(c => true))
377377
{
378378
var getTerm = cursor.GetGetter<DvText>(0);
379379
var getVal = cursor.GetGetter<DvText>(1);
@@ -448,7 +448,7 @@ private static IComponentFactory<IMultiStreamSource, IDataLoader> GetLoaderFacto
448448
}
449449

450450
return ComponentFactoryUtils.CreateFromFunction<IMultiStreamSource, IDataLoader>(
451-
(env, files) => new TextLoader(
451+
(env, files) => TextLoader.Create(
452452
env,
453453
new TextLoader.Arguments()
454454
{

src/Microsoft.ML.Transforms/Text/StopWordsRemoverTransform.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ private static IDataLoader LoadStopwords(IHostEnvironment env, IChannel ch, stri
658658
ch.Warning("{0} should not be specified when default loader is TextLoader. Ignoring stopwordsColumn={0}",
659659
stopwordsCol);
660660
}
661-
dataLoader = new TextLoader(
661+
dataLoader = TextLoader.Create(
662662
env,
663663
new TextLoader.Arguments()
664664
{

test/Microsoft.ML.Benchmarks/KMeansAndLogisticRegressionBench.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ private static IPredictor TrainKMeansAndLRCore()
3232
using (var env = new TlcEnvironment(seed: 1))
3333
{
3434
// Pipeline
35-
var loader = new TextLoader(env,
35+
var loader = TextLoader.ReadFile(env,
3636
new TextLoader.Arguments()
3737
{
3838
HasHeader = true,

test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ private static IPredictor TrainSentimentCore()
102102
using (var env = new TlcEnvironment(seed: 1))
103103
{
104104
// Pipeline
105-
var loader = new TextLoader(env,
105+
var loader = TextLoader.ReadFile(env,
106106
new TextLoader.Arguments()
107107
{
108108
AllowQuoting = false,

test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,13 @@ protected bool SaveLoadText(IDataView view, IHostEnvironment env,
295295

296296
// Note that we don't pass in "args", but pass in a default args so we test
297297
// the auto-schema parsing.
298-
var loader = new TextLoader(env, new TextLoader.Arguments(), new MultiFileSource(pathData));
299-
if (!CheckMetadataTypes(loader.Schema))
298+
var loadedData = TextLoader.ReadFile(env, new TextLoader.Arguments(), new MultiFileSource(pathData));
299+
if (!CheckMetadataTypes(loadedData.Schema))
300300
Failed();
301301

302-
if (!CheckSameSchemas(view.Schema, loader.Schema, exactTypes: false, keyNames: false))
302+
if (!CheckSameSchemas(view.Schema, loadedData.Schema, exactTypes: false, keyNames: false))
303303
return Failed();
304-
if (!CheckSameValues(view, loader, exactTypes: false, exactDoubles: false, checkId: false))
304+
if (!CheckSameValues(view, loadedData, exactTypes: false, exactDoubles: false, checkId: false))
305305
return Failed();
306306
return true;
307307
}

test/Microsoft.ML.TestFramework/ModelHelper.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ public static IDataView GetKcHouseDataView(string dataPath)
5151
var txtArgs = new Runtime.Data.TextLoader.Arguments();
5252
bool parsed = CmdParser.ParseArguments(s_environment, dataSchema, txtArgs);
5353
s_environment.Assert(parsed);
54-
var txtLoader = new Runtime.Data.TextLoader(s_environment, txtArgs, new MultiFileSource(dataPath));
55-
return txtLoader;
54+
return Runtime.Data.TextLoader.ReadFile(s_environment, txtArgs, new MultiFileSource(dataPath));
5655
}
5756

5857
private static ITransformModel CreateKcHousePricePredictorModel(string dataPath)

test/Microsoft.ML.Tests/Scenarios/Api/AutoNormalizationAndCaching.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public void AutoNormalizationAndCaching()
2424
using (var env = new TlcEnvironment(seed: 1, conc: 1))
2525
{
2626
// Pipeline.
27-
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
27+
var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
2828

2929
var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);
3030

test/Microsoft.ML.Tests/Scenarios/Api/CrossValidation.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ void CrossValidation()
3232
using (var env = new TlcEnvironment(seed: 1, conc: 1))
3333
{
3434
// Pipeline.
35-
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
35+
var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
3636

3737
var text = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);
3838
IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn");
@@ -43,7 +43,6 @@ void CrossValidation()
4343
ConvergenceTolerance = 1f
4444
});
4545

46-
4746
var metrics = new List<BinaryClassificationMetrics>();
4847
for (int fold = 0; fold < numFolds; fold++)
4948
{
@@ -80,7 +79,7 @@ void CrossValidation()
8079
var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, preCachedData.Data, testPipe, trainPipe);
8180

8281
var testRoles = new RoleMappedData(pipe, trainData.Schema.GetColumnRoleNames());
83-
82+
8483
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema);
8584

8685
BinaryClassifierMamlEvaluator eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { });

0 commit comments

Comments
 (0)