"""A thin pytorch / numpy compat layer.
Things imported from here have numpy-compatible signatures but operate on
pytorch tensors.
"""
#import numpy as np
import torch
from . import _util
from . import _dtypes
from . import _helpers
from ._ndarray import ndarray, asarray, array, asarray_replacer, newaxis
from ._ndarray import can_cast, result_type
# Things to decide on (punt for now)
#
# 1. Q: What are the return types of wrapper functions: plain torch.Tensors or
# wrapper ndarrays.
# A: Wrapper ndarrays.
#
# 2. Q: Default dtypes: numpy defaults to float64, pytorch defaults to float32
# A: Stick to pytorch defaults?
# NB: numpy recommends `dtype=float`?
#
# 3. Q: Masked arrays. Record, structured arrays.
# A: Ignore for now
#
# 4. Q: What are the defaults for pytorch-specific args not present in numpy signatures?
# device=..., requires_grad=... etc
# A: ignore, keep whatever they are from inputs; test w/various combinations
#
# 5. Q: What is the useful action for numpy-specific arguments? e.g. like=...
# A: like=... and subok=True both raise ValueErrors.
# initial=... can be useful though, punt on
# where=... punt on for now
# TODO
# 1. Mapping of the numpy layout ('C', 'K' etc) to torch layout / memory_format.
# 2. np.dtype <-> torch.dtype
# 3. numpy type casting rules (to be cleaned up in numpy: follow old or new)
#
# 4. wrap/unwrap/wrap patterns:
# - inputs are scalars, output is an array
# - two-arg functions (second may be None)
#    - first arg is a sequence/tuple (_stack family, concatenate, atleast_Nd etc)
# - optional out arg
#
# 5. handle the out= arg: verify dimensions, handle dtype (blocked on dtype decision)
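#
# For orientation, a minimal usage sketch implied by the decisions above (the
# absolute import name `torch_np` is illustrative only; within the package the
# import is relative):
#
#     import torch
#     from torch_np import _wrapper as w
#
#     out = w.concatenate((torch.ones(2), torch.zeros(3)))
#     isinstance(out, ndarray)    # True -- decision 1: wrapper ndarrays
#     out.get().dtype             # torch.float32 -- decision 2: torch defaults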
NoValue = None
###### array creation routines
def copy(a, order='K', subok=False):
a = asarray(a)
_util.subok_not_ok(subok=subok)
if order != 'K':
raise NotImplementedError
# XXX: ndarray.copy only accepts order='C'
return a.copy(order='C')
def atleast_1d(*arys):
res = torch.atleast_1d([asarray(a).get() for a in arys])
if len(res) == 1:
return asarray(res[0])
else:
return list(asarray(_) for _ in res)
def atleast_2d(*arys):
res = torch.atleast_2d([asarray(a).get() for a in arys])
if len(res) == 1:
return asarray(res[0])
else:
return list(asarray(_) for _ in res)
def atleast_3d(*arys):
res = torch.atleast_3d([asarray(a).get() for a in arys])
if len(res) == 1:
return asarray(res[0])
else:
return list(asarray(_) for _ in res)
def vstack(tup, *, dtype=None, casting='same_kind'):
arrs = atleast_2d(*tup)
if not isinstance(arrs, list):
arrs = [arrs]
return concatenate(arrs, 0, dtype=dtype, casting=casting)
row_stack = vstack
def hstack(tup, *, dtype=None, casting='same_kind'):
arrs = atleast_1d(*tup)
if not isinstance(arrs, list):
arrs = [arrs]
# As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
if arrs and arrs[0].ndim == 1:
return concatenate(arrs, 0, dtype=dtype, casting=casting)
else:
return concatenate(arrs, 1, dtype=dtype, casting=casting)
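# A shape-only illustration of the 1-D special case above (assuming list
# inputs are converted via asarray, as elsewhere in this module):
#
#     hstack([[1, 2], [3, 4]])            # 1-D inputs: concat axis 0 -> shape (4,)
#     hstack([[[1], [2]], [[3], [4]]])    # 2-D inputs: concat axis 1 -> shape (2, 2)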
def dstack(tup, *, dtype=None, casting='same_kind'):
# XXX: in numpy 1.24 dstack does not have dtype and casting keywords
# but {h,v}stack do. Hence add them here for consistency.
arrs = atleast_3d(*tup)
if not isinstance(arrs, list):
arrs = [arrs]
return concatenate(arrs, 2, dtype=dtype, casting=casting)
def column_stack(tup, *, dtype=None, casting='same_kind'):
# XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
# but row_stack does. (because row_stack is an alias for vstack, really).
# Hence add these keywords here for consistency.
arrays = []
for v in tup:
arr = asarray(v)
if arr.ndim < 2:
arr = array(arr, copy=False, ndmin=2).T
arrays.append(arr)
return concatenate(arrays, 1, dtype=dtype, casting=casting)
def stack(arrays, axis=0, out=None, *, dtype=None, casting='same_kind'):
arrays = [asarray(arr) for arr in arrays]
if not arrays:
raise ValueError('need at least one array to stack')
shapes = {arr.shape for arr in arrays}
if len(shapes) != 1:
raise ValueError('all input arrays must have the same shape')
result_ndim = arrays[0].ndim + 1
axis = _util.normalize_axis_index(axis, result_ndim)
sl = (slice(None),) * axis + (newaxis,)
expanded_arrays = [arr[sl] for arr in arrays]
return concatenate(expanded_arrays, axis=axis, out=out,
dtype=dtype, casting=casting)
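# For reference, stack() places the new axis at position `axis` of the result;
# with three arrays a, b, c of shape (2, 4):
#
#     stack([a, b, c], axis=0).shape    # (3, 2, 4)
#     stack([a, b, c], axis=1).shape    # (2, 3, 4)
#     stack([a, b, c], axis=-1).shape   # (2, 4, 3)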
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
             axis=0):
    if axis != 0 or retstep or not endpoint:
        raise NotImplementedError
    torch_dtype = _dtypes.torch_dtype_from(dtype)
    # XXX: raises TypeError if start or stop are not scalars
    return asarray(torch.linspace(start, stop, num, dtype=torch_dtype))
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
if axis != 0 or not endpoint:
raise NotImplementedError
tstart, tstop = torch.as_tensor([start, stop])
base = torch.pow(tstop / tstart, 1./(num-1))
result = torch.logspace(torch.log(tstart)/torch.log(base),
torch.log(tstop)/torch.log(base), num, base=base)
return asarray(result)
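# The identity behind the logspace call above: with b = (stop/start)**(1/(num-1)),
# geomspace(start, stop, num) == b ** linspace(log_b(start), log_b(stop), num).
# A quick check (positive real endpoints assumed; num == 1 is not handled):
#
#     geomspace(1, 1000, 4)    # ~ [1., 10., 100., 1000.]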
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
    if axis != 0 or not endpoint:
        raise NotImplementedError
    torch_dtype = _dtypes.torch_dtype_from(dtype)
    return asarray(torch.logspace(start, stop, num, base=base, dtype=torch_dtype))
def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
_util.subok_not_ok(like)
if step == 0:
raise ZeroDivisionError
if stop is None and start is None:
raise TypeError
if stop is None:
# XXX: this breaks if start is passed as a kwarg:
# arange(start=4) should raise (no stop) but doesn't
start, stop = 0, start
if start is None:
start = 0
if dtype is None:
dtype = _dtypes.default_int_type()
dtype = result_type(start, stop, step, dtype)
torch_dtype = _dtypes.torch_dtype_from(dtype)
start, stop, step = _helpers.to_tensors(start, stop, step)
try:
return asarray(torch.arange(start, stop, step, dtype=torch_dtype))
except RuntimeError:
raise ValueError("Maximum allowed size exceeded")
def empty(shape, dtype=float, order='C', *, like=None):
_util.subok_not_ok(like)
if order != 'C':
raise NotImplementedError
torch_dtype = _dtypes.torch_dtype_from(dtype)
return asarray(torch.empty(shape, dtype=torch_dtype))
# NB: the *_like functions deliberately deviate from numpy: numpy defaults to
# subok=True; here we default to subok=False and raise on anything else.
@asarray_replacer()
def empty_like(prototype, dtype=None, order='K', subok=False, shape=None):
_util.subok_not_ok(subok=subok)
if order != 'K':
raise NotImplementedError
torch_dtype = _dtypes.torch_dtype_from(dtype)
result = torch.empty_like(prototype, dtype=torch_dtype)
if shape is not None:
result = result.reshape(shape)
return result
def full(shape, fill_value, dtype=None, order='C', *, like=None):
_util.subok_not_ok(like)
if order != 'C':
raise NotImplementedError
if isinstance(fill_value, ndarray):
fill_value = fill_value.get()
torch_dtype = _dtypes.torch_dtype_from(dtype)
return asarray(torch.full(shape, fill_value, dtype=torch_dtype))
@asarray_replacer()
def full_like(a, fill_value, dtype=None, order='K', subok=False, shape=None):
_util.subok_not_ok(subok=subok)
if order != 'K':
raise NotImplementedError
torch_dtype = _dtypes.torch_dtype_from(dtype)
result = torch.full_like(a, fill_value, dtype=torch_dtype)
if shape is not None:
result = result.reshape(shape)
return result
def ones(shape, dtype=None, order='C', *, like=None):
_util.subok_not_ok(like)
if order != 'C':
raise NotImplementedError
torch_dtype = _dtypes.torch_dtype_from(dtype)
return asarray(torch.ones(shape, dtype=torch_dtype))
@asarray_replacer()
def ones_like(a, dtype=None, order='K', subok=False, shape=None):
_util.subok_not_ok(subok=subok)
if order != 'K':
raise NotImplementedError
torch_dtype = _dtypes.torch_dtype_from(dtype)
result = torch.ones_like(a, dtype=torch_dtype)
if shape is not None:
result = result.reshape(shape)
return result
# XXX: dtype=float
def zeros(shape, dtype=float, order='C', *, like=None):
_util.subok_not_ok(like)
if order != 'C':
raise NotImplementedError
torch_dtype = _dtypes.torch_dtype_from(dtype)
return asarray(torch.zeros(shape, dtype=torch_dtype))
@asarray_replacer()
def zeros_like(a, dtype=None, order='K', subok=False, shape=None):
_util.subok_not_ok(subok=subok)
if order != 'K':
raise NotImplementedError
torch_dtype = _dtypes.torch_dtype_from(dtype)
result = torch.zeros_like(a, dtype=torch_dtype)
if shape is not None:
result = result.reshape(shape)
return result
# XXX: dtype=float
def eye(N, M=None, k=0, dtype=float, order='C', *, like=None):
    _util.subok_not_ok(like)
    if order != 'C':
        raise NotImplementedError
    if M is None:
        M = N
    torch_dtype = _dtypes.torch_dtype_from(dtype)
    z = torch.zeros(N, M, dtype=torch_dtype)
    z.diagonal(k).fill_(1)
    return asarray(z)
def identity(n, dtype=None, *, like=None):
    _util.subok_not_ok(like)
    torch_dtype = _dtypes.torch_dtype_from(dtype)
    return asarray(torch.eye(n, dtype=torch_dtype))
###### misc/unordered
#YYY: pattern: initial=...
@asarray_replacer()
def prod(a, axis=None, dtype=None, out=None, keepdims=NoValue,
initial=NoValue, where=NoValue):
if initial is not None or where is not None:
raise NotImplementedError
if axis is None:
if keepdims is not None:
raise NotImplementedError
return torch.prod(a, dtype=dtype)
elif _util.is_sequence(axis):
raise NotImplementedError
return torch.prod(a, dim=axis, dtype=dtype, keepdim=bool(keepdims), out=out)
@asarray_replacer()
def corrcoef(x, y=None, rowvar=True, bias=NoValue, ddof=NoValue, *, dtype=None):
if bias is not None or ddof is not None:
# deprecated in NumPy
raise NotImplementedError
if y is not None:
# go figure what it means, XXX
raise NotImplementedError
if rowvar is False:
x = x.T
if dtype is not None:
x = x.to(dtype)
return torch.corrcoef(x)
def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"):
if out is not None:
if dtype is not None:
# mimic numpy
raise TypeError("concatenate() only takes `out` or `dtype` as an "
"argument, but both were provided.")
if not isinstance(out, ndarray):
raise ValueError("'out' must be an array")
    if len(ar_tuple) == 0:
# XXX: RuntimeError in torch, ValueError in numpy
raise ValueError("need at least one array to concatenate")
# make sure inputs are arrays
arrays = tuple(asarray(ar) for ar in ar_tuple)
# np.concatenate ravels if axis=None
arrays, axis = _helpers.axis_none_ravel(*arrays, axis=axis)
# figure out the type of the inputs and outputs
if out is None and dtype is None:
out_dtype = None
tensors = tuple(ar.get() for ar in arrays)
else:
out_dtype = _dtypes.dtype(dtype) if dtype is not None else out.dtype
        # cast input arrays if necessary; do not broadcast them against `out`
tensors = _helpers.cast_dont_broadcast(arrays, out_dtype, casting)
try:
result = torch.cat(tensors, axis)
except (IndexError, RuntimeError):
raise _util.AxisError
return _helpers.result_or_out(result, out)
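# The numpy-compatible corner cases handled above, for reference:
#
#     concatenate((a, b), axis=None)             # inputs are ravelled first
#     concatenate((a, b), dtype=float)           # inputs cast, never broadcast
#     concatenate((a, b), out=c, dtype=float)    # TypeError: out xor dtype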
@asarray_replacer()
def bincount(x, /, weights=None, minlength=0):
return torch.bincount(x, weights, minlength)
# YYY: pattern: sequence of arrays
def where(condition, x=None, y=None, /):
selector = (x is None) == (y is None)
if not selector:
raise ValueError("either both or neither of x and y should be given")
condition = asarray(condition).get()
if x is None and y is None:
return tuple(asarray(_) for _ in torch.where(condition))
    x = asarray(x).get()
    y = asarray(y).get()
return asarray(torch.where(condition, x, y))
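# The two calling conventions, for reference:
#
#     where(cond)          # tuple of index arrays, equivalent to nonzero(cond)
#     where(cond, x, y)    # elementwise select: x where cond is true, else y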
###### module-level queries of object properties
def ndim(a):
a = asarray(a).get()
return a.ndim
def shape(a):
a = asarray(a).get()
return tuple(a.shape)
def size(a, axis=None):
a = asarray(a).get()
if axis is None:
return a.numel()
else:
return a.shape[axis]
###### shape manipulations and indexing
def transpose(a, axes=None):
arr = asarray(a)
return arr.transpose(axes)
def reshape(a, newshape, order='C'):
arr = asarray(a)
return arr.reshape(*newshape, order=order)
def ravel(a, order='C'):
arr = asarray(a)
return arr.ravel(order=order)
def squeeze(a, axis=None):
arr = asarray(a)
return arr.squeeze(axis)
def expand_dims(a, axis):
a = asarray(a)
shape = _util.expand_shape(a.shape, axis)
tensor = a.get().view(shape) # never copies
return ndarray._from_tensor_and_base(tensor, a)
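# For illustration -- the result is a view sharing memory with `a`, hence the
# _from_tensor_and_base call keeping `a` alive as the base:
#
#     a = asarray(torch.arange(6).reshape(2, 3))
#     expand_dims(a, 0).shape     # (1, 2, 3)
#     expand_dims(a, -1).shape    # (2, 3, 1)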
@asarray_replacer()
def flip(m, axis=None):
# XXX: semantic difference: np.flip returns a view, torch.flip copies
if axis is None:
axis = tuple(range(m.ndim))
else:
axis = _util.normalize_axis_tuple(axis, m.ndim)
return torch.flip(m, axis)
@asarray_replacer()
def broadcast_to(array, shape, subok=False):
_util.subok_not_ok(subok=subok)
return torch.broadcast_to(array, size=shape)
from torch import broadcast_shapes
# YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero
def broadcast_arrays(*args, subok=False):
_util.subok_not_ok(subok=subok)
res = torch.broadcast_tensors(*[asarray(a).get() for a in args])
return tuple(asarray(_) for _ in res)
@asarray_replacer()
def moveaxis(a, source, destination):
    return torch.moveaxis(a, source, destination)
def unravel_index(indices, shape, order='C'):
# cf https://github.com/pytorch/pytorch/pull/66687
# this version is from
# https://discuss.pytorch.org/t/how-to-do-a-unravel-index-in-pytorch-just-like-in-numpy/12987/3
if order != 'C':
raise NotImplementedError
result = []
for index in indices:
out = []
for dim in reversed(shape):
out.append(index % dim)
index = index // dim
result.append(tuple(reversed(out)))
return result
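# Worked example: in C order, shape (3, 4, 5) has strides (20, 5, 1), so the
# flat index 22 = 1*20 + 0*5 + 2 unravels as
#
#     unravel_index([22], (3, 4, 5))    # [(1, 0, 2)]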
def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
    # XXX: not available in pytorch; C order: axis i strides by prod(dims[i+1:])
    result, stride = 0, 1
    for idx, dim in zip(reversed(multi_index), reversed(dims)):
        result, stride = result + idx * stride, stride * dim
    return result
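# The inverse of the worked unravel_index example above:
#
#     ravel_multi_index((1, 0, 2), (3, 4, 5))    # 1*20 + 0*5 + 2 == 22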
def nonzero(a):
arr = asarray(a)
return arr.nonzero()
def flatnonzero(a):
arr = asarray(a)
return nonzero(arr.ravel())[0]
def argwhere(a):
arr = asarray(a)
tensor = arr.get()
return asarray(torch.argwhere(tensor))
def abs(a):
# FIXME: should go the other way, together with other ufuncs
arr = asarray(a)
    return arr.__abs__()
from ._ndarray import axis_out_keepdims_wrapper
@axis_out_keepdims_wrapper
def count_nonzero(a, axis=None, *, keepdims=False):
# XXX: this all should probably be generalized to a sum(a != 0, dtype=bool)
try:
tensor = a.get().count_nonzero(axis)
except RuntimeError:
raise ValueError
return tensor
@asarray_replacer()
def roll(a, shift, axis=None):
return a.roll(shift, axis)
@asarray_replacer()
def round_(a, decimals=0, out=None):
if torch.is_floating_point(a):
return torch.round(a, decimals=decimals, out=out)
else:
return a
around = round_
###### tri{l, u} and related
@asarray_replacer()
def tril(m, k=0):
return m.tril(k)
@asarray_replacer()
def triu(m, k=0):
return m.triu(k)
# YYY: pattern: return sequence
def tril_indices(n, k=0, m=None):
if m is None:
m = n
tensor_2 = torch.tril_indices(n, m, offset=k)
return tuple(asarray(_) for _ in tensor_2)
def triu_indices(n, k=0, m=None):
if m is None:
m = n
    tensor_2 = torch.triu_indices(n, m, offset=k)
return tuple(asarray(_) for _ in tensor_2)
# YYY: pattern: array in, sequence of arrays out
def tril_indices_from(arr, k=0):
arr = asarray(arr).get()
if arr.ndim != 2:
raise ValueError("input array must be 2-d")
tensor_2 = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
return tuple(asarray(_) for _ in tensor_2)
def triu_indices_from(arr, k=0):
arr = asarray(arr).get()
if arr.ndim != 2:
raise ValueError("input array must be 2-d")
    tensor_2 = torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)
return tuple(asarray(_) for _ in tensor_2)
def tri(N, M=None, k=0, dtype=float, *, like=None):
    _util.subok_not_ok(like)
    if M is None:
        M = N
    torch_dtype = _dtypes.torch_dtype_from(dtype)
    tensor = torch.tril(torch.ones((N, M), dtype=torch_dtype), diagonal=k)
    return asarray(tensor)
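# For reference:
#
#     tri(3)    # [[1., 0., 0.],
#               #  [1., 1., 0.],
#               #  [1., 1., 1.]]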
###### reductions
# YYY: pattern : argmax, argmin
def argmax(a, axis=None, out=None, *, keepdims=NoValue):
arr = asarray(a)
return arr.argmax(axis=axis, out=out, keepdims=keepdims)
def argmin(a, axis=None, out=None, *, keepdims=NoValue):
arr = asarray(a)
return arr.argmin(axis=axis, out=out, keepdims=keepdims)
def amax(a, axis=None, out=None, keepdims=NoValue, initial=NoValue,
where=NoValue):
arr = asarray(a)
return arr.max(axis=axis, out=out, keepdims=keepdims, initial=initial,
where=where)
max = amax
def amin(a, axis=None, out=None, keepdims=NoValue, initial=NoValue,
where=NoValue):
arr = asarray(a)
return arr.min(axis=axis, out=out, keepdims=keepdims, initial=initial,
where=where)
min = amin
def all(a, axis=None, out=None, keepdims=NoValue, *, where=NoValue):
arr = asarray(a)
return arr.all(axis=axis, out=out, keepdims=keepdims, where=where)
def any(a, axis=None, out=None, keepdims=NoValue, *, where=NoValue):
arr = asarray(a)
return arr.any(axis=axis, out=out, keepdims=keepdims, where=where)
# YYY: pattern: dtype kwarg, None not accepted
def mean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue):
arr = asarray(a)
return arr.mean(axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where)
def sum(a, axis=None, dtype=None, out=None, keepdims=NoValue,
initial=NoValue, where=NoValue):
arr = asarray(a)
return arr.sum(axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where=where)
@asarray_replacer()
def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue):
if where is not None:
raise NotImplementedError
if dtype is None:
dtype = a.dtype
if axis is None:
result = a.nanmean(dtype=dtype)
        if keepdims:
            # numpy semantics: keepdims with axis=None yields shape (1,) * ndim
            result = result.reshape((1,) * a.ndim)
else:
result = a.nanmean(dtype=dtype, dim=axis, keepdim=bool(keepdims))
if out is not None:
out.copy_(result)
return result
# YYY: pattern : std, var
@asarray_replacer()
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue):
if where is not None:
raise NotImplementedError
if dtype is not None:
raise NotImplementedError
if not torch.is_floating_point(a):
a = a * 1.0
return torch.std(a, axis, correction=ddof, keepdim=bool(keepdims), out=out)
@asarray_replacer()
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue):
if where is not None:
raise NotImplementedError
if dtype is not None:
raise NotImplementedError
if not torch.is_floating_point(a):
a = a * 1.0
return torch.var(a, axis, correction=ddof, keepdim=bool(keepdims), out=out)
@asarray_replacer()
def argsort(a, axis=-1, kind=None, order=None):
if order is not None:
raise NotImplementedError
    stable = (kind == 'stable')
if axis is None:
axis = -1
return torch.argsort(a, stable=stable, dim=axis, descending=False)
##### math functions
@asarray_replacer()
def angle(z, deg=False):
result = torch.angle(z)
if deg:
result *= 180 / torch.pi
return result
def real(a):
arr = asarray(a)
return arr.real
def imag(a):
arr = asarray(a)
return arr.imag
@asarray_replacer()
def real_if_close(a, tol=100):
if not torch.is_complex(a):
return a
    if (torch.abs(torch.imag(a)) < tol * torch.finfo(a.dtype).eps).all():
return torch.real(a)
else:
return a
@asarray_replacer()
def iscomplex(x):
if torch.is_complex(x):
return torch.as_tensor(x).imag != 0
result = torch.zeros_like(x, dtype=torch.bool)
return result[()]
@asarray_replacer()
def isreal(x):
if torch.is_complex(x):
return torch.as_tensor(x).imag == 0
result = torch.zeros_like(x, dtype=torch.bool)
return result[()]
@asarray_replacer()
def iscomplexobj(x):
return torch.is_complex(x)
@asarray_replacer()
def isrealobj(x):
return not torch.is_complex(x)
@asarray_replacer()
def isneginf(x, out=None):
return torch.isneginf(x, out=out)
@asarray_replacer()
def isposinf(x, out=None):
return torch.isposinf(x, out=out)
@asarray_replacer()
def i0(x):
return torch.special.i0(x)
def isscalar(a):
# XXX: this is a stub
try:
arr = asarray(a)
return arr.size == 1
except Exception:
return False
###### mapping from numpy API objects to wrappers from this module ######
# Everything lives in the mapping dict in _mapping.py.