-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Adding Factorization Machines #383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Reflection; | ||
using System.Runtime.CompilerServices; | ||
using System.Runtime.InteropServices; | ||
|
||
[assembly: InternalsVisibleTo("Microsoft.ML.StandardLearners, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] | ||
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,94 @@ | ||||
// Licensed to the .NET Foundation under one or more agreements. | ||||
// The .NET Foundation licenses this file to you under the MIT license. | ||||
// See the LICENSE file in the project root for more information. | ||||
|
||||
using Microsoft.ML.Runtime.Internal.CpuMath; | ||||
using Microsoft.ML.Runtime.Internal.Utilities; | ||||
using System.Runtime.InteropServices; | ||||
|
||||
using System.Security; | ||||
|
||||
namespace Microsoft.ML.Runtime.FactorizationMachine | ||||
{ | ||||
internal unsafe static class FieldAwareFactorizationMachineInterface | ||||
{ | ||||
internal const string NativePath = "FactorizationMachineNative"; | ||||
public const int CbAlign = 16; | ||||
|
||||
private static bool Compat(AlignedArray a) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @sfilipi and @wschin , could I ask, was the usage of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In a few small benchmark performance tests I've run, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FFM's core computation is done by SSE code, which requires the memory blocks to be aligned. The main computation doesn't call any member functions of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From my investigation, I concluded that the performance penalty being paid was where we are moving the array elements around in memory to manually align it.
In my experience, doing this copying is worse performance than just using unaligned reads. Another way to fix this issue is to use both a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||
{ | ||||
Contracts.AssertValue(a); | ||||
Contracts.Assert(a.Size > 0); | ||||
return a.CbAlign == CbAlign; | ||||
} | ||||
|
||||
private unsafe static float* Ptr(AlignedArray a, float* p) | ||||
{ | ||||
Contracts.AssertValue(a); | ||||
float* q = p + a.GetBase((long)p); | ||||
Contracts.Assert(((long)q & (CbAlign - 1)) == 0); | ||||
return q; | ||||
} | ||||
|
||||
[DllImport(NativePath), SuppressUnmanagedCodeSecurity] | ||||
public static extern void CalculateIntermediateVariablesNative(int fieldCount, int latentDim, int count, int* /*const*/ fieldIndices, int* /*const*/ featureIndices, | ||||
float* /*const*/ featureValues, float* /*const*/ linearWeights, float* /*const*/ latentWeights, float* latentSum, float* response); | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
are this comments here for a reason? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added those comments There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||
|
||||
[DllImport(NativePath), SuppressUnmanagedCodeSecurity] | ||||
public static extern void CalculateGradientAndUpdateNative(float lambdaLinear, float lambdaLatent, float learningRate, int fieldCount, int latentDim, float weight, | ||||
int count, int* /*const*/ fieldIndices, int* /*const*/ featureIndices, float* /*const*/ featureValues, float* /*const*/ latentSum, float slope, | ||||
float* linearWeights, float* latentWeights, float* linearAccumulatedSquaredGrads, float* latentAccumulatedSquaredGrads); | ||||
|
||||
public static void CalculateIntermediateVariables(int fieldCount, int latentDim, int count, int[] fieldIndices, int[] featureIndices, float[] featureValues, | ||||
float[] linearWeights, AlignedArray latentWeights, AlignedArray latentSum, ref float response) | ||||
{ | ||||
Contracts.AssertNonEmpty(fieldIndices); | ||||
Contracts.AssertNonEmpty(featureValues); | ||||
Contracts.AssertNonEmpty(featureIndices); | ||||
Contracts.AssertNonEmpty(linearWeights); | ||||
Contracts.Assert(Compat(latentWeights)); | ||||
Contracts.Assert(Compat(latentSum)); | ||||
|
||||
unsafe | ||||
{ | ||||
fixed (int* pf = &fieldIndices[0]) | ||||
fixed (int* pi = &featureIndices[0]) | ||||
fixed (float* px = &featureValues[0]) | ||||
fixed (float* pw = &linearWeights[0]) | ||||
fixed (float* pv = &latentWeights.Items[0]) | ||||
fixed (float* pq = &latentSum.Items[0]) | ||||
fixed (float* pr = &response) | ||||
CalculateIntermediateVariablesNative(fieldCount, latentDim, count, pf, pi, px, pw, Ptr(latentWeights, pv), Ptr(latentSum, pq), pr); | ||||
} | ||||
} | ||||
|
||||
public static void CalculateGradientAndUpdate(float lambdaLinear, float lambdaLatent, float learningRate, int fieldCount, int latentDim, | ||||
float weight, int count, int[] fieldIndices, int[] featureIndices, float[] featureValues, AlignedArray latentSum, float slope, | ||||
float[] linearWeights, AlignedArray latentWeights, float[] linearAccumulatedSquaredGrads, AlignedArray latentAccumulatedSquaredGrads) | ||||
{ | ||||
Contracts.AssertNonEmpty(fieldIndices); | ||||
Contracts.AssertNonEmpty(featureIndices); | ||||
Contracts.AssertNonEmpty(featureValues); | ||||
Contracts.Assert(Compat(latentSum)); | ||||
Contracts.AssertNonEmpty(linearWeights); | ||||
Contracts.Assert(Compat(latentWeights)); | ||||
Contracts.AssertNonEmpty(linearAccumulatedSquaredGrads); | ||||
Contracts.Assert(Compat(latentAccumulatedSquaredGrads)); | ||||
|
||||
unsafe | ||||
{ | ||||
fixed (int* pf = &fieldIndices[0]) | ||||
fixed (int* pi = &featureIndices[0]) | ||||
fixed (float* px = &featureValues[0]) | ||||
fixed (float* pq = &latentSum.Items[0]) | ||||
fixed (float* pw = &linearWeights[0]) | ||||
fixed (float* pv = &latentWeights.Items[0]) | ||||
fixed (float* phw = &linearAccumulatedSquaredGrads[0]) | ||||
fixed (float* phv = &latentAccumulatedSquaredGrads.Items[0]) | ||||
CalculateGradientAndUpdateNative(lambdaLinear, lambdaLatent, learningRate, fieldCount, latentDim, weight, count, pf, pi, px, | ||||
Ptr(latentSum, pq), slope, pw, Ptr(latentWeights, pv), phw, Ptr(latentAccumulatedSquaredGrads, phv)); | ||||
} | ||||
|
||||
} | ||||
} | ||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this necessary? #Pending
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AlignedArray.Items is internal, but gets accessed in FactorizationMachines.
Will get rid of it we move off AlignedArray.
In reply to: 198213276 [](ancestors = 198213276)