Skip to content

Commit 11cb76b

Browse files
authored
Merge branch 'main' into clib/load-libgmt
2 parents bf6083f + 88ab1ca commit 11cb76b

File tree

4 files changed

+55
-8
lines changed

4 files changed

+55
-8
lines changed

pygmt/session_management.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""
22
Modern mode session management modules.
33
"""
4+
import os
5+
import sys
6+
47
from pygmt.clib import Session
8+
from pygmt.helpers import unique_name
59

610

711
def begin():
@@ -12,6 +16,10 @@ def begin():
1216
1317
Only meant to be used once for creating the global session.
1418
"""
19+
# On Windows, need to set GMT_SESSION_NAME to a unique value
20+
if sys.platform == "win32":
21+
os.environ["GMT_SESSION_NAME"] = unique_name()
22+
1523
prefix = "pygmt-session"
1624
with Session() as lib:
1725
lib.call_module(module="begin", args=prefix)

pygmt/src/meca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def convention_params(convention):
168168
],
169169
"mt": ["mrr", "mtt", "mff", "mrt", "mrf", "mtf", "exponent"],
170170
"partial": ["strike1", "dip1", "strike2", "fault_type", "magnitude"],
171-
"pricipal_axis": [
171+
"principal_axis": [
172172
"t_value",
173173
"t_azimuth",
174174
"t_plunge",
@@ -401,7 +401,7 @@ def meca( # noqa: PLR0912, PLR0913, PLR0915
401401
# Convert spec to pandas.DataFrame unless it's a file
402402
if isinstance(spec, (dict, pd.DataFrame)): # spec is a dict or pd.DataFrame
403403
# determine convention from dict keys or pd.DataFrame column names
404-
for conv in ["aki", "gcmt", "mt", "partial", "pricipal_axis"]:
404+
for conv in ["aki", "gcmt", "mt", "partial", "principal_axis"]:
405405
if set(convention_params(conv)).issubset(set(spec.keys())):
406406
convention = conv
407407
break

pygmt/tests/test_clib_virtualfiles.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ def fixture_dtypes():
3434
return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split()
3535

3636

37+
@pytest.fixture(scope="module", name="dtypes_pandas")
38+
def fixture_dtypes_pandas(dtypes):
39+
"""
40+
List of supported pandas dtypes.
41+
"""
42+
dtypes_pandas = dtypes.copy()
43+
44+
if find_spec("pyarrow") is not None:
45+
dtypes_pandas.extend([f"{dtype}[pyarrow]" for dtype in dtypes_pandas])
46+
47+
return tuple(dtypes_pandas)
48+
49+
3750
def test_virtual_file(dtypes):
3851
"""
3952
Test passing in data via a virtual file with a Dataset.
@@ -248,11 +261,10 @@ def test_virtualfile_from_vectors_two_string_or_object_columns(dtype):
248261
assert output == expected
249262

250263

251-
def test_virtualfile_from_vectors_transpose():
264+
def test_virtualfile_from_vectors_transpose(dtypes):
252265
"""
253266
Test transforming matrix columns to virtual file dataset.
254267
"""
255-
dtypes = "float32 float64 int32 int64 uint32 uint64".split()
256268
shape = (7, 5)
257269
for dtype in dtypes:
258270
data = np.arange(shape[0] * shape[1], dtype=dtype).reshape(shape)
@@ -315,16 +327,14 @@ def test_virtualfile_from_matrix_slice(dtypes):
315327
assert output == expected
316328

317329

318-
def test_virtualfile_from_vectors_pandas(dtypes):
330+
def test_virtualfile_from_vectors_pandas(dtypes_pandas):
319331
"""
320332
Pass vectors to a dataset using pandas.Series, checking both numpy and pyarrow
321333
dtypes.
322334
"""
323335
size = 13
324-
if find_spec("pyarrow") is not None:
325-
dtypes.extend([f"{dtype}[pyarrow]" for dtype in dtypes])
326336

327-
for dtype in dtypes:
337+
for dtype in dtypes_pandas:
328338
data = pd.DataFrame(
329339
data={
330340
"x": np.arange(size),

pygmt/tests/test_session_management.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""
22
Test the session management modules.
33
"""
4+
import multiprocessing as mp
45
import os
6+
from importlib import reload
7+
from pathlib import Path
58

69
import pytest
710
from pygmt.clib import Session
@@ -57,3 +60,29 @@ def test_gmt_compat_6_is_applied(capsys):
5760
# Make sure no global "gmt.conf" in the current directory
5861
assert not os.path.exists("gmt.conf")
5962
begin() # Restart the global session
63+
64+
65+
def _gmt_func_wrapper(figname):
66+
"""
67+
A wrapper for running PyGMT scripts with multiprocessing.
68+
69+
Currently, we have to import pygmt and reload it in each process. Workaround from
70+
https://github.com/GenericMappingTools/pygmt/issues/217#issuecomment-754774875.
71+
"""
72+
import pygmt
73+
74+
reload(pygmt)
75+
fig = pygmt.Figure()
76+
fig.basemap(region=[10, 70, -3, 8], projection="X8c/6c", frame="afg")
77+
fig.savefig(figname)
78+
79+
80+
def test_session_multiprocessing():
81+
"""
82+
Make sure that multiprocessing is supported if pygmt is re-imported.
83+
"""
84+
prefix = "test_session_multiprocessing"
85+
with mp.Pool(2) as p:
86+
p.map(_gmt_func_wrapper, [f"{prefix}-1.png", f"{prefix}-2.png"])
87+
Path(f"{prefix}-1.png").unlink()
88+
Path(f"{prefix}-2.png").unlink()

0 commit comments

Comments
 (0)