Move Learner* input base and Transform* input base out of Entrypoints… #2748

Merged · 18 commits · Feb 28, 2019
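
For orientation before the file-by-file diff, here is a minimal sketch of the input-base hierarchy this PR relocates out of Microsoft.ML.EntryPoints, reconstructed only from the lines removed from InputBase.cs below. Attributes, doc comments, and the real Microsoft.Data.DataView/option types are replaced by local placeholders so the sketch stands alone; the destination namespace is not visible in this part of the diff and is not assumed here.

// Hierarchy sketch only; IDataView, NormalizeOption, and CachingOptions are local
// placeholders for the real ML.NET types, and all [Argument]/[BestFriend] attributes are omitted.
public interface IDataView { }                                  // placeholder for Microsoft.Data.DataView.IDataView
public enum NormalizeOption { Auto, No, Yes, Warn }             // placeholder member set
internal enum CachingOptions { Auto, Memory, Disk, None }

public abstract class TransformInputBase
{
    internal IDataView Data;                                    // input dataset (entry-points only)
}

public abstract class LearnerInputBase
{
    internal IDataView TrainingData;                            // training data (entry-points only)
    public string FeatureColumn = "Features";                   // DefaultColumnNames.Features
    internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
    internal CachingOptions Caching = CachingOptions.Auto;
}

public abstract class LearnerInputBaseWithLabel : LearnerInputBase
{
    public string LabelColumn = "Label";                        // DefaultColumnNames.Label
}

public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
{
    public string WeightColumn = null;                          // optional example-weight column
}

public abstract class UnsupervisedLearnerInputBaseWithWeight : LearnerInputBase
{
    public string WeightColumn = null;                          // optional example-weight column
}

public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
{
    public string GroupIdColumn = null;                         // ranking group-id column
}

By contrast, EvaluateInputBase stays in this file and only changes from public to internal with [BestFriend], as shown later in the diff.
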
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/EntryPoints/Cache.cs
@@ -9,6 +9,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(typeof(void), typeof(Cache), null, typeof(SignatureEntryPointModule), "Cache")]
namespace Microsoft.ML.EntryPoints
202 changes: 2 additions & 200 deletions src/Microsoft.ML.Data/EntryPoints/InputBase.cs
@@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.Data.DataView;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
@@ -12,21 +10,6 @@

namespace Microsoft.ML.EntryPoints
{
/// <summary>
/// The base class for all transform inputs.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITransformInput))]
public abstract class TransformInputBase
{
/// <summary>
/// The input dataset. Used only in entry-point methods, since the normal API mechanism for feeding in a dataset to
/// create an <see cref="ITransformer"/> is to use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> method.
/// </summary>
[BestFriend]
[Argument(ArgumentType.Required, HelpText = "Input dataset", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, SortOrder = 1)]
internal IDataView Data;
}

[BestFriend]
internal enum CachingOptions
{
@@ -36,89 +19,12 @@ internal enum CachingOptions
None
}

/// <summary>
/// The base class for all learner inputs.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
public abstract class LearnerInputBase
{
/// <summary>
/// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
/// that the user will use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> or some other train
/// method.
/// </summary>
[BestFriend]
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal IDataView TrainingData;

/// <summary>
/// Column to use for features.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string FeatureColumn = DefaultColumnNames.Features;

/// <summary>
/// Normalize option for the feature column. Used only in entry-points, since in the API the user is expected to do this themselves.
/// </summary>
[BestFriend]
[Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;

/// <summary>
/// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
/// is that the user will use the <see cref="DataOperationsCatalog.Cache(IDataView, string[])"/> or other method
/// like <see cref="EstimatorChain{TLastTransformer}.AppendCacheCheckpoint(IHostEnvironment)"/>.
/// </summary>
[BestFriend]
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal CachingOptions Caching = CachingOptions.Auto;
}

/// <summary>
/// The base class for all learner inputs that support a Label column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
public abstract class LearnerInputBaseWithLabel : LearnerInputBase
{
/// <summary>
/// Column to use for labels.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string LabelColumn = DefaultColumnNames.Label;
}

// REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
/// <summary>
/// The base class for all learner inputs that support a weight column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
{
/// <summary>
/// The name of the example weight column.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string WeightColumn = null;
}

/// <summary>
/// The base class for all unsupervised learner inputs that support a weight column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
public abstract class UnsupervisedLearnerInputBaseWithWeight : LearnerInputBase
{
/// <summary>
/// Column to use for example weight.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string WeightColumn = null;
}

/// <summary>
/// The base class for all evaluators inputs.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.IEvaluatorInput))]
public abstract class EvaluateInputBase
[BestFriend]
internal abstract class EvaluateInputBase
{
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for evaluation.", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public IDataView Data;
@@ -127,110 +33,6 @@ public abstract class EvaluateInputBase
public string NameColumn = DefaultColumnNames.Name;
}

[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
{
/// <summary>
/// Column to use for example groupId.
/// </summary>
[Argument(ArgumentType.AtMostOnce, Name = "GroupIdColumn", HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string GroupIdColumn = null;
}

[BestFriend]
internal static class LearnerEntryPointsUtils
{
public static string FindColumn(IExceptionContext ectx, DataViewSchema schema, Optional<string> value)
{
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(schema, nameof(schema));
ectx.CheckValue(value, nameof(value));

if (string.IsNullOrEmpty(value?.Value))
return null;
if (!schema.TryGetColumnIndex(value, out int col))
{
if (value.IsExplicit)
throw ectx.Except("Column '{0}' not found", value);
return null;
}
return value;
}

public static TOut Train<TArg, TOut>(IHost host, TArg input,
Func<ITrainer> createTrainer,
Func<string> getLabel = null,
Func<string> getWeight = null,
Func<string> getGroup = null,
Func<string> getName = null,
Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
ICalibratorTrainerFactory calibrator = null,
int maxCalibrationExamples = 0)
where TArg : LearnerInputBase
where TOut : CommonOutputs.TrainerOutput, new()
{
using (var ch = host.Start("Training"))
{
var schema = input.TrainingData.Schema;
var feature = FindColumn(ch, schema, input.FeatureColumn);
var label = getLabel?.Invoke();
var weight = getWeight?.Invoke();
var group = getGroup?.Invoke();
var name = getName?.Invoke();
var custom = getCustom?.Invoke();

var trainer = createTrainer();

IDataView view = input.TrainingData;
TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref view, feature, input.NormalizeFeatures);

ch.Trace("Binding columns");
var roleMappedData = new RoleMappedData(view, label, feature, group, weight, name, custom);

RoleMappedData cachedRoleMappedData = roleMappedData;
Cache.CachingType? cachingType = null;
switch (input.Caching)
{
case CachingOptions.Memory:
{
cachingType = Cache.CachingType.Memory;
break;
}
case CachingOptions.Disk:
{
cachingType = Cache.CachingType.Disk;
break;
}
case CachingOptions.Auto:
{
// REVIEW: we should switch to hybrid caching in future.
if (!(input.TrainingData is BinaryLoader) && trainer.Info.WantCaching)
// default to Memory so mml is on par with maml
cachingType = Cache.CachingType.Memory;
break;
}
case CachingOptions.None:
break;
default:
throw ch.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
}

if (cachingType.HasValue)
{
var cacheView = Cache.CacheData(host, new Cache.CacheInput()
{
Data = roleMappedData.Data,
Caching = cachingType.Value
}).OutputData;
cachedRoleMappedData = new RoleMappedData(cacheView, roleMappedData.Schema.GetColumnRoleNames());
}

var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, calibrator, maxCalibrationExamples);
return new TOut() { PredictorModel = new PredictorModelImpl(host, roleMappedData, input.TrainingData, predictor) };
}
}
}

/// <summary>
/// Common input interfaces for TLC components.
/// </summary>
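
As a reading aid for the LearnerEntryPointsUtils.Train helper removed above, the caching switch it contains boils down to the small decision sketched below. This is a standalone sketch: CachingType and the two boolean inputs (trainer.Info.WantCaching and whether TrainingData is a BinaryLoader) are placeholders, and only the mapping itself mirrors the deleted code.

// Standalone sketch of the caching decision in the removed Train helper.
// wantCaching stands in for trainer.Info.WantCaching; isBinaryLoader stands in for
// (input.TrainingData is BinaryLoader). Returning null means "do not insert a cache".
internal enum CachingOptions { Auto, Memory, Disk, None }
internal enum CachingType { Memory, Disk }

internal static class CachingDecision
{
    public static CachingType? Resolve(CachingOptions option, bool wantCaching, bool isBinaryLoader)
    {
        switch (option)
        {
            case CachingOptions.Memory:
                return CachingType.Memory;
            case CachingOptions.Disk:
                return CachingType.Disk;
            case CachingOptions.Auto:
                // Auto caches in memory only when the trainer wants caching and the
                // training data is not already a binary (IDV) loader backed by disk.
                return !isBinaryLoader && wantCaching ? CachingType.Memory : (CachingType?)null;
            case CachingOptions.None:
                return null;
            default:
                throw new System.ArgumentOutOfRangeException(nameof(option), option, "Unknown option for caching");
        }
    }
}

Per the inline comment in the removed code, defaulting Auto to Memory is what keeps the new API (mml) behavior on par with the legacy command line (maml).
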
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -19,6 +19,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Transforms;
using Newtonsoft.Json.Linq;

[assembly: LoadableClass(PlattCalibratorTrainer.Summary, typeof(PlattCalibratorTrainer), null, typeof(SignatureCalibrator),