@@ -75,20 +75,29 @@ def remove_unfinished(self):
75
75
self .neighbors_combined = deepcopy (self .neighbors )
76
76
77
77
def plot (self , * , with_sem = True ):
78
- scatter = super ().plot ()
78
+ hv = ensure_holoviews ()
79
+
80
+ if not self ._data :
81
+ return hv .Scatter ([])
82
+
83
+ xs , ys = zip (* sorted (self .data .items ()))
84
+ get = lambda attr : [getattr (self ._data [x ], attr ) for x in xs ]
85
+ sems = get ('standard_error' )
86
+ stds = get ('std' )
87
+ Ns = get ('n' )
88
+
89
+ scatter = hv .Scatter (
90
+ (xs , ys , stds , sems , Ns ),
91
+ vdims = ['mean' , 'std' , 'standard_error' , 'n' ]
92
+ )
93
+
79
94
if not with_sem :
80
- return scatter
81
-
82
- if self ._data :
83
- hv = ensure_holoviews ()
84
- xs , ys = zip (* sorted (self .data .items ()))
85
- sem = self .data_sem
86
- err = [sem [x ] if sem [x ] < sys .float_info .max
87
- else np .nan for x in xs ]
88
- spread = hv .Spread ((xs , ys , err ))
89
- return scatter * spread
95
+ plot = scatter .opts (plot = dict (tools = ['hover' ]))
90
96
else :
91
- return scatter
97
+ err = [x if x < sys .float_info .max else np .nan for x in sems ]
98
+ spread = hv .Spread ((xs , ys , err ))
99
+ plot = (scatter * spread )
100
+ return plot .opts (hv .opts .Scatter (tools = ['hover' ]))
92
101
93
102
def _set_data (self , data ):
94
103
# change dict -> DataPoint, because the points are saved using dicts
0 commit comments