Skip to content

Commit 5d36b83

Browse files
committed
Address comments
1 parent dbfb543 commit 5d36b83

File tree

3 files changed

+44
-21
lines changed

3 files changed

+44
-21
lines changed

src/Microsoft.ML.Data/DataView/CompositeSchema.cs renamed to src/Microsoft.ML.Data/DataView/ZipBinding.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Collections.Generic;
7-
using System.Linq;
86
using Microsoft.ML.Internal.Utilities;
97

108
namespace Microsoft.ML.Data

src/Microsoft.ML.Data/DataView/ZipDataView.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public sealed class ZipDataView : IDataView
2525

2626
private readonly IHost _host;
2727
private readonly IDataView[] _sources;
28-
private readonly ZipBinding _compositeSchema;
28+
private readonly ZipBinding _zipBinding;
2929

3030
public static IDataView Create(IHostEnvironment env, IEnumerable<IDataView> sources)
3131
{
@@ -47,12 +47,12 @@ private ZipDataView(IHost host, IDataView[] sources)
4747

4848
_host.Assert(Utils.Size(sources) > 1);
4949
_sources = sources;
50-
_compositeSchema = new ZipBinding(_sources.Select(x => x.Schema).ToArray());
50+
_zipBinding = new ZipBinding(_sources.Select(x => x.Schema).ToArray());
5151
}
5252

5353
public bool CanShuffle { get { return false; } }
5454

55-
public Schema Schema => _compositeSchema.OutputSchema;
55+
public Schema Schema => _zipBinding.OutputSchema;
5656

5757
public long? GetRowCount()
5858
{
@@ -75,7 +75,7 @@ public RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
7575
_host.CheckValue(predicate, nameof(predicate));
7676
_host.CheckValueOrNull(rand);
7777

78-
var srcPredicates = _compositeSchema.GetInputPredicates(predicate);
78+
var srcPredicates = _zipBinding.GetInputPredicates(predicate);
7979

8080
// REVIEW: if we know the row counts, we could only open cursor if it has needed columns, and have the
8181
// outer cursor handle the early stopping. If we don't know row counts, we need to open all the cursors because
@@ -106,7 +106,7 @@ public RowCursor[] GetRowCursorSet(Func<int, bool> predicate, int n, Random rand
106106
private sealed class Cursor : RootCursorBase
107107
{
108108
private readonly RowCursor[] _cursors;
109-
private readonly ZipBinding _compositeSchema;
109+
private readonly ZipBinding _zipBinding;
110110
private readonly bool[] _isColumnActive;
111111
private bool _disposed;
112112

@@ -119,8 +119,8 @@ public Cursor(ZipDataView parent, RowCursor[] srcCursors, Func<int, bool> predic
119119
Ch.AssertValue(predicate);
120120

121121
_cursors = srcCursors;
122-
_compositeSchema = parent._compositeSchema;
123-
_isColumnActive = Utils.BuildArray(_compositeSchema.ColumnCount, predicate);
122+
_zipBinding = parent._zipBinding;
123+
_isColumnActive = Utils.BuildArray(_zipBinding.ColumnCount, predicate);
124124
}
125125

126126
protected override void Dispose(bool disposing)
@@ -172,19 +172,19 @@ protected override bool MoveManyCore(long count)
172172
return true;
173173
}
174174

175-
public override Schema Schema => _compositeSchema.OutputSchema;
175+
public override Schema Schema => _zipBinding.OutputSchema;
176176

177177
public override bool IsColumnActive(int col)
178178
{
179-
_compositeSchema.CheckColumnInRange(col);
179+
_zipBinding.CheckColumnInRange(col);
180180
return _isColumnActive[col];
181181
}
182182

183183
public override ValueGetter<TValue> GetGetter<TValue>(int col)
184184
{
185185
int dv;
186186
int srcCol;
187-
_compositeSchema.GetColumnSource(col, out dv, out srcCol);
187+
_zipBinding.GetColumnSource(col, out dv, out srcCol);
188188
return _cursors[dv].GetGetter<TValue>(srcCol);
189189
}
190190
}

src/Microsoft.ML.Transforms/GroupTransform.cs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,18 @@ private sealed class GroupSchema : ISchema
195195
private readonly IExceptionContext _ectx;
196196
private readonly Schema _input;
197197

198+
// Column names in source schema used to group rows.
198199
private readonly string[] _groupColumns;
200+
// Column names in source schema grouped into rows.
199201
private readonly string[] _keepColumns;
200202

203+
// GroupIds[i] is the i-th group-key column's column index in the source schema.
201204
public readonly int[] GroupIds;
205+
// GroupIds[i] is the i-th grouped column's column index in the source schema.
202206
public readonly int[] KeepIds;
203207

204208
private readonly int _groupCount;
205-
private readonly ColumnType[] _columnTypes;
209+
private readonly ColumnType[] _keepColumnTypes;
206210

207211
private readonly Dictionary<string, int> _columnNameMap;
208212

@@ -224,10 +228,20 @@ public GroupSchema(IExceptionContext ectx, Schema inputSchema, string[] groupCol
224228
_keepColumns = keepColumns;
225229
KeepIds = GetColumnIds(inputSchema, keepColumns, x => _ectx.ExceptUserArg(nameof(Arguments.Column), x));
226230

227-
_columnTypes = BuildColumnTypes(_input, KeepIds);
231+
_keepColumnTypes = BuildColumnTypes(_input, KeepIds);
228232
_columnNameMap = BuildColumnNameMap();
229233

230-
AsSchema = Schema.Create(this);
234+
var schemaBuilder = new SchemaBuilder();
235+
foreach (var groupKeyColumnName in groupColumns)
236+
schemaBuilder.AddColumn(groupKeyColumnName, inputSchema[groupKeyColumnName].Type, inputSchema[groupKeyColumnName].Metadata);
237+
foreach (var groupValueColumnName in keepColumns)
238+
{
239+
var metadataBuilder = new MetadataBuilder();
240+
metadataBuilder.Add(inputSchema[groupValueColumnName].Metadata,
241+
s => s == MetadataUtils.Kinds.IsNormalized || s == MetadataUtils.Kinds.KeyValues);
242+
schemaBuilder.AddColumn(groupValueColumnName, inputSchema[groupValueColumnName].Type, metadataBuilder.GetMetadata());
243+
}
244+
AsSchema = schemaBuilder.GetSchema();
231245
}
232246

233247
public GroupSchema(Schema inputSchema, IHostEnvironment env, ModelLoadContext ctx)
@@ -261,7 +275,7 @@ public GroupSchema(Schema inputSchema, IHostEnvironment env, ModelLoadContext ct
261275

262276
KeepIds = GetColumnIds(inputSchema, _keepColumns, _ectx.Except);
263277

264-
_columnTypes = BuildColumnTypes(_input, KeepIds);
278+
_keepColumnTypes = BuildColumnTypes(_input, KeepIds);
265279
_columnNameMap = BuildColumnNameMap();
266280

267281
AsSchema = Schema.Create(this);
@@ -319,23 +333,34 @@ public void Save(ModelSaveContext ctx)
319333
}
320334
}
321335

336+
/// <summary>
337+
/// Given column names, extract and return column indexes from source schema.
338+
/// </summary>
339+
/// <param name="schema">Source schema</param>
340+
/// <param name="names">Column names</param>
341+
/// <param name="except">Marked exception function</param>
342+
/// <returns>column indexes</returns>
322343
private int[] GetColumnIds(Schema schema, string[] names, Func<string, Exception> except)
323344
{
324345
Contracts.AssertValue(schema);
325346
Contracts.AssertValue(names);
326347

327348
var ids = new int[names.Length];
349+
328350
for (int i = 0; i < names.Length; i++)
329351
{
330-
int col;
331-
if (!schema.TryGetColumnIndex(names[i], out col))
352+
// Find column called names[i] from input schema.
353+
var retrievedColumn = schema.GetColumnOrNull(names[i]);
354+
355+
// Throw if no such a schema.
356+
if (!retrievedColumn.HasValue)
332357
throw except(string.Format("Could not find column '{0}'", names[i]));
333358

334-
var colType = schema[col].Type;
359+
var colType = retrievedColumn.Value.Type;
335360
if (!colType.IsPrimitive)
336361
throw except(string.Format("Column '{0}' has type '{1}', but must have a primitive type", names[i], colType));
337362

338-
ids[i] = col;
363+
ids[i] = retrievedColumn.Value.Index;
339364
}
340365

341366
return ids;
@@ -366,7 +391,7 @@ public ColumnType GetColumnType(int col)
366391
CheckColumnInRange(col);
367392
if (col < _groupCount)
368393
return _input[GroupIds[col]].Type;
369-
return _columnTypes[col - _groupCount];
394+
return _keepColumnTypes[col - _groupCount];
370395
}
371396

372397
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)

0 commit comments

Comments
 (0)