Skip to content

Commit 5fb8959

Browse files
committed
WIP: ENH: enable complex dtype support unconditionally
This also gets rid of usage of the deprecated IF macro in Cython code, which resolves a bunch of build warnings.
1 parent d4f3854 commit 5fb8959

13 files changed

+344
-412
lines changed

pywt/_dwt.py

-24
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44

5-
from ._c99_config import _have_c99_complex
65
from ._extensions._dwt import downcoef as _downcoef
76
from ._extensions._dwt import dwt_axis, dwt_single, idwt_axis, idwt_single
87
from ._extensions._dwt import dwt_coeff_len as _dwt_coeff_len
@@ -161,12 +160,6 @@ def dwt(data, wavelet, mode='symmetric', axis=-1):
161160
array([-0.70710678, -0.70710678, -0.70710678])
162161
163162
"""
164-
if not _have_c99_complex and np.iscomplexobj(data):
165-
data = np.asarray(data)
166-
cA_r, cD_r = dwt(data.real, wavelet, mode, axis)
167-
cA_i, cD_i = dwt(data.imag, wavelet, mode, axis)
168-
return (cA_r + 1j*cA_i, cD_r + 1j*cD_i)
169-
170163
# accept array_like input; make a copy to ensure a contiguous array
171164
dt = _check_dtype(data)
172165
data = np.asarray(data, dtype=dt, order='C')
@@ -241,17 +234,6 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
241234
raise ValueError("At least one coefficient parameter must be "
242235
"specified.")
243236

244-
# for complex inputs: compute real and imaginary separately then combine
245-
if not _have_c99_complex and (np.iscomplexobj(cA) or np.iscomplexobj(cD)):
246-
if cA is None:
247-
cD = np.asarray(cD)
248-
cA = np.zeros_like(cD)
249-
elif cD is None:
250-
cA = np.asarray(cA)
251-
cD = np.zeros_like(cA)
252-
return (idwt(cA.real, cD.real, wavelet, mode, axis) +
253-
1j*idwt(cA.imag, cD.imag, wavelet, mode, axis))
254-
255237
if cA is not None:
256238
dt = _check_dtype(cA)
257239
cA = np.asarray(cA, dtype=dt, order='C')
@@ -328,9 +310,6 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1):
328310
upcoef
329311
330312
"""
331-
if not _have_c99_complex and np.iscomplexobj(data):
332-
return (downcoef(part, data.real, wavelet, mode, level) +
333-
1j*downcoef(part, data.imag, wavelet, mode, level))
334313
# accept array_like input; make a copy to ensure a contiguous array
335314
dt = _check_dtype(data)
336315
data = np.asarray(data, dtype=dt, order='C')
@@ -387,9 +366,6 @@ def upcoef(part, coeffs, wavelet, level=1, take=0):
387366
array([ 1., 2., 3., 4., 5., 6.])
388367
389368
"""
390-
if not _have_c99_complex and np.iscomplexobj(coeffs):
391-
return (upcoef(part, coeffs.real, wavelet, level, take) +
392-
1j*upcoef(part, coeffs.imag, wavelet, level, take))
393369
# accept array_like input; make a copy to ensure a contiguous array
394370
dt = _check_dtype(coeffs)
395371
coeffs = np.asarray(coeffs, dtype=dt, order='C')

pywt/_extensions/_dwt.pyx

+146-147
Large diffs are not rendered by default.

pywt/_extensions/_pywt.pxd

+5-13
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,17 @@ cimport numpy as np
33

44
np.import_array()
55

6-
include "config.pxi"
7-
86
ctypedef Py_ssize_t pywt_index_t
97

108
ctypedef fused data_t:
119
np.float32_t
1210
np.float64_t
1311

14-
cdef int have_c99_complex
15-
IF HAVE_C99_CPLX:
16-
ctypedef fused cdata_t:
17-
np.float32_t
18-
np.float64_t
19-
np.complex64_t
20-
np.complex128_t
21-
have_c99_complex = 1
22-
ELSE:
23-
ctypedef data_t cdata_t
24-
have_c99_complex = 0
12+
ctypedef fused cdata_t:
13+
np.float32_t
14+
np.float64_t
15+
np.complex64_t
16+
np.complex128_t
2517

2618
cdef public class Wavelet [type WaveletType, object WaveletObject]:
2719
cdef wavelet.DiscreteWavelet* w

pywt/_extensions/_swt.pyx

+78-83
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ cimport numpy as np
1010
from .common cimport pywt_index_t
1111
from ._pywt cimport c_wavelet_from_object, cdata_t, Wavelet, _check_dtype
1212

13-
include "config.pxi"
14-
1513
np.import_array()
1614

1715

@@ -99,21 +97,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level,
9997
&cD[0], output_len, i)
10098
if retval < 0:
10199
raise RuntimeError("C swt failed.")
102-
IF HAVE_C99_CPLX:
103-
if cdata_t is np.complex128_t:
104-
cD = np.zeros(output_len, dtype=np.complex128)
105-
with nogil:
106-
retval = c_wt.double_complex_swt_d(&data[0], data_size, wavelet.w,
107-
&cD[0], output_len, i)
108-
if retval < 0:
109-
raise RuntimeError("C swt failed.")
110-
elif cdata_t is np.complex64_t:
111-
cD = np.zeros(output_len, dtype=np.complex64)
112-
with nogil:
113-
retval = c_wt.float_complex_swt_d(&data[0], data_size, wavelet.w,
114-
&cD[0], output_len, i)
115-
if retval < 0:
116-
raise RuntimeError("C swt failed.")
100+
elif cdata_t is np.complex128_t:
101+
cD = np.zeros(output_len, dtype=np.complex128)
102+
with nogil:
103+
retval = c_wt.double_complex_swt_d(&data[0], data_size, wavelet.w,
104+
&cD[0], output_len, i)
105+
if retval < 0:
106+
raise RuntimeError("C swt failed.")
107+
elif cdata_t is np.complex64_t:
108+
cD = np.zeros(output_len, dtype=np.complex64)
109+
with nogil:
110+
retval = c_wt.float_complex_swt_d(&data[0], data_size, wavelet.w,
111+
&cD[0], output_len, i)
112+
if retval < 0:
113+
raise RuntimeError("C swt failed.")
117114

118115
# alloc memory, decompose A
119116
if cdata_t is np.float64_t:
@@ -130,21 +127,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level,
130127
&cA[0], output_len, i)
131128
if retval < 0:
132129
raise RuntimeError("C swt failed.")
133-
IF HAVE_C99_CPLX:
134-
if cdata_t is np.complex128_t:
135-
cA = np.zeros(output_len, dtype=np.complex128)
136-
with nogil:
137-
retval = c_wt.double_complex_swt_a(&data[0], data_size, wavelet.w,
138-
&cA[0], output_len, i)
139-
if retval < 0:
140-
raise RuntimeError("C swt failed.")
141-
elif cdata_t is np.complex64_t:
142-
cA = np.zeros(output_len, dtype=np.complex64)
143-
with nogil:
144-
retval = c_wt.float_complex_swt_a(&data[0], data_size, wavelet.w,
145-
&cA[0], output_len, i)
146-
if retval < 0:
147-
raise RuntimeError("C swt failed.")
130+
elif cdata_t is np.complex128_t:
131+
cA = np.zeros(output_len, dtype=np.complex128)
132+
with nogil:
133+
retval = c_wt.double_complex_swt_a(&data[0], data_size, wavelet.w,
134+
&cA[0], output_len, i)
135+
if retval < 0:
136+
raise RuntimeError("C swt failed.")
137+
elif cdata_t is np.complex64_t:
138+
cA = np.zeros(output_len, dtype=np.complex64)
139+
with nogil:
140+
retval = c_wt.float_complex_swt_a(&data[0], data_size, wavelet.w,
141+
&cA[0], output_len, i)
142+
if retval < 0:
143+
raise RuntimeError("C swt failed.")
148144

149145
data = cA
150146
if not trim_approx:
@@ -253,58 +249,57 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
253249
if retval:
254250
raise RuntimeError(
255251
"C wavelet transform failed with error code %d" % retval)
252+
elif data.dtype == np.complex128:
253+
cA = np.zeros(output_shape, dtype=np.complex128)
254+
with nogil:
255+
retval = c_wt.double_complex_downcoef_axis(
256+
<double complex *> data.data, data_info,
257+
<double complex *> cA.data, output_info,
258+
wavelet.w, axis,
259+
common.COEF_APPROX, common.MODE_PERIODIZATION,
260+
i, common.SWT_TRANSFORM)
261+
if retval:
262+
raise RuntimeError(
263+
"C wavelet transform failed with error code %d" %
264+
retval)
265+
cD = np.zeros(output_shape, dtype=np.complex128)
266+
with nogil:
267+
retval = c_wt.double_complex_downcoef_axis(
268+
<double complex *> data.data, data_info,
269+
<double complex *> cD.data, output_info,
270+
wavelet.w, axis,
271+
common.COEF_DETAIL, common.MODE_PERIODIZATION,
272+
i, common.SWT_TRANSFORM)
273+
if retval:
274+
raise RuntimeError(
275+
"C wavelet transform failed with error code %d" %
276+
retval)
277+
elif data.dtype == np.complex64:
278+
cA = np.zeros(output_shape, dtype=np.complex64)
279+
with nogil:
280+
retval = c_wt.float_complex_downcoef_axis(
281+
<float complex *> data.data, data_info,
282+
<float complex *> cA.data, output_info,
283+
wavelet.w, axis,
284+
common.COEF_APPROX, common.MODE_PERIODIZATION,
285+
i, common.SWT_TRANSFORM)
286+
if retval:
287+
raise RuntimeError(
288+
"C wavelet transform failed with error code %d" %
289+
retval)
290+
cD = np.zeros(output_shape, dtype=np.complex64)
291+
with nogil:
292+
retval = c_wt.float_complex_downcoef_axis(
293+
<float complex *> data.data, data_info,
294+
<float complex *> cD.data, output_info,
295+
wavelet.w, axis,
296+
common.COEF_DETAIL, common.MODE_PERIODIZATION,
297+
i, common.SWT_TRANSFORM)
298+
if retval:
299+
raise RuntimeError(
300+
"C wavelet transform failed with error code %d" %
301+
retval)
256302

257-
IF HAVE_C99_CPLX:
258-
if data.dtype == np.complex128:
259-
cA = np.zeros(output_shape, dtype=np.complex128)
260-
with nogil:
261-
retval = c_wt.double_complex_downcoef_axis(
262-
<double complex *> data.data, data_info,
263-
<double complex *> cA.data, output_info,
264-
wavelet.w, axis,
265-
common.COEF_APPROX, common.MODE_PERIODIZATION,
266-
i, common.SWT_TRANSFORM)
267-
if retval:
268-
raise RuntimeError(
269-
"C wavelet transform failed with error code %d" %
270-
retval)
271-
cD = np.zeros(output_shape, dtype=np.complex128)
272-
with nogil:
273-
retval = c_wt.double_complex_downcoef_axis(
274-
<double complex *> data.data, data_info,
275-
<double complex *> cD.data, output_info,
276-
wavelet.w, axis,
277-
common.COEF_DETAIL, common.MODE_PERIODIZATION,
278-
i, common.SWT_TRANSFORM)
279-
if retval:
280-
raise RuntimeError(
281-
"C wavelet transform failed with error code %d" %
282-
retval)
283-
elif data.dtype == np.complex64:
284-
cA = np.zeros(output_shape, dtype=np.complex64)
285-
with nogil:
286-
retval = c_wt.float_complex_downcoef_axis(
287-
<float complex *> data.data, data_info,
288-
<float complex *> cA.data, output_info,
289-
wavelet.w, axis,
290-
common.COEF_APPROX, common.MODE_PERIODIZATION,
291-
i, common.SWT_TRANSFORM)
292-
if retval:
293-
raise RuntimeError(
294-
"C wavelet transform failed with error code %d" %
295-
retval)
296-
cD = np.zeros(output_shape, dtype=np.complex64)
297-
with nogil:
298-
retval = c_wt.float_complex_downcoef_axis(
299-
<float complex *> data.data, data_info,
300-
<float complex *> cD.data, output_info,
301-
wavelet.w, axis,
302-
common.COEF_DETAIL, common.MODE_PERIODIZATION,
303-
i, common.SWT_TRANSFORM)
304-
if retval:
305-
raise RuntimeError(
306-
"C wavelet transform failed with error code %d" %
307-
retval)
308303
if retval == -5:
309304
raise TypeError("Array must be floating point, not {}"
310305
.format(data.dtype))

pywt/_extensions/c/common.h

+4-5
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88

99
#pragma once
1010

11-
#ifdef HAVE_C99_COMPLEX
12-
/* For templating, we need typedefs without spaces for complex types. */
13-
typedef float _Complex float_complex;
14-
typedef double _Complex double_complex;
15-
#endif
11+
/* For templating, we need typedefs without spaces for complex types. */
12+
/* FIXME: needs more portable complex types here */
13+
typedef float _Complex float_complex;
14+
typedef double _Complex double_complex;
1615

1716
/* ##### Typedefs ##### */
1817

pywt/_extensions/c/convolution.c

+10-12
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,17 @@
2020
#undef REAL_TYPE
2121
#undef TYPE
2222

23-
#ifdef HAVE_C99_COMPLEX
24-
#define TYPE float_complex
25-
#define REAL_TYPE float
26-
#include "convolution.template.c"
27-
#undef REAL_TYPE
28-
#undef TYPE
23+
#define TYPE float_complex
24+
#define REAL_TYPE float
25+
#include "convolution.template.c"
26+
#undef REAL_TYPE
27+
#undef TYPE
2928

30-
#define TYPE double_complex
31-
#define REAL_TYPE double
32-
#include "convolution.template.c"
33-
#undef REAL_TYPE
34-
#undef TYPE
35-
#endif
29+
#define TYPE double_complex
30+
#define REAL_TYPE double
31+
#include "convolution.template.c"
32+
#undef REAL_TYPE
33+
#undef TYPE
3634

3735
#endif /* REAL_TYPE */
3836
#endif /* TYPE */

pywt/_extensions/c/convolution.h

+11-13
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,17 @@
2222
#undef REAL_TYPE
2323
#undef TYPE
2424

25-
#ifdef HAVE_C99_COMPLEX
26-
#define TYPE float_complex
27-
#define REAL_TYPE float
28-
#include "convolution.template.h"
29-
#undef REAL_TYPE
30-
#undef TYPE
31-
32-
#define TYPE double_complex
33-
#define REAL_TYPE double
34-
#include "convolution.template.h"
35-
#undef REAL_TYPE
36-
#undef TYPE
37-
#endif
25+
#define TYPE float_complex
26+
#define REAL_TYPE float
27+
#include "convolution.template.h"
28+
#undef REAL_TYPE
29+
#undef TYPE
30+
31+
#define TYPE double_complex
32+
#define REAL_TYPE double
33+
#include "convolution.template.h"
34+
#undef REAL_TYPE
35+
#undef TYPE
3836

3937
#endif /* REAL_TYPE */
4038
#endif /* TYPE */

pywt/_extensions/c/wt.c

+10-12
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,17 @@
2020
#undef REAL_TYPE
2121
#undef TYPE
2222

23-
#ifdef HAVE_C99_COMPLEX
24-
#define TYPE float_complex
25-
#define REAL_TYPE float
26-
#include "wt.template.c"
27-
#undef REAL_TYPE
28-
#undef TYPE
23+
#define TYPE float_complex
24+
#define REAL_TYPE float
25+
#include "wt.template.c"
26+
#undef REAL_TYPE
27+
#undef TYPE
2928

30-
#define TYPE double_complex
31-
#define REAL_TYPE double
32-
#include "wt.template.c"
33-
#undef REAL_TYPE
34-
#undef TYPE
35-
#endif
29+
#define TYPE double_complex
30+
#define REAL_TYPE double
31+
#include "wt.template.c"
32+
#undef REAL_TYPE
33+
#undef TYPE
3634

3735
#endif /* REAL_TYPE */
3836
#endif /* TYPE */

0 commit comments

Comments
 (0)