Skip to content

Commit 0db264a

Browse files
authored
Update visualize for main changes (#179)
* Update visualize for main changes * Update tests/__init__.py
1 parent ccf578d commit 0db264a

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

flox/visualize.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import pandas as pd
88

9-
from .core import find_group_cohorts
9+
from .core import _unique, find_group_cohorts
1010

1111

1212
def draw_mesh(
@@ -131,14 +131,14 @@ def factorize_cohorts(by, cohorts):
131131
return factorized
132132

133133

134-
def visualize_cohorts_2d(by, array, method="cohorts"):
134+
def visualize_cohorts_2d(by, array):
135135
assert by.ndim == 2
136136
print("finding cohorts...")
137137
before_merged = find_group_cohorts(
138-
by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=False, method=method
138+
by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=False
139139
).values()
140140
merged = find_group_cohorts(
141-
by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=True, method=method
141+
by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=True
142142
).values()
143143
print("finished cohorts...")
144144

@@ -149,16 +149,12 @@ def visualize_cohorts_2d(by, array, method="cohorts"):
149149
ax = ax.ravel()
150150
ax[1].set_visible(False)
151151
ax = ax[[0, 2, 3]]
152-
flat = by.ravel()
153-
ngroups = len(np.unique(flat[~np.isnan(flat)]))
154152

153+
ngroups = len(_unique(by))
155154
h0 = ax[0].imshow(by, cmap=get_colormap(ngroups))
156-
h1 = ax[1].imshow(
157-
factorize_cohorts(by, before_merged),
158-
vmin=0,
159-
cmap=get_colormap(len(before_merged)),
160-
)
161-
h2 = ax[2].imshow(factorize_cohorts(by, merged), vmin=0, cmap=get_colormap(len(merged)))
155+
h1 = _visualize_cohorts(by, before_merged, ax=ax[1])
156+
h2 = _visualize_cohorts(by, merged, ax=ax[2])
157+
162158
for axx in ax:
163159
axx.grid(True, which="both")
164160
axx.set_xticks(xticks)
@@ -170,3 +166,10 @@ def visualize_cohorts_2d(by, array, method="cohorts"):
170166
ax[1].set_title(f"{len(before_merged)} cohorts")
171167
ax[2].set_title(f"{len(merged)} merged cohorts")
172168
f.set_size_inches((6, 6))
169+
170+
171+
def _visualize_cohorts(by, cohorts, ax=None):
172+
if ax is None:
173+
_, ax = plt.subplots(1, 1)
174+
175+
ax.imshow(factorize_cohorts(by, cohorts), vmin=0, cmap=get_colormap(len(cohorts)))

0 commit comments

Comments
 (0)