Skip to content

Commit 5c04a35

Browse files
committed
Updated format.
1 parent 4ba3fe4 commit 5c04a35

File tree

3 files changed

+26
-21
lines changed

3 files changed

+26
-21
lines changed

sklbench/datasets/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ def load_data(bench_case: BenchCase) -> Tuple[Dict, Dict]:
6767
generation_kwargs = get_bench_case_value(
6868
bench_case, "data:generation_kwargs", dict()
6969
)
70-
if 'center_box' in generation_kwargs:
71-
generation_kwargs['center_box'] = (-1 * generation_kwargs['center_box'], generation_kwargs['center_box'])
70+
if "center_box" in generation_kwargs:
71+
generation_kwargs["center_box"] = (
72+
-1 * generation_kwargs["center_box"],
73+
generation_kwargs["center_box"],
74+
)
7275
return load_sklearn_synthetic_data(
7376
function_name=source,
7477
input_kwargs=generation_kwargs,

sklbench/datasets/transformer.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -116,34 +116,36 @@ def split_and_transform_data(bench_case, data, data_description):
116116
)
117117

118118
if distributed_split == "sample_shift":
119-
comm = MPI.COMM_WORLD
120-
rank = comm.Get_rank()
121-
size = comm.Get_size()
119+
comm = MPI.COMM_WORLD
120+
rank = comm.Get_rank()
121+
size = comm.Get_size()
122122

123-
n_train = len(x_train)
124-
n_test = len(x_test)
123+
n_train = len(x_train)
124+
n_test = len(x_test)
125125

126-
train_start = 0
127-
train_end = n_train
128-
test_start = 0
129-
test_end = n_test
126+
train_start = 0
127+
train_end = n_train
128+
test_start = 0
129+
test_end = n_test
130130

131-
adjust_number = (math.sqrt(rank) * 0.003) + 1
131+
adjust_number = (math.sqrt(rank) * 0.003) + 1
132132

133-
if "y" in data:
133+
if "y" in data:
134134
x_train, y_train = (
135-
x_train[train_start:train_end] * adjust_number,
135+
x_train[train_start:train_end] * adjust_number,
136136
y_train[train_start:train_end],
137137
)
138-
139-
x_test, y_test = x_test[test_start:test_end] * adjust_number, y_test[test_start:test_end]
140-
else:
138+
139+
x_test, y_test = (
140+
x_test[test_start:test_end] * adjust_number,
141+
y_test[test_start:test_end],
142+
)
143+
else:
141144
x_train = x_train[train_start:train_end]
142-
145+
143146
x_test = x_test[test_start:test_end] * adjust_number
144147

145148
elif distributed_split == "rank_based" or knn_split_train:
146-
147149

148150
comm = MPI.COMM_WORLD
149151
rank = comm.Get_rank()
@@ -156,7 +158,7 @@ def split_and_transform_data(bench_case, data, data_description):
156158
train_end = (1 + rank) * n_train // size
157159
test_start = rank * n_test // size
158160
test_end = (1 + rank) * n_test // size
159-
x_train_rank = x_train[train_start:train_end]
161+
x_train_rank = x_train[train_start:train_end]
160162

161163
if "y" in data:
162164
x_train, y_train = (

sklbench/utils/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def flatten_list(input_list: List, ensure_type_homogeneity: bool = False) -> Lis
120120

121121

122122
def get_module_members(
123-
module_names_chain: Union[List, str]
123+
module_names_chain: Union[List, str],
124124
) -> Tuple[ModuleContentMap, ModuleContentMap]:
125125
def get_module_name(module_names_chain: List[str]) -> str:
126126
name = module_names_chain[0]

0 commit comments

Comments
 (0)