Skip to content

Commit 62477b6

Browse files
committed
Merge branch 'issue573-stac-item-assets-support'
2 parents 14854ce + 84a9f5b commit 62477b6

File tree

7 files changed

+247
-3
lines changed

7 files changed

+247
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
- Document PEP 723 based Python UDF dependency declarations ([Open-EO/openeo-geopyspark-driver#237](https://github.com/Open-EO/openeo-geopyspark-driver/issues/237))
1515
- Added more `openeo.api.process.Parameter` helpers to easily create "bounding_box", "date", "datetime", "geojson" and "temporal_interval" parameters for UDP construction.
1616
- Added convenience method `Connection.load_stac_from_job(job)` to easily load the results of a batch job with the `load_stac` process ([#566](https://github.com/Open-EO/openeo-python-client/issues/566))
17+
- `load_stac`/`metadata_from_stac`: add support for extracting band info from "item_assets" in collection metadata ([#573](https://github.com/Open-EO/openeo-python-client/issues/573))
18+
- Added initial `openeo.testing` submodule for reusable test utilities
1719

1820
### Changed
1921

openeo/metadata.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

3+
import functools
34
import logging
45
import warnings
5-
from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
67

78
import pystac
9+
import pystac.extensions.eo
10+
import pystac.extensions.item_assets
811

912
from openeo.internal.jupyter import render_component
1013
from openeo.util import deep_get
@@ -107,6 +110,7 @@ class Band(NamedTuple):
107110

108111

109112
class BandDimension(Dimension):
113+
# TODO #575 support unordered bands and avoid assumption that band order is known.
110114
def __init__(self, name: str, bands: List[Band]):
111115
super().__init__(type="bands", name=name)
112116
self.bands = bands
@@ -534,6 +538,8 @@ def metadata_from_stac(url: str) -> CubeMetadata:
534538
:return: A :py:class:`CubeMetadata` containing the DataCube band metadata from the url.
535539
"""
536540

541+
# TODO move these nested functions and other logic to _StacMetadataParser
542+
537543
def get_band_metadata(eo_bands_location: dict) -> List[Band]:
538544
# TODO: return None iso empty list when no metadata?
539545
return [
@@ -573,6 +579,10 @@ def is_band_asset(asset: pystac.Asset) -> bool:
573579
for asset_band in asset_bands:
574580
if asset_band.name not in get_band_names(bands):
575581
bands.append(asset_band)
582+
if _PYSTAC_1_9_EXTENSION_INTERFACE and collection.ext.has("item_assets"):
583+
# TODO #575 support unordered band names and avoid conversion to a list.
584+
bands = list(_StacMetadataParser().get_bands_from_item_assets(collection.ext.item_assets))
585+
576586
elif isinstance(stac_object, pystac.Catalog):
577587
catalog = stac_object
578588
bands = get_band_metadata(catalog.extra_fields.get("summaries", {}))
@@ -586,3 +596,77 @@ def is_band_asset(asset: pystac.Asset) -> bool:
586596
temporal_dimension = TemporalDimension(name="t", extent=[None, None])
587597
metadata = CubeMetadata(dimensions=[band_dimension, temporal_dimension])
588598
return metadata
599+
600+
601+
# Sniff for PySTAC extension API since version 1.9.0 (which is not available below Python 3.9)
602+
# TODO: remove this once support for Python 3.7 and 3.8 is dropped
603+
_PYSTAC_1_9_EXTENSION_INTERFACE = hasattr(pystac.Item, "ext")
604+
605+
606+
class _StacMetadataParser:
    """
    Helper to extract openEO band metadata from STAC metadata resources.
    """

    def __init__(self):
        # TODO: toggles for how to handle strictness, warnings, logging, etc
        pass

    def _get_band_from_eo_bands_item(self, eo_band: Union[dict, pystac.extensions.eo.Band]) -> Band:
        """
        Convert a single `eo:bands` entry to a :py:class:`Band`.

        :param eo_band: band description, either a `pystac.extensions.eo.Band` instance
            or a dict that has at least a "name" key
        :raises ValueError: if `eo_band` is neither of the supported forms
        """
        if isinstance(eo_band, pystac.extensions.eo.Band):
            name = eo_band.name
            common_name = eo_band.common_name
            wavelength_um = eo_band.center_wavelength
        elif isinstance(eo_band, dict) and "name" in eo_band:
            name = eo_band["name"]
            common_name = eo_band.get("common_name")
            wavelength_um = eo_band.get("center_wavelength")
        else:
            raise ValueError(eo_band)
        return Band(name=name, common_name=common_name, wavelength_um=wavelength_um)

    def get_bands_from_eo_bands(self, eo_bands: List[Union[dict, pystac.extensions.eo.Band]]) -> List[Band]:
        """
        Extract bands from STAC `eo:bands` array

        :param eo_bands: List of band objects, as dict or `pystac.extensions.eo.Band` instances
        :return: list of :py:class:`Band` objects, in the same order as `eo_bands`
        """
        # TODO: option to skip bands that failed to parse in some way?
        return list(map(self._get_band_from_eo_bands_item, eo_bands))

    def _get_bands_from_item_asset(
        self, item_asset: pystac.extensions.item_assets.AssetDefinition, *, _warn: Callable[[str], None] = _log.warning
    ) -> Union[List[Band], None]:
        """
        Get bands from a STAC 'item_assets' asset definition.

        :param item_asset: asset definition to inspect
        :param _warn: callable used to emit the "extension not declared" warning
        :return: list of bands, or `None` when no band info was found
        """
        eo_is_declared = _PYSTAC_1_9_EXTENSION_INTERFACE and item_asset.ext.has("eo")
        if eo_is_declared:
            declared_bands = item_asset.ext.eo.bands
            if declared_bands is None:
                return None
            return self.get_bands_from_eo_bands(declared_bands)

        if "eo:bands" not in item_asset.properties:
            return None

        # TODO: skip this in strict mode?
        if _PYSTAC_1_9_EXTENSION_INTERFACE:
            _warn("Extracting band info from 'eo:bands' metadata, but 'eo' STAC extension was not declared.")
        return self.get_bands_from_eo_bands(item_asset.properties["eo:bands"])

    def get_bands_from_item_assets(
        self, item_assets: Dict[str, pystac.extensions.item_assets.AssetDefinition]
    ) -> Set[Band]:
        """
        Get bands extracted from "item_assets" objects (defined by the "item-assets" extension,
        in combination with the "eo" extension) at STAC Collection top-level.

        Note that "item_assets" in STAC is a mapping, so the band order is undefined,
        which is why a set of bands is returned here.

        :param item_assets: a STAC `item_assets` mapping
        :return: set of all bands found across the asset definitions
        """
        # Trick to just warn once per collection:
        # the cache swallows repeated calls with the same message.
        warn_once = functools.lru_cache()(_log.warning)
        collected = set()
        for definition in item_assets.values():
            found = self._get_bands_from_item_asset(definition, _warn=warn_once)
            if found:
                collected.update(found)
        return collected

openeo/testing.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""
2+
Utilities for testing of openEO client workflows.
3+
"""
4+
5+
import json
6+
from pathlib import Path
7+
from typing import Callable, Optional, Union
8+
9+
10+
class TestDataLoader:
    """
    Helper to resolve paths to test data files, load them as JSON, optionally preprocess them, etc.

    It's intended to be used as a pytest fixture, e.g. from ``conftest.py``:

    .. code-block:: python

        @pytest.fixture
        def test_data() -> TestDataLoader:
            return TestDataLoader(root=Path(__file__).parent / "data")

    .. versionadded:: 0.30.0

    """

    # The class name matches pytest's default "Test*" collection pattern:
    # explicitly opt out, so that importing this helper in a test module
    # does not make pytest try to collect it (and emit a collection warning).
    __test__ = False

    def __init__(self, root: Union[str, Path]):
        """
        :param root: root folder containing the test data files
        """
        self.data_root = Path(root)

    def get_path(self, filename: Union[str, Path]) -> Path:
        """
        Get path to a test data file, resolved against the data root.

        :param filename: file name or path, relative to the data root
        :return: path under the data root (only absolute if the data root is absolute)
        """
        return self.data_root / filename

    def load_json(self, filename: Union[str, Path], preprocess: Optional[Callable[[str], str]] = None) -> dict:
        """
        Parse data from a test JSON file

        :param filename: file name or path, relative to the data root
        :param preprocess: optional callable applied to the raw (UTF-8 decoded) file content
            before it is parsed as JSON
        :return: the parsed JSON data
        """
        data = self.get_path(filename).read_text(encoding="utf8")
        if preprocess:
            data = preprocess(data)
        return json.loads(data)

tests/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66

77
def get_test_resource(relative_path: str) -> Path:
    """Build the path of a resource given relative to the folder of this module (real path)."""
    # TODO: migrate to TestDataLoader
    module_file = os.path.realpath(__file__)
    return Path(os.path.dirname(module_file)) / relative_path
1011

1112

1213
def load_json_resource(relative_path, preprocess: Callable = None) -> dict:
14+
# TODO: migrate to TestDataLoader
1315
with get_test_resource(relative_path).open("r+") as f:
1416
data = f.read()
1517
if preprocess:

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66

7+
from openeo.testing import TestDataLoader
78
from openeo.util import ensure_dir
89

910
pytest_plugins = "pytester"
@@ -25,3 +26,8 @@ def tmp_openeo_config_home(tmp_path):
2526
path = ensure_dir(Path(str(tmp_path)) / "openeo-conf")
2627
with mock.patch.dict("os.environ", {"OPENEO_CONFIG_HOME": str(path)}):
2728
yield path
29+
30+
31+
@pytest.fixture
def test_data() -> TestDataLoader:
    """Test data loader rooted at the ``data`` folder next to this module."""
    data_root = Path(__file__).parent / "data"
    return TestDataLoader(root=data_root)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
{
2+
"type": "Collection",
3+
"id": "agera5_daily",
4+
"stac_version": "1.0.0",
5+
"description": "ERA5",
6+
"links": [
7+
{
8+
"rel": "self",
9+
"type": "application/json",
10+
"href": "https://stac.test/collections/agera5_daily"
11+
}
12+
],
13+
"stac_extensions": [
14+
"https://stac-extensions.github.io/item-assets/v1.0.0/schema.json",
15+
"https://stac-extensions.github.io/eo/v1.1.0/schema.json"
16+
],
17+
"item_assets": {
18+
"2m_temperature_min": {
19+
"type": "image/tiff; application=geotiff",
20+
"title": "2m temperature min 24h",
21+
"eo:bands": [
22+
{
23+
"name": "2m_temperature_min",
24+
"description": "temperature 2m above ground (Kelvin)"
25+
}
26+
]
27+
},
28+
"2m_temperature_max": {
29+
"type": "image/tiff; application=geotiff",
30+
"eo:bands": [
31+
{
32+
"name": "2m_temperature_max",
33+
"description": "temperature 2m above ground (Kelvin)"
34+
}
35+
]
36+
},
37+
"dewpoint_temperature_mean": {
38+
"type": "image/tiff; application=geotiff",
39+
"title": "2m dewpoint temperature",
40+
"eo:bands": [
41+
{
42+
"name": "dewpoint_temperature_mean",
43+
"description": "dewpoint temperature 2m above ground (Kelvin)"
44+
}
45+
]
46+
},
47+
"vapour_pressure": {
48+
"eo:bands": [
49+
{
50+
"name": "vapour_pressure"
51+
}
52+
]
53+
}
54+
},
55+
"title": "agERA5 data",
56+
"extent": {
57+
"spatial": {
58+
"bbox": [
59+
[
60+
-180,
61+
-90,
62+
180,
63+
90
64+
]
65+
]
66+
},
67+
"temporal": {
68+
"interval": [
69+
[
70+
"2010-01-01T00:00:00Z",
71+
"2024-06-12T00:00:00Z"
72+
]
73+
]
74+
}
75+
},
76+
"keywords": [
77+
"ERA5"
78+
],
79+
"summaries": {},
80+
"assets": {},
81+
"license": "proprietary"
82+
}

tests/test_metadata.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
import json
4-
from typing import List
4+
import re
5+
from typing import List, Union
56

67
import pytest
78

89
from openeo.metadata import (
10+
_PYSTAC_1_9_EXTENSION_INTERFACE,
911
Band,
1012
BandDimension,
1113
CollectionMetadata,
@@ -835,8 +837,38 @@ def filter_bbox(self, bbox):
835837
],
836838
)
837839
def test_metadata_from_stac(tmp_path, test_stac, expected):
838-
839840
path = tmp_path / "stac.json"
840841
path.write_text(json.dumps(test_stac))
841842
metadata = metadata_from_stac(path)
842843
assert metadata.band_names == expected
844+
845+
846+
@pytest.mark.skipif(not _PYSTAC_1_9_EXTENSION_INTERFACE, reason="Requires PySTAC 1.9+ extension interface")
@pytest.mark.parametrize("eo_extension_is_declared", [False, True])
def test_metadata_from_stac_collection_bands_from_item_assets(test_data, tmp_path, eo_extension_is_declared, caplog):
    """
    Bands are extracted from collection-level "item_assets",
    with exactly one warning when the "eo" extension is not declared.
    """
    eo_extension_prefix = "https://stac-extensions.github.io/eo/"
    expected_warning = "Extracting band info from 'eo:bands' metadata, but 'eo' STAC extension was not declared."

    stac_data = test_data.load_json("stac/collections/agera5_daily01.json")
    if not eo_extension_is_declared:
        # Strip the "eo" extension declaration from the fixture
        stac_data["stac_extensions"] = [
            ext for ext in stac_data["stac_extensions"] if not ext.startswith(eo_extension_prefix)
        ]
    # Sanity check of the test setup itself
    declared = any(ext.startswith(eo_extension_prefix) for ext in stac_data["stac_extensions"])
    assert declared == eo_extension_is_declared

    path = tmp_path / "stac.json"
    path.write_text(json.dumps(stac_data))

    metadata = metadata_from_stac(path)
    # Band order from "item_assets" is undefined: compare sorted.
    assert sorted(metadata.band_names) == [
        "2m_temperature_max",
        "2m_temperature_min",
        "dewpoint_temperature_mean",
        "vapour_pressure",
    ]

    matching_messages = [m for m in caplog.messages if expected_warning in m]
    assert len(matching_messages) == (0 if eo_extension_is_declared else 1)

0 commit comments

Comments
 (0)