Skip to content

PrivateAI Demo. #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions Microsoft.ML.sln

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions PrivateAI/App.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8" ?>
<configuration>
<startup>
<supportedRuntime version="v4.0" sku=".NETFramework,Version=v4.6.2" />
</startup>
</configuration>
79 changes: 79 additions & 0 deletions PrivateAI/EncryptionContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using Microsoft.Research.SEAL;

namespace PrivateAI
{
public class EncryptionContext
{
public Evaluator Evaluator { get; }
public Encryptor Encryptor { get; }
public Decryptor Decryptor { get; }
public FractionalEncoder Encoder { get; }

// These keys should **NOT** be saved here
// Its just for demo purpose
public PublicKey PublicKey { get; }

public SecretKey SecretKey { get; }

public EncryptionContext()
{
var context = CreateSEALContext();
var keygen = new KeyGenerator(context);
PublicKey = keygen.PublicKey;
SecretKey = keygen.SecretKey;

Encryptor = new Encryptor(context, PublicKey);
Evaluator = new Evaluator(context);
Decryptor = new Decryptor(context, SecretKey);

Encoder = new FractionalEncoder(context.PlainModulus, context.PolyModulus, 64, 32, 3);
}

public EncryptionContext(PublicKey publicKey, SecretKey secretKey)
{
var context = CreateSEALContext();
PublicKey = publicKey;
SecretKey = secretKey;

Encryptor = new Encryptor(context, publicKey);
Evaluator = new Evaluator(context);
Decryptor = new Decryptor(context, secretKey);

Encoder = new FractionalEncoder(context.PlainModulus, context.PolyModulus, 64, 32, 3);
}

public EncryptionContext(PublicKey publicKey)
{
var context = CreateSEALContext();
PublicKey = publicKey;

Encryptor = new Encryptor(context, publicKey);
Evaluator = new Evaluator(context);
Decryptor = null;

Encoder = new FractionalEncoder(context.PlainModulus, context.PolyModulus, 64, 32, 3);
}

public EncryptionContext(SecretKey secretKey)
{
var context = CreateSEALContext();
SecretKey = secretKey;

Encryptor = null;
Evaluator = new Evaluator(context);
Decryptor = new Decryptor(context, secretKey);

Encoder = new FractionalEncoder(context.PlainModulus, context.PolyModulus, 64, 32, 3);
}

private static SEALContext CreateSEALContext()
{
EncryptionParameters encryptionParams = new EncryptionParameters();
encryptionParams.PolyModulus = "1x^2048 + 1";
encryptionParams.CoeffModulus = DefaultParams.CoeffModulus128(2048);
encryptionParams.PlainModulus = 1 << 8;

return new SEALContext(encryptionParams);
}
}
}
105 changes: 105 additions & 0 deletions PrivateAI/PrivateAI.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" />
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
<ProjectGuid>{BEF8FCD5-5872-4F92-91FC-48DCA00118EC}</ProjectGuid>
<OutputType>Exe</OutputType>
<RootNamespace>PrivateAI</RootNamespace>
<AssemblyName>PrivateAI</AssemblyName>
<TargetFrameworkVersion>v4.6.2</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<AutoGenerateBindingRedirects>true</AutoGenerateBindingRedirects>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
<DebugSymbols>true</DebugSymbols>
<DebugType>full</DebugType>
<Optimize>false</Optimize>
<OutputPath>bin\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<ErrorReport>prompt</ErrorReport>
<WarningLevel>4</WarningLevel>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
<DebugType>pdbonly</DebugType>
<Optimize>true</Optimize>
<OutputPath>bin\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<ErrorReport>prompt</ErrorReport>
<WarningLevel>4</WarningLevel>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'">
<DebugSymbols>true</DebugSymbols>
<OutputPath>bin\x64\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<DebugType>full</DebugType>
<PlatformTarget>x64</PlatformTarget>
<LangVersion>latest</LangVersion>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
<Prefer32Bit>true</Prefer32Bit>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x64'">
<OutputPath>bin\x64\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<Optimize>true</Optimize>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<DebugType>pdbonly</DebugType>
<PlatformTarget>x64</PlatformTarget>
<LangVersion>latest</LangVersion>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
<Prefer32Bit>true</Prefer32Bit>
</PropertyGroup>
<ItemGroup>
<Reference Include="SEALNET">
<HintPath>..\..\..\SEAL\SEAL\bin\x64\Release\SEALNET.dll</HintPath>
</Reference>
<Reference Include="System" />
<Reference Include="System.Core" />
<Reference Include="System.Xml.Linq" />
<Reference Include="System.Data.DataSetExtensions" />
<Reference Include="Microsoft.CSharp" />
<Reference Include="System.Data" />
<Reference Include="System.Net.Http" />
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="EncryptionContext.cs" />
<Compile Include="PrivateAIUtils.cs" />
<Compile Include="Program.cs" />
</ItemGroup>
<ItemGroup>
<None Include="App.config" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj">
<Project>{a6ca6cc6-5d7c-4d7f-a0f5-35e14b383b0a}</Project>
<Name>Microsoft.ML.Core</Name>
</ProjectReference>
<ProjectReference Include="..\src\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj">
<Project>{46f2f967-c23f-4076-858d-33f7da9bd2da}</Project>
<Name>Microsoft.ML.CpuMath</Name>
</ProjectReference>
<ProjectReference Include="..\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj">
<Project>{ad92d96b-0e96-4f22-8dce-892e13b1f282}</Project>
<Name>Microsoft.ML.Data</Name>
</ProjectReference>
<ProjectReference Include="..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj">
<Project>{707bb22c-7e5f-497a-8c2f-74578f675705}</Project>
<Name>Microsoft.ML.StandardLearners</Name>
</ProjectReference>
<ProjectReference Include="..\src\Microsoft.ML\Microsoft.ML.csproj">
<Project>{7288c084-11c0-43be-ac7f-45dcfeaeebf6}</Project>
<Name>Microsoft.ML</Name>
</ProjectReference>
</ItemGroup>
<ItemGroup>
<Folder Include="Properties\" />
</ItemGroup>
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
</Project>
183 changes: 183 additions & 0 deletions PrivateAI/PrivateAIUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Training;
using Microsoft.Research.SEAL;
using System;
using System.IO;

namespace PrivateAI
{
public class PrivateAIUtils
{
public static void EncryptModel(string modelFile, string publicKeyFile)
{
PublicKey key = new PublicKey();
using (BinaryReader reader = new BinaryReader(new FileStream(publicKeyFile, FileMode.Open)))
{
key.Load(reader.BaseStream);
}

EncryptionContext context = new EncryptionContext(key);

using (var env = new TlcEnvironment(seed: 1, conc: 1))
{
LinearPredictor pred = (LinearPredictor)LoadModel(env, modelFile);
pred.Evaluator = context.Evaluator;

// Now encrypt the model
pred.EncryptModel(context.Encryptor, context.Encoder);

// Save model
var trainRoles = LoadRoleMapping(env, modelFile);
SaveModel(env, pred, trainRoles, modelFile + ".encrypted");
}
}

public static void EncryptData(string dataPath, string modelFile, string publicKeyFile)
{
PublicKey key = new PublicKey();
using (BinaryReader reader = new BinaryReader(new FileStream(publicKeyFile, FileMode.Open)))
{
key.Load(reader.BaseStream);
}

EncryptionContext context = new EncryptionContext(key);

using (var env = new TlcEnvironment(seed: 1, conc: 1))
{
IDataView testData = GetTestPipeline(env, dataPath, modelFile);
LinearPredictor pred = (LinearPredictor)LoadModel(env, modelFile);
pred.Evaluator = context.Evaluator;

// Get the valuemapper methods. Both for normal and encrypted case.
// We will use these mappers to score the feature vector before and after encryption.
// Since non of ML.Net transforms are encryption aware, feature vector is featurized here.
// Featurized vector is then ecrypted and passed on to model for scoring.
var valueMapperEncrypted = pred.GetEncryptedMapper<VBuffer<Ciphertext>, Ciphertext>();
var valueMapper = pred.GetMapper<VBuffer<Single>, Single>();

BinaryWriter writer = new BinaryWriter(new FileStream(dataPath + ".encrypted", FileMode.Create));
// Prepare for iteration over the data pipeline.
var cursorFactory = new FloatLabelCursor.Factory(new RoleMappedData(testData, DefaultColumnNames.Label, DefaultColumnNames.Features)
, CursOpt.Label | CursOpt.Features);
using (var cursor = cursorFactory.Create())
{
int sampleCount = 0;
// Iterate over the data and match encrypted and non-encrypted score.
while (cursor.MoveNext())
{
sampleCount++;
// Predict on Encrypted Data
var vBufferencryptedFeatures = EncryptData(context, ref cursor.Features);
Ciphertext encryptedResult = new Ciphertext();
var watch = System.Diagnostics.Stopwatch.StartNew();
valueMapperEncrypted(ref vBufferencryptedFeatures, ref encryptedResult);


WriteData(writer, vBufferencryptedFeatures);
}
}
}
}

public static void DecryptData(string dataPath, string privateKeyFile)
{
SecretKey key = new SecretKey();
using (BinaryReader reader = new BinaryReader(new FileStream(privateKeyFile, FileMode.Open)))
{
key.Load(reader.BaseStream);
}

EncryptionContext context = new EncryptionContext(key);

using (BinaryReader reader = new BinaryReader(new FileStream(dataPath, FileMode.Open)))
{
while (reader.BaseStream.Position != reader.BaseStream.Length)
{
var ciphertext = new Ciphertext();
ciphertext.Load(reader.BaseStream);

var plainResult = new Plaintext();
context.Decryptor.Decrypt(ciphertext, plainResult);
var predictionEncrypted = (float)context.Encoder.Decode(plainResult);
Console.WriteLine(predictionEncrypted);
}
}
}

public static void WriteData(BinaryWriter writer, VBuffer<Ciphertext> features)
{
if (features.Indices == null)
{
writer.Write(false);
writer.Write(features.Values.Length);
for (int i = 0; i < features.Values.Length; i++)
{
features.Values[i].Save(writer.BaseStream);
}
}
else
{
writer.Write(true);
writer.Write(features.Length);
writer.Write(features.Values.Length);
for (int i = 0; i < features.Values.Length; i++)
{
writer.Write(features.Indices[i]);
features.Values[i].Save(writer.BaseStream);
}
}
}

public static VBuffer<Ciphertext> EncryptData(EncryptionContext EncryContext, ref VBuffer<Single> features)
{
Ciphertext[] encryptedFeatures = new Ciphertext[features.Values.Length];

for (int i = 0; i < features.Values.Length; i++)
{
encryptedFeatures[i] = new Ciphertext();
EncryContext.Encryptor.Encrypt(EncryContext.Encoder.Encode(features.Values[i]), encryptedFeatures[i]);
}
return new VBuffer<Ciphertext>(features.Length, features.Count, encryptedFeatures, features.Indices);
}

public static IDataView GetTestPipeline(IHostEnvironment env, string testDataPath, string modelFile)
{
using (var stream = new FileStream(modelFile, FileMode.Open))
{
return ModelFileUtils.LoadPipeline(env, stream, new MultiFileSource(testDataPath), true);
}
}

public static void SaveModel(IHostEnvironment env, IPredictor pred, RoleMappedData trainRoles, string modelFile)
{
using (var ch = env.Start("Saving model"))
using (var filestream = new FileStream(modelFile, FileMode.Create))
{
// Model cannot be saved with CacheDataView
TrainUtils.SaveModel(env, ch, filestream, pred, trainRoles);
}
}

public static IPredictor LoadModel(IHostEnvironment env, string modelFile)
{
using (var filestream = new FileStream(modelFile, FileMode.Open))
{
// Model cannot be saved with CacheDataView
return ModelFileUtils.LoadPredictorOrNull(env, filestream);
}
}

public static RoleMappedData LoadRoleMapping(IHostEnvironment env, string modelFile)
{
using (var filestream = new FileStream(modelFile, FileMode.Open))
{
var dataview = ModelFileUtils.LoadPipeline(env, filestream, new MultiFileSource(null), true);
// Model cannot be saved with CacheDataView
return new RoleMappedData(dataview, ModelFileUtils.LoadRoleMappingsOrNull(env, filestream));
}
}
}
}
Loading