Skip to content

Commit f97c3a4

Browse files
authored
Figure.plot & Figure.plot3d: Move common codes into _common.py (#3461)
1 parent d2fbb38 commit f97c3a4

File tree

3 files changed

+49
-32
lines changed

3 files changed

+49
-32
lines changed

pygmt/src/_common.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
Common functions used in multiple PyGMT functions/methods.
3+
"""
4+
5+
from pathlib import Path
6+
from typing import Any
7+
8+
from pygmt.src.which import which
9+
10+
11+
def _data_geometry_is_point(data: Any, kind: str) -> bool:
12+
"""
13+
Check if the geometry of the input data is Point or MultiPoint.
14+
15+
The inptu data can be a GeoJSON object or a OGR_GMT file.
16+
17+
This function is used in ``Figure.plot`` and ``Figure.plot3d``.
18+
19+
Parameters
20+
----------
21+
data
22+
The data being plotted.
23+
kind
24+
The data kind.
25+
26+
Returns
27+
-------
28+
bool
29+
``True`` if the geometry is Point/MultiPoint, ``False`` otherwise.
30+
"""
31+
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
32+
return True
33+
if kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file
34+
try:
35+
with Path(which(data)).open(encoding="utf-8") as file:
36+
line = file.readline()
37+
if "@GMULTIPOINT" in line or "@GPOINT" in line:
38+
return True
39+
except FileNotFoundError:
40+
pass
41+
return False

pygmt/src/plot.py

+4-17
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
plot - Plot in two dimensions.
33
"""
44

5-
from pathlib import Path
6-
75
from pygmt.clib import Session
86
from pygmt.exceptions import GMTInvalidInput
97
from pygmt.helpers import (
@@ -14,7 +12,7 @@
1412
kwargs_to_strings,
1513
use_alias,
1614
)
17-
from pygmt.src.which import which
15+
from pygmt.src._common import _data_geometry_is_point
1816

1917

2018
@fmt_docstring
@@ -50,9 +48,7 @@
5048
w="wrap",
5149
)
5250
@kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence")
53-
def plot( # noqa: PLR0912
54-
self, data=None, x=None, y=None, size=None, direction=None, **kwargs
55-
):
51+
def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
5652
r"""
5753
Plot lines, polygons, and symbols in 2-D.
5854
@@ -242,17 +238,8 @@ def plot( # noqa: PLR0912
242238
raise GMTInvalidInput(f"'{name}' can't be 1-D array if 'data' is used.")
243239

244240
# Set the default style if data has a geometry of Point or MultiPoint
245-
if kwargs.get("S") is None:
246-
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
247-
kwargs["S"] = "s0.2c"
248-
elif kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file
249-
try:
250-
with Path(which(data)).open(encoding="utf-8") as file:
251-
line = file.readline()
252-
if "@GMULTIPOINT" in line or "@GPOINT" in line:
253-
kwargs["S"] = "s0.2c"
254-
except FileNotFoundError:
255-
pass
241+
if kwargs.get("S") is None and _data_geometry_is_point(data, kind):
242+
kwargs["S"] = "s0.2c"
256243

257244
with Session() as lib:
258245
with lib.virtualfile_in(

pygmt/src/plot3d.py

+4-15
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
plot3d - Plot in three dimensions.
33
"""
44

5-
from pathlib import Path
6-
75
from pygmt.clib import Session
86
from pygmt.exceptions import GMTInvalidInput
97
from pygmt.helpers import (
@@ -14,7 +12,7 @@
1412
kwargs_to_strings,
1513
use_alias,
1614
)
17-
from pygmt.src.which import which
15+
from pygmt.src._common import _data_geometry_is_point
1816

1917

2018
@fmt_docstring
@@ -51,7 +49,7 @@
5149
w="wrap",
5250
)
5351
@kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence")
54-
def plot3d( # noqa: PLR0912
52+
def plot3d(
5553
self, data=None, x=None, y=None, z=None, size=None, direction=None, **kwargs
5654
):
5755
r"""
@@ -218,17 +216,8 @@ def plot3d( # noqa: PLR0912
218216
raise GMTInvalidInput(f"'{name}' can't be 1-D array if 'data' is used.")
219217

220218
# Set the default style if data has a geometry of Point or MultiPoint
221-
if kwargs.get("S") is None:
222-
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
223-
kwargs["S"] = "u0.2c"
224-
elif kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file
225-
try:
226-
with Path(which(data)).open(encoding="utf-8") as file:
227-
line = file.readline()
228-
if "@GMULTIPOINT" in line or "@GPOINT" in line:
229-
kwargs["S"] = "u0.2c"
230-
except FileNotFoundError:
231-
pass
219+
if kwargs.get("S") is None and _data_geometry_is_point(data, kind):
220+
kwargs["S"] = "u0.2c"
232221

233222
with Session() as lib:
234223
with lib.virtualfile_in(

0 commit comments

Comments
 (0)