Skip to content

Commit 9b7768a

Browse files
sfilipieerhardt
authored andcommitted
Adding Factorization Machines (dotnet#383)
* Adding Factorization Machines
1 parent 1d88e46 commit 9b7768a

26 files changed

+4498
-4
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Reflection;
6+
using System.Runtime.CompilerServices;
7+
using System.Runtime.InteropServices;
8+
9+
[assembly: InternalsVisibleTo("Microsoft.ML.StandardLearners, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Internal.CpuMath;
6+
using Microsoft.ML.Runtime.Internal.Utilities;
7+
using System.Runtime.InteropServices;
8+
9+
using System.Security;
10+
11+
namespace Microsoft.ML.Runtime.FactorizationMachine
12+
{
13+
internal unsafe static class FieldAwareFactorizationMachineInterface
14+
{
15+
internal const string NativePath = "FactorizationMachineNative";
16+
public const int CbAlign = 16;
17+
18+
private static bool Compat(AlignedArray a)
19+
{
20+
Contracts.AssertValue(a);
21+
Contracts.Assert(a.Size > 0);
22+
return a.CbAlign == CbAlign;
23+
}
24+
25+
private unsafe static float* Ptr(AlignedArray a, float* p)
26+
{
27+
Contracts.AssertValue(a);
28+
float* q = p + a.GetBase((long)p);
29+
Contracts.Assert(((long)q & (CbAlign - 1)) == 0);
30+
return q;
31+
}
32+
33+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
34+
public static extern void CalculateIntermediateVariablesNative(int fieldCount, int latentDim, int count, int* /*const*/ fieldIndices, int* /*const*/ featureIndices,
35+
float* /*const*/ featureValues, float* /*const*/ linearWeights, float* /*const*/ latentWeights, float* latentSum, float* response);
36+
37+
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
38+
public static extern void CalculateGradientAndUpdateNative(float lambdaLinear, float lambdaLatent, float learningRate, int fieldCount, int latentDim, float weight,
39+
int count, int* /*const*/ fieldIndices, int* /*const*/ featureIndices, float* /*const*/ featureValues, float* /*const*/ latentSum, float slope,
40+
float* linearWeights, float* latentWeights, float* linearAccumulatedSquaredGrads, float* latentAccumulatedSquaredGrads);
41+
42+
public static void CalculateIntermediateVariables(int fieldCount, int latentDim, int count, int[] fieldIndices, int[] featureIndices, float[] featureValues,
43+
float[] linearWeights, AlignedArray latentWeights, AlignedArray latentSum, ref float response)
44+
{
45+
Contracts.AssertNonEmpty(fieldIndices);
46+
Contracts.AssertNonEmpty(featureValues);
47+
Contracts.AssertNonEmpty(featureIndices);
48+
Contracts.AssertNonEmpty(linearWeights);
49+
Contracts.Assert(Compat(latentWeights));
50+
Contracts.Assert(Compat(latentSum));
51+
52+
unsafe
53+
{
54+
fixed (int* pf = &fieldIndices[0])
55+
fixed (int* pi = &featureIndices[0])
56+
fixed (float* px = &featureValues[0])
57+
fixed (float* pw = &linearWeights[0])
58+
fixed (float* pv = &latentWeights.Items[0])
59+
fixed (float* pq = &latentSum.Items[0])
60+
fixed (float* pr = &response)
61+
CalculateIntermediateVariablesNative(fieldCount, latentDim, count, pf, pi, px, pw, Ptr(latentWeights, pv), Ptr(latentSum, pq), pr);
62+
}
63+
}
64+
65+
public static void CalculateGradientAndUpdate(float lambdaLinear, float lambdaLatent, float learningRate, int fieldCount, int latentDim,
66+
float weight, int count, int[] fieldIndices, int[] featureIndices, float[] featureValues, AlignedArray latentSum, float slope,
67+
float[] linearWeights, AlignedArray latentWeights, float[] linearAccumulatedSquaredGrads, AlignedArray latentAccumulatedSquaredGrads)
68+
{
69+
Contracts.AssertNonEmpty(fieldIndices);
70+
Contracts.AssertNonEmpty(featureIndices);
71+
Contracts.AssertNonEmpty(featureValues);
72+
Contracts.Assert(Compat(latentSum));
73+
Contracts.AssertNonEmpty(linearWeights);
74+
Contracts.Assert(Compat(latentWeights));
75+
Contracts.AssertNonEmpty(linearAccumulatedSquaredGrads);
76+
Contracts.Assert(Compat(latentAccumulatedSquaredGrads));
77+
78+
unsafe
79+
{
80+
fixed (int* pf = &fieldIndices[0])
81+
fixed (int* pi = &featureIndices[0])
82+
fixed (float* px = &featureValues[0])
83+
fixed (float* pq = &latentSum.Items[0])
84+
fixed (float* pw = &linearWeights[0])
85+
fixed (float* pv = &latentWeights.Items[0])
86+
fixed (float* phw = &linearAccumulatedSquaredGrads[0])
87+
fixed (float* phv = &latentAccumulatedSquaredGrads.Items[0])
88+
CalculateGradientAndUpdateNative(lambdaLinear, lambdaLatent, learningRate, fieldCount, latentDim, weight, count, pf, pi, px,
89+
Ptr(latentSum, pq), slope, pw, Ptr(latentWeights, pv), phw, Ptr(latentAccumulatedSquaredGrads, phv));
90+
}
91+
92+
}
93+
}
94+
}

0 commit comments

Comments
 (0)