Skip to content

Commit 4acf5aa

Browse files
authored
Explicit implementation for IsRowToRowMapper and GetRowToRowMapper (#2673)
1 parent 7cc208c commit 4acf5aa

20 files changed

+69
-59
lines changed

src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public sealed class TransformWrapper : ITransformer
2525
private readonly IHost _host;
2626
private readonly IDataView _xf;
2727
private readonly bool _allowSave;
28+
private readonly bool _isRowToRowMapper;
2829

2930
public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false)
3031
{
@@ -33,7 +34,7 @@ public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = fal
3334
_host.CheckValue(xf, nameof(xf));
3435
_xf = xf;
3536
_allowSave = allowSave;
36-
IsRowToRowMapper = IsChainRowToRowMapper(_xf);
37+
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
3738
}
3839

3940
public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
@@ -108,7 +109,7 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
108109
}
109110

110111
_xf = data;
111-
IsRowToRowMapper = IsChainRowToRowMapper(_xf);
112+
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
112113
}
113114

114115
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
@@ -123,9 +124,9 @@ private static bool IsChainRowToRowMapper(IDataView view)
123124
return true;
124125
}
125126

126-
public bool IsRowToRowMapper { get; }
127+
bool ITransformer.IsRowToRowMapper => _isRowToRowMapper;
127128

128-
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema)
129+
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
129130
{
130131
_host.CheckValue(inputSchema, nameof(inputSchema));
131132
var input = new EmptyDataView(_host, inputSchema);

src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public sealed class TransformerChain<TLastTransformer> : ITransformer, IEnumerab
5959

6060
private const string TransformDirTemplate = "Transform_{0:000}";
6161

62-
public bool IsRowToRowMapper => _transformers.All(t => t.IsRowToRowMapper);
62+
bool ITransformer.IsRowToRowMapper => _transformers.All(t => t.IsRowToRowMapper);
6363

6464
ITransformer[] ITransformerChainAccessor.Transformers => _transformers;
6565

@@ -216,10 +216,11 @@ public void SaveTo(IHostEnvironment env, Stream outputStream)
216216

217217
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
218218

219-
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema)
219+
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
220220
{
221221
Contracts.CheckValue(inputSchema, nameof(inputSchema));
222-
Contracts.Check(IsRowToRowMapper, nameof(GetRowToRowMapper) + " method called despite " + nameof(IsRowToRowMapper) + " being false.");
222+
Contracts.Check(((ITransformer)this).IsRowToRowMapper, nameof(ITransformer.GetRowToRowMapper) + " method called despite " +
223+
nameof(ITransformer.IsRowToRowMapper) + " being false.");
223224

224225
IRowToRowMapper[] mappers = new IRowToRowMapper[_transformers.Length];
225226
DataViewSchema schema = inputSchema;

src/Microsoft.ML.Data/Prediction/PredictionEngine.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ private static Func<DataViewSchema, IRowToRowMapper> StreamChecker(IHostEnvironm
116116
{
117117
var pipe = DataViewConstructionUtils.LoadPipeWithPredictor(env, modelStream, new EmptyDataView(env, schema));
118118
var transformer = new TransformWrapper(env, pipe);
119-
env.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
120-
return transformer.GetRowToRowMapper(schema);
119+
env.CheckParam(((ITransformer)transformer).IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
120+
return ((ITransformer)transformer).GetRowToRowMapper(schema);
121121
};
122122
}
123123

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
5555
protected DataViewSchema TrainSchema;
5656

5757
/// <summary>
58-
/// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
58+
/// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
5959
/// appropriate schema.
6060
/// </summary>
61-
public bool IsRowToRowMapper => true;
61+
bool ITransformer.IsRowToRowMapper => true;
6262

6363
/// <summary>
6464
/// This class is more or less a thin wrapper over the <see cref="IDataScorerTransform"/> implementing
@@ -132,7 +132,7 @@ public IDataView Transform(IDataView input)
132132
/// </summary>
133133
/// <param name="inputSchema"></param>
134134
/// <returns></returns>
135-
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema)
135+
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
136136
{
137137
Host.CheckValue(inputSchema, nameof(inputSchema));
138138
return (IRowToRowMapper)Scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema));

src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ public sealed class ColumnSelectingTransformer : ITransformer
138138
private readonly IHost _host;
139139
private string[] _selectedColumns;
140140

141-
public bool IsRowToRowMapper => true;
141+
bool ITransformer.IsRowToRowMapper => true;
142142

143143
public IEnumerable<string> SelectColumns => _selectedColumns.AsReadOnly();
144144

@@ -458,13 +458,13 @@ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
458458
}
459459

460460
/// <summary>
461-
/// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
461+
/// Constructs a row-to-row mapper based on an input schema. If <see cref="ITransformer.IsRowToRowMapper"/>
462462
/// is <c>false</c>, then an exception is thrown. If the input schema is in any way
463463
/// unsuitable for constructing the mapper, an exception should likewise be thrown.
464464
/// </summary>
465465
/// <param name="inputSchema">The input schema for which we should get the mapper.</param>
466466
/// <returns>The row to row mapper.</returns>
467-
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema)
467+
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
468468
{
469469
_host.CheckValue(inputSchema, nameof(inputSchema));
470470
if (!IgnoreMissing && !IsSchemaValid(inputSchema.Select(x => x.Name),

src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ protected RowToRowTransformerBase(IHost host)
2626

2727
private protected abstract void SaveModel(ModelSaveContext ctx);
2828

29-
public bool IsRowToRowMapper => true;
29+
bool ITransformer.IsRowToRowMapper => true;
3030

31-
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema)
31+
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
3232
{
3333
Host.CheckValue(inputSchema, nameof(inputSchema));
3434
return new RowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema), MakeRowMapper);

src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs

+11-7
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ namespace Microsoft.ML.Transforms.TimeSeries
1818
public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel
1919
{
2020
/// <summary>
21-
/// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
21+
/// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
2222
/// appropriate schema.
2323
/// </summary>
24-
public bool IsRowToRowMapper => InternalTransform.IsRowToRowMapper;
24+
bool ITransformer.IsRowToRowMapper => ((ITransformer)InternalTransform).IsRowToRowMapper;
2525

2626
/// <summary>
2727
/// Creates a clone of the transfomer. Used for taking the snapshot of the state.
@@ -36,20 +36,22 @@ public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveMode
3636
public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => InternalTransform.GetOutputSchema(inputSchema);
3737

3838
/// <summary>
39-
/// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
39+
/// Constructs a row-to-row mapper based on an input schema. If <see cref="ITransformer.IsRowToRowMapper"/>
4040
/// is <c>false</c>, then an exception should be thrown. If the input schema is in any way
4141
/// unsuitable for constructing the mapper, an exception should likewise be thrown.
4242
/// </summary>
4343
/// <param name="inputSchema">The input schema for which we should get the mapper.</param>
4444
/// <returns>The row to row mapper.</returns>
45-
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => InternalTransform.GetRowToRowMapper(inputSchema);
45+
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
46+
=> ((ITransformer)InternalTransform).GetRowToRowMapper(inputSchema);
4647

4748
/// <summary>
4849
/// Same as <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> but also supports mechanism to save the state.
4950
/// </summary>
5051
/// <param name="inputSchema">The input schema for which we should get the mapper.</param>
5152
/// <returns>The row to row mapper.</returns>
52-
public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema) => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema);
53+
public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema)
54+
=> ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema);
5355

5456
/// <summary>
5557
/// Take the data in, make transformations, output the data.
@@ -60,7 +62,9 @@ public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveMode
6062
/// <summary>
6163
/// For saving a model into a repository.
6264
/// </summary>
63-
public virtual void Save(ModelSaveContext ctx)
65+
void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
66+
67+
private protected virtual void SaveModel(ModelSaveContext ctx)
6468
{
6569
InternalTransform.SaveThis(ctx);
6670
}
@@ -129,7 +133,7 @@ public override DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
129133

130134
private protected override void SaveModel(ModelSaveContext ctx)
131135
{
132-
Parent.Save(ctx);
136+
((ICanSaveModel)Parent).Save(ctx);
133137
}
134138

135139
internal void SaveThis(ModelSaveContext ctx)

src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector tran
172172
{
173173
}
174174

175-
public override void Save(ModelSaveContext ctx)
175+
private protected override void SaveModel(ModelSaveContext ctx)
176176
{
177177
InternalTransform.Host.CheckValue(ctx, nameof(ctx));
178178
ctx.CheckAtModel();
@@ -184,7 +184,7 @@ public override void Save(ModelSaveContext ctx)
184184
// *** Binary format ***
185185
// <base>
186186

187-
base.Save(ctx);
187+
base.SaveModel(ctx);
188188
}
189189

190190
// Factory method for SignatureLoadRowMapper.

src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform)
153153
{
154154
}
155155

156-
public override void Save(ModelSaveContext ctx)
156+
private protected override void SaveModel(ModelSaveContext ctx)
157157
{
158158
InternalTransform.Host.CheckValue(ctx, nameof(ctx));
159159
ctx.CheckAtModel();
@@ -164,7 +164,7 @@ public override void Save(ModelSaveContext ctx)
164164
// *** Binary format ***
165165
// <base>
166166

167-
base.Save(ctx);
167+
base.SaveModel(ctx);
168168
}
169169

170170
// Factory method for SignatureLoadRowMapper.

src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ public Func<int, bool> GetDependencies(Func<int, bool> activeOutput)
348348
return col => false;
349349
}
350350

351-
public void Save(ModelSaveContext ctx) => _parent.SaveModel(ctx);
351+
void ICanSaveModel.Save(ModelSaveContext ctx) => _parent.SaveModel(ctx);
352352

353353
public Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
354354
{

src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ private protected virtual void CloneCore(TState state)
269269
internal readonly string OutputColumnName;
270270
private protected DataViewType OutputColumnType;
271271

272-
public bool IsRowToRowMapper => false;
272+
bool ITransformer.IsRowToRowMapper => false;
273273

274274
internal TState StateRef { get; set; }
275275

src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs

+11-7
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ public static Func<Double, Double, Double> GetErrorFunction(ErrorFunction errorF
8787
public class SsaAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel
8888
{
8989
/// <summary>
90-
/// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
90+
/// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
9191
/// appropriate schema.
9292
/// </summary>
93-
public bool IsRowToRowMapper => InternalTransform.IsRowToRowMapper;
93+
bool ITransformer.IsRowToRowMapper => ((ITransformer)InternalTransform).IsRowToRowMapper;
9494

9595
/// <summary>
9696
/// Creates a clone of the transfomer. Used for taking the snapshot of the state.
@@ -105,20 +105,22 @@ public class SsaAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveMode
105105
public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => InternalTransform.GetOutputSchema(inputSchema);
106106

107107
/// <summary>
108-
/// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
108+
/// Constructs a row-to-row mapper based on an input schema. If <see cref="ITransformer.IsRowToRowMapper"/>
109109
/// is <c>false</c>, then an exception should be thrown. If the input schema is in any way
110110
/// unsuitable for constructing the mapper, an exception should likewise be thrown.
111111
/// </summary>
112112
/// <param name="inputSchema">The input schema for which we should get the mapper.</param>
113113
/// <returns>The row to row mapper.</returns>
114-
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => InternalTransform.GetRowToRowMapper(inputSchema);
114+
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
115+
=> ((ITransformer)InternalTransform).GetRowToRowMapper(inputSchema);
115116

116117
/// <summary>
117118
/// Same as <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> but also supports mechanism to save the state.
118119
/// </summary>
119120
/// <param name="inputSchema">The input schema for which we should get the mapper.</param>
120121
/// <returns>The row to row mapper.</returns>
121-
public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema) => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema);
122+
public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema)
123+
=> ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema);
122124

123125
/// <summary>
124126
/// Take the data in, make transformations, output the data.
@@ -129,7 +131,9 @@ public class SsaAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveMode
129131
/// <summary>
130132
/// For saving a model into a repository.
131133
/// </summary>
132-
public virtual void Save(ModelSaveContext ctx) => InternalTransform.SaveThis(ctx);
134+
void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
135+
136+
private protected virtual void SaveModel(ModelSaveContext ctx) => InternalTransform.SaveThis(ctx);
133137

134138
/// <summary>
135139
/// Creates a row mapper from Schema.
@@ -255,7 +259,7 @@ public override DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
255259

256260
private protected override void SaveModel(ModelSaveContext ctx)
257261
{
258-
Parent.Save(ctx);
262+
((ICanSaveModel)Parent).Save(ctx);
259263
}
260264

261265
internal void SaveThis(ModelSaveContext ctx)

src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ internal SsaChangePointDetector(IHostEnvironment env, ModelLoadContext ctx)
180180
InternalTransform.Host.CheckDecode(InternalTransform.IsAdaptive == false);
181181
}
182182

183-
public override void Save(ModelSaveContext ctx)
183+
private protected override void SaveModel(ModelSaveContext ctx)
184184
{
185185
InternalTransform.Host.CheckValue(ctx, nameof(ctx));
186186
ctx.CheckAtModel();
@@ -194,7 +194,7 @@ public override void Save(ModelSaveContext ctx)
194194
// *** Binary format ***
195195
// <base>
196196

197-
base.Save(ctx);
197+
base.SaveModel(ctx);
198198
}
199199

200200
// Factory method for SignatureLoadRowMapper.

src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ internal SsaSpikeDetector(IHostEnvironment env, ModelLoadContext ctx)
162162
InternalTransform.Host.CheckDecode(InternalTransform.IsAdaptive == false);
163163
}
164164

165-
public override void Save(ModelSaveContext ctx)
165+
private protected override void SaveModel(ModelSaveContext ctx)
166166
{
167167
InternalTransform.Host.CheckValue(ctx, nameof(ctx));
168168
ctx.CheckAtModel();
@@ -175,7 +175,7 @@ public override void Save(ModelSaveContext ctx)
175175
// *** Binary format ***
176176
// <base>
177177

178-
base.Save(ctx);
178+
base.SaveModel(ctx);
179179
}
180180

181181
// Factory method for SignatureLoadRowMapper.

src/Microsoft.ML.Transforms/CustomMappingTransformer.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ public sealed class CustomMappingTransformer<TSrc, TDst> : ITransformer
3030
internal SchemaDefinition InputSchemaDefinition { get; }
3131

3232
/// <summary>
33-
/// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
33+
/// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
3434
/// appropriate schema.
3535
/// </summary>
36-
public bool IsRowToRowMapper => true;
36+
bool ITransformer.IsRowToRowMapper => true;
3737

3838
/// <summary>
3939
/// Create a custom mapping of input columns to output columns.
@@ -95,11 +95,11 @@ public IDataView Transform(IDataView input)
9595
}
9696

9797
/// <summary>
98-
/// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
98+
/// Constructs a row-to-row mapper based on an input schema. If <see cref="ITransformer.IsRowToRowMapper"/>
9999
/// is <c>false</c>, then an exception is thrown. If the <paramref name="inputSchema"/> is in any way
100100
/// unsuitable for constructing the mapper, an exception is likewise thrown.
101101
/// </summary>
102-
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema)
102+
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
103103
{
104104
_host.CheckValue(inputSchema, nameof(inputSchema));
105105
var simplerMapper = MakeRowMapper(inputSchema);

src/Microsoft.ML.Transforms/OneHotEncoding.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ internal OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator<I
166166

167167
void ICanSaveModel.Save(ModelSaveContext ctx) => (_transformer as ICanSaveModel).Save(ctx);
168168

169-
public bool IsRowToRowMapper => _transformer.IsRowToRowMapper;
169+
bool ITransformer.IsRowToRowMapper => ((ITransformer)_transformer).IsRowToRowMapper;
170170

171-
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => _transformer.GetRowToRowMapper(inputSchema);
171+
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) => ((ITransformer)_transformer).GetRowToRowMapper(inputSchema);
172172
}
173173
/// <summary>
174174
/// Estimator which takes set of columns and produce for each column indicator array.

0 commit comments

Comments
 (0)