Skip to content

Commit 5133797

Browse files
authored
Image transforms become Estimators (#753)
Converted the following transforms to Estimators: ImageLoader, ImageResizer, ImagePixelExtractor, and ImageGrayscale. Fixes #707.
1 parent 40aab06 commit 5133797

File tree

17 files changed

+1438
-666
lines changed

17 files changed

+1438
-666
lines changed

src/Microsoft.ML.Console/Microsoft.ML.Console.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
<ProjectReference Include="..\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj" />
1717
<ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
1818
<ProjectReference Include="..\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj" />
19+
<ProjectReference Include="..\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
1920
<ProjectReference Include="..\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
2021
<ProjectReference Include="..\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
2122
<ProjectReference Include="..\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />

src/Microsoft.ML.Core/Data/IEstimator.cs

+32-8
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,42 @@ public enum VectorKind
2828
VariableVector
2929
}
3030

31+
/// <summary>
32+
/// The column name.
33+
/// </summary>
3134
public readonly string Name;
35+
36+
/// <summary>
37+
/// The type of the column: scalar, fixed vector or variable vector.
38+
/// </summary>
3239
public readonly VectorKind Kind;
33-
public readonly DataKind ItemKind;
40+
41+
/// <summary>
42+
/// The 'raw' type of column item: must be a primitive type or a structured type.
43+
/// </summary>
44+
public readonly ColumnType ItemType;
45+
/// <summary>
46+
/// Indicates whether the column is actually a key. If so, <see cref="ItemType"/> represents
47+
/// the underlying primitive type.
48+
/// </summary>
3449
public readonly bool IsKey;
50+
/// <summary>
51+
/// The metadata kinds that are present for this column.
52+
/// </summary>
3553
public readonly string[] MetadataKinds;
3654

37-
public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds = null)
55+
public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, string[] metadataKinds = null)
3856
{
3957
Contracts.CheckNonEmpty(name, nameof(name));
4058
Contracts.CheckValueOrNull(metadataKinds);
59+
Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key");
60+
Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector");
61+
62+
Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");
4163

4264
Name = name;
4365
Kind = vecKind;
44-
ItemKind = itemKind;
66+
ItemType = itemType;
4567
IsKey = isKey;
4668
MetadataKinds = metadataKinds ?? new string[0];
4769
}
@@ -51,7 +73,7 @@ public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, st
5173
/// requirement.
5274
///
5375
/// Namely, it returns true iff:
54-
/// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemKind"/>, <see cref="IsKey"/> fields match.
76+
/// - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemType"/>, <see cref="IsKey"/> fields match.
5577
/// - The <see cref="MetadataKinds"/> of <paramref name="inputColumn"/> is a superset of our <see cref="MetadataKinds"/>.
5678
/// </summary>
5779
public bool IsCompatibleWith(Column inputColumn)
@@ -61,7 +83,7 @@ public bool IsCompatibleWith(Column inputColumn)
6183
return false;
6284
if (Kind != inputColumn.Kind)
6385
return false;
64-
if (ItemKind != inputColumn.ItemKind)
86+
if (!ItemType.Equals(inputColumn.ItemType))
6587
return false;
6688
if (IsKey != inputColumn.IsKey)
6789
return false;
@@ -72,7 +94,7 @@ public bool IsCompatibleWith(Column inputColumn)
7294

7395
public string GetTypeString()
7496
{
75-
string result = ItemKind.ToString();
97+
string result = ItemType.ToString();
7698
if (IsKey)
7799
result = $"Key<{result}>";
78100
if (Kind == VectorKind.Vector)
@@ -110,13 +132,15 @@ public static SchemaShape Create(ISchema schema)
110132
else
111133
vecKind = Column.VectorKind.Scalar;
112134

113-
var kind = type.ItemType.RawKind;
135+
ColumnType itemType = type.ItemType;
136+
if (type.ItemType.IsKey)
137+
itemType = PrimitiveType.FromKind(type.ItemType.RawKind);
114138
var isKey = type.ItemType.IsKey;
115139

116140
var metadataNames = schema.GetMetadataTypes(iCol)
117141
.Select(kvp => kvp.Key)
118142
.ToArray();
119-
cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey, metadataNames));
143+
cols.Add(new Column(schema.GetColumnName(iCol), vecKind, itemType, isKey, metadataNames));
120144
}
121145
}
122146
return new SchemaShape(cols.ToArray());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
7+
namespace Microsoft.ML.Runtime.Data
8+
{
9+
/// <summary>
/// A minimal <see cref="IEstimator{TTransformer}"/> that wraps a transformer supplied at
/// construction time and hands that same instance back from every call to <see cref="Fit(IDataView)"/>.
///
/// Schema propagation is left to the concrete implementations, since there is no easy way
/// to infer it from the transformer.
/// </summary>
public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer>
    where TTransformer : class, ITransformer
{
    protected readonly IHost Host;
    protected readonly TTransformer Transformer;

    /// <summary>
    /// Stores the host and the ready-made transformer.
    /// </summary>
    /// <param name="host">The host used for validation; asserted to be non-null.</param>
    /// <param name="transformer">The transformer returned by <see cref="Fit(IDataView)"/>; checked to be non-null.</param>
    protected TrivialEstimator(IHost host, TTransformer transformer)
    {
        Contracts.AssertValue(host);
        host.CheckValue(transformer, nameof(transformer));

        Host = host;
        Transformer = transformer;
    }

    /// <summary>
    /// Returns the wrapped transformer; <paramref name="input"/> is not inspected.
    /// </summary>
    public TTransformer Fit(IDataView input)
    {
        return Transformer;
    }

    /// <summary>
    /// Computes the output schema shape for the given input schema shape.
    /// </summary>
    public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);
}
35+
}

src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
7070
var originalColumn = inputSchema.FindColumn(column.Source);
7171
if (originalColumn != null)
7272
{
73-
var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemKind, originalColumn.IsKey, originalColumn.MetadataKinds);
73+
var col = new SchemaShape.Column(column.Name, originalColumn.Kind, originalColumn.ItemType, originalColumn.IsKey, originalColumn.MetadataKinds);
7474
resultDic[column.Name] = col;
7575
}
7676
else
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
using Microsoft.ML.Core.Data;
9+
using Microsoft.ML.Runtime.Model;
10+
11+
namespace Microsoft.ML.Runtime.Data
12+
{
13+
/// <summary>
/// Base class for transformers configured as a list of (input column, output column) name pairs.
/// It validates and (de)serializes the pairs, checks that every input column exists in a schema,
/// and wraps the mapper produced by <see cref="MakeRowMapper(ISchema)"/> in a
/// <see cref="RowToRowMapperTransform"/>.
/// </summary>
public abstract class OneToOneTransformerBase : ITransformer, ICanSaveModel
{
    protected readonly IHost Host;
    protected readonly (string input, string output)[] ColumnPairs;

    /// <summary>
    /// Creates the transformer from explicit column pairs.
    /// </summary>
    /// <param name="host">The host used for validation; asserted to be non-null.</param>
    /// <param name="columns">The (input, output) name pairs. Names must be non-empty and
    /// every output name may appear only once.</param>
    protected OneToOneTransformerBase(IHost host, (string input, string output)[] columns)
    {
        Contracts.AssertValue(host);
        host.CheckValue(columns, nameof(columns));

        var newNames = new HashSet<string>();
        foreach (var column in columns)
        {
            host.CheckNonEmpty(column.input, nameof(columns));
            host.CheckNonEmpty(column.output, nameof(columns));

            if (!newNames.Add(column.output))
                throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times");
        }

        Host = host;
        ColumnPairs = columns;
    }

    /// <summary>
    /// Creates the transformer by reading the column pairs written by <see cref="SaveColumns(ModelSaveContext)"/>.
    /// </summary>
    /// <param name="host">The host used for validation; asserted to be non-null.</param>
    /// <param name="ctx">The load context to read from; checked to be non-null.</param>
    protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx)
    {
        // Validate the arguments the same way the other constructor does, so that a null
        // context is reported as a bad argument rather than as a NullReferenceException below.
        Contracts.AssertValue(host);
        host.CheckValue(ctx, nameof(ctx));

        Host = host;
        // *** Binary format ***
        // int: number of added columns
        // for each added column
        //   int: id of output column name
        //   int: id of input column name

        int n = ctx.Reader.ReadInt32();
        ColumnPairs = new (string input, string output)[n];
        for (int i = 0; i < n; i++)
        {
            // The output name is stored before the input name; keep this order in sync with SaveColumns.
            string output = ctx.LoadNonEmptyString();
            string input = ctx.LoadNonEmptyString();
            ColumnPairs[i] = (input, output);
        }
    }

    public abstract void Save(ModelSaveContext ctx);

    /// <summary>
    /// Writes the column pairs to <paramref name="ctx"/> in the format read by the loading constructor.
    /// </summary>
    protected void SaveColumns(ModelSaveContext ctx)
    {
        Host.CheckValue(ctx, nameof(ctx));

        // *** Binary format ***
        // int: number of added columns
        // for each added column
        //   int: id of output column name
        //   int: id of input column name

        ctx.Writer.Write(ColumnPairs.Length);
        for (int i = 0; i < ColumnPairs.Length; i++)
        {
            ctx.SaveNonEmptyString(ColumnPairs[i].output);
            ctx.SaveNonEmptyString(ColumnPairs[i].input);
        }
    }

    /// <summary>
    /// Resolves the input column of pair <paramref name="col"/> in <paramref name="inputSchema"/>,
    /// throwing a schema-mismatch exception if it is missing, then runs the
    /// <see cref="CheckInputColumn(ISchema, int, int)"/> hook.
    /// </summary>
    private void CheckInput(ISchema inputSchema, int col, out int srcCol)
    {
        Contracts.AssertValue(inputSchema);
        Contracts.Assert(0 <= col && col < ColumnPairs.Length);

        if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].input, out srcCol))
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input);
        CheckInputColumn(inputSchema, col, srcCol);
    }

    /// <summary>
    /// Extension point for additional validation of an input column.
    /// </summary>
    /// <param name="inputSchema">The schema being validated.</param>
    /// <param name="col">The index of the pair in <see cref="ColumnPairs"/>.</param>
    /// <param name="srcCol">The index of the input column in <paramref name="inputSchema"/>.</param>
    protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
    {
        // By default, there are no extra checks.
    }

    protected abstract IRowMapper MakeRowMapper(ISchema schema);

    /// <summary>
    /// Validates every input column against <paramref name="inputSchema"/> and returns the schema
    /// obtained by transforming an empty data view that has that schema.
    /// </summary>
    public ISchema GetOutputSchema(ISchema inputSchema)
    {
        Host.CheckValue(inputSchema, nameof(inputSchema));

        // Check that all the input columns are present and correct.
        // Only the validation matters here, so the resolved index is discarded.
        for (int i = 0; i < ColumnPairs.Length; i++)
            CheckInput(inputSchema, i, out _);

        return Transform(new EmptyDataView(Host, inputSchema)).Schema;
    }

    public IDataView Transform(IDataView input) => MakeDataTransform(input);

    /// <summary>
    /// Wraps the row mapper built for the schema of <paramref name="input"/> in a <see cref="RowToRowMapperTransform"/>.
    /// </summary>
    protected RowToRowMapperTransform MakeDataTransform(IDataView input)
    {
        Host.CheckValue(input, nameof(input));
        return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema));
    }

    /// <summary>
    /// Base row mapper: resolves each pair's input column once at construction and
    /// delegates getter creation to <see cref="MakeGetter(IRow, int, out Action)"/>.
    /// </summary>
    protected abstract class MapperBase : IRowMapper
    {
        protected readonly IHost Host;
        // Maps the index of a pair in the parent's ColumnPairs to the index of its input column in InputSchema.
        protected readonly Dictionary<int, int> ColMapNewToOld;
        protected readonly ISchema InputSchema;
        private readonly OneToOneTransformerBase _parent;

        protected MapperBase(IHost host, OneToOneTransformerBase parent, ISchema inputSchema)
        {
            Contracts.AssertValue(host);
            Contracts.AssertValue(parent);
            Contracts.AssertValue(inputSchema);

            Host = host;
            _parent = parent;

            ColMapNewToOld = new Dictionary<int, int>();
            for (int i = 0; i < _parent.ColumnPairs.Length; i++)
            {
                _parent.CheckInput(inputSchema, i, out int srcCol);
                ColMapNewToOld.Add(i, srcCol);
            }
            InputSchema = inputSchema;
        }

        /// <summary>
        /// Returns a predicate over input column indices that is true for the input column
        /// of every output for which <paramref name="activeOutput"/> returns true.
        /// </summary>
        public Func<int, bool> GetDependencies(Func<int, bool> activeOutput)
        {
            var active = new bool[InputSchema.ColumnCount];
            foreach (var pair in ColMapNewToOld)
                if (activeOutput(pair.Key))
                    active[pair.Value] = true;
            return col => active[col];
        }

        public abstract RowMapperColumnInfo[] GetOutputColumns();

        public void Save(ModelSaveContext ctx) => _parent.Save(ctx);

        /// <summary>
        /// Creates one getter per active output. Entries of the returned array stay null for inactive outputs.
        /// </summary>
        /// <param name="input">The input row; its schema is asserted to be <see cref="InputSchema"/>.</param>
        /// <param name="activeOutput">Predicate selecting the outputs that need a getter.</param>
        /// <param name="disposer">Set to an action that runs every non-null per-column disposer,
        /// or to null when no column produced one.</param>
        public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Action disposer)
        {
            Contracts.Assert(input.Schema == InputSchema);
            var result = new Delegate[_parent.ColumnPairs.Length];
            var disposers = new Action[_parent.ColumnPairs.Length];
            for (int i = 0; i < _parent.ColumnPairs.Length; i++)
            {
                if (!activeOutput(i))
                    continue;
                result[i] = MakeGetter(input, i, out disposers[i]);
            }
            if (disposers.Any(x => x != null))
            {
                disposer = () =>
                {
                    // Slots for inactive outputs, or for getters without cleanup, are null:
                    // skip them instead of invoking a null delegate.
                    foreach (var act in disposers)
                        act?.Invoke();
                };
            }
            else
                disposer = null;
            return result;
        }

        /// <summary>
        /// Creates the getter for the output at index <paramref name="iinfo"/>.
        /// </summary>
        /// <param name="input">The input row to read from.</param>
        /// <param name="iinfo">The index of the pair in the parent's column pairs.</param>
        /// <param name="disposer">Optional cleanup action for the getter; may be set to null.</param>
        protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer);
    }
}
177+
}

src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public static class ImageAnalytics
1616
public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, ImageLoaderTransform.Arguments input)
1717
{
1818
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageLoaderTransform", input);
19-
var xf = new ImageLoaderTransform(h, input, input.Data);
19+
var xf = ImageLoaderTransform.Create(h, input, input.Data);
2020
return new CommonOutputs.TransformOutput()
2121
{
2222
Model = new TransformModel(h, xf, input.Data),
@@ -29,7 +29,7 @@ public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, Im
2929
public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, ImageResizerTransform.Arguments input)
3030
{
3131
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageResizerTransform", input);
32-
var xf = new ImageResizerTransform(h, input, input.Data);
32+
var xf = ImageResizerTransform.Create(h, input, input.Data);
3333
return new CommonOutputs.TransformOutput()
3434
{
3535
Model = new TransformModel(h, xf, input.Data),
@@ -42,7 +42,7 @@ public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, I
4242
public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment env, ImagePixelExtractorTransform.Arguments input)
4343
{
4444
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImagePixelExtractorTransform", input);
45-
var xf = new ImagePixelExtractorTransform(h, input, input.Data);
45+
var xf = ImagePixelExtractorTransform.Create(h, input, input.Data);
4646
return new CommonOutputs.TransformOutput()
4747
{
4848
Model = new TransformModel(h, xf, input.Data),
@@ -55,7 +55,7 @@ public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment
5555
public static CommonOutputs.TransformOutput ImageGrayscale(IHostEnvironment env, ImageGrayscaleTransform.Arguments input)
5656
{
5757
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageGrayscaleTransform", input);
58-
var xf = new ImageGrayscaleTransform(h, input, input.Data);
58+
var xf = ImageGrayscaleTransform.Create(h, input, input.Data);
5959
return new CommonOutputs.TransformOutput()
6060
{
6161
Model = new TransformModel(h, xf, input.Data),

0 commit comments

Comments
 (0)