Skip to content

Simplify scatter code #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
0.0.2
0.4.0
=====

New features
------------
- `HistogramWidget` now shows individual histograms for RGB channels when
present.


Bug fixes
---------
- `HistogramWidget` now works properly with 2D images.
Changes
-------
- The scatter widgets no longer use a LogNorm() for 2D histogram scaling.
This is to move the widget in line with the philosophy of using Matplotlib default
settings throughout ``napari-matplotlib``. This still leaves open the option of
adding the option to change the normalization in the future. If this is something
you would be interested in please open an issue at https://github.com/matplotlib/napari-matplotlib.
- Labels plotting with the features scatter widget no longer have underscores
replaced with spaces.
100 changes: 48 additions & 52 deletions src/napari_matplotlib/scatter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, List, Optional, Tuple

import matplotlib.colors as mcolor
import napari
import numpy.typing as npt
from magicgui import magicgui
Expand All @@ -17,15 +16,8 @@ class ScatterBaseWidget(NapariMPLWidget):
Base class for widgets that scatter two datasets against each other.
"""

# opacity value for the markers
_marker_alpha = 0.5

# flag set to True if histogram should be used
# for plotting large points
_histogram_for_large_data = True

# if the number of points is greater than this value,
# the scatter is plotted as a 2dhist
# the scatter is plotted as a 2D histogram
_threshold_to_switch_to_histogram = 500

def __init__(self, napari_viewer: napari.viewer.Viewer):
Expand All @@ -44,40 +36,32 @@ def draw(self) -> None:
"""
Scatter the currently selected layers.
"""
data, x_axis_name, y_axis_name = self._get_data()

if len(data) == 0:
# don't plot if there isn't data
return
x, y, x_axis_name, y_axis_name = self._get_data()

if self._histogram_for_large_data and (
data[0].size > self._threshold_to_switch_to_histogram
):
if x.size > self._threshold_to_switch_to_histogram:
self.axes.hist2d(
data[0].ravel(),
data[1].ravel(),
x.ravel(),
y.ravel(),
bins=100,
norm=mcolor.LogNorm(),
)
else:
self.axes.scatter(data[0], data[1], alpha=self._marker_alpha)
self.axes.scatter(x, y, alpha=0.5)

self.axes.set_xlabel(x_axis_name)
self.axes.set_ylabel(y_axis_name)

def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
"""Get the plot data.
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
"""
Get the plot data.

This must be implemented on the subclass.

Returns
-------
data : np.ndarray
The list containing the scatter plot data.
x_axis_name : str
The label to display on the x axis
y_axis_name: str
The label to display on the y axis
x, y : np.ndarray
x and y values of plot data.
x_axis_name, y_axis_name : str
Label to display on the x/y axis
"""
raise NotImplementedError

Expand All @@ -93,7 +77,7 @@ class ScatterWidget(ScatterBaseWidget):
n_layers_input = Interval(2, 2)
input_layer_types = (napari.layers.Image,)

def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
"""
Get the plot data.

Expand All @@ -106,11 +90,12 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
y_axis_name: str
The title to display on the y axis
"""
data = [layer.data[self.current_z] for layer in self.layers]
x = self.layers[0].data[self.current_z]
y = self.layers[1].data[self.current_z]
x_axis_name = self.layers[0].name
y_axis_name = self.layers[1].name

return data, x_axis_name, y_axis_name
return x, y, x_axis_name, y_axis_name


class FeaturesScatterWidget(ScatterBaseWidget):
Expand Down Expand Up @@ -191,9 +176,33 @@ def _get_valid_axis_keys(
else:
return self.layers[0].features.keys()

def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
def _ready_to_scatter(self) -> bool:
"""
Get the plot data.
Return True if selected layer has a feature table we can scatter with,
and the two columns to be scatterd have been selected.
"""
if not hasattr(self.layers[0], "features"):
return False

feature_table = self.layers[0].features
return (
feature_table is not None
and len(feature_table) > 0
and self.x_axis_key is not None
and self.y_axis_key is not None
)

def draw(self) -> None:
"""
Scatter two features from the currently selected layer.
"""
if self._ready_to_scatter():
super().draw()

def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
"""
Get the plot data from the ``features`` attribute of the first
selected layer.

Returns
-------
Expand All @@ -207,28 +216,15 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]:
The title to display on the y axis. Returns
an empty string if nothing to plot.
"""
if not hasattr(self.layers[0], "features"):
# if the selected layer doesn't have a featuretable,
# skip draw
return [], "", ""

feature_table = self.layers[0].features

if (
(len(feature_table) == 0)
or (self.x_axis_key is None)
or (self.y_axis_key is None)
):
return [], "", ""

data_x = feature_table[self.x_axis_key]
data_y = feature_table[self.y_axis_key]
data = [data_x, data_y]
x = feature_table[self.x_axis_key]
y = feature_table[self.y_axis_key]

x_axis_name = self.x_axis_key.replace("_", " ")
y_axis_name = self.y_axis_key.replace("_", " ")
x_axis_name = str(self.x_axis_key)
y_axis_name = str(self.y_axis_key)

return data, x_axis_name, y_axis_name
return x, y, x_axis_name, y_axis_name

def _on_update_layers(self) -> None:
"""
Expand Down
19 changes: 10 additions & 9 deletions src/napari_matplotlib/tests/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def make_labels_layer_with_features() -> (


def test_features_scatter_get_data(make_napari_viewer):
"""Test the get data method"""
"""
Test the get data method.
"""
# make the label image
label_image, feature_table = make_labels_layer_with_features()

Expand All @@ -55,17 +57,16 @@ def test_features_scatter_get_data(make_napari_viewer):
y_column = "feature_2"
scatter_widget.y_axis_key = y_column

data, x_axis_name, y_axis_name = scatter_widget._get_data()
np.testing.assert_allclose(
data, np.stack((feature_table[x_column], feature_table[y_column]))
)
assert x_axis_name == x_column.replace("_", " ")
assert y_axis_name == y_column.replace("_", " ")
x, y, x_axis_name, y_axis_name = scatter_widget._get_data()
np.testing.assert_allclose(x, feature_table[x_column])
np.testing.assert_allclose(y, np.stack(feature_table[y_column]))
assert x_axis_name == x_column
assert y_axis_name == y_column


def test_get_valid_axis_keys(make_napari_viewer):
"""Test the values returned from
FeaturesScatterWidget._get_valid_keys() when there
"""
Test the values returned from _get_valid_keys() when there
are valid keys.
"""
# make the label image
Expand Down