diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 840b866fe4..f185dc6c0b 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -432,7 +432,7 @@ public override bool CanShuffle public override long? GetRowCount(bool lazy = true) { - return null; + return (_data as ICollection)?.Count; } public override IRowCursor GetRowCursor(Func predicate, IRandom rand = null) diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index 3594d921b8..8dc1138f55 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -143,7 +143,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid var inputs = Source.GetRowCursorSet(out consolidator, predicateInput, n, rand); Contracts.AssertNonEmpty(inputs); - if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate)) + if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate) && (Source.GetRowCount() ?? int.MaxValue) > n) inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); Contracts.AssertNonEmpty(inputs); @@ -432,14 +432,14 @@ protected override void GetMetadataCore(string kind, int iinfo, ref TVal Contracts.Assert(0 <= iinfo && iinfo < InfoCount); switch (kind) { - case MetadataUtils.Kinds.ScoreColumnSetId: - _getScoreColumnSetId.Marshal(iinfo, ref value); - break; - default: - if (iinfo < DerivedColumnCount) - throw MetadataUtils.ExceptGetMetadata(); - Mapper.OutputSchema.GetMetadata(kind, iinfo - DerivedColumnCount, ref value); - break; + case MetadataUtils.Kinds.ScoreColumnSetId: + _getScoreColumnSetId.Marshal(iinfo, ref value); + break; + default: + if (iinfo < DerivedColumnCount) + throw MetadataUtils.ExceptGetMetadata(); + Mapper.OutputSchema.GetMetadata(kind, iinfo - DerivedColumnCount, ref value); + break; } }