Skip to content

Commit 6d5a41d

Browse files
authored
CollectionDataSource (train on top of memory collection instead of loading data from file) (#106)
* in memory loader * add test file for memory collection * even in afterlife EntryPointCatalog will chase me down. * Address some comments. * update tests * address more comments. * remove empty param description * hide collectionloader * refactor classes a little. * pesky new lines! * slightly better comments. but only slightly * rename it * make class static * not a loader * remove alias in entrypoint * address comments
1 parent 160e9e4 commit 6d5a41d

File tree

7 files changed

+400
-2
lines changed

7 files changed

+400
-2
lines changed

ZBaselines/Common/EntryPoints/core_ep-list.tsv

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
Data.DataViewReference Pass dataview from memory to experiment Microsoft.ML.Runtime.EntryPoints.DataViewReference ImportData Microsoft.ML.Runtime.EntryPoints.DataViewReference+Input Microsoft.ML.Runtime.EntryPoints.DataViewReference+Output
12
Data.IDataViewArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput
23
Data.PredictorModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput
34
Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData ImportText Microsoft.ML.Runtime.EntryPoints.ImportTextData+Input Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output

ZBaselines/Common/EntryPoints/core_manifest.json

+23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,28 @@
11
{
22
"EntryPoints": [
3+
{
4+
"Name": "Data.DataViewReference",
5+
"Desc": "Pass dataview from memory to experiment",
6+
"FriendlyName": null,
7+
"ShortName": null,
8+
"Inputs": [
9+
{
10+
"Name": "Data",
11+
"Type": "DataView",
12+
"Desc": "Pointer to IDataView in memory",
13+
"Required": true,
14+
"SortOrder": 1.0,
15+
"IsNullable": false
16+
}
17+
],
18+
"Outputs": [
19+
{
20+
"Name": "Data",
21+
"Type": "DataView",
22+
"Desc": "The resulting data view"
23+
}
24+
]
25+
},
326
{
427
"Name": "Data.IDataViewArrayConverter",
528
"Desc": "Create and array variable",

src/Microsoft.ML/CSharpApi.cs

+28
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,22 @@ public Microsoft.ML.Data.TextLoader.Output Add(Microsoft.ML.Data.TextLoader inpu
5353
return output;
5454
}
5555

56+
/// <summary>
/// Creates a fresh <see cref="Microsoft.ML.Data.DataViewReference.Output"/>, registers
/// <paramref name="input"/> together with it through the two-argument overload, and returns it.
/// </summary>
public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input)
{
    var result = new Microsoft.ML.Data.DataViewReference.Output();
    Add(input, result);
    return result;
}
62+
5663
/// <summary>
/// Serializes the "Data.TextLoader" node for the given input/output pair and appends it to the node list.
/// </summary>
public void Add(Microsoft.ML.Data.TextLoader input, Microsoft.ML.Data.TextLoader.Output output)
{
    var node = Serialize("Data.TextLoader", input, output);
    _jsonNodes.Add(node);
}
6067

68+
/// <summary>
/// Serializes the "Data.DataViewReference" node for the given input/output pair and appends it to the node list.
/// </summary>
public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output)
{
    var node = Serialize("Data.DataViewReference", input, output);
    _jsonNodes.Add(node);
}
6172
public Microsoft.ML.Models.AnomalyDetectionEvaluator.Output Add(Microsoft.ML.Models.AnomalyDetectionEvaluator input)
6273
{
6374
var output = new Microsoft.ML.Models.AnomalyDetectionEvaluator.Output();
@@ -1335,6 +1346,23 @@ public sealed partial class TextLoader
13351346
public string CustomSchema { get; set; }
13361347

13371348

1349+
/// <summary>
/// Holds the variable that refers to the resulting data view.
/// </summary>
public sealed class Output
{
    /// <summary>
    /// The resulting data view
    /// </summary>
    public Var<Microsoft.ML.Runtime.Data.IDataView> Data { get; set; } = new Var<Microsoft.ML.Runtime.Data.IDataView>();
}
1357+
}
1358+
1359+
public sealed partial class DataViewReference
1360+
{
1361+
/// <summary>
1362+
/// Pointer to IDataView in memory
1363+
/// </summary>
1364+
public Var<Microsoft.ML.Runtime.Data.IDataView> Data { get; set; } = new Var<Microsoft.ML.Runtime.Data.IDataView>();
1365+
13381366
public sealed class Output
13391367
{
13401368
/// <summary>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.Api;
8+
using Microsoft.ML.Runtime.Data;
9+
using Microsoft.ML.Runtime.EntryPoints;
10+
using Microsoft.ML.Runtime.Internal.Utilities;
11+
12+
namespace Microsoft.ML.Data
13+
{
14+
/// <summary>
/// Creates data source for pipeline based on provided collection of data.
/// </summary>
public static class CollectionDataSource
{
    /// <summary>
    /// Creates pipeline data source. Support shuffle.
    /// </summary>
    /// <typeparam name="T">Type of the items held by the collection.</typeparam>
    /// <param name="data">The list to train on. A null or empty list is rejected.</param>
    /// <returns>A pipeline loader backed by <paramref name="data"/>.</returns>
    public static ILearningPipelineLoader Create<T>(IList<T> data) where T : class
    {
        return new ListDataSource<T>(data);
    }

    /// <summary>
    /// Creates pipeline data source which can't be shuffled.
    /// </summary>
    /// <typeparam name="T">Type of the items produced by the sequence.</typeparam>
    /// <param name="data">The sequence to train on. Must not be null.</param>
    /// <returns>A pipeline loader backed by <paramref name="data"/>.</returns>
    public static ILearningPipelineLoader Create<T>(IEnumerable<T> data) where T : class
    {
        return new EnumerableDataSource<T>(data);
    }

    /// <summary>
    /// Shared plumbing for collection-backed loaders: adds a DataViewReference node to the
    /// experiment and later binds the data view produced by the derived class to that node.
    /// </summary>
    private abstract class BaseDataSource<TInput> : ILearningPipelineLoader where TInput : class
    {
        // Created by ApplyStep; SetInput binds the data view to its Data variable.
        private Data.DataViewReference _dataViewEntryPoint;
        private IDataView _dataView;

        /// <summary>
        /// Adds the DataViewReference node to <paramref name="experiment"/> and returns a data step
        /// exposing the node's output variable.
        /// </summary>
        public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
        {
            // A data source is expected to be the first step of a pipeline.
            Contracts.Assert(previousStep == null);
            Contracts.CheckValue(experiment, nameof(experiment));

            _dataViewEntryPoint = new Data.DataViewReference();
            var importOutput = experiment.Add(_dataViewEntryPoint);
            return new CollectionDataSourcePipelineStep(importOutput.Data);
        }

        /// <summary>
        /// Builds the data view from the wrapped collection and sets it as the input of the
        /// node created by <see cref="ApplyStep"/>.
        /// </summary>
        public void SetInput(IHostEnvironment environment, Experiment experiment)
        {
            // Validate up front so that misuse surfaces as a descriptive exception
            // instead of a NullReferenceException further down.
            Contracts.CheckValue(environment, nameof(environment));
            environment.CheckValue(experiment, nameof(experiment));
            environment.Check(_dataViewEntryPoint != null, "ApplyStep must be called before SetInput");

            _dataView = GetDataView(environment);
            environment.CheckValue(_dataView, nameof(_dataView));
            experiment.SetInput(_dataViewEntryPoint.Data, _dataView);
        }

        /// <summary>
        /// Wraps the underlying collection into an <see cref="IDataView"/>.
        /// </summary>
        public abstract IDataView GetDataView(IHostEnvironment environment);
    }

    /// <summary>
    /// Data source over an arbitrary <see cref="IEnumerable{T}"/>, exposed as a streaming data view.
    /// </summary>
    private sealed class EnumerableDataSource<TInput> : BaseDataSource<TInput> where TInput : class
    {
        private readonly IEnumerable<TInput> _enumerableCollection;

        public EnumerableDataSource(IEnumerable<TInput> collection)
        {
            Contracts.CheckValue(collection, nameof(collection));
            _enumerableCollection = collection;
        }

        public override IDataView GetDataView(IHostEnvironment environment)
        {
            return ComponentCreation.CreateStreamingDataView(environment, _enumerableCollection);
        }
    }

    /// <summary>
    /// Data source over a non-empty <see cref="IList{T}"/>.
    /// </summary>
    private sealed class ListDataSource<TInput> : BaseDataSource<TInput> where TInput : class
    {
        private readonly IList<TInput> _listCollection;

        public ListDataSource(IList<TInput> collection)
        {
            Contracts.CheckParamValue(Utils.Size(collection) > 0, collection, nameof(collection), "Must be non-empty");
            _listCollection = collection;
        }

        public override IDataView GetDataView(IHostEnvironment environment)
        {
            return ComponentCreation.CreateDataView(environment, _listCollection);
        }
    }

    /// <summary>
    /// Pipeline step that carries only the data variable; it has no transform model.
    /// </summary>
    private sealed class CollectionDataSourcePipelineStep : ILearningPipelineDataStep
    {
        public CollectionDataSourcePipelineStep(Var<IDataView> data)
        {
            Data = data;
        }

        public Var<IDataView> Data { get; }
        public Var<ITransformModel> Model => null;
    }
}
101+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.CommandLine;
7+
using Microsoft.ML.Runtime.Data;
8+
using Microsoft.ML.Runtime.EntryPoints;
9+
10+
[assembly: LoadableClass(typeof(void), typeof(DataViewReference), null, typeof(SignatureEntryPointModule), "DataViewReference")]
11+
namespace Microsoft.ML.Runtime.EntryPoints
12+
{
13+
/// <summary>
/// Entry point module for "Data.DataViewReference": passes the <see cref="IDataView"/>
/// supplied as input straight through as the output.
/// </summary>
public class DataViewReference
{
    /// <summary>
    /// Arguments of the entry point.
    /// </summary>
    public sealed class Input
    {
        [Argument(ArgumentType.Required, HelpText = "Pointer to IDataView in memory", SortOrder = 1)]
        public IDataView Data;
    }

    /// <summary>
    /// Result of the entry point.
    /// </summary>
    public sealed class Output
    {
        [TlcModule.Output(Desc = "The resulting data view", SortOrder = 1)]
        public IDataView Data;
    }

    /// <summary>
    /// Validates the arguments and returns an <see cref="Output"/> whose data view is the
    /// very instance received in <paramref name="input"/>.
    /// </summary>
    [TlcModule.EntryPoint(Name = "Data.DataViewReference", Desc = "Pass dataview from memory to experiment")]
    public static Output ImportData(IHostEnvironment env, Input input)
    {
        Contracts.CheckValue(env, nameof(env));
        var registeredHost = env.Register("DataViewReference");
        env.CheckValue(input, nameof(input));
        EntryPointUtils.CheckInputArgs(registeredHost, input);

        var result = new Output();
        result.Data = input.Data;
        return result;
    }
}
37+
}

src/Microsoft.ML/TextLoader.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,10 @@ private class TextLoaderPipelineStep : ILearningPipelineDataStep
115115
public TextLoaderPipelineStep(Var<IDataView> data)
116116
{
117117
Data = data;
118-
Model = null;
119118
}
120119

121120
public Var<IDataView> Data { get; }
122-
public Var<ITransformModel> Model { get; }
121+
public Var<ITransformModel> Model => null;
123122
}
124123
}
125124
}

0 commit comments

Comments
 (0)