Skip to content

Commit c7f38f4

Browse files
committed
Rolled back the accidental changes to the ranked_based distributed_split.
1 parent af48e96 commit c7f38f4

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

sklbench/datasets/transformer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@ def split_and_transform_data(bench_case, data, data_description):
146146

147147
x_test = x_test[test_start:test_end] * adjust_number
148148

149-
elif distributed_split == "rank_based" or knn_split_train:
150-
149+
if distributed_split == "rank_based" or knn_split_train:
151150
comm = MPI.COMM_WORLD
152151
rank = comm.Get_rank()
153152
size = comm.Get_size()
@@ -159,7 +158,6 @@ def split_and_transform_data(bench_case, data, data_description):
159158
train_end = (1 + rank) * n_train // size
160159
test_start = rank * n_test // size
161160
test_end = (1 + rank) * n_test // size
162-
x_train_rank = x_train[train_start:train_end]
163161

164162
if "y" in data:
165163
x_train, y_train = (
@@ -171,7 +169,8 @@ def split_and_transform_data(bench_case, data, data_description):
171169
else:
172170
x_train = x_train[train_start:train_end]
173171
if distributed_split == "rank_based":
174-
x_test = x_test[test_start:test_end] * adjust_number
172+
x_test = x_test[test_start:test_end]
173+
175174

176175
device = get_bench_case_value(bench_case, "algorithm:device", None)
177176
common_data_format = get_bench_case_value(bench_case, "data:format", "pandas")

0 commit comments

Comments
 (0)