Skip to content

Commit 985db07

Browse files
committed
Added shift.
1 parent fd59a64 commit 985db07

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

configs/spmd/kmeans_wide_weak.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"SETS": [
2626
"synthetic data",
2727
"sklearnex spmd implementation",
28-
"large scale 2k parameters sample shift",
28+
"large scale 2k parameters",
2929
"spmd kmeans parameters"
3030
]
3131
}

sklbench/datasets/transformer.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from ..utils.bench_case import get_bench_case_value
2525
from ..utils.logger import logger
26-
26+
from mpi4py import MPI
2727

2828
def convert_data(data, dformat: str, order: str, dtype: str, device: str = None):
2929
if isinstance(data, csr_matrix) and dformat != "csr_matrix":
@@ -113,8 +113,36 @@ def split_and_transform_data(bench_case, data, data_description):
113113
"KNeighbors" in get_bench_case_value(bench_case, "algorithm:estimator", "")
114114
and int(get_bench_case_value(bench_case, "bench:mpi_params:n", 1)) > 1
115115
)
116-
if distributed_split == "rank_based" or knn_split_train:
117-
from mpi4py import MPI
116+
117+
if distributed_split == "sample_shift":
118+
comm = MPI.COMM_WORLD
119+
rank = comm.Get_rank()
120+
size = comm.Get_size()
121+
122+
n_train = len(x_train)
123+
n_test = len(x_test)
124+
125+
train_start = 0
126+
train_end = n_train
127+
test_start = 0
128+
test_end = n_test
129+
130+
adjust_number = (math.sqrt(rank) * 0.003) + 1
131+
132+
if "y" in data:
133+
x_train, y_train = (
134+
x_train[train_start:train_end] * adjust_number,
135+
y_train[train_start:train_end],
136+
)
137+
138+
x_test, y_test = x_test[test_start:test_end] * adjust_number, y_test[test_start:test_end]
139+
else:
140+
x_train = x_train[train_start:train_end]
141+
142+
x_test = x_test[test_start:test_end] * adjust_number
143+
144+
elif distributed_split == "rank_based" or knn_split_train:
145+
118146

119147
comm = MPI.COMM_WORLD
120148
rank = comm.Get_rank()
@@ -127,6 +155,7 @@ def split_and_transform_data(bench_case, data, data_description):
127155
train_end = (1 + rank) * n_train // size
128156
test_start = rank * n_test // size
129157
test_end = (1 + rank) * n_test // size
158+
x_train_rank = x_train[train_start:train_end]
130159

131160
if "y" in data:
132161
x_train, y_train = (
@@ -138,7 +167,7 @@ def split_and_transform_data(bench_case, data, data_description):
138167
else:
139168
x_train = x_train[train_start:train_end]
140169
if distributed_split == "rank_based":
141-
x_test = x_test[test_start:test_end]
170+
x_test = x_test[test_start:test_end] * adjust_number
142171

143172
device = get_bench_case_value(bench_case, "algorithm:device", None)
144173
common_data_format = get_bench_case_value(bench_case, "data:format", "pandas")

0 commit comments

Comments
 (0)