Skip to content

Commit 405fb4d

Browse files
author
Thomas
committed
Cosmetic changes
1 parent c3c20ff commit 405fb4d

File tree

4 files changed

+34
-25
lines changed

4 files changed

+34
-25
lines changed

mpi4py_fft/distarray.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def local_slice(self):
246246
(slice(0, 16, None), slice(7, 14, None), slice(6, 12, None))
247247
"""
248248
v = [slice(start, start+shape) for start, shape in zip(self._p0.substart,
249-
self._p0.subshape)]
249+
self._p0.subshape)]
250250
return tuple([slice(0, s) for s in self.shape[:self.rank]] + v)
251251

252252
def redistribute(self, axis=None, out=None):
@@ -298,10 +298,10 @@ def redistribute(self, axis=None, out=None):
298298
p1, transfer = self.get_pencil_and_transfer(axis)
299299
if out is None:
300300
out = type(self)(self.global_shape,
301-
subcomm=p1.subcomm,
302-
dtype=self.dtype,
303-
alignment=axis,
304-
rank=self.rank)
301+
subcomm=p1.subcomm,
302+
dtype=self.dtype,
303+
alignment=axis,
304+
rank=self.rank)
305305

306306
if self.rank == 0:
307307
transfer.forward(self, out)

mpi4py_fft/libfft.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def _Xfftn_plan_mkl(shape, axes, dtype, transforms, options): #pragma: no cover
154154

155155
def _Xfftn_plan_cupyx_scipy(shape, axes, dtype, transforms, options):
156156
import cupy as cp
157-
import cupyx.scipy.fftpack as cufft
157+
import cupyx.scipy.fft as cufft
158158

159159
transforms = {} if transforms is None else transforms
160160
if tuple(axes) in transforms:
@@ -168,8 +168,8 @@ def _Xfftn_plan_cupyx_scipy(shape, axes, dtype, transforms, options):
168168
V = plan_fwd(U, s=s, axes=axes)
169169
V = cp.array(V)
170170
M = np.prod(s)
171-
return (_Yfftn_wrap(plan_fwd, U, V, 1, {'shape': s, 'axes': axes, 'overwrite_x': True}),
172-
_Yfftn_wrap(plan_bck, V, U, M, {'shape': s, 'axes': axes, 'overwrite_x': True}))
171+
return (_Yfftn_wrap(plan_fwd, U, V, 1, {'s': s, 'axes': axes, 'overwrite_x': True}),
172+
_Yfftn_wrap(plan_bck, V, U, M, {'s': s, 'axes': axes, 'overwrite_x': True}))
173173

174174
def _Xfftn_plan_scipy(shape, axes, dtype, transforms, options):
175175

mpi4py_fft/pencil.py

-2
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,6 @@ def transfer(self, pencil, dtype):
471471
transfer_class = Transfer
472472
elif self.backend == 'NCCL':
473473
transfer_class = NCCLTransfer
474-
elif self.backend == 'CUDAMemCpy':
475-
transfer_class = CUDAMemCpy
476474
elif self.backend == 'customMPI':
477475
transfer_class = CustomMPITransfer
478476
else:

tests/test_transfer_classes.py

+26-15
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22
from mpi4py_fft.pencil import Transfer, CustomMPITransfer, Pencil, Subcomm
33
import numpy as np
44

5+
transfer_classes = [CustomMPITransfer]
6+
xps = {CustomMPITransfer: np}
7+
8+
try:
9+
import cupy as cp
10+
from mpi4py_fft.pencil import NCCLTransfer
11+
transfer_classes += [NCCLTransfer]
12+
xps[NCCLTransfer] = cp
13+
except ModuleNotFoundError:
14+
pass
15+
516

617
def get_args(comm, shape, dtype):
718
subcomm = Subcomm(comm=comm, dims=None)
@@ -20,43 +31,43 @@ def get_args(comm, shape, dtype):
2031
return kwargs
2132

2233

23-
def get_arrays(kwargs):
24-
arrayA = np.zeros(shape=kwargs['subshapeA'], dtype=kwargs['dtype'])
25-
arrayB = np.zeros(shape=kwargs['subshapeB'], dtype=kwargs['dtype'])
34+
def get_arrays(kwargs, xp):
35+
arrayA = xp.zeros(shape=kwargs['subshapeA'], dtype=kwargs['dtype'])
36+
arrayB = xp.zeros(shape=kwargs['subshapeB'], dtype=kwargs['dtype'])
2637

27-
arrayA[:] = np.random.random(arrayA.shape).astype(arrayA.dtype)
38+
arrayA[:] = xp.random.random(arrayA.shape).astype(arrayA.dtype)
2839
return arrayA, arrayB
2940

3041

31-
def single_test_all_to_allw(transfer_class, shape, dtype, comm=None):
42+
def single_test_all_to_allw(transfer_class, shape, dtype, comm=None, xp=None):
3243
comm = comm if comm else MPI.COMM_WORLD
3344
kwargs = get_args(comm, shape, dtype)
34-
arrayA, arrayB = get_arrays(kwargs)
45+
arrayA, arrayB = get_arrays(kwargs, xp)
3546
arrayB_ref = arrayB.copy()
3647

3748
transfer = transfer_class(**kwargs)
3849
reference_transfer = Transfer(**kwargs)
3950

4051
transfer.Alltoallw(arrayA, transfer._subtypesA, arrayB, transfer._subtypesB)
4152
reference_transfer.Alltoallw(arrayA, transfer._subtypesA, arrayB_ref, transfer._subtypesB)
42-
assert np.allclose(arrayB, arrayB_ref), f'Did not get the same result from `alltoallw` with {transfer_class.__name__} transfer class as MPI implementation on rank {comm.rank}!'
53+
assert xp.allclose(arrayB, arrayB_ref), f'Did not get the same result from `alltoallw` with {transfer_class.__name__} transfer class as MPI implementation on rank {comm.rank}!'
4354

4455
comm.Barrier()
4556
if comm.rank == 0:
4657
print(f'{transfer_class.__name__} passed alltoallw test with shape {shape} and dtype {dtype}')
4758

4859

49-
def single_test_forward_backward(transfer_class, shape, dtype, comm=None):
60+
def single_test_forward_backward(transfer_class, shape, dtype, comm=None, xp=None):
5061
comm = comm if comm else MPI.COMM_WORLD
5162
kwargs = get_args(comm, shape, dtype)
52-
arrayA, arrayB = get_arrays(kwargs)
63+
arrayA, arrayB = get_arrays(kwargs, xp)
5364
arrayA_ref = arrayA.copy()
5465

5566
transfer = transfer_class(**kwargs)
5667

5768
transfer.forward(arrayA, arrayB)
5869
transfer.backward(arrayB, arrayA)
59-
assert np.allclose(arrayA, arrayA_ref), f'Did not get the same result when transferring back and forth with {transfer_class.__name__} transfer class on rank {comm.rank}!'
70+
assert xp.allclose(arrayA, arrayA_ref), f'Did not get the same result when transferring back and forth with {transfer_class.__name__} transfer class on rank {comm.rank}!'
6071

6172
comm.Barrier()
6273
if comm.rank == 0:
@@ -67,14 +78,14 @@ def test_transfer_class():
6778
dims = (2, 3)
6879
sizes = (7, 8, 9, 128)
6980
dtypes = 'fFdD'
70-
transfer_class = CustomMPITransfer
7181

7282
shapes = [[size] * dim for size in sizes for dim in dims] + [[32, 256, 129]]
7383

74-
for shape in shapes:
75-
for dtype in dtypes:
76-
single_test_all_to_allw(transfer_class, shape, dtype)
77-
single_test_forward_backward(transfer_class, shape, dtype)
84+
for transfer_class in transfer_classes:
85+
for shape in shapes:
86+
for dtype in dtypes:
87+
single_test_all_to_allw(transfer_class, shape, dtype, xp=xps[transfer_class])
88+
single_test_forward_backward(transfer_class, shape, dtype, xp=xps[transfer_class])
7889

7990

8091
if __name__ == '__main__':

0 commit comments

Comments
 (0)