""" General Data Source output pydantic class. """
from __future__ import annotations
- import os
- from nowcasting_dataset.filesystem.utils import make_folder
- from nowcasting_dataset.utils import get_netcdf_filename

+ import logging
+ import os
from pathlib import Path
- from pydantic import BaseModel, Field
- import pandas as pd
- import xarray as xr
+ from typing import List
+
import numpy as np
- from typing import List, Union
- import logging
- from datetime import datetime
+ from pydantic import BaseModel, Field

- from nowcasting_dataset.utils import to_numpy
+ from nowcasting_dataset.dataset.xr_utils import PydanticXArrayDataSet
+ from nowcasting_dataset.filesystem.utils import make_folder
+ from nowcasting_dataset.utils import get_netcdf_filename

logger = logging.getLogger(__name__)


- class DataSourceOutput(BaseModel):
+ class DataSourceOutput(PydanticXArrayDataSet):
    """General Data Source output pydantic class.

    Data source output classes should inherit from this class
    """

-     class Config:
-         """ Allowed classes e.g. tensor.Tensor"""
-
-         # TODO maybe there is a better way to do this
-         arbitrary_types_allowed = True
-
-     batch_size: int = Field(
-         0,
-         ge=0,
-         description="The size of this batch. If the batch size is 0, "
-         "then this item stores one data item i.e Example",
-     )
+     __slots__ = []

    def get_name(self) -> str:
-         """ Get the name of the class """
+         """Get the name of the class"""
        return self.__class__.__name__.lower()

-     def to_numpy(self):
-         """Change to numpy"""
-         for k, v in self.dict().items():
-             self.__setattr__(k, to_numpy(v))
-
-     def to_xr_data_array(self):
-         """ Change to xr DataArray"""
-         raise NotImplementedError()
-
-     @staticmethod
-     def create_batch_from_examples(data):
-         """
-         Join a list of data source items to a batch.
-
-         Note that this only works for numpy objects, so objects are changed into numpy
-         """
-         _ = [d.to_numpy() for d in data]
-
-         # use the first item in the list, and then update each item
-         batch = data[0]
-         for k in batch.dict().keys():
-
-             # set batch size to the list of the items
-             if k == "batch_size":
-                 batch.batch_size = len(data)
-             else:
-
-                 # get list of one variable from the list of data items.
-                 one_variable_list = [d.__getattribute__(k) for d in data]
-                 batch.__setattr__(k, np.stack(one_variable_list, axis=0))
-
-         return batch
-
-     def split(self) -> List[DataSourceOutput]:
-         """
-         Split the datasource from a batch to a list of items
-
-         Returns: List of single data source items
-         """
-         cls = self.__class__
-
-         items = []
-         for batch_idx in range(self.batch_size):
-             d = {k: v[batch_idx] for k, v in self.dict().items() if k != "batch_size"}
-             d["batch_size"] = 0
-             items.append(cls(**d))
-
-         return items
-
-     def to_xr_dataset(self, **kwargs):
-         """ Make a xr dataset. Each data source needs to define this """
-         raise NotImplementedError
-
-     def from_xr_dataset(self):
-         """ Load from xr dataset. Each data source needs to define this """
-         raise NotImplementedError
-
-     def get_datetime_index(self):
-         """ Datetime index for the data """
-         pass
-
-     def save_netcdf(self, batch_i: int, path: Path, xr_dataset: xr.Dataset):
+     def save_netcdf(self, batch_i: int, path: Path):
        """
        Save batch to netcdf file

        Args:
            batch_i: the batch id, used to make the filename
            path: the path where it will be saved. This can be local or in the cloud.
-             xr_dataset: xr dataset that has batch information in it
        """
        filename = get_netcdf_filename(batch_i)

@@ -124,77 +49,46 @@ def save_netcdf(self, batch_i: int, path: Path, xr_dataset: xr.Dataset):
        # make file
        local_filename = os.path.join(folder, filename)

-         encoding = {name: {"compression": "lzf"} for name in xr_dataset.data_vars}
-         xr_dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)
-
-     def select_time_period(
-         self,
-         keys: List[str],
-         history_minutes: int,
-         forecast_minutes: int,
-         t0_dt_of_first_example: Union[datetime, pd.Timestamp],
-     ):
-         """
-         Selects a subset of data between the indicies of [start, end] for each key in keys
-
-         Note that class is edited so nothing is returned.
-
-         Args:
-             keys: Keys in batch to use
-             t0_dt_of_first_example: datetime of the current time (t0) in the first example of the batch
-             history_minutes: How many minutes of history to use
-             forecast_minutes: How many minutes of future data to use for forecasting
-
-         """
-         logger.debug(
-             f"Taking a sub-selection of the batch data based on a history minutes of {history_minutes} "
-             f"and forecast minutes of {forecast_minutes}"
-         )
+         encoding = {name: {"compression": "lzf"} for name in self.data_vars}
+         self.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)

-         start_time_of_first_batch = t0_dt_of_first_example - pd.to_timedelta(
-             f"{history_minutes} minute 30 second"
-         )
-         end_time_of_first_example = t0_dt_of_first_example + pd.to_timedelta(
-             f"{forecast_minutes} minute 30 second"
-         )

-         logger.debug(f"New start time for first batch is {start_time_of_first_batch}")
-         logger.debug(f"New end time for first batch is {end_time_of_first_example}")
+ class DataSourceOutputML(BaseModel):
+     """General Data Source output pydantic class.

-         start_time_of_first_example = to_numpy(start_time_of_first_batch)
-         end_time_of_first_example = to_numpy(end_time_of_first_example)
+     Data source output classes should inherit from this class
+     """

-         if self.get_datetime_index() is not None:
+     class Config:
+         """Allowed classes e.g. tensor.Tensor"""

-             time_of_first_example = to_numpy(pd.to_datetime(self.get_datetime_index()[0]))
+         # TODO maybe there is a better way to do this
+         arbitrary_types_allowed = True

-             # find the start and end index, that we will then use to slice the data
-             start_i, end_i = np.searchsorted(
-                 time_of_first_example, [start_time_of_first_example, end_time_of_first_example]
-             )
+     batch_size: int = Field(
+         0,
+         ge=0,
+         description="The size of this batch. If the batch size is 0, "
+         "then this item stores one data item i.e Example",
+     )

-             # slice all the data
-             for key in keys:
-                 if "time" in self.__getattribute__(key).dims:
-                     self.__setattr__(
-                         key, self.__getattribute__(key).isel(time=slice(start_i, end_i))
-                     )
-                 elif "time_30" in self.__getattribute__(key).dims:
-                     self.__setattr__(
-                         key, self.__getattribute__(key).isel(time_30=slice(start_i, end_i))
-                     )
+     def get_name(self) -> str:
+         """Get the name of the class"""
+         return self.__class__.__name__.lower()

-                 logger.debug(f"{self.__class__.__name__} {key}: {self.__getattribute__(key).shape}")
+     def get_datetime_index(self):
+         """Datetime index for the data"""
+         pass


def pad_nans(array, pad_width) -> np.ndarray:
-     """ Pad nans with nans"""
+     """Pad nans with nans"""
    array = array.astype(np.float32)
    return np.pad(array, pad_width, constant_values=np.NaN)


def pad_data(
-     data: DataSourceOutput,
+     data: DataSourceOutputML,
    pad_size: int,
    one_dimensional_arrays: List[str],
    two_dimensional_arrays: List[str],
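
With this change, `DataSourceOutput` subclasses `PydanticXArrayDataSet`, so a batch is itself an xarray Dataset and `save_netcdf` writes `self` rather than a separate `xr_dataset` argument. A minimal usage sketch, assuming `batch` is an instance of some `DataSourceOutput` subclass; the variable name and output directory are hypothetical:

    from pathlib import Path

    # Before this commit: batch.save_netcdf(batch_i=0, path=Path("prepared_data"), xr_dataset=ds)
    # After: the instance is the dataset, so no xr_dataset argument is needed.
    batch.save_netcdf(batch_i=0, path=Path("prepared_data"))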