Skip to content

Commit 0c176aa

Browse files
committed
Reverting the signature change on Fit()
1 parent 58dbbac commit 0c176aa

File tree

11 files changed

+22
-19
lines changed

11 files changed

+22
-19
lines changed

src/Microsoft.ML.Core/Data/IEstimator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ public interface IEstimator<out TTransformer>
237237
/// <summary>
238238
/// Train and return a transformer.
239239
/// </summary>
240-
TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null);
240+
TTransformer Fit(IDataView input);
241241

242242
/// <summary>
243243
/// Schema propagation for estimators.

src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public EstimatorChain()
4242
LastEstimator = null;
4343
}
4444

45-
public TransformerChain<TLastTransformer> Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
45+
public TransformerChain<TLastTransformer> Fit(IDataView input)
4646
{
4747
// REVIEW: before fitting, run schema propagation.
4848
// Currently, it throws.

src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public DelegateEstimator(IEstimator<TTransformer> estimator, Action<TTransformer
9090
_onFit = onFit;
9191
}
9292

93-
public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
93+
public TTransformer Fit(IDataView input)
9494
{
9595
var trans = _est.Fit(input);
9696
_onFit(trans);
@@ -102,9 +102,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
102102
}
103103

104104
/// <summary>
105-
/// Given an estimator, return a wrapping object that will call a delegate once <see cref="IEstimator{TTransformer}.Fit(IDataView, IDataView, IPredictor)"/>
105+
/// Given an estimator, return a wrapping object that will call a delegate once <see cref="IEstimator{TTransformer}.Fit(IDataView)"/>
106106
/// is called. It is often important for an estimator to return information about what was fit, which is why the
107-
/// <see cref="IEstimator{TTransformer}.Fit(IDataView, IDataView, IPredictor)"/> method returns a specifically typed object, rather than just a general
107+
/// <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> method returns a specifically typed object, rather than just a general
108108
/// <see cref="ITransformer"/>. However, at the same time, <see cref="IEstimator{TTransformer}"/> are often formed into pipelines
109109
/// with many objects, so we may need to build a chain of estimators via <see cref="EstimatorChain{TLastTransformer}"/> where the
110110
/// estimator for which we want to get the transformer is buried somewhere in this chain. For that scenario, we can through this
@@ -113,7 +113,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
113113
/// <typeparam name="TTransformer">The type of <see cref="ITransformer"/> returned by <paramref name="estimator"/></typeparam>
114114
/// <param name="estimator">The estimator to wrap</param>
115115
/// <param name="onFit">The delegate that is called with the resulting <typeparamref name="TTransformer"/> instances once
116-
/// <see cref="IEstimator{TTransformer}.Fit(IDataView, IDataView, IPredictor)"/> is called. Because <see cref="IEstimator{TTransformer}.Fit(IDataView, IDataView, IPredictor)"/>
116+
/// <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> is called. Because <see cref="IEstimator{TTransformer}.Fit(IDataView)"/>
117117
/// may be called multiple times, this delegate may also be called multiple times.</param>
118118
/// <returns>A wrapping estimator that calls the indicated delegate whenever fit is called</returns>
119119
public static IEstimator<TTransformer> WithOnFitDelegate<TTransformer>(this IEstimator<TTransformer> estimator, Action<TTransformer> onFit)

src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Microsoft.ML.Runtime.Data
88
{
99
/// <summary>
1010
/// The trivial implementation of <see cref="IEstimator{TTransformer}"/> that already has
11-
/// the transformer and returns it on every call to <see cref="Fit(IDataView, IDataView, IPredictor)"/>.
11+
/// the transformer and returns it on every call to <see cref="Fit(IDataView)"/>.
1212
///
1313
/// Concrete implementations still have to provide the schema propagation mechanism, since
1414
/// there is no easy way to infer it from the transformer.
@@ -28,7 +28,7 @@ protected TrivialEstimator(IHost host, TTransformer transformer)
2828
Transformer = transformer;
2929
}
3030

31-
public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => Transformer;
31+
public TTransformer Fit(IDataView input) => Transformer;
3232

3333
public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);
3434
}

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.
6767
WeightColumn = weight;
6868
}
6969

70-
public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(input, validationData, initialPredictor);
70+
public TTransformer Fit(IDataView input) => TrainTransformer(input);
7171

7272
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
7373
{

src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public CopyColumnsEstimator(IHostEnvironment env, params (string source, string
5454
_columns = columns;
5555
}
5656

57-
public CopyColumnsTransform Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
57+
public CopyColumnsTransform Fit(IDataView input)
5858
{
5959
// Invoke schema validation.
6060
GetOutputSchema(SchemaShape.Create(input.Schema));

src/Microsoft.ML.Data/Transforms/Normalizer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ public Normalizer(IHostEnvironment env, params ColumnBase[] columns)
170170
_columns = columns.ToArray();
171171
}
172172

173-
public NormalizerTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
173+
public NormalizerTransformer Fit(IDataView input)
174174
{
175175
_host.CheckValue(input, nameof(input));
176176
return NormalizerTransformer.Train(_host, input, _columns);

src/Microsoft.ML.Data/Transforms/TermEstimator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public TermEstimator(IHostEnvironment env, params TermTransform.ColumnInfo[] col
2323
_columns = columns;
2424
}
2525

26-
public TermTransform Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => new TermTransform(_host, input, _columns);
26+
public TermTransform Fit(IDataView input) => new TermTransform(_host, input, _columns);
2727

2828
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
2929
{

test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public override IEstimator<ITransformer> Reconcile(
5151

5252
private sealed class FakeEstimator : IEstimator<ITransformer>
5353
{
54-
public ITransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null) => throw new NotImplementedException();
54+
public ITransformer Fit(IDataView input) => throw new NotImplementedException();
5555
public SchemaShape GetOutputSchema(SchemaShape inputSchema) => throw new NotImplementedException();
5656
}
5757

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Runtime;
56
using Microsoft.ML.Runtime.Data;
67
using Microsoft.ML.Runtime.Learners;
78
using Xunit;
@@ -38,7 +39,9 @@ public void New_TrainWithInitialPredictor()
3839

3940
// Train the second predictor on the same data.
4041
var secondTrainer = new AveragedPerceptronTrainer(env, new AveragedPerceptronTrainer.Arguments());
41-
var finalModel = secondTrainer.Fit(trainData, initialPredictor: firstModel.Model);
42+
43+
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");
44+
var finalModel = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: firstModel.Model));
4245
}
4346
}
4447
}

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ protected TrainerBase(IHostEnvironment env, TrainerInfo trainerInfo, string feat
245245
Info = trainerInfo;
246246
}
247247

248-
public TTransformer Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
248+
public TTransformer Fit(IDataView input)
249249
{
250250
return TrainTransformer(input);
251251
}
@@ -311,7 +311,7 @@ public MyTextTransform(IHostEnvironment env, TextTransform.Arguments args)
311311
_args = args;
312312
}
313313

314-
public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
314+
public TransformWrapper Fit(IDataView input)
315315
{
316316
var xf = TextTransform.Create(_env, _args, input);
317317
var empty = new EmptyDataView(_env, input.Schema);
@@ -338,7 +338,7 @@ public MyConcatTransform(IHostEnvironment env, string name, params string[] sour
338338
_source = source;
339339
}
340340

341-
public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
341+
public TransformWrapper Fit(IDataView input)
342342
{
343343
var xf = new ConcatTransform(_env, input, _name, _source);
344344
var empty = new EmptyDataView(_env, input.Schema);
@@ -365,7 +365,7 @@ public MyKeyToValueTransform(IHostEnvironment env, string name, string source =
365365
_source = source;
366366
}
367367

368-
public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
368+
public TransformWrapper Fit(IDataView input)
369369
{
370370
var xf = new KeyToValueTransform(_env, input, _name, _source);
371371
var empty = new EmptyDataView(_env, input.Schema);
@@ -510,7 +510,7 @@ public MyLambdaTransform(IHostEnvironment env, Action<TSrc, TDst> action)
510510
_action = action;
511511
}
512512

513-
public TransformWrapper Fit(IDataView input, IDataView validationData = null, IPredictor initialPredictor = null)
513+
public TransformWrapper Fit(IDataView input)
514514
{
515515
var xf = LambdaTransform.CreateMap(_env, input, _action);
516516
var empty = new EmptyDataView(_env, input.Schema);

0 commit comments

Comments
 (0)