6
6
import numpy as np
7
7
import pandas as pd
8
8
9
- from .core import find_group_cohorts
9
+ from .core import _unique , find_group_cohorts
10
10
11
11
12
12
def draw_mesh (
@@ -131,14 +131,14 @@ def factorize_cohorts(by, cohorts):
131
131
return factorized
132
132
133
133
134
- def visualize_cohorts_2d (by , array , method = "cohorts" ):
134
+ def visualize_cohorts_2d (by , array ):
135
135
assert by .ndim == 2
136
136
print ("finding cohorts..." )
137
137
before_merged = find_group_cohorts (
138
- by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = False , method = method
138
+ by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = False
139
139
).values ()
140
140
merged = find_group_cohorts (
141
- by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = True , method = method
141
+ by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = True
142
142
).values ()
143
143
print ("finished cohorts..." )
144
144
@@ -149,16 +149,12 @@ def visualize_cohorts_2d(by, array, method="cohorts"):
149
149
ax = ax .ravel ()
150
150
ax [1 ].set_visible (False )
151
151
ax = ax [[0 , 2 , 3 ]]
152
- flat = by .ravel ()
153
- ngroups = len (np .unique (flat [~ np .isnan (flat )]))
154
152
153
+ ngroups = len (_unique (by ))
155
154
h0 = ax [0 ].imshow (by , cmap = get_colormap (ngroups ))
156
- h1 = ax [1 ].imshow (
157
- factorize_cohorts (by , before_merged ),
158
- vmin = 0 ,
159
- cmap = get_colormap (len (before_merged )),
160
- )
161
- h2 = ax [2 ].imshow (factorize_cohorts (by , merged ), vmin = 0 , cmap = get_colormap (len (merged )))
155
+ h1 = _visualize_cohorts (by , before_merged , ax = ax [1 ])
156
+ h2 = _visualize_cohorts (by , merged , ax = ax [2 ])
157
+
162
158
for axx in ax :
163
159
axx .grid (True , which = "both" )
164
160
axx .set_xticks (xticks )
@@ -170,3 +166,10 @@ def visualize_cohorts_2d(by, array, method="cohorts"):
170
166
ax [1 ].set_title (f"{ len (before_merged )} cohorts" )
171
167
ax [2 ].set_title (f"{ len (merged )} merged cohorts" )
172
168
f .set_size_inches ((6 , 6 ))
169
+
170
+
171
+ def _visualize_cohorts (by , cohorts , ax = None ):
172
+ if ax is None :
173
+ _ , ax = plt .subplots (1 , 1 )
174
+
175
+ ax .imshow (factorize_cohorts (by , cohorts ), vmin = 0 , cmap = get_colormap (len (cohorts )))
0 commit comments