2
2
import warnings
3
3
from collections import OrderedDict
4
4
from distutils .version import LooseVersion
5
-
6
5
import numpy as np
7
6
8
7
from .. import DataArray
24
23
class RasterioArrayWrapper (BackendArray ):
25
24
"""A wrapper around rasterio dataset objects"""
26
25
27
- def __init__ (self , manager , lock ):
26
+ def __init__ (self , manager , lock , vrt_params = None ):
27
+ from rasterio .vrt import WarpedVRT
28
28
self .manager = manager
29
29
self .lock = lock
30
30
31
31
# cannot save riods as an attribute: this would break pickleability
32
32
riods = manager .acquire ()
33
-
33
+ if vrt_params is not None :
34
+ riods = WarpedVRT (riods , ** vrt_params )
35
+ self .vrt_params = vrt_params
34
36
self ._shape = (riods .count , riods .height , riods .width )
35
37
36
38
dtypes = riods .dtypes
@@ -104,6 +106,7 @@ def _get_indexer(self, key):
104
106
return band_key , tuple (window ), tuple (squeeze_axis ), tuple (np_inds )
105
107
106
108
def _getitem (self , key ):
109
+ from rasterio .vrt import WarpedVRT
107
110
band_key , window , squeeze_axis , np_inds = self ._get_indexer (key )
108
111
109
112
if not band_key or any (start == stop for (start , stop ) in window ):
@@ -114,6 +117,8 @@ def _getitem(self, key):
114
117
else :
115
118
with self .lock :
116
119
riods = self .manager .acquire (needs_lock = False )
120
+ if self .vrt_params is not None :
121
+ riods = WarpedVRT (riods , ** self .vrt_params )
117
122
out = riods .read (band_key , window = window )
118
123
119
124
if squeeze_axis :
@@ -178,8 +183,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
178
183
179
184
Parameters
180
185
----------
181
- filename : str
182
- Path to the file to open.
186
+ filename : str, rasterio.DatasetReader, or rasterio.WarpedVRT
187
+ Path to the file to open. Or already open rasterio dataset.
183
188
parse_coordinates : bool, optional
184
189
Whether to parse the x and y coordinates out of the file's
185
190
``transform`` attribute or not. The default is to automatically
@@ -206,14 +211,28 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
206
211
data : DataArray
207
212
The newly created DataArray.
208
213
"""
209
-
210
214
import rasterio
215
+ from rasterio .vrt import WarpedVRT
216
+ vrt_params = None
217
+ if isinstance (filename , rasterio .io .DatasetReader ):
218
+ filename = filename .name
219
+ elif isinstance (filename , rasterio .vrt .WarpedVRT ):
220
+ vrt = filename
221
+ filename = vrt .src_dataset .name
222
+ vrt_params = dict (crs = vrt .crs .to_string (),
223
+ resampling = vrt .resampling ,
224
+ src_nodata = vrt .src_nodata ,
225
+ dst_nodata = vrt .dst_nodata ,
226
+ tolerance = vrt .tolerance ,
227
+ warp_extras = vrt .warp_extras )
211
228
212
229
if lock is None :
213
230
lock = RASTERIO_LOCK
214
231
215
232
manager = CachingFileManager (rasterio .open , filename , lock = lock , mode = 'r' )
216
233
riods = manager .acquire ()
234
+ if vrt_params is not None :
235
+ riods = WarpedVRT (riods , ** vrt_params )
217
236
218
237
if cache is None :
219
238
cache = chunks is None
@@ -287,14 +306,14 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
287
306
for k , v in meta .items ():
288
307
# Add values as coordinates if they match the band count,
289
308
# as attributes otherwise
290
- if (isinstance (v , (list , np .ndarray )) and
291
- len (v ) == riods .count ):
309
+ if (isinstance (v , (list , np .ndarray ))
310
+ and len (v ) == riods .count ):
292
311
coords [k ] = ('band' , np .asarray (v ))
293
312
else :
294
313
attrs [k ] = v
295
314
296
315
data = indexing .LazilyOuterIndexedArray (
297
- RasterioArrayWrapper (manager , lock ))
316
+ RasterioArrayWrapper (manager , lock , vrt_params ))
298
317
299
318
# this lets you write arrays loaded with rasterio
300
319
data = indexing .CopyOnWriteArray (data )
0 commit comments