diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index fda2d2d5..8c717d6a 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -14,7 +14,7 @@ from .util import Interval, from_napari_css_get_size_of -__all__ = ["BaseNapariMPLWidget", "NapariMPLWidget"] +__all__ = ["BaseNapariMPLWidget", "NapariMPLWidget", "SingleAxesWidget"] class BaseNapariMPLWidget(QWidget): @@ -270,6 +270,27 @@ def on_update_layers(self) -> None: """ +class SingleAxesWidget(NapariMPLWidget): + """ + In addition to `NapariMPLWidget`, this sets up a single axes and + the callback to clear it. + """ + + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer=napari_viewer, parent=parent) + self.add_single_axes() + + def clear(self) -> None: + """ + Clear the axes. + """ + self.axes.clear() + + class NapariNavigationToolbar(NavigationToolbar2QT): """Custom Toolbar style for Napari.""" diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index 1e273e70..39ad41a3 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -4,7 +4,7 @@ import numpy as np from qtpy.QtWidgets import QWidget -from .base import NapariMPLWidget +from .base import SingleAxesWidget from .util import Interval __all__ = ["HistogramWidget"] @@ -12,7 +12,7 @@ _COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"} -class HistogramWidget(NapariMPLWidget): +class HistogramWidget(SingleAxesWidget): """ Display a histogram of the currently selected layer. """ @@ -26,15 +26,8 @@ def __init__( parent: Optional[QWidget] = None, ): super().__init__(napari_viewer, parent=parent) - self.add_single_axes() self._update_layers(None) - def clear(self) -> None: - """ - Clear the axes. - """ - self.axes.clear() - def draw(self) -> None: """ Clear the axes and histogram the currently selected layer/slice. diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index c439677b..334f941c 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -4,13 +4,13 @@ import numpy.typing as npt from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget -from .base import NapariMPLWidget +from .base import SingleAxesWidget from .util import Interval __all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"] -class ScatterBaseWidget(NapariMPLWidget): +class ScatterBaseWidget(SingleAxesWidget): """ Base class for widgets that scatter two datasets against each other. """ @@ -19,20 +19,6 @@ class ScatterBaseWidget(NapariMPLWidget): # the scatter is plotted as a 2D histogram _threshold_to_switch_to_histogram = 500 - def __init__( - self, - napari_viewer: napari.viewer.Viewer, - parent: Optional[QWidget] = None, - ): - super().__init__(napari_viewer, parent=parent) - self.add_single_axes() - - def clear(self) -> None: - """ - Clear the axes. - """ - self.axes.clear() - def draw(self) -> None: """ Scatter the currently selected layers. diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py index 54a13ea2..8cbef453 100644 --- a/src/napari_matplotlib/slice.py +++ b/src/napari_matplotlib/slice.py @@ -5,7 +5,7 @@ import numpy.typing as npt from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox, QWidget -from .base import NapariMPLWidget +from .base import SingleAxesWidget from .util import Interval __all__ = ["SliceWidget"] @@ -14,7 +14,7 @@ _dims = ["x", "y", "z"] -class SliceWidget(NapariMPLWidget): +class SliceWidget(SingleAxesWidget): """ Plot a 1D slice along a given dimension. """ @@ -29,7 +29,6 @@ def __init__( ): # Setup figure/axes super().__init__(napari_viewer, parent=parent) - self.add_single_axes() button_layout = QHBoxLayout() self.layout().addLayout(button_layout) @@ -108,12 +107,6 @@ def _get_xy(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any]]: return x, y - def clear(self) -> None: - """ - Clear the axes. - """ - self.axes.cla() - def draw(self) -> None: """ Clear axes and draw a 1D plot.