Conversion of ITrainer.Train returns predictor, accepts +TrainContext #522

Merged
merged 13 commits into from
Jul 18, 2018
155 changes: 41 additions & 114 deletions src/Microsoft.ML.Core/Prediction/ITrainer.cs
@@ -2,9 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Runtime.Data;

namespace Microsoft.ML.Runtime
{
@@ -27,151 +26,79 @@ namespace Microsoft.ML.Runtime
public delegate void SignatureSequenceTrainer();
public delegate void SignatureMatrixRecommendingTrainer();

/// <summary>
/// Interface to provide extra information about a trainer.
/// </summary>
public interface ITrainerEx : ITrainer
{
// REVIEW: Ideally trainers should be able to communicate
// something about the type of data they are capable of being trained
// on, e.g., what ColumnKinds they want, how many of each, of what type,
// etc. This interface seems like the most natural conduit for that sort
// of extra information.

// REVIEW: Can we please have consistent naming here?
// 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to
// be 'Needs' / 'Wants' anyway.

/// <summary>
/// Whether the trainer needs to see data in normalized form.
/// </summary>
bool NeedNormalization { get; }

/// <summary>
/// Whether the trainer needs calibration to produce probabilities.
/// </summary>
bool NeedCalibration { get; }

/// <summary>
/// Whether this trainer could benefit from a cached view of the data.
/// </summary>
bool WantCaching { get; }
}

public interface ITrainerHost
{
Random Rand { get; }
int Verbosity { get; }

TextWriter StdOut { get; }
TextWriter StdErr { get; }
}

// The Trainer (of Factory) can optionally implement this.
public interface IModelCombiner<TModel, TPredictor>
where TPredictor : IPredictor
{
TPredictor CombineModels(IEnumerable<TModel> models);
}

public delegate void SignatureModelCombiner(PredictionKind kind);

/// <summary>
/// Weakly typed interface for a trainer "session" that produces a predictor.
/// The base interface for a trainer. Implementors should not implement this interface directly,
/// but rather implement the more specific <see cref="ITrainer{TPredictor}"/>.
/// </summary>
public interface ITrainer
{
/// <summary>
/// Return the type of prediction task for the produced predictor.
/// Auxiliary information about the trainer in terms of its capabilities
/// and requirements.
/// </summary>
PredictionKind PredictionKind { get; }
TrainerInfo Info { get; }

/// <summary>
/// Returns the trained predictor.
/// REVIEW: Consider removing this.
/// Return the type of prediction task for the produced predictor.
/// </summary>
IPredictor CreatePredictor();
}

/// <summary>
/// Interface implemented by the MetalinearLearners base class.
/// Used to distinguish the MetaLinear Learners from the other learners
/// </summary>
public interface IMetaLinearTrainer
{

}
PredictionKind PredictionKind { get; }
Member

Do you think this could/should move into TrainerInfo? Or is the thought that we are going to delete it altogether, so there is no need to move it?

Contributor Author

I'd rather get rid of PredictionKind altogether, but not as part of this PR, since what we replace it with will require some discussion and consideration in a separate issue. That is, get rid of it from both IPredictor and ITrainer, and replace with something that uses the types instead. If, however, we were to keep it, there's something attractive about having IPredictor resemble ITrainer.


public interface ITrainer<in TDataSet> : ITrainer
{
/// <summary>
/// Trains a predictor using the specified dataset.
/// Trains a predictor.
/// </summary>
/// <param name="data"> Training dataset </param>
void Train(TDataSet data);
/// <param name="context">A context containing at least the training data</param>
/// <returns>The trained predictor</returns>
/// <seealso cref="ITrainer{TPredictor}.Train(TrainContext)"/>
IPredictor Train(TrainContext context);
}

/// <summary>
/// Strongly typed generic interface for a trainer. A trainer object takes
/// supervision data and produces a predictor.
/// Strongly typed generic interface for a trainer. A trainer object takes training data
/// and produces a predictor.
/// </summary>
/// <typeparam name="TDataSet"> Type of the training dataset</typeparam>
/// <typeparam name="TPredictor"> Type of predictor produced</typeparam>
public interface ITrainer<in TDataSet, out TPredictor> : ITrainer<TDataSet>
public interface ITrainer<out TPredictor> : ITrainer
where TPredictor : IPredictor
{
/// <summary>
/// Returns the trained predictor.
/// </summary>
/// <returns>Trained predictor ready to make predictions</returns>
new TPredictor CreatePredictor();
}

/// <summary>
/// Trainers that want data to do their own validation implement this interface.
/// </summary>
public interface IValidatingTrainer<in TDataSet> : ITrainer<TDataSet>
{
/// <summary>
/// Trains a predictor using the specified dataset.
/// Trains a predictor.
/// </summary>
/// <param name="data">Training dataset</param>
/// <param name="validData">Validation dataset</param>
void Train(TDataSet data, TDataSet validData);
/// <param name="context">A context containing at least the training data</param>
/// <returns>The trained predictor</returns>
new TPredictor Train(TrainContext context);
}

public interface IIncrementalTrainer<in TDataSet, in TPredictor> : ITrainer<TDataSet>
public static class TrainerExtensions
{
/// <summary>
/// Trains a predictor using the specified dataset and a trained predictor.
/// Convenience train extension for the case where one has only a training set with no auxiliary information.
/// Equivalent to calling <see cref="ITrainer.Train(TrainContext)"/>
/// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
/// </summary>
/// <param name="data">Training dataset</param>
/// <param name="predictor">A trained predictor</param>
void Train(TDataSet data, TPredictor predictor);
}
/// <param name="trainer">The trainer</param>
/// <param name="trainData">The training data.</param>
/// <returns>The trained predictor</returns>
public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData)
=> trainer.Train(new TrainContext(trainData));

public interface IIncrementalValidatingTrainer<in TDataSet, in TPredictor> : ITrainer<TDataSet>
{
/// <summary>
/// Trains a predictor using the specified dataset and a trained predictor.
/// Convenience train extension for the case where one has only a training set with no auxiliary information.
/// Equivalent to calling <see cref="ITrainer{TPredictor}.Train(TrainContext)"/>
/// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
/// </summary>
/// <param name="data">Training dataset</param>
/// <param name="validData">Validation dataset</param>
/// <param name="predictor">A trained predictor</param>
void Train(TDataSet data, TDataSet validData, TPredictor predictor);
/// <param name="trainer">The trainer</param>
/// <param name="trainData">The training data.</param>
/// <returns>The trained predictor</returns>
public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, RoleMappedData trainData) where TPredictor : IPredictor
=> trainer.Train(new TrainContext(trainData));
}

#if FUTURE
public interface IMultiTrainer<in TDataSet, in TFeatures, out TResult> :
IMultiTrainer<TDataSet, TDataSet, TFeatures, TResult>
{
}

public interface IMultiTrainer<in TDataSet, in TDataBatch, in TFeatures, out TResult> :
ITrainer<TDataSet, TFeatures, TResult>
// A trainer can optionally implement this to indicate it can combine multiple models into a single predictor.
public interface IModelCombiner<TModel, TPredictor>
where TPredictor : IPredictor
{
void UpdatePredictor(TDataBatch trainInstance);
IPredictor<TFeatures, TResult> GetCurrentPredictor();
TPredictor CombineModels(IEnumerable<TModel> models);
}
#endif
}
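
The reshaped interfaces above can be exercised with a small self-contained sketch. It is illustrative only: `IPredictor`, `RoleMappedData` and `TrainContext` are reduced to hypothetical stand-ins so the snippet compiles on its own, `ITrainer` omits `Info` and `PredictionKind`, and `ConstantPredictor`, `ConstantTrainer` and `Demo` are invented names that do not appear in the PR.

```csharp
using System;

// Hypothetical stand-ins so the sketch compiles on its own.
// The real types live in Microsoft.ML.Runtime and carry much more.
public interface IPredictor { }
public sealed class RoleMappedData { }

public sealed class TrainContext
{
    public RoleMappedData TrainingSet { get; }

    public TrainContext(RoleMappedData trainingSet)
    {
        TrainingSet = trainingSet ?? throw new ArgumentNullException(nameof(trainingSet));
    }
}

// Reduced versions of the two interfaces from the diff.
public interface ITrainer
{
    IPredictor Train(TrainContext context);
}

public interface ITrainer<out TPredictor> : ITrainer
    where TPredictor : IPredictor
{
    new TPredictor Train(TrainContext context);
}

// Invented predictor and trainer, for illustration only.
public sealed class ConstantPredictor : IPredictor
{
    public float Value { get; }
    public ConstantPredictor(float value) { Value = value; }
}

public sealed class ConstantTrainer : ITrainer<ConstantPredictor>
{
    // Strongly typed entry point: training returns the predictor directly.
    public ConstantPredictor Train(TrainContext context)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));
        return new ConstantPredictor(0.5f);
    }

    // Weakly typed entry point, implemented explicitly by forwarding.
    IPredictor ITrainer.Train(TrainContext context) => Train(context);
}

// Mirrors the generic convenience overload shown in the diff.
public static class TrainerExtensions
{
    public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, RoleMappedData trainData)
        where TPredictor : IPredictor
        => trainer.Train(new TrainContext(trainData));
}

public static class Demo
{
    public static float Run()
    {
        var trainer = new ConstantTrainer();
        // Extension overload: wraps the data in a TrainContext.
        ConstantPredictor typed = trainer.Train(new RoleMappedData());
        // Weakly typed call through the base interface.
        IPredictor weak = ((ITrainer)trainer).Train(new TrainContext(new RoleMappedData()));
        return typed.Value + ((ConstantPredictor)weak).Value;
    }
}
```

Because `Train` now returns the predictor, a caller no longer needs a separate `CreatePredictor()` step after training, and the explicit `ITrainer.Train` implementation keeps the weakly typed path available to callers that only know the base interface.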
57 changes: 57 additions & 0 deletions src/Microsoft.ML.Core/Prediction/TrainContext.cs
@@ -0,0 +1,57 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime.Data;

namespace Microsoft.ML.Runtime
{
/// <summary>
/// Holds information relevant to trainers. Instances of this class are meant to be constructed and passed
/// into <see cref="ITrainer{TPredictor}.Train(TrainContext)"/> or <see cref="ITrainer.Train(TrainContext)"/>.
/// This holds at least a training set, and optionally a validation set and an initial predictor.
/// </summary>
public sealed class TrainContext
{
/// <summary>
/// The training set. Cannot be <c>null</c>.
/// </summary>
public RoleMappedData TrainingSet { get; }

/// <summary>
/// The validation set. Can be <c>null</c>. Note that passing a non-<c>null</c> validation set into
/// a trainer that does not support validation sets should not be considered an error condition. It
Contributor

@zeahmed zeahmed Jul 12, 2018

"an error condition"

Can we encourage trainer implementers to show at least a warning, if that makes sense?

Contributor Author

I'd prefer not to. Warnings are already printed by the tooling that uses it, as we see here for validation sets and here for initial predictors. Decentralizing the checking, besides the problem of having to insert that boilerplate somehow, also runs somewhat contrary to the philosophy behind TrainContext, which is that we can add new members to it without trainer implementors having to care. Under this proposal they would suddenly be encouraged to care, and I'd rather keep that object as simple as possible. Plus, of course, if we did that, anything that works through those aforementioned pathways would show two warnings.


/// should simply be ignored in that case.
/// </summary>
public RoleMappedData ValidationSet { get; }

/// <summary>
/// The initial predictor, for incremental training. Note that if a <see cref="ITrainer"/> implementor
/// does not support incremental training, then it can ignore it similarly to how one would ignore
/// <see cref="ValidationSet"/>. However, if the trainer does support incremental training and there
/// is something wrong with a non-<c>null</c> value of this, then the trainer ought to throw an exception.
/// </summary>
public IPredictor InitialPredictor { get; }


/// <summary>
/// Constructor, given a training set and optional other arguments.
/// </summary>
/// <param name="trainingSet">Will set <see cref="TrainingSet"/> to this value. This must be specified</param>
/// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified</param>
/// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified</param>
public TrainContext(RoleMappedData trainingSet, RoleMappedData validationSet = null, IPredictor initialPredictor = null)
{
Contracts.CheckValue(trainingSet, nameof(trainingSet));
Contracts.CheckValueOrNull(validationSet);
Contracts.CheckValueOrNull(initialPredictor);

// REVIEW: Should there be code here to ensure that the role mappings between the two are compatible?
// That is, all the role mappings are the same and the columns between them have identical types?

TrainingSet = trainingSet;
ValidationSet = validationSet;
InitialPredictor = initialPredictor;
}
}
}
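
The contract described in the doc comments above (ignore an unsupported validation set, but throw on a bad initial predictor when incremental training is supported) might look like the following from a trainer's point of view. This is a sketch under stated assumptions: `IPredictor` and `RoleMappedData` are empty stand-ins, the `Contracts` checks are replaced by a plain null check, and `LinearPredictor`, `OtherPredictor` and `IncrementalLinearTrainer` are hypothetical names that are not part of the PR.

```csharp
using System;

// Hypothetical stand-ins so the sketch compiles on its own.
public interface IPredictor { }
public sealed class RoleMappedData { }

// Simplified copy of the TrainContext shape from the diff.
// The Contracts checks are replaced by a plain null check.
public sealed class TrainContext
{
    public RoleMappedData TrainingSet { get; }
    public RoleMappedData ValidationSet { get; }
    public IPredictor InitialPredictor { get; }

    public TrainContext(RoleMappedData trainingSet, RoleMappedData validationSet = null,
        IPredictor initialPredictor = null)
    {
        TrainingSet = trainingSet ?? throw new ArgumentNullException(nameof(trainingSet));
        ValidationSet = validationSet;
        InitialPredictor = initialPredictor;
    }
}

// Invented predictors, for illustration only.
public sealed class LinearPredictor : IPredictor
{
    public int Iterations { get; }
    public LinearPredictor(int iterations) { Iterations = iterations; }
}

public sealed class OtherPredictor : IPredictor { }

// Invented trainer that supports incremental training but not validation sets.
public sealed class IncrementalLinearTrainer
{
    public LinearPredictor Train(TrainContext context)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        // context.ValidationSet is deliberately never read:
        // this trainer does not support validation, so the value is ignored.
        int start = 0;
        if (context.InitialPredictor != null)
        {
            // Incremental training is supported, so an unusable value is an error.
            var initial = context.InitialPredictor as LinearPredictor;
            if (initial == null)
                throw new ArgumentException("Initial predictor must be a LinearPredictor.", nameof(context));
            start = initial.Iterations;
        }

        // Pretend that one more pass over context.TrainingSet happened.
        return new LinearPredictor(start + 1);
    }
}
```

A caller would pass `new TrainContext(data, validationSet: valid)` or `new TrainContext(data, initialPredictor: previous)` as needed; the trainer itself emits no warning about the ignored validation set, which matches the position taken in the review thread.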
71 changes: 71 additions & 0 deletions src/Microsoft.ML.Core/Prediction/TrainerInfo.cs
@@ -0,0 +1,71 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Runtime
{
/// <summary>
/// Instances of this class possess information about trainers, in terms of their requirements and capabilities.
/// The intended usage is as the value for <see cref="ITrainer.Info"/>.
/// </summary>
public sealed class TrainerInfo
{
// REVIEW: Ideally trainers should be able to communicate
// something about the type of data they are capable of being trained
// on, e.g., what ColumnKinds they want, how many of each, of what type,
// etc. This interface seems like the most natural conduit for that sort
// of extra information.

/// <summary>
/// Whether the trainer needs to see data in normalized form. Only non-parametric learners will tend to have
/// a <c>false</c> value here.
/// </summary>
public bool NeedNormalization { get; }

/// <summary>
/// Whether the trainer needs calibration to produce probabilities. As a general rule only trainers that produce
/// binary classifier predictors that also do not have a natural probabilistic interpretation should have a
/// <c>true</c> value here.
/// </summary>
public bool NeedCalibration { get; }

/// <summary>
/// Whether this trainer could benefit from a cached view of the data. Trainers that have few passes over the
/// data, or that need to build their own custom data structure over the data, will have a <c>false</c> here.
/// </summary>
public bool WantCaching { get; }

/// <summary>
/// Whether the trainer supports validation sets via <see cref="TrainContext.ValidationSet"/>. A <c>false</c>
/// value from this property is an indication the trainer does not support that.
/// </summary>
public bool SupportsValidation { get; }

/// <summary>
/// Whether the trainer supports incremental training via <see cref="TrainContext.InitialPredictor"/>. A
/// <c>false</c> value from this property is an indication the trainer does not support that.
/// </summary>
public bool SupportsIncrementalTraining { get; }

/// <summary>
/// Initializes with the given parameters. The parameters have default values for the most typical values
/// for most classical trainers.
/// </summary>
/// <param name="normalization">The value for the property <see cref="NeedNormalization"/></param>
/// <param name="calibration">The value for the property <see cref="NeedCalibration"/></param>
/// <param name="caching">The value for the property <see cref="WantCaching"/></param>
/// <param name="supportValid">The value for the property <see cref="SupportsValidation"/></param>
/// <param name="supportIncrementalTrain">The value for the property <see cref="SupportsIncrementalTraining"/></param>
public TrainerInfo(bool normalization = true, bool calibration = false, bool caching = true,
bool supportValid = false, bool supportIncrementalTrain = false)
{
NeedNormalization = normalization;
NeedCalibration = calibration;
WantCaching = caching;
SupportsValidation = supportValid;
SupportsIncrementalTraining = supportIncrementalTrain;
}
}
}
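
One way calling code might consume these flags is to check them once, centrally, before building a `TrainContext`, as the `CrossValidationCommand` change in this PR does for validation data. The sketch below is an assumption-laden illustration: `TrainerInfo` is cut down to the two capability flags, `TrainContextPlanner.GetWarnings` is a hypothetical helper that is not part of the PR, and the incremental-training message text is invented (only the validation message is taken from the diff).

```csharp
using System;
using System.Collections.Generic;

// Simplified copy of TrainerInfo from the diff, reduced to the two capability flags used here.
public sealed class TrainerInfo
{
    public bool SupportsValidation { get; }
    public bool SupportsIncrementalTraining { get; }

    public TrainerInfo(bool supportValid = false, bool supportIncrementalTrain = false)
    {
        SupportsValidation = supportValid;
        SupportsIncrementalTraining = supportIncrementalTrain;
    }
}

// Hypothetical helper, not part of the PR: the caller inspects the flags in one
// place, so individual trainers never have to emit these warnings themselves.
public static class TrainContextPlanner
{
    public static List<string> GetWarnings(TrainerInfo info, bool haveValidationSet, bool haveInitialPredictor)
    {
        if (info == null) throw new ArgumentNullException(nameof(info));

        var warnings = new List<string>();
        if (haveValidationSet && !info.SupportsValidation)
            warnings.Add("Trainer does not accept validation dataset.");
        if (haveInitialPredictor && !info.SupportsIncrementalTraining)
            warnings.Add("Trainer does not support incremental training."); // invented message text
        return warnings;
    }
}
```

With the default-constructed `new TrainerInfo()`, both capability flags are `false`, so supplying a validation set and an initial predictor yields two warnings; a trainer that declares `supportValid: true, supportIncrementalTrain: true` yields none.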
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Utilities/ObjectPool.cs
@@ -39,7 +39,7 @@ public abstract class ObjectPoolBase<T>
public int Count => _pool.Count;
public int NumCreated { get { return _numCreated; } }

protected internal ObjectPoolBase()
private protected ObjectPoolBase()
{
_pool = new ConcurrentBag<T>();
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -538,7 +538,7 @@ private FoldResult RunFold(int fold)
if (_getValidationDataView != null)
{
ch.Assert(_applyTransformsToValidationData != null);
if (!TrainUtils.CanUseValidationData(trainer))
if (!trainer.Info.SupportsValidation)
ch.Warning("Trainer does not accept validation dataset.");
else
{