
Commit e7372be

Figure.meca: Refactor to improve the processing of the input data (#3831)
1 parent eb0ebe7 commit e7372be

File tree

1 file changed: +101 -72 lines changed


pygmt/src/meca.py

+101 -72
@@ -6,7 +6,13 @@
 import pandas as pd
 from pygmt.clib import Session
 from pygmt.exceptions import GMTInvalidInput
-from pygmt.helpers import build_arg_list, fmt_docstring, kwargs_to_strings, use_alias
+from pygmt.helpers import (
+    build_arg_list,
+    data_kind,
+    fmt_docstring,
+    kwargs_to_strings,
+    use_alias,
+)
 from pygmt.src._common import _FocalMechanismConvention
 
 
@@ -25,6 +31,82 @@ def _get_focal_convention(spec, convention, component) -> _FocalMechanismConvent
     return _FocalMechanismConvention(convention=convention, component=component)
 
 
+def _preprocess_spec(spec, colnames, override_cols):
+    """
+    Preprocess the input data.
+
+    Parameters
+    ----------
+    spec
+        The input data to be preprocessed.
+    colnames
+        The minimum required column names of the input data.
+    override_cols
+        Dictionary of column names and values to override in the input data. Only makes
+        sense if ``spec`` is a dict or :class:`pandas.DataFrame`.
+    """
+    kind = data_kind(spec)  # Determine the kind of the input data.
+
+    # Convert pandas.DataFrame and numpy.ndarray to dict.
+    if isinstance(spec, pd.DataFrame):
+        spec = {k: v.to_numpy() for k, v in spec.items()}
+    elif isinstance(spec, np.ndarray):
+        spec = np.atleast_2d(spec)
+        # Optional columns that are not required by the convention. The key is the
+        # number of extra columns, and the value is a list of optional column names.
+        extra_cols = {
+            0: [],
+            1: ["event_name"],
+            2: ["plot_longitude", "plot_latitude"],
+            3: ["plot_longitude", "plot_latitude", "event_name"],
+        }
+        ndiff = spec.shape[1] - len(colnames)
+        if ndiff not in extra_cols:
+            msg = f"Input array must have {len(colnames)} or two/three more columns."
+            raise GMTInvalidInput(msg)
+        spec = dict(zip([*colnames, *extra_cols[ndiff]], spec.T, strict=False))
+
+    # Now, the input data is a dict or an ASCII file.
+    if isinstance(spec, dict):
+        # The columns can be overridden by the parameters given in the function
+        # arguments. Only makes sense for dict/pandas.DataFrame input.
+        if kind != "matrix" and override_cols is not None:
+            spec.update({k: v for k, v in override_cols.items() if v is not None})
+        # Due to the internal implementation of the meca module, we need to convert the
+        # ``plot_longitude``, ``plot_latitude``, and ``event_name`` columns into strings
+        # if they exist.
+        for key in ["plot_longitude", "plot_latitude", "event_name"]:
+            if key in spec:
+                spec[key] = np.array(spec[key], dtype=str)
+
+        # Reorder columns to match convention if necessary. The expected columns are:
+        # longitude, latitude, depth, focal_parameters, [plot_longitude, plot_latitude],
+        # [event_name].
+        extra_cols = []
+        if "plot_longitude" in spec and "plot_latitude" in spec:
+            extra_cols.extend(["plot_longitude", "plot_latitude"])
+        if "event_name" in spec:
+            extra_cols.append("event_name")
+        cols = [*colnames, *extra_cols]
+        if list(spec.keys()) != cols:
+            spec = {k: spec[k] for k in cols}
+    return spec
+
+
+def _auto_offset(spec) -> bool:
+    """
+    Determine if offset should be set based on the input data.
+
+    If the input data contains ``plot_longitude`` and ``plot_latitude``, then we set the
+    ``offset`` parameter to ``True`` automatically.
+    """
+    return (
+        isinstance(spec, dict | pd.DataFrame)
+        and "plot_longitude" in spec
+        and "plot_latitude" in spec
+    )
+
+
 @fmt_docstring
 @use_alias(
     A="offset",
@@ -45,7 +127,7 @@ def _get_focal_convention(spec, convention, component) -> _FocalMechanismConvent
     t="transparency",
 )
 @kwargs_to_strings(R="sequence", c="sequence_comma", p="sequence")
-def meca(  # noqa: PLR0912, PLR0913
+def meca(  # noqa: PLR0913
     self,
     spec,
     scale,
@@ -248,78 +330,25 @@ def meca(  # noqa: PLR0912, PLR0913
     {transparency}
     """
     kwargs = self._preprocess(**kwargs)
-
     # Determine the focal mechanism convention from the input data or parameters.
     _convention = _get_focal_convention(spec, convention, component)
-
-    # Convert spec to pandas.DataFrame unless it's a file
-    if isinstance(spec, dict | pd.DataFrame):  # spec is a dict or pd.DataFrame
-        # convert dict to pd.DataFrame so columns can be reordered
-        if isinstance(spec, dict):
-            # convert values to ndarray so pandas doesn't complain about "all
-            # scalar values". See
-            # https://github.com/GenericMappingTools/pygmt/pull/2174
-            spec = pd.DataFrame(
-                {key: np.atleast_1d(value) for key, value in spec.items()}
-            )
-    elif isinstance(spec, np.ndarray):  # spec is a numpy array
-        # Convert array to pd.DataFrame and assign column names
-        spec = pd.DataFrame(np.atleast_2d(spec))
-        colnames = ["longitude", "latitude", "depth", *_convention.params]
-        # check if spec has the expected number of columns
-        ncolsdiff = len(spec.columns) - len(colnames)
-        if ncolsdiff == 0:
-            pass
-        elif ncolsdiff == 1:
-            colnames += ["event_name"]
-        elif ncolsdiff == 2:
-            colnames += ["plot_longitude", "plot_latitude"]
-        elif ncolsdiff == 3:
-            colnames += ["plot_longitude", "plot_latitude", "event_name"]
-        else:
-            msg = (
-                f"Input array must have {len(colnames)} to {len(colnames) + 3} columns."
-            )
-            raise GMTInvalidInput(msg)
-        spec.columns = colnames
-
-    # Now spec is a pd.DataFrame or a file
-    if isinstance(spec, pd.DataFrame):
-        # override the values in pd.DataFrame if parameters are given
-        for arg, name in [
-            (longitude, "longitude"),
-            (latitude, "latitude"),
-            (depth, "depth"),
-            (plot_longitude, "plot_longitude"),
-            (plot_latitude, "plot_latitude"),
-            (event_name, "event_name"),
-        ]:
-            if arg is not None:
-                spec[name] = np.atleast_1d(arg)
-
-        # Due to the internal implementation of the meca module, we need to
-        # convert the following columns to strings if they exist
-        if "plot_longitude" in spec.columns and "plot_latitude" in spec.columns:
-            spec["plot_longitude"] = spec["plot_longitude"].astype(str)
-            spec["plot_latitude"] = spec["plot_latitude"].astype(str)
-        if "event_name" in spec.columns:
-            spec["event_name"] = spec["event_name"].astype(str)
-
-        # Reorder columns in DataFrame to match convention if necessary
-        # expected columns are:
-        # longitude, latitude, depth, focal_parameters,
-        # [plot_longitude, plot_latitude] [event_name]
-        newcols = ["longitude", "latitude", "depth", *_convention.params]
-        if "plot_longitude" in spec.columns and "plot_latitude" in spec.columns:
-            newcols += ["plot_longitude", "plot_latitude"]
-            if kwargs.get("A") is None:
-                kwargs["A"] = True
-        if "event_name" in spec.columns:
-            newcols += ["event_name"]
-        # reorder columns in DataFrame
-        if spec.columns.tolist() != newcols:
-            spec = spec.reindex(newcols, axis=1)
-
+    # Preprocess the input data.
+    spec = _preprocess_spec(
+        spec,
+        # The minimum expected columns for the input data.
+        colnames=["longitude", "latitude", "depth", *_convention.params],
+        override_cols={
+            "longitude": longitude,
+            "latitude": latitude,
+            "depth": depth,
+            "plot_longitude": plot_longitude,
+            "plot_latitude": plot_latitude,
+            "event_name": event_name,
+        },
+    )
+    # Determine the offset parameter if not provided.
+    if kwargs.get("A") is None:
+        kwargs["A"] = _auto_offset(spec)
     kwargs["S"] = f"{_convention.code}{scale}"
     with Session() as lib:
         with lib.virtualfile_in(check_kind="vector", data=spec) as vintbl:
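Again outside the diff: a short usage sketch, in the style of the PyGMT documentation, of the public call path these hunks rewire. The dict spec together with the longitude/latitude/depth arguments goes through the new override_cols handling in _preprocess_spec; the coordinates and magnitude are made-up values.

import pygmt

fig = pygmt.Figure()
fig.basemap(region=[-127, -120, 46, 50], projection="M12c", frame=True)
# Focal parameters in the Aki & Richards convention; the event location is
# supplied through the longitude/latitude/depth parameters, which the
# refactored meca() merges into the spec via override_cols.
focal_mechanism = {"strike": 330, "dip": 30, "rake": 90, "magnitude": 3}
fig.meca(spec=focal_mechanism, scale="1c", longitude=-124.3, latitude=48.1, depth=12.0)
fig.show()

Supplying plot_longitude/plot_latitude as well would now also enable the offset line automatically via _auto_offset, without the caller setting the offset (A) parameter.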
