Skip to content

De-duplicate single axis logic #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/napari_matplotlib/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .util import Interval, from_napari_css_get_size_of

__all__ = ["BaseNapariMPLWidget", "NapariMPLWidget"]
__all__ = ["BaseNapariMPLWidget", "NapariMPLWidget", "SingleAxesWidget"]


class BaseNapariMPLWidget(QWidget):
Expand Down Expand Up @@ -270,6 +270,27 @@ def on_update_layers(self) -> None:
"""


class SingleAxesWidget(NapariMPLWidget):
"""
In addition to `NapariMPLWidget`, this sets up a single axes and
the callback to clear it.
"""

def __init__(
self,
napari_viewer: napari.viewer.Viewer,
parent: Optional[QWidget] = None,
):
super().__init__(napari_viewer=napari_viewer, parent=parent)
self.add_single_axes()

def clear(self) -> None:
"""
Clear the axes.
"""
self.axes.clear()


class NapariNavigationToolbar(NavigationToolbar2QT):
"""Custom Toolbar style for Napari."""

Expand Down
11 changes: 2 additions & 9 deletions src/napari_matplotlib/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import numpy as np
from qtpy.QtWidgets import QWidget

from .base import NapariMPLWidget
from .base import SingleAxesWidget
from .util import Interval

__all__ = ["HistogramWidget"]

_COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"}


class HistogramWidget(NapariMPLWidget):
class HistogramWidget(SingleAxesWidget):
"""
Display a histogram of the currently selected layer.
"""
Expand All @@ -26,15 +26,8 @@ def __init__(
parent: Optional[QWidget] = None,
):
super().__init__(napari_viewer, parent=parent)
self.add_single_axes()
self._update_layers(None)

def clear(self) -> None:
"""
Clear the axes.
"""
self.axes.clear()

def draw(self) -> None:
"""
Clear the axes and histogram the currently selected layer/slice.
Expand Down
18 changes: 2 additions & 16 deletions src/napari_matplotlib/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import numpy.typing as npt
from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget

from .base import NapariMPLWidget
from .base import SingleAxesWidget
from .util import Interval

__all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"]


class ScatterBaseWidget(NapariMPLWidget):
class ScatterBaseWidget(SingleAxesWidget):
"""
Base class for widgets that scatter two datasets against each other.
"""
Expand All @@ -19,20 +19,6 @@ class ScatterBaseWidget(NapariMPLWidget):
# the scatter is plotted as a 2D histogram
_threshold_to_switch_to_histogram = 500

def __init__(
self,
napari_viewer: napari.viewer.Viewer,
parent: Optional[QWidget] = None,
):
super().__init__(napari_viewer, parent=parent)
self.add_single_axes()

def clear(self) -> None:
"""
Clear the axes.
"""
self.axes.clear()

def draw(self) -> None:
"""
Scatter the currently selected layers.
Expand Down
11 changes: 2 additions & 9 deletions src/napari_matplotlib/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy.typing as npt
from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox, QWidget

from .base import NapariMPLWidget
from .base import SingleAxesWidget
from .util import Interval

__all__ = ["SliceWidget"]
Expand All @@ -14,7 +14,7 @@
_dims = ["x", "y", "z"]


class SliceWidget(NapariMPLWidget):
class SliceWidget(SingleAxesWidget):
"""
Plot a 1D slice along a given dimension.
"""
Expand All @@ -29,7 +29,6 @@ def __init__(
):
# Setup figure/axes
super().__init__(napari_viewer, parent=parent)
self.add_single_axes()

button_layout = QHBoxLayout()
self.layout().addLayout(button_layout)
Expand Down Expand Up @@ -108,12 +107,6 @@ def _get_xy(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any]]:

return x, y

def clear(self) -> None:
"""
Clear the axes.
"""
self.axes.cla()

def draw(self) -> None:
"""
Clear axes and draw a 1D plot.
Expand Down