Skip to content

Commit e03104b

Browse files
committed
Resolved merged conflicts.
2 parents 7c5b324 + e023ab8 commit e03104b

File tree

93 files changed

+1201
-1102
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

93 files changed

+1201
-1102
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptron.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public static void Example()
1919
var trainTestData = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);
2020

2121
// Create data training pipeline.
22-
var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numIterations: 10);
22+
var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numberOfIterations: 10);
2323

2424
// Fit this pipeline to the training data.
2525
var model = pipeline.Fit(trainTestData.TrainSet);

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/AveragedPerceptronWithOptions.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ public static void Example()
2323
// Define the trainer options.
2424
var options = new AveragedPerceptronTrainer.Options()
2525
{
26-
LossFunction = new SmoothedHingeLoss.Options(),
26+
LossFunction = new SmoothedHingeLoss(),
2727
LearningRate = 0.1f,
28-
DoLazyUpdates = false,
28+
LazyUpdate = false,
2929
RecencyGain = 0.1f,
3030
NumberOfIterations = 10
3131
};

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public static void Example()
2929
var options = new SdcaMultiClassTrainer.Options
3030
{
3131
// Add custom loss
32-
LossFunction = new HingeLoss.Options(),
32+
LossFunction = new HingeLoss(),
3333
// Make the convergence tolerance tighter.
3434
ConvergenceTolerance = 0.05f,
3535
// Increase the maximum number of passes over training data.

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/DnnFeaturizeImage.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using System;
2-
using System.IO;
1+
using System.IO;
32
using System.Linq;
43
using Microsoft.ML.Data;
54
using Microsoft.ML.Transforms;

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ExtractPixels.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public static void Example()
3838
var imagesFolder = Path.GetDirectoryName(imagesDataFile);
3939
// Image loading pipeline.
4040
var pipeline = mlContext.Transforms.LoadImages(imagesFolder, ("ImageObject", "ImagePath"))
41-
.Append(mlContext.Transforms.ResizeImages("ImageObject",imageWidth: 100 , imageHeight: 100 ))
41+
.Append(mlContext.Transforms.ResizeImages("ImageObject", imageWidth: 100, imageHeight: 100 ))
4242
.Append(mlContext.Transforms.ExtractPixels("Pixels", "ImageObject"));
4343

4444

src/Microsoft.ML.Core/Data/AnnotationUtils.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Column
308308
schema.Schema[list[0].Index].Annotations.GetValue(Kinds.SlotNames, ref slotNames);
309309
}
310310

311-
public static bool HasKeyValues(this SchemaShape.Column col)
311+
public static bool NeedsSlotNames(this SchemaShape.Column col)
312312
{
313313
return col.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol)
314314
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector
@@ -442,7 +442,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
442442
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
443443
{
444444
var cols = new List<SchemaShape.Column>();
445-
if (labelColumn != null && labelColumn.Value.IsKey && HasKeyValues(labelColumn.Value))
445+
if (labelColumn != null && labelColumn.Value.IsKey && NeedsSlotNames(labelColumn.Value))
446446
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
447447
cols.AddRange(GetTrainerOutputAnnotation());
448448
return cols;

src/Microsoft.ML.Data/Data/Conversion.cs

+25-2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ private Conversions()
111111
AddStd<I1, R4>(Convert);
112112
AddStd<I1, R8>(Convert);
113113
AddAux<I1, SB>(Convert);
114+
AddStd<I1, BL>(Convert);
114115

115116
AddStd<I2, I1>(Convert);
116117
AddStd<I2, I2>(Convert);
@@ -119,6 +120,7 @@ private Conversions()
119120
AddStd<I2, R4>(Convert);
120121
AddStd<I2, R8>(Convert);
121122
AddAux<I2, SB>(Convert);
123+
AddStd<I2, BL>(Convert);
122124

123125
AddStd<I4, I1>(Convert);
124126
AddStd<I4, I2>(Convert);
@@ -127,6 +129,7 @@ private Conversions()
127129
AddStd<I4, R4>(Convert);
128130
AddStd<I4, R8>(Convert);
129131
AddAux<I4, SB>(Convert);
132+
AddStd<I4, BL>(Convert);
130133

131134
AddStd<I8, I1>(Convert);
132135
AddStd<I8, I2>(Convert);
@@ -135,6 +138,7 @@ private Conversions()
135138
AddStd<I8, R4>(Convert);
136139
AddStd<I8, R8>(Convert);
137140
AddAux<I8, SB>(Convert);
141+
AddStd<I8, BL>(Convert);
138142

139143
AddStd<U1, U1>(Convert);
140144
AddStd<U1, U2>(Convert);
@@ -144,6 +148,7 @@ private Conversions()
144148
AddStd<U1, R4>(Convert);
145149
AddStd<U1, R8>(Convert);
146150
AddAux<U1, SB>(Convert);
151+
AddStd<U1, BL>(Convert);
147152

148153
AddStd<U2, U1>(Convert);
149154
AddStd<U2, U2>(Convert);
@@ -153,6 +158,7 @@ private Conversions()
153158
AddStd<U2, R4>(Convert);
154159
AddStd<U2, R8>(Convert);
155160
AddAux<U2, SB>(Convert);
161+
AddStd<U2, BL>(Convert);
156162

157163
AddStd<U4, U1>(Convert);
158164
AddStd<U4, U2>(Convert);
@@ -162,6 +168,7 @@ private Conversions()
162168
AddStd<U4, R4>(Convert);
163169
AddStd<U4, R8>(Convert);
164170
AddAux<U4, SB>(Convert);
171+
AddStd<U4, BL>(Convert);
165172

166173
AddStd<U8, U1>(Convert);
167174
AddStd<U8, U2>(Convert);
@@ -171,6 +178,7 @@ private Conversions()
171178
AddStd<U8, R4>(Convert);
172179
AddStd<U8, R8>(Convert);
173180
AddAux<U8, SB>(Convert);
181+
AddStd<U8, BL>(Convert);
174182

175183
AddStd<UG, U1>(Convert);
176184
AddStd<UG, U2>(Convert);
@@ -180,11 +188,13 @@ private Conversions()
180188
AddAux<UG, SB>(Convert);
181189

182190
AddStd<R4, R4>(Convert);
191+
AddStd<R4, BL>(Convert);
183192
AddStd<R4, R8>(Convert);
184193
AddAux<R4, SB>(Convert);
185194

186195
AddStd<R8, R4>(Convert);
187196
AddStd<R8, R8>(Convert);
197+
AddStd<R8, BL>(Convert);
188198
AddAux<R8, SB>(Convert);
189199

190200
AddStd<TX, I1>(Convert);
@@ -901,6 +911,19 @@ public void Convert(in BL src, ref SB dst)
901911
public void Convert(in DZ src, ref SB dst) { ClearDst(ref dst); dst.AppendFormat("{0:o}", src); }
902912
#endregion ToStringBuilder
903913

914+
#region ToBL
915+
public void Convert(in R8 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
916+
public void Convert(in R4 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
917+
public void Convert(in I1 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
918+
public void Convert(in I2 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
919+
public void Convert(in I4 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
920+
public void Convert(in I8 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
921+
public void Convert(in U1 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
922+
public void Convert(in U2 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
923+
public void Convert(in U4 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
924+
public void Convert(in U8 src, ref BL dst) => dst = System.Convert.ToBoolean(src);
925+
#endregion
926+
904927
#region FromR4
905928
public void Convert(in R4 src, ref R4 dst) => dst = src;
906929
public void Convert(in R4 src, ref R8 dst) => dst = src;
@@ -1139,7 +1162,7 @@ private bool TryParseCore(ReadOnlySpan<char> span, out ulong dst)
11391162
dst = res;
11401163
return true;
11411164

1142-
LFail:
1165+
LFail:
11431166
dst = 0;
11441167
return false;
11451168
}
@@ -1246,7 +1269,7 @@ private bool TryParseNonNegative(ReadOnlySpan<char> span, out long result)
12461269
result = res;
12471270
return true;
12481271

1249-
LFail:
1272+
LFail:
12501273
result = 0;
12511274
return false;
12521275
}

src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx)
200200
LastTransformer = null;
201201
}
202202

203-
public void SaveTo(IHostEnvironment env, Stream outputStream)
203+
[BestFriend]
204+
internal void SaveTo(IHostEnvironment env, Stream outputStream)
204205
{
205206
using (var ch = env.Start("Saving pipeline"))
206207
{

src/Microsoft.ML.Data/Dirty/ILoss.cs

+7-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public interface ILossFunction<in TOutput, in TLabel>
1717
Double Loss(TOutput output, TLabel label);
1818
}
1919

20-
public interface IScalarOutputLoss : ILossFunction<float, float>
20+
public interface IScalarLoss : ILossFunction<float, float>
2121
{
2222
/// <summary>
2323
/// Derivative of the loss function with respect to output
@@ -26,20 +26,22 @@ public interface IScalarOutputLoss : ILossFunction<float, float>
2626
}
2727

2828
[TlcModule.ComponentKind("RegressionLossFunction")]
29-
public interface ISupportRegressionLossFactory : IComponentFactory<IRegressionLoss>
29+
[BestFriend]
30+
internal interface ISupportRegressionLossFactory : IComponentFactory<IRegressionLoss>
3031
{
3132
}
3233

33-
public interface IRegressionLoss : IScalarOutputLoss
34+
public interface IRegressionLoss : IScalarLoss
3435
{
3536
}
3637

3738
[TlcModule.ComponentKind("ClassificationLossFunction")]
38-
public interface ISupportClassificationLossFactory : IComponentFactory<IClassificationLoss>
39+
[BestFriend]
40+
internal interface ISupportClassificationLossFactory : IComponentFactory<IClassificationLoss>
3941
{
4042
}
4143

42-
public interface IClassificationLoss : IScalarOutputLoss
44+
public interface IClassificationLoss : IScalarLoss
4345
{
4446
}
4547

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

-16
Original file line numberDiff line numberDiff line change
@@ -53,21 +53,5 @@ internal ExplainabilityTransforms(ModelOperationsCatalog owner)
5353
_env = owner._env;
5454
}
5555
}
56-
57-
/// <summary>
58-
/// Create a prediction engine for one-time prediction.
59-
/// </summary>
60-
/// <typeparam name="TSrc">The class that defines the input data.</typeparam>
61-
/// <typeparam name="TDst">The class that defines the output data.</typeparam>
62-
/// <param name="transformer">The transformer to use for prediction.</param>
63-
/// <param name="inputSchemaDefinition">Additional settings of the input schema.</param>
64-
/// <param name="outputSchemaDefinition">Additional settings of the output schema.</param>
65-
public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransformer transformer,
66-
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
67-
where TSrc : class
68-
where TDst : class, new()
69-
{
70-
return new PredictionEngine<TSrc, TDst>(_env, transformer, false, inputSchemaDefinition, outputSchemaDefinition);
71-
}
7256
}
7357
}

src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ internal static bool CanWrap(ISchemaBoundMapper mapper, DataViewType labelNameTy
429429
var scoreType = outSchema[scoreIdx].Type;
430430

431431
// Check that the type is vector, and is of compatible size with the score output.
432-
return labelNameType is VectorType vectorType && vectorType.Size == scoreType.GetVectorSize();
432+
return labelNameType is VectorType vectorType && vectorType.Size == scoreType.GetVectorSize() && vectorType.ItemType == TextDataViewType.Instance;
433433
}
434434

435435
internal static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

+8-1
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,15 @@ private protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
145145
private protected TTransformer TrainTransformer(IDataView trainSet,
146146
IDataView validationSet = null, IPredictor initPredictor = null)
147147
{
148+
CheckInputSchema(SchemaShape.Create(trainSet.Schema));
148149
var trainRoleMapped = MakeRoles(trainSet);
149-
var validRoleMapped = validationSet == null ? null : MakeRoles(validationSet);
150+
RoleMappedData validRoleMapped = null;
151+
152+
if (validationSet != null)
153+
{
154+
CheckInputSchema(SchemaShape.Create(validationSet.Schema));
155+
validRoleMapped = MakeRoles(validationSet);
156+
}
150157

151158
var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor));
152159
return MakeTransformer(pred, trainSet.Schema);

0 commit comments

Comments
 (0)