Skip to content

Commit 6812cb5

Browse files
sfilipiTomFinley
authored andcommitted
Enabling DI framework to scan the constructors with non-public visibility
* enabling scanning the constructors with non-public visibility, and reducing the visibility of some of them to avoid confusing the users.
1 parent 044a6d3 commit 6812cb5

File tree

17 files changed

+18
-18
lines changed

17 files changed

+18
-18
lines changed

src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,9 @@ private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTyp
454454
var parmTypesWithEnv = Utils.Concat(new Type[1] { typeof(IHostEnvironment) }, parmTypes);
455455
if (Utils.Size(parmTypes) == 0 && (getter = FindInstanceGetter(instType, loaderType)) != null)
456456
return true;
457-
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(parmTypes ?? Type.EmptyTypes)) != null)
457+
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) != null)
458458
return true;
459-
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(parmTypesWithEnv ?? Type.EmptyTypes)) != null)
459+
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null)
460460
{
461461
requireEnvironment = true;
462462
return true;

src/Microsoft.ML.FastTree/FastTreeClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelCol
142142
/// <summary>
143143
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/> by using the legacy <see cref="Arguments"/> class.
144144
/// </summary>
145-
public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args)
145+
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args)
146146
: base(env, args, MakeLabelColumn(args.LabelColumn))
147147
{
148148
_outputColumns = new[]

src/Microsoft.ML.FastTree/FastTreeRanking.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string f
8282
/// <summary>
8383
/// Initializes a new instance of <see cref="FastTreeRankingTrainer"/> by using the legacy <see cref="Arguments"/> class.
8484
/// </summary>
85-
public FastTreeRankingTrainer(IHostEnvironment env, Arguments args)
85+
internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args)
8686
: base(env, args, MakeLabelColumn(args.LabelColumn))
8787
{
8888
_outputColumns = new[]

src/Microsoft.ML.FastTree/FastTreeRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, strin
7575
/// <summary>
7676
/// Initializes a new instance of <see cref="FastTreeRegressionTrainer"/> by using the legacy <see cref="Arguments"/> class.
7777
/// </summary>
78-
public FastTreeRegressionTrainer(IHostEnvironment env, Arguments args)
78+
internal FastTreeRegressionTrainer(IHostEnvironment env, Arguments args)
7979
: base(env, args, MakeLabelColumn(args.LabelColumn))
8080
{
8181
_outputColumns = new[]

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string f
6767
/// <summary>
6868
/// Initializes a new instance of <see cref="FastTreeTweedieTrainer"/> by using the legacy <see cref="Arguments"/> class.
6969
/// </summary>
70-
public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args)
70+
internal FastTreeTweedieTrainer(IHostEnvironment env, Arguments args)
7171
: base(env, args, MakeLabelColumn(args.LabelColumn))
7272
{
7373
Initialize();

src/Microsoft.ML.FastTree/GamClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public sealed class Arguments : ArgumentsBase
4646
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
4747
private protected override bool NeedCalibration => true;
4848

49-
public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args)
49+
internal BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args)
5050
: base(env, args)
5151
{
5252
_sigmoidParameter = 1;

src/Microsoft.ML.FastTree/GamRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public partial class Arguments : ArgumentsBase
4040

4141
public override PredictionKind PredictionKind => PredictionKind.Regression;
4242

43-
public RegressionGamTrainer(IHostEnvironment env, Arguments args)
43+
internal RegressionGamTrainer(IHostEnvironment env, Arguments args)
4444
: base(env, args) { }
4545

4646
internal override void CheckLabel(RoleMappedData data)

src/Microsoft.ML.FastTree/RandomForestClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ public FastForestClassification(IHostEnvironment env, string labelColumn, string
158158
/// <summary>
159159
/// Initializes a new instance of <see cref="FastForestClassification"/> by using the legacy <see cref="Arguments"/> class.
160160
/// </summary>
161-
public FastForestClassification(IHostEnvironment env, Arguments args)
161+
internal FastForestClassification(IHostEnvironment env, Arguments args)
162162
: base(env, args, MakeLabelColumn(args.LabelColumn))
163163
{
164164
_outputColumns = new[]

src/Microsoft.ML.FastTree/RandomForestRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ public FastForestRegression(IHostEnvironment env, string labelColumn, string fea
178178
/// <summary>
179179
/// Initializes a new instance of <see cref="FastForestRegression"/> by using the legacy <see cref="Arguments"/> class.
180180
/// </summary>
181-
public FastForestRegression(IHostEnvironment env, Arguments args)
181+
internal FastForestRegression(IHostEnvironment env, Arguments args)
182182
: base(env, args, MakeLabelColumn(args.LabelColumn), true)
183183
{
184184
_outputColumns = new[]

src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public sealed class Arguments : ArgumentsBase
7373
/// Developers should instantiate <see cref="Pkpd"/> by supplying the trainer argument directly to the <see cref="Pkpd"/> constructor
7474
/// using the other public constructor.
7575
/// </summary>
76-
public Pkpd(IHostEnvironment env, Arguments args)
76+
internal Pkpd(IHostEnvironment env, Arguments args)
7777
: base(env, args, LoadNameValue)
7878
{
7979
}

src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
7070
};
7171
}
7272

73-
public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args)
73+
internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args)
7474
: this(env, args, args.FeatureColumn, args.LabelColumn)
7575
{
7676
}

src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featur
6767
};
6868
}
6969

70-
public SdcaRegressionTrainer(IHostEnvironment env, Arguments args)
70+
internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args)
7171
: this(env, args, args.FeatureColumn, args.LabelColumn)
7272
{
7373
}

test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public void TrainSentiment()
122122
}, text);
123123

124124
// Train
125-
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { MaxIterations = 20 });
125+
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { MaxIterations = 20 }, "Features", "Label");
126126
var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
127127

128128
var predicted = trainer.Train(trainRoles);

test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ void DecomposableTrainAndPredict()
3030
var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
3131
var term = TermTransform.Create(env, loader, "Label");
3232
var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
33-
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 });
33+
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label");
3434

3535
IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
3636
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ void New_DecomposableTrainAndPredict()
3030
var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
3131
var term = TermTransform.Create(env, loader, "Label");
3232
var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
33-
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 });
33+
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label");
3434

3535
IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
3636
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void Extensibility()
3535
var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
3636
.Transform(term);
3737

38-
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 });
38+
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label");
3939

4040
IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
4141
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
4646
pipeline = NormalizeTransform.CreateMinMaxNormalizer(env, pipeline, "Features");
4747

4848
// Train
49-
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 } );
49+
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 }, "Features", "Label");
5050

5151
// Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto
5252
var cached = new CacheDataView(env, pipeline, prefetch: null);

0 commit comments

Comments
 (0)