-
Notifications
You must be signed in to change notification settings - Fork 1.9k
CollectionDataSource (train on top of memory collection instead of loading data from file) #106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
f69d659
55b6e46
b166f05
a1761b1
12d1b9e
ebcf448
110e205
62ab575
1da42ca
0cac7dc
ca9c031
d78afa3
ab86b09
ebe6f33
04ff469
9698d19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Collections.Generic; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Api; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.EntryPoints; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
|
||
namespace Microsoft.ML | ||
{ | ||
public class MemoryCollection<TInput> : ILearningPipelineLoader | ||
where TInput : class | ||
{ | ||
public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) | ||
{ | ||
Contracts.Assert(previousStep == null); | ||
_dataViewEntryPoint = new Data.DataViewReference(); | ||
var importOutput = experiment.Add(_dataViewEntryPoint); | ||
return new MemoryCollectionPipelineStep(importOutput.Data); | ||
} | ||
|
||
private readonly IList<TInput> _listCollection; | ||
private readonly IEnumerable<TInput> _enumerableCollection; | ||
|
||
private Data.DataViewReference _dataViewEntryPoint; | ||
private IDataView _dataView; | ||
|
||
public MemoryCollection(IList<TInput> collection) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
public constructor required comments. #Resolved |
||
{ | ||
Contracts.CheckParamValue(Utils.Size(collection) > 0, collection, nameof(collection), "Must be non-empty"); | ||
_listCollection = collection; | ||
} | ||
|
||
public MemoryCollection(IEnumerable<TInput> collection) | ||
{ | ||
Contracts.CheckParamValue(collection != null, collection, nameof(collection), "Must be non-null"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I'm sorry, I'm silly... I didn't mean |
||
_enumerableCollection = collection; | ||
} | ||
|
||
public void SetInput(IHostEnvironment env, Experiment experiment) | ||
{ | ||
if (_listCollection != null) | ||
_dataView = ComponentCreation.CreateDataView(env, _listCollection); | ||
if (_enumerableCollection != null) | ||
_dataView = ComponentCreation.CreateStreamingDataView(env, _listCollection); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Note that you are not using _enumerableCollection here. Likely a bug There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thank you for pointing that out, I update test to call both implementation and I also split class into two classes, to get rid of if condition In reply to: 187225560 [](ancestors = 187225560) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still want to have two different implementation for IList and IEnumerable, so I would prefer to keep it that way, but with two separate classes In reply to: 187395933 [](ancestors = 187395933,187225560) |
||
env.CheckValue(_dataView, nameof(_dataView)); | ||
experiment.SetInput(_dataViewEntryPoint.Data, _dataView); | ||
} | ||
|
||
private class MemoryCollectionPipelineStep : ILearningPipelineDataStep | ||
{ | ||
public MemoryCollectionPipelineStep(Var<IDataView> data) | ||
{ | ||
Data = data; | ||
Model = null; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unnecessary, the default value is That's fine, but I'd go one step further... What I'd do is change the |
||
} | ||
|
||
public Var<IDataView> Data { get; } | ||
public Var<ITransformModel> Model { get; } | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.CommandLine; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.EntryPoints; | ||
|
||
[assembly: LoadableClass(typeof(void), typeof(InMemoryDataView), null, typeof(SignatureEntryPointModule), "InMemoryDataView")] | ||
namespace Microsoft.ML.Runtime.EntryPoints | ||
{ | ||
public class InMemoryDataView | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Where is this being used? I don't see any references to this class in code or tests. #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a entrypoint :) Data.DataViewReference it get used in MemoryCollection.cs (At least it entry point wrapper) In reply to: 187222065 [](ancestors = 187222065) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, silly me as looking for "InMemoryDataView" not "Data.DataViewReference" In reply to: 187222365 [](ancestors = 187222365,187222065) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
So to clarify, all this entrypoint does it turns input to output? Should we call it as such, something like a data passthrough or something? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No it doesn't. but I wouldn't call it DataPass entrypoint either, since it allow you pass only dataview from you code to experiment, and DataViewReference is already taken by entrypoint class. In reply to: 187240561 [](ancestors = 187240561) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not DataViewReference or DataViewReferenceEp? Seems like very unrelated class name to the entrypoint name. In reply to: 187395356 [](ancestors = 187395356,187240561) |
||
{ | ||
public sealed class Input | ||
{ | ||
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "Pointer to IDataView in memory", SortOrder = 1)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Since shortname is same as longname, you can safely omit. #Resolved |
||
public IDataView Data; | ||
} | ||
|
||
public sealed class Output | ||
{ | ||
[TlcModule.Output(Desc = "The resulting data view", SortOrder = 1)] | ||
public IDataView Data; | ||
} | ||
|
||
[TlcModule.EntryPoint(Name = "Data.DataViewReference", Desc = "Pass dataview from memory to experiment")] | ||
public static Output ImportData(IHostEnvironment env, Input input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("DataViewReference"); | ||
env.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
return new Output { Data = input.Data }; | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Api; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.TestFramework; | ||
using Microsoft.ML.Trainers; | ||
using Microsoft.ML.Transforms; | ||
using System.Collections.Generic; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
|
||
namespace Microsoft.ML.EntryPoints.Tests | ||
{ | ||
public class MemoryCollectionTests : BaseTestClass | ||
{ | ||
public MemoryCollectionTests(ITestOutputHelper output) | ||
: base(output) | ||
{ | ||
|
||
} | ||
|
||
[Fact] | ||
public void CheckConstructor() | ||
{ | ||
Assert.NotNull(new MemoryCollection<Input>(new List<Input>() { new Input { Number1 = 1, String1 = "1" } })); | ||
Assert.NotNull(new MemoryCollection<Input>(new Input[1] { new Input { Number1 = 1, String1 = "1" } })); | ||
bool thrown = false; | ||
try | ||
{ | ||
new MemoryCollection<Input>(null); | ||
} | ||
catch | ||
{ | ||
thrown = true; | ||
} | ||
Assert.True(thrown); | ||
thrown = false; | ||
try | ||
{ | ||
new MemoryCollection<Input>(new List<Input>()); | ||
} | ||
catch | ||
{ | ||
thrown = true; | ||
} | ||
Assert.True(thrown); | ||
|
||
thrown = false; | ||
try | ||
{ | ||
new MemoryCollection<Input>(new Input[0]); | ||
} | ||
catch | ||
{ | ||
thrown = true; | ||
} | ||
Assert.True(thrown); | ||
} | ||
|
||
[Fact] | ||
public void CanSuccessfullyApplyATransform() | ||
{ | ||
var collection = new MemoryCollection<Input>(new List<Input>() { new Input { Number1 = 1, String1 = "1" } }); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Just an observation... if we had instead structured this as a static utility method somewhere, then we could avoid having the double-specification of the So: if we had a input type |
||
using (var environment = new TlcEnvironment()) | ||
{ | ||
Experiment experiment = environment.CreateExperiment(); | ||
ILearningPipelineDataStep output = collection.ApplyStep(null, experiment) as ILearningPipelineDataStep; | ||
|
||
Assert.NotNull(output.Data); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If you really meant the |
||
Assert.NotNull(output.Data.VarName); | ||
Assert.Null(output.Model); | ||
} | ||
} | ||
|
||
[Fact] | ||
public void CanSuccessfullyEnumerated() | ||
{ | ||
var collection = new MemoryCollection<Input>(new List<Input>() { | ||
new Input { Number1 = 1, String1 = "1" }, | ||
new Input { Number1 = 2, String1 = "2" }, | ||
new Input { Number1 = 3, String1 = "3" } | ||
}); | ||
|
||
using (var environment = new TlcEnvironment()) | ||
{ | ||
Experiment experiment = environment.CreateExperiment(); | ||
ILearningPipelineDataStep output = collection.ApplyStep(null, experiment) as ILearningPipelineDataStep; | ||
|
||
experiment.Compile(); | ||
collection.SetInput(environment, experiment); | ||
experiment.Run(); | ||
|
||
IDataView data = experiment.GetOutput(output.Data); | ||
Assert.NotNull(data); | ||
|
||
using (var cursor = data.GetRowCursor((a => true))) | ||
{ | ||
var IDGetter = cursor.GetGetter<float>(0); | ||
var TextGetter = cursor.GetGetter<DvText>(1); | ||
|
||
Assert.True(cursor.MoveNext()); | ||
|
||
float ID = 0; | ||
IDGetter(ref ID); | ||
Assert.Equal(1, ID); | ||
|
||
DvText Text = new DvText(); | ||
TextGetter(ref Text); | ||
Assert.Equal("1", Text.ToString()); | ||
|
||
Assert.True(cursor.MoveNext()); | ||
|
||
ID = 0; | ||
IDGetter(ref ID); | ||
Assert.Equal(2, ID); | ||
|
||
Text = new DvText(); | ||
TextGetter(ref Text); | ||
Assert.Equal("2", Text.ToString()); | ||
|
||
Assert.True(cursor.MoveNext()); | ||
|
||
ID = 0; | ||
IDGetter(ref ID); | ||
Assert.Equal(3, ID); | ||
|
||
Text = new DvText(); | ||
TextGetter(ref Text); | ||
Assert.Equal("3", Text.ToString()); | ||
|
||
Assert.False(cursor.MoveNext()); | ||
} | ||
} | ||
} | ||
|
||
[Fact] | ||
public void CanTrain() | ||
{ | ||
var pipeline = new LearningPipeline(); | ||
var collection = new MemoryCollection<IrisData>(new List<IrisData>() { | ||
new IrisData { SepalLength = 1f, SepalWidth = 1f ,PetalLength=0.3f, PetalWidth=5.1f, Label=1}, | ||
new IrisData { SepalLength = 1f, SepalWidth = 1f ,PetalLength=0.3f, PetalWidth=5.1f, Label=1}, | ||
new IrisData { SepalLength = 1.2f, SepalWidth = 0.5f ,PetalLength=0.3f, PetalWidth=5.1f, Label=0} | ||
}); | ||
|
||
pipeline.Add(collection); | ||
|
||
pipeline.Add(new ColumnConcatenator(outputColumn: "Features", | ||
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); | ||
|
||
pipeline.Add(new StochasticDualCoordinateAscentClassifier()); | ||
PredictionModel<IrisData, IrisPrediction> model = pipeline.Train<IrisData, IrisPrediction>(); | ||
|
||
IrisPrediction prediction = model.Predict(new IrisData() | ||
{ | ||
SepalLength = 3.3f, | ||
SepalWidth = 1.6f, | ||
PetalLength = 0.2f, | ||
PetalWidth = 5.1f, | ||
}); | ||
|
||
} | ||
|
||
public class Input | ||
{ | ||
[Column("0")] | ||
public float Number1; | ||
|
||
[Column("1")] | ||
public string String1; | ||
} | ||
|
||
public class IrisData | ||
{ | ||
[Column("0")] | ||
public float Label; | ||
|
||
[Column("1")] | ||
public float SepalLength; | ||
|
||
[Column("2")] | ||
public float SepalWidth; | ||
|
||
[Column("3")] | ||
public float PetalLength; | ||
|
||
[Column("4")] | ||
public float PetalWidth; | ||
} | ||
|
||
public class IrisPrediction | ||
{ | ||
[ColumnName("Score")] | ||
public float[] PredictedLabels; | ||
} | ||
|
||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ML.Data #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TextLoader is part of just ML, should I change it as well?
In reply to: 187221578 [](ancestors = 187221578)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good question... seems they both should be in data. the argument went about like this - can a user has out of the box experience with just ML namespace.
I guess we can keep it in ML for now.
In reply to: 187221799 [](ancestors = 187221799,187221578)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please move to Data. My PR will move TextLoader to ML.Data. #Resolved