23
23
24
24
from ..utils .bench_case import get_bench_case_value
25
25
from ..utils .logger import logger
26
-
26
+ from mpi4py import MPI
27
27
28
28
def convert_data (data , dformat : str , order : str , dtype : str , device : str = None ):
29
29
if isinstance (data , csr_matrix ) and dformat != "csr_matrix" :
@@ -113,8 +113,36 @@ def split_and_transform_data(bench_case, data, data_description):
113
113
"KNeighbors" in get_bench_case_value (bench_case , "algorithm:estimator" , "" )
114
114
and int (get_bench_case_value (bench_case , "bench:mpi_params:n" , 1 )) > 1
115
115
)
116
- if distributed_split == "rank_based" or knn_split_train :
117
- from mpi4py import MPI
116
+
117
+ if distributed_split == "sample_shift" :
118
+ comm = MPI .COMM_WORLD
119
+ rank = comm .Get_rank ()
120
+ size = comm .Get_size ()
121
+
122
+ n_train = len (x_train )
123
+ n_test = len (x_test )
124
+
125
+ train_start = 0
126
+ train_end = n_train
127
+ test_start = 0
128
+ test_end = n_test
129
+
130
+ adjust_number = (math .sqrt (rank ) * 0.003 ) + 1
131
+
132
+ if "y" in data :
133
+ x_train , y_train = (
134
+ x_train [train_start :train_end ] * adjust_number ,
135
+ y_train [train_start :train_end ],
136
+ )
137
+
138
+ x_test , y_test = x_test [test_start :test_end ] * adjust_number , y_test [test_start :test_end ]
139
+ else :
140
+ x_train = x_train [train_start :train_end ]
141
+
142
+ x_test = x_test [test_start :test_end ] * adjust_number
143
+
144
+ elif distributed_split == "rank_based" or knn_split_train :
145
+
118
146
119
147
comm = MPI .COMM_WORLD
120
148
rank = comm .Get_rank ()
@@ -127,6 +155,7 @@ def split_and_transform_data(bench_case, data, data_description):
127
155
train_end = (1 + rank ) * n_train // size
128
156
test_start = rank * n_test // size
129
157
test_end = (1 + rank ) * n_test // size
158
+ x_train_rank = x_train [train_start :train_end ]
130
159
131
160
if "y" in data :
132
161
x_train , y_train = (
@@ -138,7 +167,7 @@ def split_and_transform_data(bench_case, data, data_description):
138
167
else :
139
168
x_train = x_train [train_start :train_end ]
140
169
if distributed_split == "rank_based" :
141
- x_test = x_test [test_start :test_end ]
170
+ x_test = x_test [test_start :test_end ] * adjust_number
142
171
143
172
device = get_bench_case_value (bench_case , "algorithm:device" , None )
144
173
common_data_format = get_bench_case_value (bench_case , "data:format" , "pandas" )
0 commit comments