Skip to content

Commit b6e9e74

Browse files
authored
Don't create parallel cursor if we have only one element in dataview (#197)
* Don't create parallel cursor if we have only one element in dataview * update for comments
1 parent 35b8134 commit b6e9e74

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

src/Microsoft.ML.Api/DataViewConstructionUtils.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ public override bool CanShuffle
432432

433433
public override long? GetRowCount(bool lazy = true)
434434
{
435-
return null;
435+
return (_data as ICollection<TRow>)?.Count;
436436
}
437437

438438
public override IRowCursor GetRowCursor(Func<int, bool> predicate, IRandom rand = null)

src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs

+9-9
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid
143143
var inputs = Source.GetRowCursorSet(out consolidator, predicateInput, n, rand);
144144
Contracts.AssertNonEmpty(inputs);
145145

146-
if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate))
146+
if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate) && (Source.GetRowCount() ?? int.MaxValue) > n)
147147
inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n);
148148
Contracts.AssertNonEmpty(inputs);
149149

@@ -432,14 +432,14 @@ protected override void GetMetadataCore<TValue>(string kind, int iinfo, ref TVal
432432
Contracts.Assert(0 <= iinfo && iinfo < InfoCount);
433433
switch (kind)
434434
{
435-
case MetadataUtils.Kinds.ScoreColumnSetId:
436-
_getScoreColumnSetId.Marshal(iinfo, ref value);
437-
break;
438-
default:
439-
if (iinfo < DerivedColumnCount)
440-
throw MetadataUtils.ExceptGetMetadata();
441-
Mapper.OutputSchema.GetMetadata<TValue>(kind, iinfo - DerivedColumnCount, ref value);
442-
break;
435+
case MetadataUtils.Kinds.ScoreColumnSetId:
436+
_getScoreColumnSetId.Marshal(iinfo, ref value);
437+
break;
438+
default:
439+
if (iinfo < DerivedColumnCount)
440+
throw MetadataUtils.ExceptGetMetadata();
441+
Mapper.OutputSchema.GetMetadata<TValue>(kind, iinfo - DerivedColumnCount, ref value);
442+
break;
443443
}
444444
}
445445

0 commit comments

Comments
 (0)