Skip to content

Converted normalizers to be estimators #797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
385 changes: 192 additions & 193 deletions src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs

Large diffs are not rendered by default.

84 changes: 38 additions & 46 deletions src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ private void Update(int j, TFloat origVal)

public sealed partial class NormalizeTransform
{
public abstract partial class AffineColumnFunction
internal abstract partial class AffineColumnFunction
{
public static IColumnFunction Create(IHost host, TFloat scale, TFloat offset)
{
Expand Down Expand Up @@ -846,7 +846,7 @@ private static void FillValues(ref VBuffer<TFloat> input, BufferBuilder<TFloat>
}
}

public abstract partial class CdfColumnFunction
internal abstract partial class CdfColumnFunction
{
public static IColumnFunction Create(IHost host, TFloat mean, TFloat stddev, bool useLog)
{
Expand Down Expand Up @@ -1021,7 +1021,7 @@ private static void FillValues(ref VBuffer<TFloat> input, BufferBuilder<TFloat>
}
}

public abstract partial class BinColumnFunction
internal abstract partial class BinColumnFunction
{
public static IColumnFunction Create(IHost host, TFloat[] binUpperBounds, bool fixZero)
{
Expand Down Expand Up @@ -1255,7 +1255,7 @@ private void GetResult(ref VBuffer<TFloat> input, ref VBuffer<TFloat> value, Buf
}
}

private static partial class MinMaxUtils
internal static partial class MinMaxUtils
{
public static void ComputeScaleAndOffset(bool fixZero, TFloat max, TFloat min, out TFloat scale, out TFloat offset)
{
Expand Down Expand Up @@ -1312,7 +1312,7 @@ private static void ComputeScaleAndOffsetFixZero(TFloat max, TFloat min, out TFl
}
}

private static partial class MeanVarUtils
internal static partial class MeanVarUtils
{
public static void ComputeScaleAndOffset(Double mean, Double stddev, out TFloat scale, out TFloat offset)
{
Expand Down Expand Up @@ -1364,7 +1364,7 @@ public static TFloat Cdf(TFloat input, TFloat mean, TFloat stddev)
}
}

private static partial class BinUtils
internal static partial class BinUtils
{
public static TFloat GetValue(ref TFloat input, TFloat[] binUpperBounds, TFloat den, TFloat offset)
{
Expand Down Expand Up @@ -1422,13 +1422,11 @@ private MinMaxOneColumnFunctionBuilder(IHost host, long lim, bool fix, ValueGett
{
}

public static IColumnFunctionBuilder Create(MinMaxArguments args, IHost host, int icol, ColumnType srcType,
public static IColumnFunctionBuilder Create(Normalizer.MinMaxColumn column, IHost host, ColumnType srcType,
ValueGetter<TFloat> getter)
{
var lim = args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1");
bool fix = args.Column[icol].FixZero ?? args.FixZero;
return new MinMaxOneColumnFunctionBuilder(host, lim, fix, getter);
host.CheckUserArg(column.MaxTrainingExamples > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1");
return new MinMaxOneColumnFunctionBuilder(host, column.MaxTrainingExamples, column.FixZero, getter);
}

public override IColumnFunction CreateColumnFunction()
Expand Down Expand Up @@ -1474,14 +1472,12 @@ private MinMaxVecColumnFunctionBuilder(IHost host, int cv, long lim, bool fix,
{
}

public static IColumnFunctionBuilder Create(MinMaxArguments args, IHost host, int icol, ColumnType srcType,
public static IColumnFunctionBuilder Create(Normalizer.MinMaxColumn column, IHost host, ColumnType srcType,
ValueGetter<VBuffer<TFloat>> getter)
{
var lim = args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1");
bool fix = args.Column[icol].FixZero ?? args.FixZero;
host.CheckUserArg(column.MaxTrainingExamples > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1");
var cv = srcType.ValueCount;
return new MinMaxVecColumnFunctionBuilder(host, cv, lim, fix, getter);
return new MinMaxVecColumnFunctionBuilder(host, cv, column.MaxTrainingExamples, column.FixZero, getter);
}

public override IColumnFunction CreateColumnFunction()
Expand Down Expand Up @@ -1538,21 +1534,19 @@ private MeanVarOneColumnFunctionBuilder(IHost host, long lim, bool fix, ValueGet
_buffer = new VBuffer<TFloat>(1, new TFloat[1]);
}

public static IColumnFunctionBuilder Create(MeanVarArguments args, IHost host, int icol, ColumnType srcType,
public static IColumnFunctionBuilder Create(Normalizer.MeanVarColumn column, IHost host, ColumnType srcType,
ValueGetter<TFloat> getter)
{
var lim = args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1");
bool fix = args.Column[icol].FixZero ?? args.FixZero;
return new MeanVarOneColumnFunctionBuilder(host, lim, fix, getter, false, args.UseCdf);
host.CheckUserArg(column.MaxTrainingExamples > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1");
return new MeanVarOneColumnFunctionBuilder(host, column.MaxTrainingExamples, column.FixZero, getter, false, column.UseCdf);
}

public static IColumnFunctionBuilder Create(LogMeanVarArguments args, IHost host, int icol, ColumnType srcType,
public static IColumnFunctionBuilder Create(Normalizer.LogMeanVarColumn column, IHost host, ColumnType srcType,
ValueGetter<TFloat> getter)
{
var lim = args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1");
return new MeanVarOneColumnFunctionBuilder(host, lim, false, getter, true, args.UseCdf);
var lim = column.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1");
return new MeanVarOneColumnFunctionBuilder(host, lim, false, getter, true, column.UseCdf);
}

protected override bool ProcessValue(ref TFloat origVal)
Expand Down Expand Up @@ -1614,23 +1608,21 @@ private MeanVarVecColumnFunctionBuilder(IHost host, int cv, long lim, bool fix,
_useCdf = useCdf;
}

public static IColumnFunctionBuilder Create(MeanVarArguments args, IHost host, int icol, ColumnType srcType,
public static IColumnFunctionBuilder Create(Normalizer.MeanVarColumn column, IHost host, ColumnType srcType,
ValueGetter<VBuffer<TFloat>> getter)
{
var lim = args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1");
bool fix = args.Column[icol].FixZero ?? args.FixZero;
host.CheckUserArg(column.MaxTrainingExamples > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1");
var cv = srcType.ValueCount;
return new MeanVarVecColumnFunctionBuilder(host, cv, lim, fix, getter, false, args.UseCdf);
return new MeanVarVecColumnFunctionBuilder(host, cv, column.MaxTrainingExamples, column.FixZero, getter, false, column.UseCdf);
}

public static IColumnFunctionBuilder Create(LogMeanVarArguments args, IHost host, int icol, ColumnType srcType,
public static IColumnFunctionBuilder Create(Normalizer.LogMeanVarColumn column, IHost host, ColumnType srcType,
ValueGetter<VBuffer<TFloat>> getter)
{
var lim = args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1");
var lim = column.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1");
var cv = srcType.ValueCount;
return new MeanVarVecColumnFunctionBuilder(host, cv, lim, false, getter, true, args.UseCdf);
return new MeanVarVecColumnFunctionBuilder(host, cv, lim, false, getter, true, column.UseCdf);
}

protected override bool ProcessValue(ref VBuffer<TFloat> buffer)
Expand Down Expand Up @@ -1732,14 +1724,14 @@ private BinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins,
_values = new List<TFloat>();
}

public static IColumnFunctionBuilder Create(BinArguments args, IHost host, int icol, ColumnType srcType,
public static IColumnFunctionBuilder Create(Normalizer.BinningColumn column, IHost host, ColumnType srcType,
ValueGetter<TFloat> getter)
{
var lim = args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1");
bool fix = args.Column[icol].FixZero ?? args.FixZero;
var numBins = args.Column[icol].NumBins ?? args.NumBins;
host.CheckUserArg(numBins > 1, nameof(args.NumBins), "Must be greater than 1");
var lim = column.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1");
bool fix = column.FixZero;
var numBins = column.NumBins;
host.CheckUserArg(numBins > 1, nameof(column.NumBins), "Must be greater than 1");
return new BinOneColumnFunctionBuilder(host, lim, fix, numBins, getter);
}

Expand Down Expand Up @@ -1781,14 +1773,14 @@ private BinVecColumnFunctionBuilder(IHost host, int cv, long lim, bool fix, int
}
}

public static IColumnFunctionBuilder Create(BinArguments args, IHost host, int icol, ColumnType srcType,
public static IColumnFunctionBuilder Create(Normalizer.BinningColumn column, IHost host, ColumnType srcType,
ValueGetter<VBuffer<TFloat>> getter)
{
var lim = args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(args.MaxTrainingExamples), "Must be greater than 1");
bool fix = args.Column[icol].FixZero ?? args.FixZero;
var numBins = args.Column[icol].NumBins ?? args.NumBins;
host.CheckUserArg(numBins > 1, nameof(args.NumBins), "Must be greater than 1");
var lim = column.MaxTrainingExamples;
host.CheckUserArg(lim > 1, nameof(column.MaxTrainingExamples), "Must be greater than 1");
bool fix = column.FixZero;
var numBins = column.NumBins;
host.CheckUserArg(numBins > 1, nameof(column.NumBins), "numBins must be greater than 1");
var cv = srcType.ValueCount;
return new BinVecColumnFunctionBuilder(host, cv, lim, fix, numBins, getter);
}
Expand Down
Loading