49
49
# string used by netCDF4.
50
50
_endian_lookup = {"=" : "native" , ">" : "big" , "<" : "little" , "|" : "native" }
51
51
52
-
53
52
NETCDF4_PYTHON_LOCK = combine_locks ([NETCDFC_LOCK , HDF5_LOCK ])
54
53
55
54
@@ -141,7 +140,9 @@ def _check_encoding_dtype_is_vlen_string(dtype):
141
140
)
142
141
143
142
144
- def _get_datatype (var , nc_format = "NETCDF4" , raise_on_invalid_encoding = False ):
143
+ def _get_datatype (
144
+ var , nc_format = "NETCDF4" , raise_on_invalid_encoding = False
145
+ ) -> np .dtype :
145
146
if nc_format == "NETCDF4" :
146
147
return _nc4_dtype (var )
147
148
if "dtype" in var .encoding :
@@ -234,13 +235,13 @@ def _force_native_endianness(var):
234
235
235
236
236
237
def _extract_nc4_variable_encoding (
237
- variable ,
238
+ variable : Variable ,
238
239
raise_on_invalid = False ,
239
240
lsd_okay = True ,
240
241
h5py_okay = False ,
241
242
backend = "netCDF4" ,
242
243
unlimited_dims = None ,
243
- ):
244
+ ) -> dict [ str , Any ] :
244
245
if unlimited_dims is None :
245
246
unlimited_dims = ()
246
247
@@ -308,7 +309,7 @@ def _extract_nc4_variable_encoding(
308
309
return encoding
309
310
310
311
311
- def _is_list_of_strings (value ):
312
+ def _is_list_of_strings (value ) -> bool :
312
313
arr = np .asarray (value )
313
314
return arr .dtype .kind in ["U" , "S" ] and arr .size > 1
314
315
@@ -414,13 +415,25 @@ def _acquire(self, needs_lock=True):
414
415
def ds (self ):
415
416
return self ._acquire ()
416
417
417
- def open_store_variable (self , name , var ):
418
+ def open_store_variable (self , name : str , var ):
419
+ import netCDF4
420
+
418
421
dimensions = var .dimensions
419
- data = indexing .LazilyIndexedArray (NetCDF4ArrayWrapper (name , self ))
420
422
attributes = {k : var .getncattr (k ) for k in var .ncattrs ()}
423
+ data = indexing .LazilyIndexedArray (NetCDF4ArrayWrapper (name , self ))
424
+ encoding : dict [str , Any ] = {}
425
+ if isinstance (var .datatype , netCDF4 .EnumType ):
426
+ encoding ["dtype" ] = np .dtype (
427
+ data .dtype ,
428
+ metadata = {
429
+ "enum" : var .datatype .enum_dict ,
430
+ "enum_name" : var .datatype .name ,
431
+ },
432
+ )
433
+ else :
434
+ encoding ["dtype" ] = var .dtype
421
435
_ensure_fill_value_valid (data , attributes )
422
436
# netCDF4 specific encoding; save _FillValue for later
423
- encoding = {}
424
437
filters = var .filters ()
425
438
if filters is not None :
426
439
encoding .update (filters )
@@ -440,7 +453,6 @@ def open_store_variable(self, name, var):
440
453
# save source so __repr__ can detect if it's local or not
441
454
encoding ["source" ] = self ._filename
442
455
encoding ["original_shape" ] = var .shape
443
- encoding ["dtype" ] = var .dtype
444
456
445
457
return Variable (dimensions , data , attributes , encoding )
446
458
@@ -485,21 +497,24 @@ def encode_variable(self, variable):
485
497
return variable
486
498
487
499
def prepare_variable (
488
- self , name , variable , check_encoding = False , unlimited_dims = None
500
+ self , name , variable : Variable , check_encoding = False , unlimited_dims = None
489
501
):
490
502
_ensure_no_forward_slash_in_name (name )
491
-
503
+ attrs = variable .attrs .copy ()
504
+ fill_value = attrs .pop ("_FillValue" , None )
492
505
datatype = _get_datatype (
493
506
variable , self .format , raise_on_invalid_encoding = check_encoding
494
507
)
495
- attrs = variable .attrs .copy ()
496
-
497
- fill_value = attrs .pop ("_FillValue" , None )
498
-
508
+ # check enum metadata and use netCDF4.EnumType
509
+ if (
510
+ (meta := np .dtype (datatype ).metadata )
511
+ and (e_name := meta .get ("enum_name" ))
512
+ and (e_dict := meta .get ("enum" ))
513
+ ):
514
+ datatype = self ._build_and_get_enum (name , datatype , e_name , e_dict )
499
515
encoding = _extract_nc4_variable_encoding (
500
516
variable , raise_on_invalid = check_encoding , unlimited_dims = unlimited_dims
501
517
)
502
-
503
518
if name in self .ds .variables :
504
519
nc4_var = self .ds .variables [name ]
505
520
else :
@@ -527,6 +542,33 @@ def prepare_variable(
527
542
528
543
return target , variable .data
529
544
545
+ def _build_and_get_enum (
546
+ self , var_name : str , dtype : np .dtype , enum_name : str , enum_dict : dict [str , int ]
547
+ ) -> Any :
548
+ """
549
+ Add or get the netCDF4 Enum based on the dtype in encoding.
550
+ The return type should be ``netCDF4.EnumType``,
551
+ but we avoid importing netCDF4 globally for performances.
552
+ """
553
+ if enum_name not in self .ds .enumtypes :
554
+ return self .ds .createEnumType (
555
+ dtype ,
556
+ enum_name ,
557
+ enum_dict ,
558
+ )
559
+ datatype = self .ds .enumtypes [enum_name ]
560
+ if datatype .enum_dict != enum_dict :
561
+ error_msg = (
562
+ f"Cannot save variable `{ var_name } ` because an enum"
563
+ f" `{ enum_name } ` already exists in the Dataset but have"
564
+ " a different definition. To fix this error, make sure"
565
+ " each variable have a uniquely named enum in their"
566
+ " `encoding['dtype'].metadata` or, if they should share"
567
+ " the same enum type, make sure the enums are identical."
568
+ )
569
+ raise ValueError (error_msg )
570
+ return datatype
571
+
530
572
def sync (self ):
531
573
self .ds .sync ()
532
574
0 commit comments