Skip to content

Commit 5e672da

Browse files
fix error when sklearnex has more parameters than sklearn
1 parent 88fd9c8 commit 5e672da

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

sklbench/report/implementation.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
import argparse
1818
import json
19+
from functools import reduce
1920
from typing import Dict, List
2021

22+
import numpy as np
2123
import openpyxl as xl
2224
import pandas as pd
2325
from openpyxl.formatting.rule import ColorScaleRule
@@ -165,7 +167,12 @@ def select_comparison(i, j, diffs_selection):
165167
df = input_df.set_index(index_columns)
166168
unique_indices = df.index.unique()
167169
splitted_dfs = split_df_by_columns(input_df, diff_columns)
168-
splitted_dfs = {key: df.set_index(index_columns) for key, df in splitted_dfs.items()}
170+
common_cols = reduce(np.intersect1d, [df.columns for df in splitted_dfs.values()])
171+
df_specific_cols = np.setdiff1d(index_columns, common_cols)
172+
splitted_dfs = {
173+
key: df.assign(**{col: None for col in df_specific_cols}).set_index(index_columns)
174+
for key, df in splitted_dfs.items()
175+
}
169176

170177
# drop results with duplicated indices (keep first entry only)
171178
for key, splitted_df in splitted_dfs.items():
@@ -181,6 +188,9 @@ def select_comparison(i, j, diffs_selection):
181188
# compared values
182189
for i, (key_ith, df_ith) in enumerate(splitted_dfs.items()):
183190
for j, (key_jth, df_jth) in enumerate(splitted_dfs.items()):
191+
common_indexes = np.intersect1d(df_ith.index, df_jth.index)
192+
df_ith = df_ith.loc[common_indexes]
193+
df_jth = df_jth.loc[common_indexes]
184194
if select_comparison(i, j, diffs_selection):
185195
comparison_name = f"{key_jth} vs {key_ith}"
186196
for column in df_ith.columns:
@@ -196,7 +206,9 @@ def select_comparison(i, j, diffs_selection):
196206
df[f"{comparison_name}\n{column} is equal"] = (
197207
df_ith[column] == df_jth[column]
198208
)
199-
df = df.reset_index()
209+
if len(df_specific_cols):
210+
df.index = df.index.droplevel(list(df_specific_cols))
211+
df = df.dropna(axis=0, how="all", ignore_index=False).reset_index()
200212
# move to multi-index
201213
df = df[reorder_columns(list(df.columns))]
202214
df.columns = [

0 commit comments

Comments
 (0)