
Commit cbcac82

Merge pull request #229 from openclimatefix/issue/209-xrdatarray-b
Issue/209 xrdatarray b
2 parents 3a63937 + 7bf0b70


63 files changed, +1596 -1390 lines

conftest.py (+8)

@@ -10,9 +10,17 @@
 from nowcasting_dataset.data_sources import SatelliteDataSource
 from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource
 from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource
+from nowcasting_dataset.dataset.xr_utils import (
+    register_xr_data_array_to_tensor,
+    register_xr_data_set_to_tensor,
+)

 pytest.IMAGE_SIZE_PIXELS = 128

+# need to run these to ensure that xarray DataArray and Dataset have torch functions
+register_xr_data_array_to_tensor()
+register_xr_data_set_to_tensor()
+

 def pytest_addoption(parser):
     parser.addoption(
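A note on the registration calls above: they attach torch conversion helpers onto xarray's classes so that test code can move examples straight to tensors. A minimal sketch of the idea, assuming a hypothetical `to_tensor` method (the real helpers live in `nowcasting_dataset.dataset.xr_utils` and may differ):

    # Sketch only: attach a hypothetical to_tensor() method to xr.DataArray.
    import torch
    import xarray as xr


    def register_xr_data_array_to_tensor():
        def to_tensor(self: xr.DataArray) -> torch.Tensor:
            # .values is a numpy array, which torch can wrap without a copy
            return torch.from_numpy(self.values)

        # class-level assignment is allowed even though DataArray defines __slots__
        xr.DataArray.to_tensor = to_tensor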

nowcasting_dataset/config/gcp.yaml (+1)

@@ -8,6 +8,7 @@ input_data:
   solar_pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv
   gsp_zarr_path: gs://solar-pv-nowcasting-data/PV/GSP/v1/pv_gsp.zarr
   topographic_filename: gs://solar-pv-nowcasting-data/Topographic/europe_dem_1km_osgb.tif
+  sun_zarr_path: gs://solar-pv-nowcasting-data/Sun/v0/sun.zarr
 output_data:
   filepath: gs://solar-pv-nowcasting-data/prepared_ML_training_data/v7/
 process:

nowcasting_dataset/data_sources/README.md (+8 -1)

@@ -38,5 +38,12 @@ General pydantic model of output of the data source. Contains the following meth

 Roughly each of the data source folders follows this pattern
 - A class which defines how to load the data source, how to select for batches etc. This inherits from 'data_source.DataSource',
-- A class which contains the output model of the data source. This is the information used in the batches.
+- A class which contains the output model of the data source, built from an xarray Dataset. This is the information used in the batches.
 This inherits from 'datasource_output.DataSourceOutput'.
+- A second class (pydantic) which moves the xarray Dataset to tensor fields. This will be used for training in ML models.
+
+
+# fake
+
+`fake.py` has several functions to create fake `Batch` data. This is useful for testing,
+and hopefully useful outside this module too.
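To make the three-class pattern concrete, a schematic sketch (the `Sun`/`SunML` names and the field are illustrative only; the base classes are the ones defined in `datasource_output.py` below):

    # Illustrative sketch of the data source output pattern described above.
    from typing import Any

    from pydantic import Field

    from nowcasting_dataset.data_sources.datasource_output import (
        DataSourceOutput,
        DataSourceOutputML,
    )


    class Sun(DataSourceOutput):
        """xarray-backed output model: an xr.Dataset subclass, validated on construction."""

        __slots__ = []


    class SunML(DataSourceOutputML):
        """pydantic model whose fields hold tensors, used for ML training."""

        sun_azimuth_angle: Any = Field(None, description="hypothetical tensor field")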

nowcasting_dataset/data_sources/data_source.py (+13 -9)

@@ -12,6 +12,7 @@
 import nowcasting_dataset.time as nd_time
 from nowcasting_dataset import square
 from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
+from nowcasting_dataset.dataset.xr_utils import join_dataset_to_batch_dataset

 logger = logging.getLogger(__name__)

@@ -122,16 +123,19 @@ def get_batch(
         examples = []
         zipped = zip(t0_datetimes, x_locations, y_locations)
         for t0_datetime, x_location, y_location in zipped:
-            output: DataSourceOutput = self.get_example(t0_datetime, x_location, y_location)
+            output: xr.Dataset = self.get_example(t0_datetime, x_location, y_location)

-            if self.convert_to_numpy:
-                output.to_numpy()
            examples.append(output)

         # could add option here, to save each data source using
         # 1. # DataSourceOutput.to_xr_dataset() to make it a dataset
         # 2. DataSourceOutput.save_netcdf(), save to netcdf
-        return DataSourceOutput.create_batch_from_examples(examples)
+
+        # get the name of the cls, this could be one of the data sources like Sun
+        cls = examples[0].__class__
+
+        # join the examples together, and cast them to the cls, so that validation can occur
+        return cls(join_dataset_to_batch_dataset(examples))

     def datetime_index(self) -> pd.DatetimeIndex:
         """Returns a complete list of all available datetimes."""
@@ -203,7 +207,7 @@ def get_example(
         t0_dt: pd.Timestamp,  #: Datetime of "now": The most recent obs.
         x_meters_center: Number,  #: Centre, in OSGB coordinates.
         y_meters_center: Number,  #: Centre, in OSGB coordinates.
-    ) -> DataSourceOutput:
+    ) -> xr.Dataset:
         """Must be overridden by child classes."""
         raise NotImplementedError()

@@ -305,7 +309,10 @@ def get_example(
                 f"actual shape {selected_data.shape}"
             )

-        return self._put_data_into_example(selected_data)
+        # rename 'variable' to 'channels'
+        selected_data = selected_data.rename({"variable": "channels"})
+
+        return selected_data

     def geospatial_border(self) -> List[Tuple[Number, Number]]:
         """
@@ -342,6 +349,3 @@ def open(self) -> None:

     def _open_data(self) -> xr.DataArray:
         raise NotImplementedError()
-
-    def _put_data_into_example(self, selected_data: xr.DataArray) -> DataSourceOutput:
-        raise NotImplementedError()
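For context, `join_dataset_to_batch_dataset` plausibly stacks the per-example `xr.Dataset`s along a new leading dimension, after which the `cls(...)` cast runs validation. A minimal sketch under that assumption (the real helper is in `nowcasting_dataset.dataset.xr_utils`; the dimension name is a guess):

    # Sketch, not the repository's implementation: stack example Datasets
    # along a new (hypothetical) "example" dimension to form one batch.
    from typing import List

    import xarray as xr


    def join_dataset_to_batch_dataset(examples: List[xr.Dataset]) -> xr.Dataset:
        # give each example a coordinate on the new dimension, then concatenate
        indexed = [ex.expand_dims(example=[i]) for i, ex in enumerate(examples)]
        return xr.concat(indexed, dim="example")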

nowcasting_dataset/data_sources/datasource_output.py (+36 -142)

@@ -1,115 +1,40 @@
 """ General Data Source output pydantic class. """
 from __future__ import annotations
-import os
-from nowcasting_dataset.filesystem.utils import make_folder
-from nowcasting_dataset.utils import get_netcdf_filename

+import logging
+import os
 from pathlib import Path
-from pydantic import BaseModel, Field
-import pandas as pd
-import xarray as xr
+from typing import List
+
 import numpy as np
-from typing import List, Union
-import logging
-from datetime import datetime
+from pydantic import BaseModel, Field

-from nowcasting_dataset.utils import to_numpy
+from nowcasting_dataset.dataset.xr_utils import PydanticXArrayDataSet
+from nowcasting_dataset.filesystem.utils import make_folder
+from nowcasting_dataset.utils import get_netcdf_filename

 logger = logging.getLogger(__name__)


-class DataSourceOutput(BaseModel):
+class DataSourceOutput(PydanticXArrayDataSet):
     """General Data Source output pydantic class.

     Data source output classes should inherit from this class
     """

-    class Config:
-        """ Allowed classes e.g. tensor.Tensor"""
-
-        # TODO maybe there is a better way to do this
-        arbitrary_types_allowed = True
-
-    batch_size: int = Field(
-        0,
-        ge=0,
-        description="The size of this batch. If the batch size is 0, "
-        "then this item stores one data item i.e Example",
-    )
+    __slots__ = []

     def get_name(self) -> str:
-        """ Get the name of the class """
+        """Get the name of the class"""
         return self.__class__.__name__.lower()

-    def to_numpy(self):
-        """Change to numpy"""
-        for k, v in self.dict().items():
-            self.__setattr__(k, to_numpy(v))
-
-    def to_xr_data_array(self):
-        """ Change to xr DataArray"""
-        raise NotImplementedError()
-
-    @staticmethod
-    def create_batch_from_examples(data):
-        """
-        Join a list of data source items to a batch.
-
-        Note that this only works for numpy objects, so objects are changed into numpy
-        """
-        _ = [d.to_numpy() for d in data]
-
-        # use the first item in the list, and then update each item
-        batch = data[0]
-        for k in batch.dict().keys():
-
-            # set batch size to the list of the items
-            if k == "batch_size":
-                batch.batch_size = len(data)
-            else:
-
-                # get list of one variable from the list of data items.
-                one_variable_list = [d.__getattribute__(k) for d in data]
-                batch.__setattr__(k, np.stack(one_variable_list, axis=0))
-
-        return batch
-
-    def split(self) -> List[DataSourceOutput]:
-        """
-        Split the datasource from a batch to a list of items
-
-        Returns: List of single data source items
-        """
-        cls = self.__class__
-
-        items = []
-        for batch_idx in range(self.batch_size):
-            d = {k: v[batch_idx] for k, v in self.dict().items() if k != "batch_size"}
-            d["batch_size"] = 0
-            items.append(cls(**d))
-
-        return items
-
-    def to_xr_dataset(self, **kwargs):
-        """ Make a xr dataset. Each data source needs to define this """
-        raise NotImplementedError
-
-    def from_xr_dataset(self):
-        """ Load from xr dataset. Each data source needs to define this """
-        raise NotImplementedError
-
-    def get_datetime_index(self):
-        """ Datetime index for the data """
-        pass
-
-    def save_netcdf(self, batch_i: int, path: Path, xr_dataset: xr.Dataset):
+    def save_netcdf(self, batch_i: int, path: Path):
         """
         Save batch to netcdf file

         Args:
             batch_i: the batch id, used to make the filename
             path: the path where it will be saved. This can be local or in the cloud.
-            xr_dataset: xr dataset that has batch information in it
         """
         filename = get_netcdf_filename(batch_i)

@@ -124,77 +49,46 @@ def save_netcdf(self, batch_i: int, path: Path, xr_dataset: xr.Dataset):
         # make file
         local_filename = os.path.join(folder, filename)

-        encoding = {name: {"compression": "lzf"} for name in xr_dataset.data_vars}
-        xr_dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)
-
-    def select_time_period(
-        self,
-        keys: List[str],
-        history_minutes: int,
-        forecast_minutes: int,
-        t0_dt_of_first_example: Union[datetime, pd.Timestamp],
-    ):
-        """
-        Selects a subset of data between the indicies of [start, end] for each key in keys
-
-        Note that class is edited so nothing is returned.
-
-        Args:
-            keys: Keys in batch to use
-            t0_dt_of_first_example: datetime of the current time (t0) in the first example of the batch
-            history_minutes: How many minutes of history to use
-            forecast_minutes: How many minutes of future data to use for forecasting
-
-        """
-        logger.debug(
-            f"Taking a sub-selection of the batch data based on a history minutes of {history_minutes} "
-            f"and forecast minutes of {forecast_minutes}"
-        )
+        encoding = {name: {"compression": "lzf"} for name in self.data_vars}
+        self.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)

-        start_time_of_first_batch = t0_dt_of_first_example - pd.to_timedelta(
-            f"{history_minutes} minute 30 second"
-        )
-        end_time_of_first_example = t0_dt_of_first_example + pd.to_timedelta(
-            f"{forecast_minutes} minute 30 second"
-        )

-        logger.debug(f"New start time for first batch is {start_time_of_first_batch}")
-        logger.debug(f"New end time for first batch is {end_time_of_first_example}")
+class DataSourceOutputML(BaseModel):
+    """General Data Source output pydantic class.

-        start_time_of_first_example = to_numpy(start_time_of_first_batch)
-        end_time_of_first_example = to_numpy(end_time_of_first_example)
+    Data source output classes should inherit from this class
+    """

-        if self.get_datetime_index() is not None:
+    class Config:
+        """Allowed classes e.g. tensor.Tensor"""

-            time_of_first_example = to_numpy(pd.to_datetime(self.get_datetime_index()[0]))
+        # TODO maybe there is a better way to do this
+        arbitrary_types_allowed = True

-            # find the start and end index, that we will then use to slice the data
-            start_i, end_i = np.searchsorted(
-                time_of_first_example, [start_time_of_first_example, end_time_of_first_example]
-            )
+    batch_size: int = Field(
+        0,
+        ge=0,
+        description="The size of this batch. If the batch size is 0, "
+        "then this item stores one data item i.e Example",
+    )

-        # slice all the data
-        for key in keys:
-            if "time" in self.__getattribute__(key).dims:
-                self.__setattr__(
-                    key, self.__getattribute__(key).isel(time=slice(start_i, end_i))
-                )
-            elif "time_30" in self.__getattribute__(key).dims:
-                self.__setattr__(
-                    key, self.__getattribute__(key).isel(time_30=slice(start_i, end_i))
-                )
+    def get_name(self) -> str:
+        """Get the name of the class"""
+        return self.__class__.__name__.lower()

-            logger.debug(f"{self.__class__.__name__} {key}: {self.__getattribute__(key).shape}")
+    def get_datetime_index(self):
+        """Datetime index for the data"""
+        pass


 def pad_nans(array, pad_width) -> np.ndarray:
-    """ Pad nans with nans"""
+    """Pad nans with nans"""
     array = array.astype(np.float32)
     return np.pad(array, pad_width, constant_values=np.NaN)


 def pad_data(
-    data: DataSourceOutput,
+    data: DataSourceOutputML,
     pad_size: int,
     one_dimensional_arrays: List[str],
     two_dimensional_arrays: List[str],
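The key move in this file is that `DataSourceOutput` now subclasses `PydanticXArrayDataSet` instead of `BaseModel`, so each output is itself an `xr.Dataset` that can validate its own contents; that is why `save_netcdf` can call `self.to_netcdf(...)` directly. One way such a base class could be wired up, purely as a sketch (the actual definition lives in `nowcasting_dataset.dataset.xr_utils`):

    # Sketch of the idea behind PydanticXArrayDataSet; details are assumptions.
    import xarray as xr


    class PydanticXArrayDataSet(xr.Dataset):
        """An xr.Dataset subclass that runs validation when constructed."""

        # xarray requires subclasses of Dataset to declare __slots__
        __slots__ = []

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # subclasses override model_validation to check dims/coords/dtypes
            self.model_validation(self)

        @classmethod
        def model_validation(cls, v):
            return v

With something like this in place, `cls(join_dataset_to_batch_dataset(examples))` in `data_source.py` builds the batch Dataset and validates it in one step.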

nowcasting_dataset/data_sources/datetime/datetime_data_source.py (+8 -1)

@@ -8,6 +8,7 @@
 from nowcasting_dataset import time as nd_time
 from nowcasting_dataset.data_sources.data_source import DataSource
 from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime
+from nowcasting_dataset.dataset.xr_utils import make_dim_index


 @dataclass
@@ -36,7 +37,13 @@ def get_example(
         start_dt = self._get_start_dt(t0_dt)
         end_dt = self._get_end_dt(t0_dt)
         index = pd.date_range(start_dt, end_dt, freq="5T")
-        return nd_time.datetime_features_in_example(index)
+
+        datetime_xr_dataset = nd_time.datetime_features_in_example(index).rename({"index": "time"})
+
+        # make sure time is indexed in the correct way
+        datetime_xr_dataset = make_dim_index(datetime_xr_dataset)
+
+        return Datetime(datetime_xr_dataset)

     def get_locations_for_batch(
         self, t0_datetimes: pd.DatetimeIndex
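`make_dim_index` is not shown in this diff. A plausible reading, given that examples with different timestamps must later be stacked into one batch, is that it keeps the datetimes as an ordinary variable and indexes the dimension by integer position. A sketch under that assumption (the helper name is real; its behaviour here is a guess):

    # Assumed behaviour of make_dim_index, for illustration only.
    import numpy as np
    import xarray as xr


    def make_dim_index(dataset: xr.Dataset, dim: str = "time") -> xr.Dataset:
        # keep the original datetime values as an ordinary variable
        dataset[f"{dim}_index"] = xr.DataArray(dataset[dim].values, dims=[dim])
        # replace the dimension coordinate with integer positions
        return dataset.assign_coords({dim: np.arange(len(dataset[dim]))})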
