diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py
index f948d017..fda2d2d5 100644
--- a/src/napari_matplotlib/base.py
+++ b/src/napari_matplotlib/base.py
@@ -245,6 +245,7 @@ def _draw(self) -> None:
             isinstance(layer, self.input_layer_types) for layer in self.layers
         ):
             self.draw()
+        self.apply_napari_colorscheme(self.figure.gca())
         self.canvas.draw()
 
     def clear(self) -> None:
diff --git a/src/napari_matplotlib/tests/baseline/test_histogram_2D.png b/src/napari_matplotlib/tests/baseline/test_histogram_2D.png
index f3f53aea..b76d1e10 100644
Binary files a/src/napari_matplotlib/tests/baseline/test_histogram_2D.png and b/src/napari_matplotlib/tests/baseline/test_histogram_2D.png differ
diff --git a/src/napari_matplotlib/tests/baseline/test_histogram_3D.png b/src/napari_matplotlib/tests/baseline/test_histogram_3D.png
index 484092b1..2dffdcb2 100644
Binary files a/src/napari_matplotlib/tests/baseline/test_histogram_3D.png and b/src/napari_matplotlib/tests/baseline/test_histogram_3D.png differ
diff --git a/src/napari_matplotlib/tests/baseline/test_slice_2D.png b/src/napari_matplotlib/tests/baseline/test_slice_2D.png
index de2cbd42..c9e4d6f6 100644
Binary files a/src/napari_matplotlib/tests/baseline/test_slice_2D.png and b/src/napari_matplotlib/tests/baseline/test_slice_2D.png differ
diff --git a/src/napari_matplotlib/tests/baseline/test_slice_3D.png b/src/napari_matplotlib/tests/baseline/test_slice_3D.png
index 30b02e93..43c8c3b6 100644
Binary files a/src/napari_matplotlib/tests/baseline/test_slice_3D.png and b/src/napari_matplotlib/tests/baseline/test_slice_3D.png differ
diff --git a/src/napari_matplotlib/tests/scatter/baseline/test_features_scatter_widget_2D.png b/src/napari_matplotlib/tests/scatter/baseline/test_features_scatter_widget_2D.png
index db9940e9..269ebd01 100644
Binary files a/src/napari_matplotlib/tests/scatter/baseline/test_features_scatter_widget_2D.png and b/src/napari_matplotlib/tests/scatter/baseline/test_features_scatter_widget_2D.png differ
diff --git a/src/napari_matplotlib/tests/scatter/baseline/test_scatter_2D.png b/src/napari_matplotlib/tests/scatter/baseline/test_scatter_2D.png
index 1977d45f..3b550666 100644
Binary files a/src/napari_matplotlib/tests/scatter/baseline/test_scatter_2D.png and b/src/napari_matplotlib/tests/scatter/baseline/test_scatter_2D.png differ
diff --git a/src/napari_matplotlib/tests/scatter/baseline/test_scatter_3D.png b/src/napari_matplotlib/tests/scatter/baseline/test_scatter_3D.png
index 6238d89d..27e7d673 100644
Binary files a/src/napari_matplotlib/tests/scatter/baseline/test_scatter_3D.png and b/src/napari_matplotlib/tests/scatter/baseline/test_scatter_3D.png differ
diff --git a/src/napari_matplotlib/tests/scatter/test_scatter.py b/src/napari_matplotlib/tests/scatter/test_scatter.py
index 493e9ab8..930f4a47 100644
--- a/src/napari_matplotlib/tests/scatter/test_scatter.py
+++ b/src/napari_matplotlib/tests/scatter/test_scatter.py
@@ -8,6 +8,7 @@
 @pytest.mark.mpl_image_compare
 def test_scatter_2D(make_napari_viewer, astronaut_data):
     viewer = make_napari_viewer()
+    viewer.theme = "light"
     widget = ScatterWidget(viewer)
     fig = widget.figure
 
@@ -28,6 +29,7 @@ def test_scatter_2D(make_napari_viewer, astronaut_data):
 @pytest.mark.mpl_image_compare
 def test_scatter_3D(make_napari_viewer, brain_data):
     viewer = make_napari_viewer()
+    viewer.theme = "light"
     widget = ScatterWidget(viewer)
     fig = widget.figure
 
diff --git a/src/napari_matplotlib/tests/scatter/test_scatter_features.py b/src/napari_matplotlib/tests/scatter/test_scatter_features.py
index 8284a1e8..fca8a767 100644
--- a/src/napari_matplotlib/tests/scatter/test_scatter_features.py
+++ b/src/napari_matplotlib/tests/scatter/test_scatter_features.py
@@ -11,6 +11,7 @@
 @pytest.mark.mpl_image_compare
 def test_features_scatter_widget_2D(make_napari_viewer):
     viewer = make_napari_viewer()
+    viewer.theme = "light"
     widget = FeaturesScatterWidget(viewer)
 
     # make the points data
diff --git a/src/napari_matplotlib/tests/test_histogram.py b/src/napari_matplotlib/tests/test_histogram.py
index eb4e3a6c..4d170014 100644
--- a/src/napari_matplotlib/tests/test_histogram.py
+++ b/src/napari_matplotlib/tests/test_histogram.py
@@ -8,6 +8,7 @@
 @pytest.mark.mpl_image_compare
 def test_histogram_2D(make_napari_viewer, astronaut_data):
     viewer = make_napari_viewer()
+    viewer.theme = "light"
     viewer.add_image(astronaut_data[0], **astronaut_data[1])
     fig = HistogramWidget(viewer).figure
     # Need to return a copy, as original figure is too eagerley garbage
@@ -18,6 +19,7 @@ def test_histogram_2D(make_napari_viewer, astronaut_data):
 @pytest.mark.mpl_image_compare
 def test_histogram_3D(make_napari_viewer, brain_data):
     viewer = make_napari_viewer()
+    viewer.theme = "light"
     viewer.add_image(brain_data[0], **brain_data[1])
     axis = viewer.dims.last_used
     slice_no = brain_data[0].shape[0] - 1
diff --git a/src/napari_matplotlib/tests/test_slice.py b/src/napari_matplotlib/tests/test_slice.py
index b14d8e38..412e71c3 100644
--- a/src/napari_matplotlib/tests/test_slice.py
+++ b/src/napari_matplotlib/tests/test_slice.py
@@ -8,6 +8,7 @@
 @pytest.mark.mpl_image_compare
 def test_slice_3D(make_napari_viewer, brain_data):
     viewer = make_napari_viewer()
+    viewer.theme = "light"
     viewer.add_image(brain_data[0], **brain_data[1])
     axis = viewer.dims.last_used
     slice_no = brain_data[0].shape[0] - 1
@@ -21,6 +22,7 @@ def test_slice_3D(make_napari_viewer, brain_data):
 @pytest.mark.mpl_image_compare
 def test_slice_2D(make_napari_viewer, astronaut_data):
     viewer = make_napari_viewer()
+    viewer.theme = "light"
     viewer.add_image(astronaut_data[0], **astronaut_data[1])
     fig = SliceWidget(viewer).figure
     # Need to return a copy, as original figure is too eagerley garbage
diff --git a/src/napari_matplotlib/tests/test_theme.py b/src/napari_matplotlib/tests/test_theme.py
index 88b271d6..dfeb5b5f 100644
--- a/src/napari_matplotlib/tests/test_theme.py
+++ b/src/napari_matplotlib/tests/test_theme.py
@@ -1,6 +1,8 @@
 import napari
+import numpy as np
 import pytest
 
+from napari_matplotlib import ScatterWidget
 from napari_matplotlib.base import NapariMPLWidget
 
 
@@ -49,3 +51,40 @@ def test_theme_background_check(make_napari_viewer):
     _mock_up_theme()
     viewer.theme = "blue"
     assert widget._theme_has_light_bg() is True
+
+
+@pytest.mark.parametrize(
+    "theme_name, expected_text_colour",
+    [
+        ("dark", "#f0f1f2"),  # #f0f1f2 is a light grey (almost white)
+        ("light", "#3b3a39"),  # #3b3a39 is a brownish dark grey (almost black)
+    ],
+)
+def test_titles_respect_theme(
+    make_napari_viewer, theme_name, expected_text_colour
+):
+    """
+    Test that the axis labels and titles are the correct color for the napari theme.
+    """
+    viewer = make_napari_viewer()
+    widget = ScatterWidget(viewer)
+    viewer.theme = theme_name
+
+    # make a scatter plot of two random layers
+    viewer.add_image(np.random.random((10, 10)), name="first test image")
+    viewer.add_image(np.random.random((10, 10)), name="second test image")
+    viewer.layers.selection.clear()
+    viewer.layers.selection.add(viewer.layers[0])
+    viewer.layers.selection.add(viewer.layers[1])
+
+    ax = widget.figure.gca()
+
+    # sanity test to make sure we've got the correct image names
+    assert ax.xaxis.label.get_text() == "first test image"
+    assert ax.yaxis.label.get_text() == "second test image"
+
+    # print(dir(ax.yaxis.label))
+    # TODO: put checks of the axis tick labels here
+
+    assert ax.xaxis.label.get_color() == expected_text_colour
+    assert ax.yaxis.label.get_color() == expected_text_colour