Skip to content

Commit 37a6aa2

Browse files
authored
Harden user PII protection logic and extend TrainingAnalytics to expose detailed configuration parameters. (#5512)
* Hash128 is not a cryptographic hash, replace with HMAC-SHA256. * Extend TrainingAnalytics side channel to expose configuration details * Change member function scopes and hash demo_paths * Extract tbiEvent hashing method and add test coverage
1 parent f2e0cb8 commit 37a6aa2

File tree

11 files changed

+249
-35
lines changed

11 files changed

+249
-35
lines changed

com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs

+31-5
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,45 @@
11
using System;
2+
using System.Text;
3+
using System.Security.Cryptography;
24
using UnityEngine;
35

46
namespace Unity.MLAgents.Analytics
57
{
8+
69
internal static class AnalyticsUtils
710
{
11+
/// <summary>
12+
/// Conversion function from byte array to hex string
13+
/// </summary>
14+
/// <param name="array">A byte array to be hex encoded.</param>
15+
/// <returns>The lowercase hexadecimal string encoding of the array, two characters per byte.</returns>
16+
private static string ToHexString(byte[] array)
17+
{
18+
StringBuilder hex = new StringBuilder(array.Length * 2);
19+
foreach (byte b in array)
20+
{
21+
hex.AppendFormat("{0:x2}", b);
22+
}
23+
return hex.ToString();
24+
}
25+
826
/// <summary>
927
/// Hash a string to remove PII or secret info before sending to analytics
1028
/// </summary>
11-
/// <param name="s"></param>
12-
/// <returns>A string containing the Hash128 of the input string.</returns>
13-
public static string Hash(string s)
29+
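/// <param name="key">A string containing the key to be used for HMAC encoding.</param>
30+
/// <param name="value">A string containing the value to be encoded.</param>
31+
/// <returns>A lowercase hex string containing the HMAC-SHA256 of the value,
32+
/// computed over the UTF-8 bytes of the value using the UTF-8 bytes of the key.</returns>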
33+
public static string Hash(string key, string value)
1434
{
15-
var behaviorNameHash = Hash128.Compute(s);
16-
return behaviorNameHash.ToString();
35+
string hash;
36+
UTF8Encoding encoder = new UTF8Encoding();
37+
using (HMACSHA256 hmac = new HMACSHA256(encoder.GetBytes(key)))
38+
{
39+
Byte[] hmBytes = hmac.ComputeHash(encoder.GetBytes(value));
40+
hash = ToHexString(hmBytes);
41+
}
42+
return hash;
1743
}
1844

1945
internal static bool s_SendEditorAnalytics = true;

com.unity.ml-agents/Runtime/Analytics/Events.cs

+2
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ internal struct TrainingEnvironmentInitializedEvent
156156
public string TorchDeviceType;
157157
public int NumEnvironments;
158158
public int NumEnvironmentParameters;
159+
public string RunOptions;
159160
}
160161

161162
[Flags]
@@ -188,5 +189,6 @@ internal struct TrainingBehaviorInitializedEvent
188189
public string VisualEncoder;
189190
public int NumNetworkLayers;
190191
public int NumNetworkHiddenUnits;
192+
public string Config;
191193
}
192194
}

com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ IList<IActuator> actuators
156156
var inferenceEvent = new InferenceEvent();
157157

158158
// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
159-
inferenceEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
159+
inferenceEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName);
160160

161161
inferenceEvent.BarracudaModelSource = barracudaModel.IrSource;
162162
inferenceEvent.BarracudaModelVersion = barracudaModel.IrVersion;

com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs

+16-4
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,21 @@ internal static string ParseBehaviorName(string fullyQualifiedBehaviorName)
192192
return fullyQualifiedBehaviorName.Substring(0, lastQuestionIndex);
193193
}
194194

195+
internal static TrainingBehaviorInitializedEvent SanitizeTrainingBehaviorInitializedEvent(TrainingBehaviorInitializedEvent tbiEvent)
196+
{
197+
// Hash the behavior name if the message version is from an older version of ml-agents that doesn't do trainer-side hashing.
198+
// We'll also, for extra safety, verify that the BehaviorName is the size of the expected SHA256 hash.
199+
// Context: The config field was added at the same time as trainer side hashing, so messages including it should already be hashed.
200+
if (tbiEvent.Config.Length == 0 || tbiEvent.BehaviorName.Length != 64)
201+
{
202+
tbiEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, tbiEvent.BehaviorName);
203+
}
204+
205+
return tbiEvent;
206+
}
207+
195208
[Conditional("MLA_UNITY_ANALYTICS_MODULE")]
196-
public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent tbiEvent)
209+
public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent rawTbiEvent)
197210
{
198211
#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE
199212
if (!IsAnalyticsEnabled())
@@ -202,6 +215,7 @@ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent
202215
if (!EnableAnalytics())
203216
return;
204217

218+
var tbiEvent = SanitizeTrainingBehaviorInitializedEvent(rawTbiEvent);
205219
var behaviorName = tbiEvent.BehaviorName;
206220
var added = s_SentTrainingBehaviorInitialized.Add(behaviorName);
207221

@@ -211,9 +225,7 @@ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent
211225
return;
212226
}
213227

214-
// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
215228
tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
216-
tbiEvent.BehaviorName = AnalyticsUtils.Hash(tbiEvent.BehaviorName);
217229

218230
// Note - to debug, use JsonUtility.ToJson on the event.
219231
// Debug.Log(
@@ -236,7 +248,7 @@ IList<IActuator> actuators
236248
var remotePolicyEvent = new RemotePolicyInitializedEvent();
237249

238250
// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
239-
remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
251+
remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName);
240252

241253
remotePolicyEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
242254
remotePolicyEvent.ActionSpec = EventActionSpec.FromActionSpec(actionSpec);

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs

+2
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitial
501501
TorchDeviceType = inputProto.TorchDeviceType,
502502
NumEnvironments = inputProto.NumEnvs,
503503
NumEnvironmentParameters = inputProto.NumEnvironmentParameters,
504+
RunOptions = inputProto.RunOptions,
504505
};
505506
}
506507

@@ -530,6 +531,7 @@ internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEv
530531
VisualEncoder = inputProto.VisualEncoder,
531532
NumNetworkLayers = inputProto.NumNetworkLayers,
532533
NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits,
534+
Config = inputProto.Config,
533535
};
534536
}
535537

com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs

+72-15
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,29 @@ static TrainingAnalyticsReflection() {
2525
byte[] descriptorData = global::System.Convert.FromBase64String(
2626
string.Concat(
2727
"CjttbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3RyYWluaW5n",
28-
"X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi2QEKHlRy",
28+
"X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7gEKHlRy",
2929
"YWluaW5nRW52aXJvbm1lbnRJbml0aWFsaXplZBIYChBtbGFnZW50c192ZXJz",
3030
"aW9uGAEgASgJEh0KFW1sYWdlbnRzX2VudnNfdmVyc2lvbhgCIAEoCRIWCg5w",
3131
"eXRob25fdmVyc2lvbhgDIAEoCRIVCg10b3JjaF92ZXJzaW9uGAQgASgJEhkK",
3232
"EXRvcmNoX2RldmljZV90eXBlGAUgASgJEhAKCG51bV9lbnZzGAYgASgFEiIK",
33-
"Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFIq0DChtUcmFpbmlu",
34-
"Z0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoNYmVoYXZpb3JfbmFtZRgBIAEoCRIU",
35-
"Cgx0cmFpbmVyX3R5cGUYAiABKAkSIAoYZXh0cmluc2ljX3Jld2FyZF9lbmFi",
36-
"bGVkGAMgASgIEhsKE2dhaWxfcmV3YXJkX2VuYWJsZWQYBCABKAgSIAoYY3Vy",
37-
"aW9zaXR5X3Jld2FyZF9lbmFibGVkGAUgASgIEhoKEnJuZF9yZXdhcmRfZW5h",
38-
"YmxlZBgGIAEoCBIiChpiZWhhdmlvcmFsX2Nsb25pbmdfZW5hYmxlZBgHIAEo",
39-
"CBIZChFyZWN1cnJlbnRfZW5hYmxlZBgIIAEoCBIWCg52aXN1YWxfZW5jb2Rl",
40-
"chgJIAEoCRIaChJudW1fbmV0d29ya19sYXllcnMYCiABKAUSIAoYbnVtX25l",
41-
"dHdvcmtfaGlkZGVuX3VuaXRzGAsgASgFEhgKEHRyYWluZXJfdGhyZWFkZWQY",
42-
"DCABKAgSGQoRc2VsZl9wbGF5X2VuYWJsZWQYDSABKAgSGgoSY3VycmljdWx1",
43-
"bV9lbmFibGVkGA4gASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0",
44-
"b3JPYmplY3RzYgZwcm90bzM="));
33+
"Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFEhMKC3J1bl9vcHRp",
34+
"b25zGAggASgJIr0DChtUcmFpbmluZ0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoN",
35+
"YmVoYXZpb3JfbmFtZRgBIAEoCRIUCgx0cmFpbmVyX3R5cGUYAiABKAkSIAoY",
36+
"ZXh0cmluc2ljX3Jld2FyZF9lbmFibGVkGAMgASgIEhsKE2dhaWxfcmV3YXJk",
37+
"X2VuYWJsZWQYBCABKAgSIAoYY3VyaW9zaXR5X3Jld2FyZF9lbmFibGVkGAUg",
38+
"ASgIEhoKEnJuZF9yZXdhcmRfZW5hYmxlZBgGIAEoCBIiChpiZWhhdmlvcmFs",
39+
"X2Nsb25pbmdfZW5hYmxlZBgHIAEoCBIZChFyZWN1cnJlbnRfZW5hYmxlZBgI",
40+
"IAEoCBIWCg52aXN1YWxfZW5jb2RlchgJIAEoCRIaChJudW1fbmV0d29ya19s",
41+
"YXllcnMYCiABKAUSIAoYbnVtX25ldHdvcmtfaGlkZGVuX3VuaXRzGAsgASgF",
42+
"EhgKEHRyYWluZXJfdGhyZWFkZWQYDCABKAgSGQoRc2VsZl9wbGF5X2VuYWJs",
43+
"ZWQYDSABKAgSGgoSY3VycmljdWx1bV9lbmFibGVkGA4gASgIEg4KBmNvbmZp",
44+
"ZxgPIAEoCUIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0",
45+
"c2IGcHJvdG8z"));
4546
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
4647
new pbr::FileDescriptor[] { },
4748
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
48-
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters" }, null, null, null),
49-
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled" }, null, null, null)
49+
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters", "RunOptions" }, null, null, null),
50+
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled", "Config" }, null, null, null)
5051
}));
5152
}
5253
#endregion
@@ -85,6 +86,7 @@ public TrainingEnvironmentInitialized(TrainingEnvironmentInitialized other) : th
8586
torchDeviceType_ = other.torchDeviceType_;
8687
numEnvs_ = other.numEnvs_;
8788
numEnvironmentParameters_ = other.numEnvironmentParameters_;
89+
runOptions_ = other.runOptions_;
8890
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
8991
}
9092

@@ -170,6 +172,17 @@ public int NumEnvironmentParameters {
170172
}
171173
}
172174

175+
/// <summary>Field number for the "run_options" field.</summary>
176+
public const int RunOptionsFieldNumber = 8;
177+
private string runOptions_ = "";
178+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
179+
public string RunOptions {
180+
get { return runOptions_; }
181+
set {
182+
runOptions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
183+
}
184+
}
185+
173186
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
174187
public override bool Equals(object other) {
175188
return Equals(other as TrainingEnvironmentInitialized);
@@ -190,6 +203,7 @@ public bool Equals(TrainingEnvironmentInitialized other) {
190203
if (TorchDeviceType != other.TorchDeviceType) return false;
191204
if (NumEnvs != other.NumEnvs) return false;
192205
if (NumEnvironmentParameters != other.NumEnvironmentParameters) return false;
206+
if (RunOptions != other.RunOptions) return false;
193207
return Equals(_unknownFields, other._unknownFields);
194208
}
195209

@@ -203,6 +217,7 @@ public override int GetHashCode() {
203217
if (TorchDeviceType.Length != 0) hash ^= TorchDeviceType.GetHashCode();
204218
if (NumEnvs != 0) hash ^= NumEnvs.GetHashCode();
205219
if (NumEnvironmentParameters != 0) hash ^= NumEnvironmentParameters.GetHashCode();
220+
if (RunOptions.Length != 0) hash ^= RunOptions.GetHashCode();
206221
if (_unknownFields != null) {
207222
hash ^= _unknownFields.GetHashCode();
208223
}
@@ -244,6 +259,10 @@ public void WriteTo(pb::CodedOutputStream output) {
244259
output.WriteRawTag(56);
245260
output.WriteInt32(NumEnvironmentParameters);
246261
}
262+
if (RunOptions.Length != 0) {
263+
output.WriteRawTag(66);
264+
output.WriteString(RunOptions);
265+
}
247266
if (_unknownFields != null) {
248267
_unknownFields.WriteTo(output);
249268
}
@@ -273,6 +292,9 @@ public int CalculateSize() {
273292
if (NumEnvironmentParameters != 0) {
274293
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvironmentParameters);
275294
}
295+
if (RunOptions.Length != 0) {
296+
size += 1 + pb::CodedOutputStream.ComputeStringSize(RunOptions);
297+
}
276298
if (_unknownFields != null) {
277299
size += _unknownFields.CalculateSize();
278300
}
@@ -305,6 +327,9 @@ public void MergeFrom(TrainingEnvironmentInitialized other) {
305327
if (other.NumEnvironmentParameters != 0) {
306328
NumEnvironmentParameters = other.NumEnvironmentParameters;
307329
}
330+
if (other.RunOptions.Length != 0) {
331+
RunOptions = other.RunOptions;
332+
}
308333
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
309334
}
310335

@@ -344,6 +369,10 @@ public void MergeFrom(pb::CodedInputStream input) {
344369
NumEnvironmentParameters = input.ReadInt32();
345370
break;
346371
}
372+
case 66: {
373+
RunOptions = input.ReadString();
374+
break;
375+
}
347376
}
348377
}
349378
}
@@ -389,6 +418,7 @@ public TrainingBehaviorInitialized(TrainingBehaviorInitialized other) : this() {
389418
trainerThreaded_ = other.trainerThreaded_;
390419
selfPlayEnabled_ = other.selfPlayEnabled_;
391420
curriculumEnabled_ = other.curriculumEnabled_;
421+
config_ = other.config_;
392422
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
393423
}
394424

@@ -551,6 +581,17 @@ public bool CurriculumEnabled {
551581
}
552582
}
553583

584+
/// <summary>Field number for the "config" field.</summary>
585+
public const int ConfigFieldNumber = 15;
586+
private string config_ = "";
587+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
588+
public string Config {
589+
get { return config_; }
590+
set {
591+
config_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
592+
}
593+
}
594+
554595
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
555596
public override bool Equals(object other) {
556597
return Equals(other as TrainingBehaviorInitialized);
@@ -578,6 +619,7 @@ public bool Equals(TrainingBehaviorInitialized other) {
578619
if (TrainerThreaded != other.TrainerThreaded) return false;
579620
if (SelfPlayEnabled != other.SelfPlayEnabled) return false;
580621
if (CurriculumEnabled != other.CurriculumEnabled) return false;
622+
if (Config != other.Config) return false;
581623
return Equals(_unknownFields, other._unknownFields);
582624
}
583625

@@ -598,6 +640,7 @@ public override int GetHashCode() {
598640
if (TrainerThreaded != false) hash ^= TrainerThreaded.GetHashCode();
599641
if (SelfPlayEnabled != false) hash ^= SelfPlayEnabled.GetHashCode();
600642
if (CurriculumEnabled != false) hash ^= CurriculumEnabled.GetHashCode();
643+
if (Config.Length != 0) hash ^= Config.GetHashCode();
601644
if (_unknownFields != null) {
602645
hash ^= _unknownFields.GetHashCode();
603646
}
@@ -667,6 +710,10 @@ public void WriteTo(pb::CodedOutputStream output) {
667710
output.WriteRawTag(112);
668711
output.WriteBool(CurriculumEnabled);
669712
}
713+
if (Config.Length != 0) {
714+
output.WriteRawTag(122);
715+
output.WriteString(Config);
716+
}
670717
if (_unknownFields != null) {
671718
_unknownFields.WriteTo(output);
672719
}
@@ -717,6 +764,9 @@ public int CalculateSize() {
717764
if (CurriculumEnabled != false) {
718765
size += 1 + 1;
719766
}
767+
if (Config.Length != 0) {
768+
size += 1 + pb::CodedOutputStream.ComputeStringSize(Config);
769+
}
720770
if (_unknownFields != null) {
721771
size += _unknownFields.CalculateSize();
722772
}
@@ -770,6 +820,9 @@ public void MergeFrom(TrainingBehaviorInitialized other) {
770820
if (other.CurriculumEnabled != false) {
771821
CurriculumEnabled = other.CurriculumEnabled;
772822
}
823+
if (other.Config.Length != 0) {
824+
Config = other.Config;
825+
}
773826
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
774827
}
775828

@@ -837,6 +890,10 @@ public void MergeFrom(pb::CodedInputStream input) {
837890
CurriculumEnabled = input.ReadBool();
838891
break;
839892
}
893+
case 122: {
894+
Config = input.ReadString();
895+
break;
896+
}
840897
}
841898
}
842899
}

com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs

+13
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@ public void TestRemotePolicy()
7070
Academy.Instance.Dispose();
7171
}
7272

73+
[TestCase("a name we expect to hash", ExpectedResult = "d084a8b6da6a6a1c097cdc9ffea95e1546da4647352113ed77cbe7b4192e6d73")]
74+
[TestCase("another_name", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")]
75+
[TestCase("0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")]
76+
public string TestTrainingBehaviorInitialized(string stringToMaybeHash)
77+
{
78+
var tbiEvent = new TrainingBehaviorInitializedEvent();
79+
tbiEvent.BehaviorName = stringToMaybeHash;
80+
tbiEvent.Config = "{}";
81+
82+
var sanitizedEvent = TrainingAnalytics.SanitizeTrainingBehaviorInitializedEvent(tbiEvent);
83+
return sanitizedEvent.BehaviorName;
84+
}
85+
7386
[Test]
7487
public void TestEnableAnalytics()
7588
{

0 commit comments

Comments
 (0)