Skip to content

Refactoring of Options for ImagePixelExtractingEstimator #3033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ private sealed class Column : OneToOneColumn
{
public Column(string src, string dst)
{
Name = dst;
Source = src;
OutputName = dst;
InputName = src;
}
}

Expand Down
50 changes: 24 additions & 26 deletions src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,21 @@

namespace Microsoft.ML.Data
{
internal abstract class SourceNameColumnBase
/// <summary>
/// Specifies input and output column names for a transformation.
/// </summary>
public abstract class OneToOneColumn
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the new column", ShortName = "name")]
public string Name;
/// <summary>Name of the column resulting from the transformation of <see cref="InputName"/>.</summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the new column", Name = "Name", ShortName = "name")]
public string OutputName;

[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the source column", ShortName = "src")]
public string Source;
/// <summary>Name of the column to transform. If set to <see langword="null"/>, the value of <see cref="OutputName"/> will be used as the source.</summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the source column", Name = "Source", ShortName = "src")]
public string InputName;

[BestFriend]
private protected SourceNameColumnBase() { }
private protected OneToOneColumn() { }

/// <summary>
/// For parsing from a string. This supports "name" and "name:source".
Expand All @@ -35,7 +40,7 @@ private protected SourceNameColumnBase() { }
private protected virtual bool TryParse(string str)
{
Contracts.AssertNonEmpty(str);
return ColumnParsingUtils.TryParse(str, out Name, out Source);
return ColumnParsingUtils.TryParse(str, out OutputName, out InputName);
}

/// <summary>
Expand All @@ -49,7 +54,7 @@ private protected virtual bool TryParse(string str)
private protected bool TryParse(string str, out string extra)
{
Contracts.AssertNonEmpty(str);
return ColumnParsingUtils.TryParse(str, out Name, out Source, out extra);
return ColumnParsingUtils.TryParse(str, out OutputName, out InputName, out extra);
}

/// <summary>
Expand All @@ -62,12 +67,12 @@ private protected virtual bool TryUnparseCore(StringBuilder sb)

if (!TrySanitize())
return false;
if (CmdQuoter.NeedsQuoting(Name) || CmdQuoter.NeedsQuoting(Source))
if (CmdQuoter.NeedsQuoting(OutputName) || CmdQuoter.NeedsQuoting(InputName))
return false;

sb.Append(Name);
if (Source != Name)
sb.Append(':').Append(Source);
sb.Append(OutputName);
if (InputName != OutputName)
sb.Append(':').Append(InputName);
return true;
}

Expand All @@ -82,10 +87,10 @@ private protected virtual bool TryUnparseCore(StringBuilder sb, string extra)

if (!TrySanitize())
return false;
if (CmdQuoter.NeedsQuoting(Name) || CmdQuoter.NeedsQuoting(Source))
if (CmdQuoter.NeedsQuoting(OutputName) || CmdQuoter.NeedsQuoting(InputName))
return false;

sb.Append(Name).Append(':').Append(extra).Append(':').Append(Source);
sb.Append(OutputName).Append(':').Append(extra).Append(':').Append(InputName);
return true;
}

Expand All @@ -95,21 +100,14 @@ private protected virtual bool TryUnparseCore(StringBuilder sb, string extra)
/// </summary>
public bool TrySanitize()
{
if (string.IsNullOrWhiteSpace(Name))
Name = Source;
else if (string.IsNullOrWhiteSpace(Source))
Source = Name;
return !string.IsNullOrWhiteSpace(Name);
if (string.IsNullOrWhiteSpace(OutputName))
OutputName = InputName;
else if (string.IsNullOrWhiteSpace(InputName))
InputName = OutputName;
return !string.IsNullOrWhiteSpace(OutputName);
}
}

[BestFriend]
internal abstract class OneToOneColumn : SourceNameColumnBase
{
[BestFriend]
private protected OneToOneColumn() { }
}

[BestFriend]
internal abstract class ManyToOneColumn
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
Contracts.CheckValue(env, nameof(env));
env.CheckValue(options, nameof(options));

var transformer = new ColumnCopyingTransformer(env, options.Columns.Select(x => (x.Name, x.Source)).ToArray());
var transformer = new ColumnCopyingTransformer(env, options.Columns.Select(x => (x.OutputName, x.InputName)).ToArray());
return transformer.MakeDataTransform(input);
}

Expand Down
12 changes: 4 additions & 8 deletions src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,17 @@ namespace Microsoft.ML
/// <summary>
/// Specifies input and output column names for a transformation.
/// </summary>
[BestFriend]
internal sealed class ColumnOptions
public sealed class ColumnOptions : OneToOneColumn
{
private readonly string _outputColumnName;
private readonly string _inputColumnName;

/// <summary>
/// Specifies input and output column names for a transformation.
/// </summary>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
public ColumnOptions(string outputColumnName, string inputColumnName = null)
{
_outputColumnName = outputColumnName;
_inputColumnName = inputColumnName ?? outputColumnName;
OutputName = outputColumnName;
InputName = inputColumnName ?? outputColumnName;
}

/// <summary>
Expand All @@ -39,7 +35,7 @@ public static implicit operator ColumnOptions((string outputColumnName, string i
[BestFriend]
internal static (string outputColumnName, string inputColumnName)[] ConvertToValueTuples(ColumnOptions[] infos)
{
return infos.Select(info => (info._outputColumnName, info._inputColumnName)).ToArray();
return infos.Select(info => (info.OutputName, info.InputName)).ToArray();
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/Hashing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat
var item = options.Columns[i];
var kind = item.MaximumNumberOfInverts ?? options.MaximumNumberOfInverts;
cols[i] = new HashingEstimator.ColumnOptions(
item.Name,
item.Source ?? item.Name,
item.OutputName,
item.InputName ?? item.OutputName,
item.NumberOfBits ?? options.NumberOfBits,
item.Seed ?? options.Seed,
item.Ordered ?? options.Ordered,
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/KeyToValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
env.CheckValue(input, nameof(input));
env.CheckNonEmpty(options.Columns, nameof(options.Columns));

var transformer = new KeyToValueMappingTransformer(env, options.Columns.Select(c => (c.Name, c.Source ?? c.Name)).ToArray());
var transformer = new KeyToValueMappingTransformer(env, options.Columns.Select(c => (c.OutputName, c.InputName ?? c.OutputName)).ToArray());
return transformer.MakeDataTransform(input);
}

Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat
var item = options.Columns[i];

cols[i] = new KeyToVectorMappingEstimator.ColumnOptions(
item.Name,
item.Source ?? item.Name,
item.OutputName,
item.InputName ?? item.OutputName,
item.Bag ?? options.Bag);
};
return new KeyToVectorMappingTransformer(env, cols).MakeDataTransform(input);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private static VersionInfo GetVersionInfo()
/// <param name="outputColumnName">Name of the output column.</param>
/// <param name="inputColumnName">Name of the input column. If this is null '<paramref name="outputColumnName"/>' will be used.</param>
public LabelConvertTransform(IHostEnvironment env, IDataView input, string outputColumnName, string inputColumnName = null)
: this(env, new Arguments() { Columns = new[] { new Column() { Source = inputColumnName ?? outputColumnName, Name = outputColumnName } } }, input)
: this(env, new Arguments() { Columns = new[] { new Column() { InputName = inputColumnName ?? outputColumnName, OutputName = outputColumnName } } }, input)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public LabelIndicatorTransform(IHostEnvironment env,
int classIndex,
string name,
string source = null)
: this(env, new Options() { Columns = new[] { new Column() { Source = source ?? name, Name = name } }, ClassIndex = classIndex }, input)
: this(env, new Options() { Columns = new[] { new Column() { InputName = source ?? name, OutputName = name } }, ClassIndex = classIndex }, input)
{
}

Expand Down
36 changes: 18 additions & 18 deletions src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ internal static IDataTransform Create(IHostEnvironment env, MinMaxArguments args

var columns = args.Columns
.Select(col => new NormalizingEstimator.MinMaxColumnOptions(
col.Name,
col.Source ?? col.Name,
col.OutputName,
col.InputName ?? col.OutputName,
col.MaximumExampleCount ?? args.MaximumExampleCount,
col.EnsureZeroUntouched ?? args.EnsureZeroUntouched))
.ToArray();
Expand All @@ -308,8 +308,8 @@ internal static IDataTransform Create(IHostEnvironment env, MeanVarArguments arg

var columns = args.Columns
.Select(col => new NormalizingEstimator.MeanVarianceColumnOptions(
col.Name,
col.Source ?? col.Name,
col.OutputName,
col.InputName ?? col.OutputName,
col.MaximumExampleCount ?? args.MaximumExampleCount,
col.EnsureZeroUntouched ?? args.EnsureZeroUntouched))
.ToArray();
Expand All @@ -328,8 +328,8 @@ internal static IDataTransform Create(IHostEnvironment env, LogMeanVarArguments

var columns = args.Columns
.Select(col => new NormalizingEstimator.LogMeanVarianceColumnOptions(
col.Name,
col.Source ?? col.Name,
col.OutputName,
col.InputName ?? col.OutputName,
col.MaximumExampleCount ?? args.MaximumExampleCount,
args.UseCdf))
.ToArray();
Expand All @@ -348,8 +348,8 @@ internal static IDataTransform Create(IHostEnvironment env, BinArguments args, I

var columns = args.Columns
.Select(col => new NormalizingEstimator.BinningColumnOptions(
col.Name,
col.Source ?? col.Name,
col.OutputName,
col.InputName ?? col.OutputName,
col.MaximumExampleCount ?? args.MaximumExampleCount,
col.EnsureZeroUntouched ?? args.EnsureZeroUntouched,
col.NumBins ?? args.NumBins))
Expand Down Expand Up @@ -926,8 +926,8 @@ public static IColumnFunctionBuilder CreateBuilder(MinMaxArguments args, IHost h
host.AssertValue(args);

return CreateBuilder(new NormalizingEstimator.MinMaxColumnOptions(
args.Columns[icol].Name,
args.Columns[icol].Source ?? args.Columns[icol].Name,
args.Columns[icol].OutputName,
args.Columns[icol].InputName ?? args.Columns[icol].OutputName,
args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
args.Columns[icol].EnsureZeroUntouched ?? args.EnsureZeroUntouched), host, srcIndex, srcType, cursor);
}
Expand Down Expand Up @@ -963,8 +963,8 @@ public static IColumnFunctionBuilder CreateBuilder(MeanVarArguments args, IHost
host.AssertValue(args);

return CreateBuilder(new NormalizingEstimator.MeanVarianceColumnOptions(
args.Columns[icol].Name,
args.Columns[icol].Source ?? args.Columns[icol].Name,
args.Columns[icol].OutputName,
args.Columns[icol].InputName ?? args.Columns[icol].OutputName,
args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
args.Columns[icol].EnsureZeroUntouched ?? args.EnsureZeroUntouched,
args.UseCdf), host, srcIndex, srcType, cursor);
Expand Down Expand Up @@ -1003,8 +1003,8 @@ public static IColumnFunctionBuilder CreateBuilder(LogMeanVarArguments args, IHo
host.AssertValue(args);

return CreateBuilder(new NormalizingEstimator.LogMeanVarianceColumnOptions(
args.Columns[icol].Name,
args.Columns[icol].Source ?? args.Columns[icol].Name,
args.Columns[icol].OutputName,
args.Columns[icol].InputName ?? args.Columns[icol].OutputName,
args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
args.UseCdf), host, srcIndex, srcType, cursor);
}
Expand Down Expand Up @@ -1043,8 +1043,8 @@ public static IColumnFunctionBuilder CreateBuilder(BinArguments args, IHost host
host.AssertValue(args);

return CreateBuilder(new NormalizingEstimator.BinningColumnOptions(
args.Columns[icol].Name,
args.Columns[icol].Source ?? args.Columns[icol].Name,
args.Columns[icol].OutputName,
args.Columns[icol].InputName ?? args.Columns[icol].OutputName,
args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
args.Columns[icol].EnsureZeroUntouched ?? args.EnsureZeroUntouched,
args.Columns[icol].NumBins ?? args.NumBins), host, srcIndex, srcType, cursor);
Expand Down Expand Up @@ -1093,8 +1093,8 @@ public static IColumnFunctionBuilder CreateBuilder(SupervisedBinArguments args,

return CreateBuilder(
new NormalizingEstimator.SupervisedBinningColumOptions(
args.Columns[icol].Name,
args.Columns[icol].Source ?? args.Columns[icol].Name,
args.Columns[icol].OutputName,
args.Columns[icol].InputName ?? args.Columns[icol].OutputName,
args.LabelColumn ?? DefaultColumnNames.Label,
args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
args.Columns[icol].EnsureZeroUntouched ?? args.EnsureZeroUntouched,
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ public static CommonOutputs.TransformOutput Bin(IHostEnvironment env, NormalizeT
var columnsToNormalize = new List<NormalizeTransform.AffineColumn>();
foreach (var column in input.Columns)
{
if (!schema.TryGetColumnIndex(column.Source, out int col))
throw env.ExceptUserArg(nameof(input.Columns), $"Column '{column.Source}' does not exist.");
if (!schema.TryGetColumnIndex(column.InputName, out int col))
throw env.ExceptUserArg(nameof(input.Columns), $"Column '{column.InputName}' does not exist.");
if (!schema[col].IsNormalized())
columnsToNormalize.Add(column);
}
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ public ColumnOptions(string name, string inputColumnName = null, params (int min

internal ColumnOptions(Column column)
{
Name = column.Name;
Name = column.OutputName;
Contracts.CheckValue(Name, nameof(Name));
InputColumnName = column.Source ?? column.Name;
InputColumnName = column.InputName ?? column.OutputName;
Contracts.CheckValue(InputColumnName, nameof(InputColumnName));
Slots = column.Slots.Select(range => (range.Min, range.Max)).ToArray();
foreach (var (min, max) in Slots)
Expand Down
14 changes: 7 additions & 7 deletions src/Microsoft.ML.Data/Transforms/TransformBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -333,19 +333,19 @@ public static Bindings Create(OneToOneTransformBase parent, OneToOneColumn[] col
for (int i = 0; i < names.Length; i++)
{
var item = column[i];
host.CheckUserArg(item.TrySanitize(), nameof(OneToOneColumn.Name), "Invalid new column name");
names[i] = item.Name;
host.CheckUserArg(item.TrySanitize(), nameof(OneToOneColumn.OutputName), "Invalid new column name");
names[i] = item.OutputName;

int colSrc;
if (!inputSchema.TryGetColumnIndex(item.Source, out colSrc))
throw host.ExceptUserArg(nameof(OneToOneColumn.Source), "Source column '{0}' not found", item.Source);
if (!inputSchema.TryGetColumnIndex(item.InputName, out colSrc))
throw host.ExceptUserArg(nameof(OneToOneColumn.InputName), "Source column '{0}' not found", item.InputName);

var type = inputSchema[colSrc].Type;
if (testType != null)
{
string reason = testType(type);
if (reason != null)
throw host.ExceptUserArg(nameof(OneToOneColumn.Source), InvalidTypeErrorFormat, item.Source, type, reason);
throw host.ExceptUserArg(nameof(OneToOneColumn.InputName), InvalidTypeErrorFormat, item.InputName, type, reason);
}

var slotType = transposedInput?.GetSlotType(i);
Expand Down Expand Up @@ -541,8 +541,8 @@ private protected OneToOneTransformBase(IHostEnvironment env, string name, OneTo
OneToOneColumn[] map = transform.Infos
.Select(x => new ColumnTmp
{
Name = x.Name,
Source = transform.Source.Schema[x.Source].Name,
OutputName = x.Name,
InputName = transform.Source.Schema[x.Source].Name,
})
.ToArray();

Expand Down
Loading