@@ -59,11 +59,13 @@ def sq2(a):
59
59
return z
60
60
61
61
62
- @torch .compile
63
- def step (n , c , Z , N , horizon ):
64
- I = abs2 (Z ) < horizon ** 2
65
- N = np .where (I , n , N ) # N[I] = n
66
- Z = np .where (I [..., None ], sq2 (Z ) + c , Z ) # Z[I] = Z[I]**2 + C[I]
62
+ @torch .compile (dynamic = True )
63
+ def step (n0 , c , Z , N , horizon , chunksize ):
64
+ for j in range (chunksize ):
65
+ n = n0 + j
66
+ I = abs2 (Z ) < horizon ** 2
67
+ N = np .where (I , n , N ) # N[I] = n
68
+ Z = np .where (I [..., None ], sq2 (Z ) + c , Z ) # Z[I] = Z[I]**2 + C[I]
67
69
return Z , N
68
70
69
71
@@ -75,8 +77,12 @@ def mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=2**10, maxiter=5):
75
77
N = np .zeros (c .shape [:- 1 ], dtype = 'int' )
76
78
Z = np .zeros_like (c , dtype = 'float32' )
77
79
78
- for n in range (maxiter ):
79
- Z , N = step (n , c , Z , N , horizon )
80
+ chunksize = 10
81
+ n_chunks = maxiter // chunksize
82
+
83
+ for i_chunk in range (n_chunks ):
84
+ n0 = i_chunk * chunksize
85
+ Z , N = step (n0 , c , Z , N , horizon , chunksize )
80
86
81
87
N = np .where (N == maxiter - 1 , 0 , N ) # N[N == maxiter-1] = 0
82
88
return Z , N
0 commit comments