dataset.py
from __future__ import annotations
import copy
import datetime
import inspect
import itertools
import math
import sys
import warnings
from collections import defaultdict
from collections.abc import (
Callable,
Collection,
Hashable,
Iterable,
Iterator,
Mapping,
MutableMapping,
Sequence,
)
from functools import partial
from html import escape
from numbers import Number
from operator import methodcaller
from os import PathLike
from types import EllipsisType
from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast, overload
import numpy as np
from pandas.api.types import is_extension_array_dtype
# remove once numpy 2.0 is the oldest supported version
try:
from numpy.exceptions import RankWarning
except ImportError:
from numpy import RankWarning # type: ignore[no-redef,attr-defined,unused-ignore]
import pandas as pd
from xarray.coding.calendar_ops import convert_calendar, interp_calendar
from xarray.coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
from xarray.core import (
alignment,
duck_array_ops,
formatting,
formatting_html,
ops,
utils,
)
from xarray.core import dtypes as xrdtypes
from xarray.core._aggregations import DatasetAggregations
from xarray.core.alignment import (
_broadcast_helper,
_get_broadcast_dims_map_common_coords,
align,
)
from xarray.core.arithmetic import DatasetArithmetic
from xarray.core.common import (
DataWithCoords,
_contains_datetime_like_objects,
get_chunksizes,
)
from xarray.core.computation import unify_chunks
from xarray.core.coordinates import (
Coordinates,
DatasetCoordinates,
assert_coordinate_consistent,
create_coords_with_default_indexes,
)
from xarray.core.duck_array_ops import datetime_to_numeric
from xarray.core.indexes import (
Index,
Indexes,
PandasIndex,
PandasMultiIndex,
assert_no_index_corrupted,
create_default_index_implicit,
filter_indexes_from_coords,
isel_indexes,
remove_unused_levels_categories,
roll_indexes,
)
from xarray.core.indexing import is_fancy_indexer, map_index_queries
from xarray.core.merge import (
dataset_merge_method,
dataset_update_method,
merge_coordinates_without_align,
merge_core,
)
from xarray.core.missing import _floatize_x
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import (
Bins,
NetcdfWriteModes,
QuantileMethods,
Self,
T_ChunkDim,
T_ChunksFreq,
T_DataArray,
T_DataArrayOrSet,
T_Dataset,
ZarrWriteModes,
)
from xarray.core.utils import (
Default,
FilteredMapping,
Frozen,
FrozenMappingWarningOnValuesAccess,
OrderedSet,
_default,
decode_numpy_dict_values,
drop_dims_from_indexers,
either_dict_or_kwargs,
emit_user_level_warning,
infix_dims,
is_dict_like,
is_duck_array,
is_duck_dask_array,
is_scalar,
maybe_wrap_array,
parse_dims_as_set,
)
from xarray.core.variable import (
IndexVariable,
Variable,
as_variable,
broadcast_variables,
calculate_dimensions,
)
from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
from xarray.namedarray.pycompat import array_type, is_chunked_array
from xarray.plot.accessor import DatasetPlotAccessor
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
if TYPE_CHECKING:
from dask.dataframe import DataFrame as DaskDataFrame
from dask.delayed import Delayed
from numpy.typing import ArrayLike
from xarray.backends import AbstractDataStore, ZarrStore
from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
from xarray.core.dataarray import DataArray
from xarray.core.groupby import DatasetGroupBy
from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult
from xarray.core.resample import DatasetResample
from xarray.core.rolling import DatasetCoarsen, DatasetRolling
from xarray.core.types import (
CFCalendar,
CoarsenBoundaryOptions,
CombineAttrsOptions,
CompatOptions,
DataVars,
DatetimeLike,
DatetimeUnitOptions,
Dims,
DsCompatible,
ErrorOptions,
ErrorOptionsWithWarn,
GroupInput,
InterpOptions,
JoinOptions,
PadModeOptions,
PadReflectOptions,
QueryEngineOptions,
QueryParserOptions,
ReindexMethodOptions,
ResampleCompatible,
SideOptions,
T_ChunkDimFreq,
T_DatasetPadConstantValues,
T_Xarray,
)
from xarray.core.weighted import DatasetWeighted
from xarray.groupers import Grouper, Resampler
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
# list of attributes of pd.DatetimeIndex that are ndarrays of time info
_DATETIMEINDEX_COMPONENTS = [
"year",
"month",
"day",
"hour",
"minute",
"second",
"microsecond",
"nanosecond",
"date",
"time",
"dayofyear",
"weekofyear",
"dayofweek",
"quarter",
]
def _get_virtual_variable(
variables, key: Hashable, dim_sizes: Mapping | None = None
) -> tuple[Hashable, Hashable, Variable]:
"""Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable
objects (if possible)
"""
from xarray.core.dataarray import DataArray
if dim_sizes is None:
dim_sizes = {}
if key in dim_sizes:
data = pd.Index(range(dim_sizes[key]), name=key)
variable = IndexVariable((key,), data)
return key, key, variable
if not isinstance(key, str):
raise KeyError(key)
split_key = key.split(".", 1)
if len(split_key) != 2:
raise KeyError(key)
ref_name, var_name = split_key
ref_var = variables[ref_name]
if _contains_datetime_like_objects(ref_var):
ref_var = DataArray(ref_var)
data = getattr(ref_var.dt, var_name).data
else:
data = getattr(ref_var, var_name).data
virtual_var = Variable(ref_var.dims, data)
return ref_name, var_name, virtual_var
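# Illustrative sketch (comments only, not a doctest): how a "ref.component" key is
# resolved. Assuming a datetime64 Variable named "time":
#
#     times = Variable(("time",), pd.date_range("2000-01-01", periods=3))
#     ref_name, var_name, virtual = _get_virtual_variable({"time": times}, "time.year")
#     # ref_name == "time", var_name == "year"
#     # virtual.values -> array([2000, 2000, 2000])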
def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
"""
Return a map from each dimension to chunk sizes, accounting for the backend's preferred chunks.
"""
if isinstance(var, IndexVariable):
return {}
dims = var.dims
shape = var.shape
# Determine the explicit requested chunks.
preferred_chunks = var.encoding.get("preferred_chunks", {})
preferred_chunk_shape = tuple(
preferred_chunks.get(dim, size) for dim, size in zip(dims, shape, strict=True)
)
if isinstance(chunks, Number) or (chunks == "auto"):
chunks = dict.fromkeys(dims, chunks)
chunk_shape = tuple(
chunks.get(dim, None) or preferred_chunk_sizes
for dim, preferred_chunk_sizes in zip(dims, preferred_chunk_shape, strict=True)
)
chunk_shape = chunkmanager.normalize_chunks(
chunk_shape, shape=shape, dtype=var.dtype, previous_chunks=preferred_chunk_shape
)
# Warn where requested chunks break preferred chunks, provided that the variable
# contains data.
if var.size:
for dim, size, chunk_sizes in zip(dims, shape, chunk_shape, strict=True):
try:
preferred_chunk_sizes = preferred_chunks[dim]
except KeyError:
continue
# Determine the stop indices of the preferred chunks, but omit the last stop
# (equal to the dim size). In particular, assume that when a sequence
# expresses the preferred chunks, the sequence sums to the size.
preferred_stops = (
range(preferred_chunk_sizes, size, preferred_chunk_sizes)
if isinstance(preferred_chunk_sizes, int)
else itertools.accumulate(preferred_chunk_sizes[:-1])
)
# Gather any stop indices of the specified chunks that are not a stop index
# of a preferred chunk. Again, omit the last stop, assuming that it equals
# the dim size.
breaks = set(itertools.accumulate(chunk_sizes[:-1])).difference(
preferred_stops
)
if breaks:
warnings.warn(
"The specified chunks separate the stored chunks along "
f'dimension "{dim}" starting at index {min(breaks)}. This could '
"degrade performance. Instead, consider rechunking after loading.",
stacklevel=2,
)
return dict(zip(dims, chunk_shape, strict=True))
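# Illustrative sketch (assumes the dask chunk manager is available): requesting chunks
# that do not line up with the backend's preferred chunks triggers the warning above.
#
#     chunkmanager = guess_chunkmanager("dask")
#     var = Variable(("x",), np.arange(10))
#     var.encoding["preferred_chunks"] = {"x": 5}
#     _get_chunk(var, {"x": 4}, chunkmanager)
#     # warns that the stored chunks along "x" are separated starting at index 4,
#     # and returns something like {"x": (4, 4, 2)}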
def _maybe_chunk(
name: Hashable,
var: Variable,
chunks: Mapping[Any, T_ChunkDim] | None,
token=None,
lock=None,
name_prefix: str = "xarray-",
overwrite_encoded_chunks: bool = False,
inline_array: bool = False,
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
from_array_kwargs=None,
) -> Variable:
from xarray.namedarray.daskmanager import DaskManager
if chunks is not None:
chunks = {dim: chunks[dim] for dim in var.dims if dim in chunks}
if var.ndim:
chunked_array_type = guess_chunkmanager(
chunked_array_type
) # coerce string to ChunkManagerEntrypoint type
if isinstance(chunked_array_type, DaskManager):
from dask.base import tokenize
# when rechunking by different amounts, make sure dask names change
# by providing chunks as an input to tokenize.
# subtle bugs result otherwise. see GH3350
# we use str() for speed, and use the name for the final array name on the next line
token2 = tokenize(token if token else var._data, str(chunks))
name2 = f"{name_prefix}{name}-{token2}"
from_array_kwargs = utils.consolidate_dask_from_array_kwargs(
from_array_kwargs,
name=name2,
lock=lock,
inline_array=inline_array,
)
var = var.chunk(
chunks,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
)
if overwrite_encoded_chunks and var.chunks is not None:
var.encoding["chunks"] = tuple(x[0] for x in var.chunks)
return var
else:
return var
def as_dataset(obj: Any) -> Dataset:
"""Cast the given object to a Dataset.
Handles Datasets, DataArrays and dictionaries of variables. A new Dataset
object is only created if the provided object is not already one.
"""
if hasattr(obj, "to_dataset"):
obj = obj.to_dataset()
if not isinstance(obj, Dataset):
obj = Dataset(obj)
return obj
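# Illustrative sketch (using the public ``xr`` namespace for brevity): the kinds of
# input ``as_dataset`` accepts.
#
#     as_dataset(xr.DataArray([1, 2], dims="x", name="a"))  # converted via .to_dataset()
#     as_dataset({"a": ("x", [1, 2])})                       # wrapped in a new Dataset
#     as_dataset(ds)                                         # an existing Dataset is returned as-is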
def _get_func_args(func, param_names):
"""Use `inspect.signature` to try accessing `func` args. Otherwise, ensure
they are provided by user.
"""
try:
func_args = inspect.signature(func).parameters
except ValueError as err:
func_args = {}
if not param_names:
raise ValueError(
"Unable to inspect `func` signature, and `param_names` was not provided."
) from err
if param_names:
params = param_names
else:
params = list(func_args)[1:]
if any(
[(p.kind in [p.VAR_POSITIONAL, p.VAR_KEYWORD]) for p in func_args.values()]
):
raise ValueError(
"`param_names` must be provided because `func` takes variable length arguments."
)
return params, func_args
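# Illustrative sketch: inspecting a model function's signature for curve fitting. The
# ``exponential`` helper below is hypothetical.
#
#     def exponential(x, a, b):
#         return a * np.exp(b * x)
#
#     params, func_args = _get_func_args(exponential, [])
#     # params == ["a", "b"]  (the first argument is treated as the coordinate)
#
#     # Functions taking *args or **kwargs cannot be inspected this way, so explicit
#     # parameter names are required:
#     _get_func_args(lambda x, *args: x, [])  # raises ValueError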
def _initialize_curvefit_params(params, p0, bounds, func_args):
"""Set initial guess and bounds for curvefit.
Priority: 1) passed args 2) func signature 3) scipy defaults
"""
from xarray.core.computation import where
def _initialize_feasible(lb, ub):
# Mimics functionality of scipy.optimize.minpack._initialize_feasible
lb_finite = np.isfinite(lb)
ub_finite = np.isfinite(ub)
p0 = where(
lb_finite,
where(
ub_finite,
0.5 * (lb + ub), # both bounds finite
lb + 1, # lower bound finite, upper infinite
),
where(
ub_finite,
ub - 1, # lower bound infinite, upper finite
0, # both bounds infinite
),
)
return p0
param_defaults = {p: 1 for p in params}
bounds_defaults = {p: (-np.inf, np.inf) for p in params}
for p in params:
if p in func_args and func_args[p].default is not func_args[p].empty:
param_defaults[p] = func_args[p].default
if p in bounds:
lb, ub = bounds[p]
bounds_defaults[p] = (lb, ub)
param_defaults[p] = where(
(param_defaults[p] < lb) | (param_defaults[p] > ub),
_initialize_feasible(lb, ub),
param_defaults[p],
)
if p in p0:
param_defaults[p] = p0[p]
return param_defaults, bounds_defaults
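# Illustrative sketch: initial guesses fall back from user-supplied ``p0`` to signature
# defaults and finally to 1, with bounds defaulting to (-inf, inf). ``model`` is a
# hypothetical fitting function.
#
#     def model(x, a, b=3):
#         ...
#
#     params, func_args = _get_func_args(model, [])
#     p0, bounds = _initialize_curvefit_params(params, {"a": 2}, {}, func_args)
#     # p0 == {"a": 2, "b": 3}
#     # bounds == {"a": (-inf, inf), "b": (-inf, inf)}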
def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult:
"""Used in Dataset.__init__."""
if isinstance(coords, Coordinates):
coords = coords.copy()
else:
coords = create_coords_with_default_indexes(coords, data_vars)
# exclude coords from alignment (all variables in a Coordinates object should
# already be aligned together) and use coordinates' indexes to align data_vars
return merge_core(
[data_vars, coords],
compat="broadcast_equals",
join="outer",
explicit_coords=tuple(coords),
indexes=coords.xindexes,
priority_arg=1,
skip_align_args=[1],
)
class DataVariables(Mapping[Any, "DataArray"]):
__slots__ = ("_dataset",)
def __init__(self, dataset: Dataset):
self._dataset = dataset
def __iter__(self) -> Iterator[Hashable]:
return (
key
for key in self._dataset._variables
if key not in self._dataset._coord_names
)
def __len__(self) -> int:
length = len(self._dataset._variables) - len(self._dataset._coord_names)
assert length >= 0, "something is wrong with Dataset._coord_names"
return length
def __contains__(self, key: Hashable) -> bool:
return key in self._dataset._variables and key not in self._dataset._coord_names
def __getitem__(self, key: Hashable) -> DataArray:
if key not in self._dataset._coord_names:
return self._dataset[key]
raise KeyError(key)
def __repr__(self) -> str:
return formatting.data_vars_repr(self)
@property
def variables(self) -> Mapping[Hashable, Variable]:
all_variables = self._dataset.variables
return Frozen({k: all_variables[k] for k in self})
@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
"""Mapping from data variable names to dtypes.
Cannot be modified directly, but is updated when adding new variables.
See Also
--------
Dataset.dtypes
"""
return self._dataset.dtypes
def _ipython_key_completions_(self):
"""Provide method for the key-autocompletions in IPython."""
return [
key
for key in self._dataset._ipython_key_completions_()
if key not in self._dataset._coord_names
]
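# Illustrative sketch: ``DataVariables`` backs ``Dataset.data_vars``, iterating only over
# variables that are not coordinates (``xr`` namespace used for brevity).
#
#     ds = xr.Dataset({"a": ("x", [1, 2])}, coords={"x": [10, 20]})
#     list(ds.data_vars)  # ["a"]  -- "x" is a coordinate, not a data variable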
class _LocIndexer(Generic[T_Dataset]):
__slots__ = ("dataset",)
def __init__(self, dataset: T_Dataset):
self.dataset = dataset
def __getitem__(self, key: Mapping[Any, Any]) -> T_Dataset:
if not utils.is_dict_like(key):
raise TypeError("can only lookup dictionaries from Dataset.loc")
return self.dataset.sel(key)
def __setitem__(self, key, value) -> None:
if not utils.is_dict_like(key):
raise TypeError(
"can only set locations defined by dictionaries from Dataset.loc."
f" Got: {key}"
)
# set new values
dim_indexers = map_index_queries(self.dataset, key).dim_indexers
self.dataset[dim_indexers] = value
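# Illustrative sketch: ``_LocIndexer`` backs the ``Dataset.loc`` property, so label-based
# selection and assignment use dictionary keys (``xr`` namespace used for brevity).
#
#     ds = xr.Dataset({"a": ("x", [10, 20, 30])}, coords={"x": [1, 2, 3]})
#     ds.loc[{"x": 2}]        # equivalent to ds.sel(x=2)
#     ds.loc[{"x": 2}] = 99   # labels are mapped to positions, then assigned in place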
class Dataset(
DataWithCoords,
DatasetAggregations,
DatasetArithmetic,
Mapping[Hashable, "DataArray"],
):
"""A multi-dimensional, in memory, array database.
A dataset resembles an in-memory representation of a NetCDF file,
and consists of variables, coordinates and attributes which
together form a self describing dataset.
Dataset implements the mapping interface with keys given by variable
names and values given by DataArray objects for each variable name.
By default, pandas indexes are created for one dimensional variables with
name equal to their dimension (i.e., :term:`Dimension coordinate`) so those
variables can be readily used as coordinates for label based indexing. When a
:py:class:`~xarray.Coordinates` object is passed to ``coords``, any existing
index(es) built from those coordinates will be added to the Dataset.
To load data from a file or file-like object, use the `open_dataset`
function.
Parameters
----------
data_vars : dict-like, optional
A mapping from variable names to :py:class:`~xarray.DataArray`
objects, :py:class:`~xarray.Variable` objects or to tuples of
the form ``(dims, data[, attrs])`` which can be used as
arguments to create a new ``Variable``. Each dimension must
have the same length in all variables in which it appears.
The following notations are accepted:
- mapping {var name: DataArray}
- mapping {var name: Variable}
- mapping {var name: (dimension name, array-like)}
- mapping {var name: (tuple of dimension names, array-like)}
- mapping {dimension name: array-like}
(if array-like is not a scalar it will be automatically moved to coords,
see below)
Each dimension must have the same length in all variables in
which it appears.
coords : :py:class:`~xarray.Coordinates` or dict-like, optional
A :py:class:`~xarray.Coordinates` object or another mapping in
similar form as the `data_vars` argument, except that each item
is saved on the dataset as a "coordinate".
These variables have an associated meaning: they describe
constant/fixed/independent quantities, unlike the
varying/measured/dependent quantities that belong in
`variables`.
The following notations are accepted for arbitrary mappings:
- mapping {coord name: DataArray}
- mapping {coord name: Variable}
- mapping {coord name: (dimension name, array-like)}
- mapping {coord name: (tuple of dimension names, array-like)}
- mapping {dimension name: array-like}
(the dimension name is implicitly set to be the same as the
coord name)
The last notation implies either that the coordinate value is a scalar
or that it is a 1-dimensional array and the coord name is the same as
the dimension name (i.e., a :term:`Dimension coordinate`). In the latter
case, the 1-dimensional array will be assumed to give index values
along the dimension with the same name.
Alternatively, a :py:class:`~xarray.Coordinates` object may be used in
order to explicitly pass indexes (e.g., a multi-index or any custom
Xarray index) or to bypass the creation of a default index for any
:term:`Dimension coordinate` included in that object.
attrs : dict-like, optional
Global attributes to save on this dataset.
Examples
--------
In this example dataset, we will represent measurements of the temperature
and pressure that were made under various conditions:
* the measurements were made on four different days;
* they were made at two separate locations, which we will represent using
their latitude and longitude; and
* they were made using three instruments developed by three different
manufacturers, which we will refer to using the strings `'manufac1'`,
`'manufac2'`, and `'manufac3'`.
>>> np.random.seed(0)
>>> temperature = 15 + 8 * np.random.randn(2, 3, 4)
>>> precipitation = 10 * np.random.rand(2, 3, 4)
>>> lon = [-99.83, -99.32]
>>> lat = [42.25, 42.21]
>>> instruments = ["manufac1", "manufac2", "manufac3"]
>>> time = pd.date_range("2014-09-06", periods=4)
>>> reference_time = pd.Timestamp("2014-09-05")
Here, we initialize the dataset with multiple dimensions. We use the string
`"loc"` to represent the location dimension of the data, the string
`"instrument"` to represent the instrument manufacturer dimension, and the
string `"time"` for the time dimension.
>>> ds = xr.Dataset(
... data_vars=dict(
... temperature=(["loc", "instrument", "time"], temperature),
... precipitation=(["loc", "instrument", "time"], precipitation),
... ),
... coords=dict(
... lon=("loc", lon),
... lat=("loc", lat),
... instrument=instruments,
... time=time,
... reference_time=reference_time,
... ),
... attrs=dict(description="Weather related data."),
... )
>>> ds
<xarray.Dataset> Size: 552B
Dimensions: (loc: 2, instrument: 3, time: 4)
Coordinates:
lon (loc) float64 16B -99.83 -99.32
lat (loc) float64 16B 42.25 42.21
* instrument (instrument) <U8 96B 'manufac1' 'manufac2' 'manufac3'
* time (time) datetime64[ns] 32B 2014-09-06 ... 2014-09-09
reference_time datetime64[ns] 8B 2014-09-05
Dimensions without coordinates: loc
Data variables:
temperature (loc, instrument, time) float64 192B 29.11 18.2 ... 9.063
precipitation (loc, instrument, time) float64 192B 4.562 5.684 ... 1.613
Attributes:
description: Weather related data.
Find out where the coldest temperature was and what values the
other variables had:
>>> ds.isel(ds.temperature.argmin(...))
<xarray.Dataset> Size: 80B
Dimensions: ()
Coordinates:
lon float64 8B -99.32
lat float64 8B 42.21
instrument <U8 32B 'manufac3'
time datetime64[ns] 8B 2014-09-06
reference_time datetime64[ns] 8B 2014-09-05
Data variables:
temperature float64 8B -5.424
precipitation float64 8B 9.884
Attributes:
description: Weather related data.
"""
_attrs: dict[Hashable, Any] | None
_cache: dict[str, Any]
_coord_names: set[Hashable]
_dims: dict[Hashable, int]
_encoding: dict[Hashable, Any] | None
_close: Callable[[], None] | None
_indexes: dict[Hashable, Index]
_variables: dict[Hashable, Variable]
__slots__ = (
"_attrs",
"_cache",
"_coord_names",
"_dims",
"_encoding",
"_close",
"_indexes",
"_variables",
"__weakref__",
)
def __init__(
self,
# could make a VariableArgs to use more generally, and refine these
# categories
data_vars: DataVars | None = None,
coords: Mapping[Any, Any] | None = None,
attrs: Mapping[Any, Any] | None = None,
) -> None:
if data_vars is None:
data_vars = {}
if coords is None:
coords = {}
both_data_and_coords = set(data_vars) & set(coords)
if both_data_and_coords:
raise ValueError(
f"variables {both_data_and_coords!r} are found in both data_vars and coords"
)
if isinstance(coords, Dataset):
coords = coords._variables
variables, coord_names, dims, indexes, _ = merge_data_and_coords(
data_vars, coords
)
self._attrs = dict(attrs) if attrs else None
self._close = None
self._encoding = None
self._variables = variables
self._coord_names = coord_names
self._dims = dims
self._indexes = indexes
# TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. Mapping
# related to https://github.com/python/mypy/issues/9319?
def __eq__(self, other: DsCompatible) -> Self: # type: ignore[override]
return super().__eq__(other)
@classmethod
def load_store(cls, store, decoder=None) -> Self:
"""Create a new dataset from the contents of a backends.*DataStore
object
"""
variables, attributes = store.load()
if decoder:
variables, attributes = decoder(variables, attributes)
obj = cls(variables, attrs=attributes)
obj.set_close(store.close)
return obj
@property
def variables(self) -> Frozen[Hashable, Variable]:
"""Low level interface to Dataset contents as dict of Variable objects.
This ordered dictionary is frozen to prevent mutation that could
violate Dataset invariants. It contains all variable objects
constituting the Dataset, including both data variables and
coordinates.
"""
return Frozen(self._variables)
@property
def attrs(self) -> dict[Any, Any]:
"""Dictionary of global attributes on this dataset"""
if self._attrs is None:
self._attrs = {}
return self._attrs
@attrs.setter
def attrs(self, value: Mapping[Any, Any]) -> None:
self._attrs = dict(value) if value else None
@property
def encoding(self) -> dict[Any, Any]:
"""Dictionary of global encoding attributes on this dataset"""
if self._encoding is None:
self._encoding = {}
return self._encoding
@encoding.setter
def encoding(self, value: Mapping[Any, Any]) -> None:
self._encoding = dict(value)
def reset_encoding(self) -> Self:
warnings.warn(
"reset_encoding is deprecated since 2023.11, use `drop_encoding` instead",
stacklevel=2,
)
return self.drop_encoding()
def drop_encoding(self) -> Self:
"""Return a new Dataset without encoding on the dataset or any of its
variables/coords."""
variables = {k: v.drop_encoding() for k, v in self.variables.items()}
return self._replace(variables=variables, encoding={})
@property
def dims(self) -> Frozen[Hashable, int]:
"""Mapping from dimension names to lengths.
Cannot be modified directly, but is updated when adding new variables.
Note that the type of this object differs from `DataArray.dims`.
See `Dataset.sizes` and `DataArray.sizes` for consistently named
properties. This property will be changed to return a type more consistent with
`DataArray.dims` in the future, i.e. a set of dimension names.
See Also
--------
Dataset.sizes
DataArray.dims
"""
return FrozenMappingWarningOnValuesAccess(self._dims)
@property
def sizes(self) -> Frozen[Hashable, int]:
"""Mapping from dimension names to lengths.
Cannot be modified directly, but is updated when adding new variables.
This is an alias for `Dataset.dims` provided for the benefit of
consistency with `DataArray.sizes`.
See Also
--------
DataArray.sizes
"""
return Frozen(self._dims)
@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
"""Mapping from data variable names to dtypes.
Cannot be modified directly, but is updated when adding new variables.
See Also
--------
DataArray.dtype
"""
return Frozen(
{
n: v.dtype
for n, v in self._variables.items()
if n not in self._coord_names
}
)
def load(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this dataset's data
from disk or a remote source into memory and return this dataset.
Unlike compute, the original dataset is modified and returned.
Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
Parameters
----------
**kwargs : dict
Additional keyword arguments passed on to ``dask.compute``.
See Also
--------
dask.compute
"""
# access .data to coerce everything to numpy or dask arrays
lazy_data = {
k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
}
if lazy_data:
chunkmanager = get_chunked_array_type(*lazy_data.values())
# evaluate all the chunked arrays simultaneously
evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
*lazy_data.values(), **kwargs
)
for k, data in zip(lazy_data, evaluated_data, strict=False):
self.variables[k].data = data
# load everything else sequentially
for k, v in self.variables.items():
if k not in lazy_data:
v.load()
return self
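    # Illustrative sketch (assumes a dask-backed dataset opened lazily; the file name is
    # hypothetical): ``load`` computes all chunked variables in one pass and stores the
    # results on this dataset, whereas ``compute`` works on a copy.
    #
    #     ds = xr.open_dataset("data.nc", chunks={"time": 100})
    #     ds.load()  # ds now holds in-memory arrays; returns ds itself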
def __dask_tokenize__(self) -> object:
from dask.base import normalize_token
return normalize_token(
(type(self), self._variables, self._coord_names, self._attrs or None)
)
def __dask_graph__(self):
graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
graphs = {k: v for k, v in graphs.items() if v is not None}
if not graphs:
return None
else:
try:
from dask.highlevelgraph import HighLevelGraph
return HighLevelGraph.merge(*graphs.values())
except ImportError:
from dask import sharedict
return sharedict.merge(*graphs.values())
def __dask_keys__(self):
import dask
return [
v.__dask_keys__()
for v in self.variables.values()
if dask.is_dask_collection(v)
]
def __dask_layers__(self):
import dask
return sum(
(
v.__dask_layers__()
for v in self.variables.values()
if dask.is_dask_collection(v)
),
(),
)
@property
def __dask_optimize__(self):
import dask.array as da
return da.Array.__dask_optimize__
@property
def __dask_scheduler__(self):
import dask.array as da
return da.Array.__dask_scheduler__
def __dask_postcompute__(self):
return self._dask_postcompute, ()
def __dask_postpersist__(self):
return self._dask_postpersist, ()
def _dask_postcompute(self, results: Iterable[Variable]) -> Self:
import dask
variables = {}
results_iter = iter(results)
for k, v in self._variables.items():
if dask.is_dask_collection(v):
rebuild, args = v.__dask_postcompute__()
v = rebuild(next(results_iter), *args)
variables[k] = v
return type(self)._construct_direct(
variables,
self._coord_names,
self._dims,
self._attrs,
self._indexes,
self._encoding,
self._close,
)
def _dask_postpersist(
self, dsk: Mapping, *, rename: Mapping[str, str] | None = None
) -> Self:
from dask import is_dask_collection
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import cull
variables = {}
for k, v in self._variables.items():
if not is_dask_collection(v):
variables[k] = v
continue
if isinstance(dsk, HighLevelGraph):
# dask >= 2021.3
# __dask_postpersist__() was called by dask.highlevelgraph.
# Don't use dsk.cull(), as we need to prevent partial layers:
# https://github.com/dask/dask/issues/7137
layers = v.__dask_layers__()
if rename:
layers = [rename.get(k, k) for k in layers]
dsk2 = dsk.cull_layers(layers)
elif rename: # pragma: nocover
# At the moment of writing, this is only for forward compatibility.
# replace_name_in_key requires dask >= 2021.3.
from dask.base import flatten, replace_name_in_key
keys = [
replace_name_in_key(k, rename) for k in flatten(v.__dask_keys__())
]