Skip to content

Commit 12be7fb

Browse files
committed
chunk the mandelbrot loop
1 parent 781ad00 commit 12be7fb

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

e2e/mandelbrot/mandelbrot.png

71 Bytes
Loading

e2e/mandelbrot/mandelbrot.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,13 @@ def sq2(a):
5959
return z
6060

6161

62-
@torch.compile
63-
def step(n, c, Z, N, horizon):
64-
I = abs2(Z) < horizon**2
65-
N = np.where(I, n, N) # N[I] = n
66-
Z = np.where(I[..., None], sq2(Z) + c, Z) # Z[I] = Z[I]**2 + C[I]
62+
@torch.compile(dynamic=True)
63+
def step(n0, c, Z, N, horizon, chunksize):
64+
for j in range(chunksize):
65+
n = n0 + j
66+
I = abs2(Z) < horizon**2
67+
N = np.where(I, n, N) # N[I] = n
68+
Z = np.where(I[..., None], sq2(Z) + c, Z) # Z[I] = Z[I]**2 + C[I]
6769
return Z, N
6870

6971

@@ -75,8 +77,12 @@ def mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=2**10, maxiter=5):
7577
N = np.zeros(c.shape[:-1], dtype='int')
7678
Z = np.zeros_like(c, dtype='float32')
7779

78-
for n in range(maxiter):
79-
Z, N = step(n, c, Z, N, horizon)
80+
chunksize=10
81+
n_chunks = maxiter // chunksize
82+
83+
for i_chunk in range(n_chunks):
84+
n0 = i_chunk*chunksize
85+
Z, N = step(n0, c, Z, N, horizon, chunksize)
8086

8187
N = np.where(N == maxiter-1, 0, N) # N[N == maxiter-1] = 0
8288
return Z, N

0 commit comments

Comments
 (0)