Skip to content

Commit f68a8bc

Browse files
committed
Commenting the code
Breaking down the main file into 3 parts. Expanding the range for the numerical value checks, if that range is smaller than 0.0001, to help with fluctuation of tests across operating systems.
1 parent 7b6c140 commit f68a8bc

File tree

5 files changed

+356
-305
lines changed

5 files changed

+356
-305
lines changed

src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachine.cs renamed to src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs

Lines changed: 7 additions & 300 deletions
Original file line numberDiff line numberDiff line change
@@ -20,60 +20,16 @@
2020
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName,
2121
FieldAwareFactorizationMachineTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")]
2222

23-
[assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictor), null, typeof(SignatureLoadModel), "Field Aware Factorization Machine", FieldAwareFactorizationMachinePredictor.LoaderSignature)]
24-
2523
[assembly: LoadableClass(typeof(void), typeof(FieldAwareFactorizationMachineTrainer), null, typeof(SignatureEntryPointModule), FieldAwareFactorizationMachineTrainer.LoadName)]
2624

2725
namespace Microsoft.ML.Runtime.FactorizationMachine
2826
{
29-
internal sealed class FieldAwareFactorizationMachineUtils
30-
{
31-
internal static int GetAlignedVectorLength(int length)
32-
{
33-
int res = length % 4;
34-
if (res == 0)
35-
return length;
36-
else
37-
return length + (4 - res);
38-
}
39-
40-
internal static bool LoadOneExampleIntoBuffer(ValueGetter<VBuffer<float>>[] getters, VBuffer<float> featureBuffer, bool norm, ref int count,
41-
int[] fieldIndexBuffer, int[] featureIndexBuffer, float[] featureValueBuffer)
42-
{
43-
count = 0;
44-
float featureNorm = 0;
45-
int bias = 0;
46-
float annihilation = 0;
47-
for (int f = 0; f < getters.Length; f++)
48-
{
49-
getters[f](ref featureBuffer);
50-
foreach (var pair in featureBuffer.Items())
51-
{
52-
fieldIndexBuffer[count] = f;
53-
featureIndexBuffer[count] = bias + pair.Key;
54-
featureValueBuffer[count] = pair.Value;
55-
featureNorm += pair.Value * pair.Value;
56-
annihilation += pair.Value - pair.Value;
57-
count++;
58-
}
59-
bias += featureBuffer.Length;
60-
}
61-
featureNorm = MathUtils.Sqrt(featureNorm);
62-
if (norm)
63-
{
64-
for (int i = 0; i < count; i++)
65-
featureValueBuffer[i] /= featureNorm;
66-
}
67-
return FloatUtils.IsFinite(annihilation);
68-
}
69-
}
70-
7127
/// <summary>
7228
/// Train a field-aware factorization machine using ADAGRAD (an advanced stochastic gradient method). See references below
7329
/// for details. This trainer is essentially faster than the one introduced in [2] because of some implementation tricks[3].
7430
/// [1] http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
7531
/// [2] http://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf
76-
/// [3] fast-ffm.tex in FactorizationMachine project folder
32+
/// [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
7733
/// </summary>
7834
public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase<RoleMappedData, FieldAwareFactorizationMachinePredictor>,
7935
IIncrementalTrainer<RoleMappedData, FieldAwareFactorizationMachinePredictor>, IValidatingTrainer<RoleMappedData>,
@@ -327,6 +283,8 @@ private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, R
327283
Func<int, bool> pred = c => fieldColumnIndexes.Contains(c) || c == data.Schema.Label.Index || (data.Schema.Weight != null && c == data.Schema.Weight.Index);
328284
InitializeTrainingState(fieldCount, totalFeatureCount, predictor, out float[] linearWeights,
329285
out AlignedArray latentWeightsAligned, out float[] linearAccSqGrads, out AlignedArray latentAccSqGradsAligned);
286+
287+
// refer to Algorithm 3 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
330288
while (iter++ < _numIterations)
331289
{
332290
using (var cursor = data.Data.GetRowCursor(pred, rng))
@@ -358,9 +316,13 @@ private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, R
358316
badExampleCount++;
359317
continue;
360318
}
319+
320+
// refer to Algorithm 1 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
361321
FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(fieldCount, _latentDimAligned, count,
362322
featureFieldBuffer, featureIndexBuffer, featureValueBuffer, linearWeights, latentWeightsAligned, latentSum, ref modelResponse);
363323
var slope = CalculateLossSlope(label, modelResponse);
324+
325+
// refer to Algorithm 2 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
364326
FieldAwareFactorizationMachineInterface.CalculateGradientAndUpdate(_lambdaLinear, _lambdaLatent, _learningRate, fieldCount, _latentDimAligned, weight, count,
365327
featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum, slope, linearWeights, latentWeightsAligned, linearAccSqGrads, latentAccSqGradsAligned);
366328
loss += weight * CalculateLoss(label, modelResponse);
@@ -453,259 +415,4 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm
453415
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
454416
}
455417
}
456-
457-
public sealed class FieldAwareFactorizationMachinePredictor : PredictorBase<float>, ISchemaBindableMapper, ICanSaveModel
458-
{
459-
public const string LoaderSignature = "FieldAwareFactMacPredict";
460-
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
461-
private bool _norm;
462-
internal int FieldCount { get; }
463-
internal int FeatureCount { get; }
464-
internal int LatentDim { get; }
465-
internal int LatentDimAligned { get; }
466-
private readonly float[] _linearWeights;
467-
private readonly AlignedArray _latentWeightsAligned;
468-
469-
private static VersionInfo GetVersionInfo()
470-
{
471-
return new VersionInfo(
472-
modelSignature: "FAFAMAPD",
473-
verWrittenCur: 0x00010001,
474-
verReadableCur: 0x00010001,
475-
verWeCanReadBack: 0x00010001,
476-
loaderSignature: LoaderSignature);
477-
}
478-
479-
internal FieldAwareFactorizationMachinePredictor(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim,
480-
float[] linearWeights, AlignedArray latentWeightsAligned) : base(env, LoaderSignature)
481-
{
482-
Host.Assert(fieldCount > 0);
483-
Host.Assert(featureCount > 0);
484-
Host.Assert(latentDim > 0);
485-
Host.Assert(Utils.Size(linearWeights) == featureCount);
486-
LatentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(latentDim);
487-
Host.Assert(latentWeightsAligned.Size == checked(featureCount * fieldCount * LatentDimAligned));
488-
489-
_norm = norm;
490-
FieldCount = fieldCount;
491-
FeatureCount = featureCount;
492-
LatentDim = latentDim;
493-
_linearWeights = linearWeights;
494-
_latentWeightsAligned = latentWeightsAligned;
495-
}
496-
497-
private FieldAwareFactorizationMachinePredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature)
498-
{
499-
Host.AssertValue(ctx);
500-
501-
// *** Binary format ***
502-
// bool: whether to normalize feature vectors
503-
// int: number of fields
504-
// int: number of features
505-
// int: latent dimension
506-
// float[]: linear coefficients
507-
// float[]: latent representation of features
508-
509-
var norm = ctx.Reader.ReadBoolean();
510-
var fieldCount = ctx.Reader.ReadInt32();
511-
Host.CheckDecode(fieldCount > 0);
512-
var featureCount = ctx.Reader.ReadInt32();
513-
Host.CheckDecode(featureCount > 0);
514-
var latentDim = ctx.Reader.ReadInt32();
515-
Host.CheckDecode(latentDim > 0);
516-
LatentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(latentDim);
517-
Host.Check(checked(featureCount * fieldCount * LatentDimAligned) <= Utils.ArrayMaxSize, "Latent dimension too large");
518-
var linearWeights = ctx.Reader.ReadFloatArray();
519-
Host.CheckDecode(Utils.Size(linearWeights) == featureCount);
520-
var latentWeights = ctx.Reader.ReadFloatArray();
521-
Host.CheckDecode(Utils.Size(latentWeights) == featureCount * fieldCount * latentDim);
522-
523-
_norm = norm;
524-
FieldCount = fieldCount;
525-
FeatureCount = featureCount;
526-
LatentDim = latentDim;
527-
_linearWeights = linearWeights;
528-
_latentWeightsAligned = new AlignedArray(FeatureCount * FieldCount * LatentDimAligned, 16);
529-
for (int j = 0; j < FeatureCount; j++)
530-
{
531-
for (int f = 0; f < FieldCount; f++)
532-
{
533-
int vBias = j * FieldCount * LatentDim + f * LatentDim;
534-
int vBiasAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned;
535-
for (int k = 0; k < LatentDimAligned; k++)
536-
{
537-
if (k < LatentDim)
538-
_latentWeightsAligned[vBiasAligned + k] = latentWeights[vBias + k];
539-
else
540-
_latentWeightsAligned[vBiasAligned + k] = 0;
541-
}
542-
}
543-
}
544-
}
545-
546-
public static FieldAwareFactorizationMachinePredictor Create(IHostEnvironment env, ModelLoadContext ctx)
547-
{
548-
Contracts.CheckValue(env, nameof(env));
549-
env.CheckValue(ctx, nameof(ctx));
550-
ctx.CheckAtModel(GetVersionInfo());
551-
return new FieldAwareFactorizationMachinePredictor(env, ctx);
552-
}
553-
554-
protected override void SaveCore(ModelSaveContext ctx)
555-
{
556-
Host.AssertValue(ctx);
557-
ctx.SetVersionInfo(GetVersionInfo());
558-
559-
// *** Binary format ***
560-
// bool: whether to normalize feature vectors
561-
// int: number of fields
562-
// int: number of features
563-
// int: latent dimension
564-
// float[]: linear coefficients
565-
// float[]: latent representation of features
566-
567-
Host.Assert(FieldCount > 0);
568-
Host.Assert(FeatureCount > 0);
569-
Host.Assert(LatentDim > 0);
570-
Host.Assert(Utils.Size(_linearWeights) == FeatureCount);
571-
Host.Assert(_latentWeightsAligned.Size == FeatureCount * FieldCount * LatentDimAligned);
572-
573-
ctx.Writer.Write(_norm);
574-
ctx.Writer.Write(FieldCount);
575-
ctx.Writer.Write(FeatureCount);
576-
ctx.Writer.Write(LatentDim);
577-
ctx.Writer.WriteFloatArray(_linearWeights);
578-
float[] latentWeights = new float[FeatureCount * FieldCount * LatentDim];
579-
for (int j = 0; j < FeatureCount; j++)
580-
{
581-
for (int f = 0; f < FieldCount; f++)
582-
{
583-
int vBias = j * FieldCount * LatentDim + f * LatentDim;
584-
int vBiasAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned;
585-
for (int k = 0; k < LatentDim; k++)
586-
latentWeights[vBias + k] = _latentWeightsAligned[vBiasAligned + k];
587-
}
588-
}
589-
ctx.Writer.WriteFloatArray(latentWeights);
590-
}
591-
592-
internal float CalculateResponse(ValueGetter<VBuffer<float>>[] getters, VBuffer<float> featureBuffer,
593-
int[] featureFieldBuffer, int[] featureIndexBuffer, float[] featureValueBuffer, AlignedArray latentSum)
594-
{
595-
int count = 0;
596-
float modelResponse = 0;
597-
FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(getters, featureBuffer, _norm, ref count,
598-
featureFieldBuffer, featureIndexBuffer, featureValueBuffer);
599-
FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(FieldCount, LatentDimAligned, count,
600-
featureFieldBuffer, featureIndexBuffer, featureValueBuffer, _linearWeights, _latentWeightsAligned, latentSum, ref modelResponse);
601-
return modelResponse;
602-
}
603-
604-
public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
605-
{
606-
return new FieldAwareFactorizationMachineScalarRowMapper(env, schema, new BinaryClassifierSchema(), this);
607-
}
608-
609-
internal void CopyLinearWeightsTo(float[] linearWeights)
610-
{
611-
Host.AssertValue(_linearWeights);
612-
Host.AssertValue(linearWeights);
613-
Array.Copy(_linearWeights, linearWeights, _linearWeights.Length);
614-
}
615-
616-
internal void CopyLatentWeightsTo(AlignedArray latentWeights)
617-
{
618-
Host.AssertValue(_latentWeightsAligned);
619-
Host.AssertValue(latentWeights);
620-
latentWeights.CopyFrom(_latentWeightsAligned);
621-
}
622-
}
623-
624-
internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBoundRowMapper
625-
{
626-
private readonly FieldAwareFactorizationMachinePredictor _pred;
627-
628-
public RoleMappedSchema InputSchema { get; }
629-
630-
public ISchema OutputSchema { get; }
631-
632-
public ISchemaBindableMapper Bindable => _pred;
633-
634-
private readonly ColumnInfo[] _columns;
635-
private readonly List<int> _inputColumnIndexes;
636-
private readonly IHostEnvironment _env;
637-
638-
public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleMappedSchema schema,
639-
ISchema outputSchema, FieldAwareFactorizationMachinePredictor pred)
640-
{
641-
Contracts.AssertValue(env);
642-
Contracts.AssertValue(schema);
643-
Contracts.CheckParam(outputSchema.ColumnCount == 2, nameof(outputSchema));
644-
Contracts.CheckParam(outputSchema.GetColumnType(0).IsNumber, nameof(outputSchema));
645-
Contracts.CheckParam(outputSchema.GetColumnType(1).IsNumber, nameof(outputSchema));
646-
Contracts.AssertValue(pred);
647-
648-
_env = env;
649-
_columns = schema.GetColumns(RoleMappedSchema.ColumnRole.Feature).ToArray();
650-
_pred = pred;
651-
652-
var inputFeatureColumns = _columns.Select(c => new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, c.Name)).ToList();
653-
InputSchema = RoleMappedSchema.Create(schema.Schema, inputFeatureColumns);
654-
OutputSchema = outputSchema;
655-
656-
_inputColumnIndexes = new List<int>();
657-
foreach (var kvp in inputFeatureColumns)
658-
{
659-
if (schema.Schema.TryGetColumnIndex(kvp.Value, out int index))
660-
_inputColumnIndexes.Add(index);
661-
}
662-
}
663-
664-
public IRow GetOutputRow(IRow input, Func<int, bool> predicate, out Action action)
665-
{
666-
var latentSum = new AlignedArray(_pred.FieldCount * _pred.FieldCount * _pred.LatentDimAligned, 16);
667-
var featureBuffer = new VBuffer<float>();
668-
var featureFieldBuffer = new int[_pred.FeatureCount];
669-
var featureIndexBuffer = new int[_pred.FeatureCount];
670-
var featureValueBuffer = new float[_pred.FeatureCount];
671-
var inputGetters = new ValueGetter<VBuffer<float>>[_pred.FieldCount];
672-
for (int f = 0; f < _pred.FieldCount; f++)
673-
inputGetters[f] = input.GetGetter<VBuffer<float>>(_inputColumnIndexes[f]);
674-
675-
action = null;
676-
var getters = new Delegate[2];
677-
if (predicate(0))
678-
{
679-
ValueGetter<float> responseGetter = (ref float value) =>
680-
{
681-
value = _pred.CalculateResponse(inputGetters, featureBuffer, featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum);
682-
};
683-
getters[0] = responseGetter;
684-
}
685-
if (predicate(1))
686-
{
687-
ValueGetter<float> probGetter = (ref float value) =>
688-
{
689-
value = _pred.CalculateResponse(inputGetters, featureBuffer, featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum);
690-
value = MathUtils.SigmoidSlow(value);
691-
};
692-
getters[1] = probGetter;
693-
}
694-
695-
return new SimpleRow(OutputSchema, input, getters);
696-
}
697-
698-
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
699-
{
700-
if (Enumerable.Range(0, OutputSchema.ColumnCount).Any(predicate))
701-
return index => _inputColumnIndexes.Any(c => c == index);
702-
else
703-
return index => false;
704-
}
705-
706-
public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
707-
{
708-
return InputSchema.GetColumnRoles().Select(kvp => new KeyValuePair<RoleMappedSchema.ColumnRole, string>(kvp.Key, kvp.Value.Name));
709-
}
710-
}
711418
}

0 commit comments

Comments
 (0)