1
- from typing import Tuple
1
+ from typing import Tuple , Union
2
2
3
3
import numpy as np
4
4
from skimage .draw import line , line_aa
5
5
6
6
_TAU = np .pi * 2
7
7
8
+ Color = Union [int , float , Tuple [int , ...], Tuple [float , ...]]
9
+
8
10
9
11
class Turtle :
10
- def __init__ (self , array , deg : bool = False , aa : bool = False ):
12
+ def __init__ (self , array : np . ndarray , deg : bool = False , aa : bool = False ):
11
13
"""Draw on a NumPy array using turtle graphics.
12
14
13
15
Starts at (0, 0) (top-left corner) with a direction of 0 (pointing
@@ -16,18 +18,23 @@ def __init__(self, array, deg: bool=False, aa: bool=False):
16
18
Parameters
17
19
----------
18
20
array: np.ndarray
19
- The 2D array to write to
21
+ The 2D array to write to. Can be either of shape h x w (grayscale),
22
+ h x w x c (e.g. rgb for c=3 channels).
23
+ The dtype is used to determine the color depth of each channel:
24
+
25
+ * `bool` for 2 colors.
26
+ * All `np.integer` subtypes for discrete depth, ranging from 0 to
27
+ its max value (e.g. `np.uint8` for values in 0 - 255).
28
+ * All `np.floating` subtypes for continuous depth, ranging from 0
29
+ to 1.
30
+
20
31
deg : :obj:`bool`, optional
21
32
Use degrees instead of radians.
22
33
aa : :obj:`bool`, optional
23
34
Enable anti-aliasing.
24
35
"""
25
- assert type (array ) is np .ndarray , 'Array should be a NumPy ndarray'
26
- assert array .ndim == 2 , 'Only 2D arrays are supported'
27
- assert (
28
- np .issubdtype (array .dtype , np .integer ) or
29
- np .issubdtype (array .dtype , np .floating )
30
- ), '{} is unsupported' .format (array .dtype )
36
+ if type (array ) is not np .ndarray :
37
+ raise TypeError ('Array should be a NumPy ndarray' )
31
38
32
39
self .array = array
33
40
self .aa = aa
@@ -37,10 +44,32 @@ def __init__(self, array, deg: bool=False, aa: bool=False):
37
44
self .__r , self .__c = 0 , 0
38
45
self .__stack = []
39
46
40
- if np .issubdtype (array .dtype , np .integer ):
41
- self .__color = np .iinfo (array .dtype ).max
47
+ if array .ndim == 2 :
48
+ self .__channels = 1
49
+ elif array .ndim == 3 :
50
+ self .__channels = array .shape [2 ]
51
+ else :
52
+ raise TypeError ('Array does not have 2 or 3 dimensions' )
53
+
54
+ if array .dtype == np .dtype (bool ):
55
+ self .__depth = 1
56
+ self .__dtype = bool
57
+ elif np .issubdtype (array .dtype , np .integer ):
58
+ self .__depth = np .iinfo (array .dtype ).max
59
+ self .__dtype = int
42
60
elif np .issubdtype (array .dtype , np .floating ):
43
- self .__color = np .finfo (array .dtype ).max
61
+ self .__depth = 1.0
62
+ self .__dtype = float
63
+ else :
64
+ raise TypeError (
65
+ 'Array should have a bool, int-like, or float-like dtype'
66
+ )
67
+
68
+ # color initially the max depth (white).
69
+ if self .__channels == 1 :
70
+ self .__color = self .__depth
71
+ else :
72
+ self .__color = np .full (self .__channels , self .__depth , self .__dtype )
44
73
45
74
def __in_array (self , r = None , c = None ):
46
75
r = self .__r if r is None else r
@@ -58,10 +87,15 @@ def __draw_line(self, new_c, new_r):
58
87
59
88
if self .aa :
60
89
rr , cc , val = line_aa (r0 , c0 , r1 , c1 )
61
- self .array [rr , cc ] = (val / 255 * self .__color ).astype (self .array .dtype )
62
90
else :
63
91
rr , cc = line (r0 , c0 , r1 , c1 )
64
- self .array [rr , cc ] = self .__color
92
+ val = 1
93
+
94
+ if self .__channels == 1 :
95
+ self .array [rr , cc ] = val * self .__color
96
+ else :
97
+ for c in range (self .__channels ):
98
+ self .array [rr , cc , c ] = val * self .__color [c ]
65
99
66
100
def forward (self , distance : float ):
67
101
"""Move in the current direction and draw a line with Euclidian
@@ -123,10 +157,19 @@ def position(self, rc: Tuple[int, int]):
123
157
self .__r , self .__c = rc
124
158
125
159
@property
126
- def color (self ) -> float :
127
- """float: Grayscale color"""
128
- return self .__color
160
+ def color (self ) -> Color :
161
+ """int, float, tuple of int or tuple of float: Grayscale color"""
162
+ if self .__channels == 1 :
163
+ return self .__color
164
+ else :
165
+ return tuple (self .__color )
129
166
130
167
@color .setter
131
- def color (self , c : float ):
132
- self .__color = c
168
+ def color (self , c : Color ):
169
+ if not np .isscalar (c ) and len (c ) != self .__channels :
170
+ raise TypeError ('Invalid amount of color values' )
171
+ for _c in [c ] if np .isscalar (c ) else c :
172
+ if _c < 0 or _c > self .__depth :
173
+ raise ValueError ('Color value out of range' )
174
+
175
+ self .__color = np .array (c , dtype = self .__dtype )
0 commit comments