|
20 | 20 | new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName,
|
21 | 21 | FieldAwareFactorizationMachineTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")]
|
22 | 22 |
|
23 |
| -[assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictor), null, typeof(SignatureLoadModel), "Field Aware Factorization Machine", FieldAwareFactorizationMachinePredictor.LoaderSignature)] |
24 |
| - |
25 | 23 | [assembly: LoadableClass(typeof(void), typeof(FieldAwareFactorizationMachineTrainer), null, typeof(SignatureEntryPointModule), FieldAwareFactorizationMachineTrainer.LoadName)]
|
26 | 24 |
|
27 | 25 | namespace Microsoft.ML.Runtime.FactorizationMachine
|
28 | 26 | {
|
29 |
| - internal sealed class FieldAwareFactorizationMachineUtils |
30 |
| - { |
31 |
| - internal static int GetAlignedVectorLength(int length) |
32 |
| - { |
33 |
| - int res = length % 4; |
34 |
| - if (res == 0) |
35 |
| - return length; |
36 |
| - else |
37 |
| - return length + (4 - res); |
38 |
| - } |
39 |
| - |
40 |
| - internal static bool LoadOneExampleIntoBuffer(ValueGetter<VBuffer<float>>[] getters, VBuffer<float> featureBuffer, bool norm, ref int count, |
41 |
| - int[] fieldIndexBuffer, int[] featureIndexBuffer, float[] featureValueBuffer) |
42 |
| - { |
43 |
| - count = 0; |
44 |
| - float featureNorm = 0; |
45 |
| - int bias = 0; |
46 |
| - float annihilation = 0; |
47 |
| - for (int f = 0; f < getters.Length; f++) |
48 |
| - { |
49 |
| - getters[f](ref featureBuffer); |
50 |
| - foreach (var pair in featureBuffer.Items()) |
51 |
| - { |
52 |
| - fieldIndexBuffer[count] = f; |
53 |
| - featureIndexBuffer[count] = bias + pair.Key; |
54 |
| - featureValueBuffer[count] = pair.Value; |
55 |
| - featureNorm += pair.Value * pair.Value; |
56 |
| - annihilation += pair.Value - pair.Value; |
57 |
| - count++; |
58 |
| - } |
59 |
| - bias += featureBuffer.Length; |
60 |
| - } |
61 |
| - featureNorm = MathUtils.Sqrt(featureNorm); |
62 |
| - if (norm) |
63 |
| - { |
64 |
| - for (int i = 0; i < count; i++) |
65 |
| - featureValueBuffer[i] /= featureNorm; |
66 |
| - } |
67 |
| - return FloatUtils.IsFinite(annihilation); |
68 |
| - } |
69 |
| - } |
70 |
| - |
71 | 27 | /// <summary>
|
72 | 28 | /// Train a field-aware factorization machine using ADAGRAD (an advanced stochastic gradient method). See references below
|
73 | 29 | /// for details. This trainer is essentially faster the one introduced in [2] because of some implemtation tricks[3].
|
74 | 30 | /// [1] http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
|
75 | 31 | /// [2] http://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf
|
76 |
| - /// [3] fast-ffm.tex in FactorizationMachine project folder |
| 32 | + /// [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf |
77 | 33 | /// </summary>
|
78 | 34 | public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase<RoleMappedData, FieldAwareFactorizationMachinePredictor>,
|
79 | 35 | IIncrementalTrainer<RoleMappedData, FieldAwareFactorizationMachinePredictor>, IValidatingTrainer<RoleMappedData>,
|
@@ -327,6 +283,8 @@ private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, R
|
327 | 283 | Func<int, bool> pred = c => fieldColumnIndexes.Contains(c) || c == data.Schema.Label.Index || (data.Schema.Weight != null && c == data.Schema.Weight.Index);
|
328 | 284 | InitializeTrainingState(fieldCount, totalFeatureCount, predictor, out float[] linearWeights,
|
329 | 285 | out AlignedArray latentWeightsAligned, out float[] linearAccSqGrads, out AlignedArray latentAccSqGradsAligned);
|
| 286 | + |
| 287 | + // refer to Algorithm 3 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf |
330 | 288 | while (iter++ < _numIterations)
|
331 | 289 | {
|
332 | 290 | using (var cursor = data.Data.GetRowCursor(pred, rng))
|
@@ -358,9 +316,13 @@ private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, R
|
358 | 316 | badExampleCount++;
|
359 | 317 | continue;
|
360 | 318 | }
|
| 319 | + |
| 320 | + // refer to Algorithm 1 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf |
361 | 321 | FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(fieldCount, _latentDimAligned, count,
|
362 | 322 | featureFieldBuffer, featureIndexBuffer, featureValueBuffer, linearWeights, latentWeightsAligned, latentSum, ref modelResponse);
|
363 | 323 | var slope = CalculateLossSlope(label, modelResponse);
|
| 324 | + |
| 325 | + // refer to Algorithm 2 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf |
364 | 326 | FieldAwareFactorizationMachineInterface.CalculateGradientAndUpdate(_lambdaLinear, _lambdaLatent, _learningRate, fieldCount, _latentDimAligned, weight, count,
|
365 | 327 | featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum, slope, linearWeights, latentWeightsAligned, linearAccSqGrads, latentAccSqGradsAligned);
|
366 | 328 | loss += weight * CalculateLoss(label, modelResponse);
|
@@ -453,259 +415,4 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm
|
453 | 415 | () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
|
454 | 416 | }
|
455 | 417 | }
|
456 |
| - |
457 |
| - public sealed class FieldAwareFactorizationMachinePredictor : PredictorBase<float>, ISchemaBindableMapper, ICanSaveModel |
458 |
| - { |
459 |
| - public const string LoaderSignature = "FieldAwareFactMacPredict"; |
460 |
| - public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; |
461 |
| - private bool _norm; |
462 |
| - internal int FieldCount { get; } |
463 |
| - internal int FeatureCount { get; } |
464 |
| - internal int LatentDim { get; } |
465 |
| - internal int LatentDimAligned { get; } |
466 |
| - private readonly float[] _linearWeights; |
467 |
| - private readonly AlignedArray _latentWeightsAligned; |
468 |
| - |
469 |
| - private static VersionInfo GetVersionInfo() |
470 |
| - { |
471 |
| - return new VersionInfo( |
472 |
| - modelSignature: "FAFAMAPD", |
473 |
| - verWrittenCur: 0x00010001, |
474 |
| - verReadableCur: 0x00010001, |
475 |
| - verWeCanReadBack: 0x00010001, |
476 |
| - loaderSignature: LoaderSignature); |
477 |
| - } |
478 |
| - |
479 |
| - internal FieldAwareFactorizationMachinePredictor(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim, |
480 |
| - float[] linearWeights, AlignedArray latentWeightsAligned) : base(env, LoaderSignature) |
481 |
| - { |
482 |
| - Host.Assert(fieldCount > 0); |
483 |
| - Host.Assert(featureCount > 0); |
484 |
| - Host.Assert(latentDim > 0); |
485 |
| - Host.Assert(Utils.Size(linearWeights) == featureCount); |
486 |
| - LatentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(latentDim); |
487 |
| - Host.Assert(latentWeightsAligned.Size == checked(featureCount * fieldCount * LatentDimAligned)); |
488 |
| - |
489 |
| - _norm = norm; |
490 |
| - FieldCount = fieldCount; |
491 |
| - FeatureCount = featureCount; |
492 |
| - LatentDim = latentDim; |
493 |
| - _linearWeights = linearWeights; |
494 |
| - _latentWeightsAligned = latentWeightsAligned; |
495 |
| - } |
496 |
| - |
497 |
| - private FieldAwareFactorizationMachinePredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature) |
498 |
| - { |
499 |
| - Host.AssertValue(ctx); |
500 |
| - |
501 |
| - // *** Binary format *** |
502 |
| - // bool: whether to normalize feature vectors |
503 |
| - // int: number of fields |
504 |
| - // int: number of features |
505 |
| - // int: latent dimension |
506 |
| - // float[]: linear coefficients |
507 |
| - // float[]: latent representation of features |
508 |
| - |
509 |
| - var norm = ctx.Reader.ReadBoolean(); |
510 |
| - var fieldCount = ctx.Reader.ReadInt32(); |
511 |
| - Host.CheckDecode(fieldCount > 0); |
512 |
| - var featureCount = ctx.Reader.ReadInt32(); |
513 |
| - Host.CheckDecode(featureCount > 0); |
514 |
| - var latentDim = ctx.Reader.ReadInt32(); |
515 |
| - Host.CheckDecode(latentDim > 0); |
516 |
| - LatentDimAligned = FieldAwareFactorizationMachineUtils.GetAlignedVectorLength(latentDim); |
517 |
| - Host.Check(checked(featureCount * fieldCount * LatentDimAligned) <= Utils.ArrayMaxSize, "Latent dimension too large"); |
518 |
| - var linearWeights = ctx.Reader.ReadFloatArray(); |
519 |
| - Host.CheckDecode(Utils.Size(linearWeights) == featureCount); |
520 |
| - var latentWeights = ctx.Reader.ReadFloatArray(); |
521 |
| - Host.CheckDecode(Utils.Size(latentWeights) == featureCount * fieldCount * latentDim); |
522 |
| - |
523 |
| - _norm = norm; |
524 |
| - FieldCount = fieldCount; |
525 |
| - FeatureCount = featureCount; |
526 |
| - LatentDim = latentDim; |
527 |
| - _linearWeights = linearWeights; |
528 |
| - _latentWeightsAligned = new AlignedArray(FeatureCount * FieldCount * LatentDimAligned, 16); |
529 |
| - for (int j = 0; j < FeatureCount; j++) |
530 |
| - { |
531 |
| - for (int f = 0; f < FieldCount; f++) |
532 |
| - { |
533 |
| - int vBias = j * FieldCount * LatentDim + f * LatentDim; |
534 |
| - int vBiasAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned; |
535 |
| - for (int k = 0; k < LatentDimAligned; k++) |
536 |
| - { |
537 |
| - if (k < LatentDim) |
538 |
| - _latentWeightsAligned[vBiasAligned + k] = latentWeights[vBias + k]; |
539 |
| - else |
540 |
| - _latentWeightsAligned[vBiasAligned + k] = 0; |
541 |
| - } |
542 |
| - } |
543 |
| - } |
544 |
| - } |
545 |
| - |
546 |
| - public static FieldAwareFactorizationMachinePredictor Create(IHostEnvironment env, ModelLoadContext ctx) |
547 |
| - { |
548 |
| - Contracts.CheckValue(env, nameof(env)); |
549 |
| - env.CheckValue(ctx, nameof(ctx)); |
550 |
| - ctx.CheckAtModel(GetVersionInfo()); |
551 |
| - return new FieldAwareFactorizationMachinePredictor(env, ctx); |
552 |
| - } |
553 |
| - |
554 |
| - protected override void SaveCore(ModelSaveContext ctx) |
555 |
| - { |
556 |
| - Host.AssertValue(ctx); |
557 |
| - ctx.SetVersionInfo(GetVersionInfo()); |
558 |
| - |
559 |
| - // *** Binary format *** |
560 |
| - // bool: whether to normalize feature vectors |
561 |
| - // int: number of fields |
562 |
| - // int: number of features |
563 |
| - // int: latent dimension |
564 |
| - // float[]: linear coefficients |
565 |
| - // float[]: latent representation of features |
566 |
| - |
567 |
| - Host.Assert(FieldCount > 0); |
568 |
| - Host.Assert(FeatureCount > 0); |
569 |
| - Host.Assert(LatentDim > 0); |
570 |
| - Host.Assert(Utils.Size(_linearWeights) == FeatureCount); |
571 |
| - Host.Assert(_latentWeightsAligned.Size == FeatureCount * FieldCount * LatentDimAligned); |
572 |
| - |
573 |
| - ctx.Writer.Write(_norm); |
574 |
| - ctx.Writer.Write(FieldCount); |
575 |
| - ctx.Writer.Write(FeatureCount); |
576 |
| - ctx.Writer.Write(LatentDim); |
577 |
| - ctx.Writer.WriteFloatArray(_linearWeights); |
578 |
| - float[] latentWeights = new float[FeatureCount * FieldCount * LatentDim]; |
579 |
| - for (int j = 0; j < FeatureCount; j++) |
580 |
| - { |
581 |
| - for (int f = 0; f < FieldCount; f++) |
582 |
| - { |
583 |
| - int vBias = j * FieldCount * LatentDim + f * LatentDim; |
584 |
| - int vBiasAligned = j * FieldCount * LatentDimAligned + f * LatentDimAligned; |
585 |
| - for (int k = 0; k < LatentDim; k++) |
586 |
| - latentWeights[vBias + k] = _latentWeightsAligned[vBiasAligned + k]; |
587 |
| - } |
588 |
| - } |
589 |
| - ctx.Writer.WriteFloatArray(latentWeights); |
590 |
| - } |
591 |
| - |
592 |
| - internal float CalculateResponse(ValueGetter<VBuffer<float>>[] getters, VBuffer<float> featureBuffer, |
593 |
| - int[] featureFieldBuffer, int[] featureIndexBuffer, float[] featureValueBuffer, AlignedArray latentSum) |
594 |
| - { |
595 |
| - int count = 0; |
596 |
| - float modelResponse = 0; |
597 |
| - FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(getters, featureBuffer, _norm, ref count, |
598 |
| - featureFieldBuffer, featureIndexBuffer, featureValueBuffer); |
599 |
| - FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(FieldCount, LatentDimAligned, count, |
600 |
| - featureFieldBuffer, featureIndexBuffer, featureValueBuffer, _linearWeights, _latentWeightsAligned, latentSum, ref modelResponse); |
601 |
| - return modelResponse; |
602 |
| - } |
603 |
| - |
604 |
| - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) |
605 |
| - { |
606 |
| - return new FieldAwareFactorizationMachineScalarRowMapper(env, schema, new BinaryClassifierSchema(), this); |
607 |
| - } |
608 |
| - |
609 |
| - internal void CopyLinearWeightsTo(float[] linearWeights) |
610 |
| - { |
611 |
| - Host.AssertValue(_linearWeights); |
612 |
| - Host.AssertValue(linearWeights); |
613 |
| - Array.Copy(_linearWeights, linearWeights, _linearWeights.Length); |
614 |
| - } |
615 |
| - |
616 |
| - internal void CopyLatentWeightsTo(AlignedArray latentWeights) |
617 |
| - { |
618 |
| - Host.AssertValue(_latentWeightsAligned); |
619 |
| - Host.AssertValue(latentWeights); |
620 |
| - latentWeights.CopyFrom(_latentWeightsAligned); |
621 |
| - } |
622 |
| - } |
623 |
| - |
624 |
| - internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBoundRowMapper |
625 |
| - { |
626 |
| - private readonly FieldAwareFactorizationMachinePredictor _pred; |
627 |
| - |
628 |
| - public RoleMappedSchema InputSchema { get; } |
629 |
| - |
630 |
| - public ISchema OutputSchema { get; } |
631 |
| - |
632 |
| - public ISchemaBindableMapper Bindable => _pred; |
633 |
| - |
634 |
| - private readonly ColumnInfo[] _columns; |
635 |
| - private readonly List<int> _inputColumnIndexes; |
636 |
| - private readonly IHostEnvironment _env; |
637 |
| - |
638 |
| - public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleMappedSchema schema, |
639 |
| - ISchema outputSchema, FieldAwareFactorizationMachinePredictor pred) |
640 |
| - { |
641 |
| - Contracts.AssertValue(env); |
642 |
| - Contracts.AssertValue(schema); |
643 |
| - Contracts.CheckParam(outputSchema.ColumnCount == 2, nameof(outputSchema)); |
644 |
| - Contracts.CheckParam(outputSchema.GetColumnType(0).IsNumber, nameof(outputSchema)); |
645 |
| - Contracts.CheckParam(outputSchema.GetColumnType(1).IsNumber, nameof(outputSchema)); |
646 |
| - Contracts.AssertValue(pred); |
647 |
| - |
648 |
| - _env = env; |
649 |
| - _columns = schema.GetColumns(RoleMappedSchema.ColumnRole.Feature).ToArray(); |
650 |
| - _pred = pred; |
651 |
| - |
652 |
| - var inputFeatureColumns = _columns.Select(c => new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, c.Name)).ToList(); |
653 |
| - InputSchema = RoleMappedSchema.Create(schema.Schema, inputFeatureColumns); |
654 |
| - OutputSchema = outputSchema; |
655 |
| - |
656 |
| - _inputColumnIndexes = new List<int>(); |
657 |
| - foreach (var kvp in inputFeatureColumns) |
658 |
| - { |
659 |
| - if (schema.Schema.TryGetColumnIndex(kvp.Value, out int index)) |
660 |
| - _inputColumnIndexes.Add(index); |
661 |
| - } |
662 |
| - } |
663 |
| - |
664 |
| - public IRow GetOutputRow(IRow input, Func<int, bool> predicate, out Action action) |
665 |
| - { |
666 |
| - var latentSum = new AlignedArray(_pred.FieldCount * _pred.FieldCount * _pred.LatentDimAligned, 16); |
667 |
| - var featureBuffer = new VBuffer<float>(); |
668 |
| - var featureFieldBuffer = new int[_pred.FeatureCount]; |
669 |
| - var featureIndexBuffer = new int[_pred.FeatureCount]; |
670 |
| - var featureValueBuffer = new float[_pred.FeatureCount]; |
671 |
| - var inputGetters = new ValueGetter<VBuffer<float>>[_pred.FieldCount]; |
672 |
| - for (int f = 0; f < _pred.FieldCount; f++) |
673 |
| - inputGetters[f] = input.GetGetter<VBuffer<float>>(_inputColumnIndexes[f]); |
674 |
| - |
675 |
| - action = null; |
676 |
| - var getters = new Delegate[2]; |
677 |
| - if (predicate(0)) |
678 |
| - { |
679 |
| - ValueGetter<float> responseGetter = (ref float value) => |
680 |
| - { |
681 |
| - value = _pred.CalculateResponse(inputGetters, featureBuffer, featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum); |
682 |
| - }; |
683 |
| - getters[0] = responseGetter; |
684 |
| - } |
685 |
| - if (predicate(1)) |
686 |
| - { |
687 |
| - ValueGetter<float> probGetter = (ref float value) => |
688 |
| - { |
689 |
| - value = _pred.CalculateResponse(inputGetters, featureBuffer, featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum); |
690 |
| - value = MathUtils.SigmoidSlow(value); |
691 |
| - }; |
692 |
| - getters[1] = probGetter; |
693 |
| - } |
694 |
| - |
695 |
| - return new SimpleRow(OutputSchema, input, getters); |
696 |
| - } |
697 |
| - |
698 |
| - public Func<int, bool> GetDependencies(Func<int, bool> predicate) |
699 |
| - { |
700 |
| - if (Enumerable.Range(0, OutputSchema.ColumnCount).Any(predicate)) |
701 |
| - return index => _inputColumnIndexes.Any(c => c == index); |
702 |
| - else |
703 |
| - return index => false; |
704 |
| - } |
705 |
| - |
706 |
| - public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles() |
707 |
| - { |
708 |
| - return InputSchema.GetColumnRoles().Select(kvp => new KeyValuePair<RoleMappedSchema.ColumnRole, string>(kvp.Key, kvp.Value.Name)); |
709 |
| - } |
710 |
| - } |
711 | 418 | }
|
0 commit comments