@@ -233,7 +233,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
233
233
result ["y" ] = trendline [:, 1 ]
234
234
hover_header = "<b>LOWESS trendline</b><br><br>"
235
235
elif v == "ols" :
236
- fit_results = sm .OLS (y , sm .add_constant (x )).fit ()
236
+ fit_results = sm .OLS (y . values , sm .add_constant (x . values )).fit ()
237
237
result ["y" ] = fit_results .predict ()
238
238
hover_header = "<b>OLS trendline</b><br>"
239
239
hover_header += "%s = %f * %s + %f<br>" % (
@@ -747,10 +747,10 @@ def apply_default_cascade(args):
747
747
]
748
748
749
749
# If both marginals and faceting are specified, faceting wins
750
- if args .get ("facet_col" , None ) and args .get ("marginal_y" , None ):
750
+ if args .get ("facet_col" , None ) is not None and args .get ("marginal_y" , None ):
751
751
args ["marginal_y" ] = None
752
752
753
- if args .get ("facet_row" , None ) and args .get ("marginal_x" , None ):
753
+ if args .get ("facet_row" , None ) is not None and args .get ("marginal_x" , None ):
754
754
args ["marginal_x" ] = None
755
755
756
756
@@ -874,7 +874,7 @@ def build_dataframe(args, attrables, array_attrables):
874
874
"pandas MultiIndex is not supported by plotly express "
875
875
"at the moment." % field
876
876
)
877
- ## ----------------- argument is a col name ----------------------
877
+ # ----------------- argument is a col name ----------------------
878
878
if isinstance (argument , str ) or isinstance (
879
879
argument , int
880
880
): # just a column name given as str or int
@@ -1042,6 +1042,13 @@ def infer_config(args, constructor, trace_patch):
1042
1042
args [position ] = args ["marginal" ]
1043
1043
args [other_position ] = None
1044
1044
1045
+ if (
1046
+ args .get ("marginal_x" , None ) is not None
1047
+ or args .get ("marginal_y" , None ) is not None
1048
+ or args .get ("facet_row" , None ) is not None
1049
+ ):
1050
+ args ["facet_col_wrap" ] = 0
1051
+
1045
1052
# Compute applicable grouping attributes
1046
1053
for k in group_attrables :
1047
1054
if k in args :
@@ -1098,15 +1105,14 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1098
1105
1099
1106
orders , sorted_group_names = get_orderings (args , grouper , grouped )
1100
1107
1101
- has_marginal_x = bool (args .get ("marginal_x" , False ))
1102
- has_marginal_y = bool (args .get ("marginal_y" , False ))
1103
-
1104
1108
subplot_type = _subplot_type_for_trace_type (constructor ().type )
1105
1109
1106
1110
trace_names_by_frame = {}
1107
1111
frames = OrderedDict ()
1108
1112
trendline_rows = []
1109
1113
nrows = ncols = 1
1114
+ col_labels = []
1115
+ row_labels = []
1110
1116
for group_name in sorted_group_names :
1111
1117
group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
1112
1118
mapping_labels = OrderedDict ()
@@ -1188,27 +1194,36 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1188
1194
# Find row for trace, handling facet_row and marginal_x
1189
1195
if m .facet == "row" :
1190
1196
row = m .val_map [val ]
1191
- trace ._subplot_row_val = val
1197
+ if args ["facet_row" ] and len (row_labels ) < row :
1198
+ row_labels .append (args ["facet_row" ] + "=" + str (val ))
1192
1199
else :
1193
- if has_marginal_x and trace_spec .marginal != "x" :
1200
+ if (
1201
+ bool (args .get ("marginal_x" , False ))
1202
+ and trace_spec .marginal != "x"
1203
+ ):
1194
1204
row = 2
1195
1205
else :
1196
1206
row = 1
1197
1207
1198
- nrows = max (nrows , row )
1199
- if row > 1 :
1200
- trace ._subplot_row = row
1201
-
1208
+ facet_col_wrap = args .get ("facet_col_wrap" , 0 )
1202
1209
# Find col for trace, handling facet_col and marginal_y
1203
1210
if m .facet == "col" :
1204
1211
col = m .val_map [val ]
1205
- trace ._subplot_col_val = val
1212
+ if args ["facet_col" ] and len (col_labels ) < col :
1213
+ col_labels .append (args ["facet_col" ] + "=" + str (val ))
1214
+ if facet_col_wrap : # assumes no facet_row, no marginals
1215
+ row = 1 + ((col - 1 ) // facet_col_wrap )
1216
+ col = 1 + ((col - 1 ) % facet_col_wrap )
1206
1217
else :
1207
1218
if trace_spec .marginal == "y" :
1208
1219
col = 2
1209
1220
else :
1210
1221
col = 1
1211
1222
1223
+ nrows = max (nrows , row )
1224
+ if row > 1 :
1225
+ trace ._subplot_row = row
1226
+
1212
1227
ncols = max (ncols , col )
1213
1228
if col > 1 :
1214
1229
trace ._subplot_col = col
@@ -1238,7 +1253,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1238
1253
if show_colorbar :
1239
1254
colorvar = "z" if constructor == go .Histogram2d else "color"
1240
1255
range_color = args ["range_color" ] or [None , None ]
1241
- d = len (args ["color_continuous_scale" ]) - 1
1242
1256
1243
1257
colorscale_validator = ColorscaleValidator ("colorscale" , "make_figure" )
1244
1258
layout_patch ["coloraxis1" ] = dict (
@@ -1260,7 +1274,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1260
1274
layout_patch ["legend" ]["itemsizing" ] = "constant"
1261
1275
1262
1276
fig = init_figure (
1263
- args , subplot_type , frame_list , ncols , nrows , has_marginal_x , has_marginal_y
1277
+ args , subplot_type , frame_list , nrows , ncols , col_labels , row_labels
1264
1278
)
1265
1279
1266
1280
# Position traces in subplots
@@ -1290,49 +1304,39 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1290
1304
return fig
1291
1305
1292
1306
1293
- def init_figure (
1294
- args , subplot_type , frame_list , ncols , nrows , has_marginal_x , has_marginal_y
1295
- ):
1307
+ def init_figure (args , subplot_type , frame_list , nrows , ncols , col_labels , row_labels ):
1296
1308
# Build subplot specs
1297
1309
specs = [[{}] * ncols for _ in range (nrows )]
1298
- column_titles = [None ] * ncols
1299
- row_titles = [None ] * nrows
1300
1310
for frame in frame_list :
1301
1311
for trace in frame ["data" ]:
1302
1312
row0 = trace ._subplot_row - 1
1303
1313
col0 = trace ._subplot_col - 1
1304
-
1305
1314
if isinstance (trace , go .Splom ):
1306
1315
# Splom not compatible with make_subplots, treat as domain
1307
1316
specs [row0 ][col0 ] = {"type" : "domain" }
1308
1317
else :
1309
1318
specs [row0 ][col0 ] = {"type" : trace .type }
1310
- if args .get ("facet_row" , None ) and hasattr (trace , "_subplot_row_val" ):
1311
- row_titles [row0 ] = args ["facet_row" ] + "=" + str (trace ._subplot_row_val )
1312
-
1313
- if args .get ("facet_col" , None ) and hasattr (trace , "_subplot_col_val" ):
1314
- column_titles [col0 ] = (
1315
- args ["facet_col" ] + "=" + str (trace ._subplot_col_val )
1316
- )
1317
1319
1318
1320
# Default row/column widths uniform
1319
1321
column_widths = [1.0 ] * ncols
1320
1322
row_heights = [1.0 ] * nrows
1321
1323
1322
1324
# Build column_widths/row_heights
1323
1325
if subplot_type == "xy" :
1324
- if has_marginal_x :
1326
+ if bool ( args . get ( "marginal_x" , False )) :
1325
1327
if args ["marginal_x" ] == "histogram" or ("color" in args and args ["color" ]):
1326
1328
main_size = 0.74
1327
1329
else :
1328
1330
main_size = 0.84
1329
1331
1330
1332
row_heights = [main_size ] * (nrows - 1 ) + [1 - main_size ]
1331
1333
vertical_spacing = 0.01
1334
+ elif args .get ("facet_col_wrap" , 0 ):
1335
+ vertical_spacing = 0.07
1332
1336
else :
1333
1337
vertical_spacing = 0.03
1334
1338
1335
- if has_marginal_y :
1339
+ if bool ( args . get ( "marginal_y" , False )) :
1336
1340
if args ["marginal_y" ] == "histogram" or ("color" in args and args ["color" ]):
1337
1341
main_size = 0.74
1338
1342
else :
@@ -1351,15 +1355,25 @@ def init_figure(
1351
1355
vertical_spacing = 0.1
1352
1356
horizontal_spacing = 0.1
1353
1357
1358
+ facet_col_wrap = args .get ("facet_col_wrap" , 0 )
1359
+ if facet_col_wrap :
1360
+ subplot_labels = [None ] * nrows * ncols
1361
+ while len (col_labels ) < nrows * ncols :
1362
+ col_labels .append (None )
1363
+ for i in range (nrows ):
1364
+ for j in range (ncols ):
1365
+ subplot_labels [i * ncols + j ] = col_labels [(nrows - 1 - i ) * ncols + j ]
1366
+
1354
1367
# Create figure with subplots
1355
1368
fig = make_subplots (
1356
1369
rows = nrows ,
1357
1370
cols = ncols ,
1358
1371
specs = specs ,
1359
1372
shared_xaxes = "all" ,
1360
1373
shared_yaxes = "all" ,
1361
- row_titles = list (reversed (row_titles )),
1362
- column_titles = column_titles ,
1374
+ row_titles = [] if facet_col_wrap else list (reversed (row_labels )),
1375
+ column_titles = [] if facet_col_wrap else col_labels ,
1376
+ subplot_titles = subplot_labels if facet_col_wrap else [],
1363
1377
horizontal_spacing = horizontal_spacing ,
1364
1378
vertical_spacing = vertical_spacing ,
1365
1379
row_heights = row_heights ,
0 commit comments