Skip to content

Commit 9b6eb9a

Browse files
committed
Merge branch 'main' into docs/shift_origin
2 parents c325925 + b561e9d commit 9b6eb9a

File tree

5 files changed

+96
-7
lines changed

5 files changed

+96
-7
lines changed

pygmt/clib/session.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@
8787
np.datetime64: "GMT_DATETIME",
8888
}
8989

90+
# Load the GMT library outside the Session class to avoid repeated loading.
91+
_libgmt = load_libgmt()
92+
9093

9194
class Session:
9295
"""
@@ -308,7 +311,7 @@ def get_libgmt_func(self, name, argtypes=None, restype=None):
308311
<class 'ctypes.CDLL.__init__.<locals>._FuncPtr'>
309312
"""
310313
if not hasattr(self, "_libgmt"):
311-
self._libgmt = load_libgmt()
314+
self._libgmt = _libgmt
312315
function = getattr(self._libgmt, name)
313316
if argtypes is not None:
314317
function.argtypes = argtypes

pygmt/session_management.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""
22
Modern mode session management modules.
33
"""
4+
import os
5+
import sys
6+
47
from pygmt.clib import Session
8+
from pygmt.helpers import unique_name
59

610

711
def begin():
@@ -12,6 +16,10 @@ def begin():
1216
1317
Only meant to be used once for creating the global session.
1418
"""
19+
# On Windows, need to set GMT_SESSION_NAME to a unique value
20+
if sys.platform == "win32":
21+
os.environ["GMT_SESSION_NAME"] = unique_name()
22+
1523
prefix = "pygmt-session"
1624
with Session() as lib:
1725
lib.call_module(module="begin", args=prefix)

pygmt/tests/test_clib_loading.py

+39
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import pytest
1313
from pygmt.clib.loading import check_libgmt, clib_full_names, clib_names, load_libgmt
14+
from pygmt.clib.session import Session
1415
from pygmt.exceptions import GMTCLibError, GMTCLibNotFoundError, GMTOSError
1516

1617

@@ -207,6 +208,44 @@ def test_brokenlib_brokenlib_workinglib(self):
207208
assert check_libgmt(load_libgmt(lib_fullnames=lib_fullnames)) is None
208209

209210

211+
class TestLibgmtCount:
212+
"""
213+
Test that the GMT library is not repeatedly loaded in every session.
214+
"""
215+
216+
loaded_libgmt = load_libgmt() # Load the GMT library and reuse it when necessary
217+
counter = 0 # Global counter for how many times ctypes.CDLL is called
218+
219+
def _mock_ctypes_cdll_return(self, libname): # noqa: ARG002
220+
"""
221+
Mock ctypes.CDLL to count how many times the function is called.
222+
223+
If ctypes.CDLL is called, the counter increases by one.
224+
"""
225+
self.counter += 1 # Increase the counter
226+
return self.loaded_libgmt
227+
228+
def test_libgmt_load_counter(self, monkeypatch):
229+
"""
230+
Make sure that the GMT library is not loaded in every session.
231+
"""
232+
# Monkeypatch the ctypes.CDLL function
233+
monkeypatch.setattr(ctypes, "CDLL", self._mock_ctypes_cdll_return)
234+
235+
# Create two sessions and check the global counter
236+
with Session() as lib:
237+
_ = lib
238+
with Session() as lib:
239+
_ = lib
240+
assert self.counter == 0 # ctypes.CDLL is not called after two sessions.
241+
242+
# Explicitly calling load_libgmt to make sure the mock function is correct
243+
load_libgmt()
244+
assert self.counter == 1
245+
load_libgmt()
246+
assert self.counter == 2
247+
248+
210249
###############################################################################
211250
# Test clib_full_names
212251
@pytest.fixture(scope="module", name="gmt_lib_names")

pygmt/tests/test_clib_virtualfiles.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ def fixture_dtypes():
3434
return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split()
3535

3636

37+
@pytest.fixture(scope="module", name="dtypes_pandas")
38+
def fixture_dtypes_pandas(dtypes):
39+
"""
40+
List of supported pandas dtypes.
41+
"""
42+
dtypes_pandas = dtypes.copy()
43+
44+
if find_spec("pyarrow") is not None:
45+
dtypes_pandas.extend([f"{dtype}[pyarrow]" for dtype in dtypes_pandas])
46+
47+
return tuple(dtypes_pandas)
48+
49+
3750
def test_virtual_file(dtypes):
3851
"""
3952
Test passing in data via a virtual file with a Dataset.
@@ -248,11 +261,10 @@ def test_virtualfile_from_vectors_two_string_or_object_columns(dtype):
248261
assert output == expected
249262

250263

251-
def test_virtualfile_from_vectors_transpose():
264+
def test_virtualfile_from_vectors_transpose(dtypes):
252265
"""
253266
Test transforming matrix columns to virtual file dataset.
254267
"""
255-
dtypes = "float32 float64 int32 int64 uint32 uint64".split()
256268
shape = (7, 5)
257269
for dtype in dtypes:
258270
data = np.arange(shape[0] * shape[1], dtype=dtype).reshape(shape)
@@ -315,16 +327,14 @@ def test_virtualfile_from_matrix_slice(dtypes):
315327
assert output == expected
316328

317329

318-
def test_virtualfile_from_vectors_pandas(dtypes):
330+
def test_virtualfile_from_vectors_pandas(dtypes_pandas):
319331
"""
320332
Pass vectors to a dataset using pandas.Series, checking both numpy and pyarrow
321333
dtypes.
322334
"""
323335
size = 13
324-
if find_spec("pyarrow") is not None:
325-
dtypes.extend([f"{dtype}[pyarrow]" for dtype in dtypes])
326336

327-
for dtype in dtypes:
337+
for dtype in dtypes_pandas:
328338
data = pd.DataFrame(
329339
data={
330340
"x": np.arange(size),

pygmt/tests/test_session_management.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""
22
Test the session management modules.
33
"""
4+
import multiprocessing as mp
45
import os
6+
from importlib import reload
7+
from pathlib import Path
58

69
import pytest
710
from pygmt.clib import Session
@@ -57,3 +60,29 @@ def test_gmt_compat_6_is_applied(capsys):
5760
# Make sure no global "gmt.conf" in the current directory
5861
assert not os.path.exists("gmt.conf")
5962
begin() # Restart the global session
63+
64+
65+
def _gmt_func_wrapper(figname):
66+
"""
67+
A wrapper for running PyGMT scripts with multiprocessing.
68+
69+
Currently, we have to import pygmt and reload it in each process. Workaround from
70+
https://github.com/GenericMappingTools/pygmt/issues/217#issuecomment-754774875.
71+
"""
72+
import pygmt
73+
74+
reload(pygmt)
75+
fig = pygmt.Figure()
76+
fig.basemap(region=[10, 70, -3, 8], projection="X8c/6c", frame="afg")
77+
fig.savefig(figname)
78+
79+
80+
def test_session_multiprocessing():
81+
"""
82+
Make sure that multiprocessing is supported if pygmt is re-imported.
83+
"""
84+
prefix = "test_session_multiprocessing"
85+
with mp.Pool(2) as p:
86+
p.map(_gmt_func_wrapper, [f"{prefix}-1.png", f"{prefix}-2.png"])
87+
Path(f"{prefix}-1.png").unlink()
88+
Path(f"{prefix}-2.png").unlink()

0 commit comments

Comments
 (0)