From f5f074b5d4eec89b8e499de6c3839723fc025527 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 21 May 2018 18:09:20 -0700 Subject: [PATCH 1/2] Don't create parallel cursor if we have only one element in dataview --- .../DataViewConstructionUtils.cs | 2 ++ .../Scorers/RowToRowScorerBase.cs | 18 +++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 840b866fe4..33c046eb87 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -432,6 +432,8 @@ public override bool CanShuffle public override long? GetRowCount(bool lazy = true) { + if (_data is ICollection collection) + return collection.Count; return null; } diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index 3594d921b8..e182f16ccd 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -143,7 +143,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid var inputs = Source.GetRowCursorSet(out consolidator, predicateInput, n, rand); Contracts.AssertNonEmpty(inputs); - if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate)) + if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate) && Source.GetRowCount() != 1) inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); Contracts.AssertNonEmpty(inputs); @@ -432,14 +432,14 @@ protected override void GetMetadataCore(string kind, int iinfo, ref TVal Contracts.Assert(0 <= iinfo && iinfo < InfoCount); switch (kind) { - case MetadataUtils.Kinds.ScoreColumnSetId: - _getScoreColumnSetId.Marshal(iinfo, ref value); - break; - default: - if (iinfo < DerivedColumnCount) - throw MetadataUtils.ExceptGetMetadata(); - Mapper.OutputSchema.GetMetadata(kind, iinfo - DerivedColumnCount, ref value); - break; + case MetadataUtils.Kinds.ScoreColumnSetId: + _getScoreColumnSetId.Marshal(iinfo, ref value); + break; + default: + if (iinfo < DerivedColumnCount) + throw MetadataUtils.ExceptGetMetadata(); + Mapper.OutputSchema.GetMetadata(kind, iinfo - DerivedColumnCount, ref value); + break; } } From 749b4e9b8d9f2a08914007db1c08b1bc162273bc Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Tue, 22 May 2018 16:38:49 -0700 Subject: [PATCH 2/2] update for comments --- src/Microsoft.ML.Api/DataViewConstructionUtils.cs | 4 +--- src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 33c046eb87..f185dc6c0b 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -432,9 +432,7 @@ public override bool CanShuffle public override long? GetRowCount(bool lazy = true) { - if (_data is ICollection collection) - return collection.Count; - return null; + return (_data as ICollection)?.Count; } public override IRowCursor GetRowCursor(Func predicate, IRandom rand = null) diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index e182f16ccd..8dc1138f55 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -143,7 +143,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid var inputs = Source.GetRowCursorSet(out consolidator, predicateInput, n, rand); Contracts.AssertNonEmpty(inputs); - if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate) && Source.GetRowCount() != 1) + if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate) && (Source.GetRowCount() ?? int.MaxValue) > n) inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); Contracts.AssertNonEmpty(inputs);