Skip to content

Commit 93c5bc5

Browse files
authored
Merge pull request #850 from deepmodeling/zjgemi
fix: add HDF5Datasets type Artifact
2 parents fa3201f + 28e855a commit 93c5bc5

File tree

4 files changed

+138
-24
lines changed

4 files changed

+138
-24
lines changed

src/dflow/python/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .op import OP
2-
from .opio import OPIO, Artifact, BigParameter, OPIOSign, Parameter, NestedDict
2+
from .opio import (OPIO, Artifact, BigParameter, HDF5Datasets, NestedDict,
3+
OPIOSign, Parameter)
34
from .python_op_template import (FatalError, PythonOPTemplate, Slices,
45
TransientError, upload_packages)
56

67
__all__ = ["OP", "OPIO", "Artifact", "BigParameter", "OPIOSign", "Parameter",
78
"FatalError", "PythonOPTemplate", "Slices", "TransientError",
8-
"upload_packages", "NestedDict"]
9+
"upload_packages", "NestedDict", "HDF5Datasets"]

src/dflow/python/op.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from ..io import (InputArtifact, InputParameter, OutputArtifact,
2020
OutputParameter, type_to_str)
2121
from ..utils import dict2list, get_key, randstr, s3_config
22-
from .vendor.typeguard import check_type
2322
from .opio import OPIO, Artifact, BigParameter, OPIOSign, Parameter
23+
from .vendor.typeguard import check_type
2424

2525
iwd = os.getcwd()
2626

@@ -190,6 +190,8 @@ def _check_signature(
190190
ss = Set[Union[str, None]]
191191
elif ss == Set[Path]:
192192
ss = Set[Union[Path, None]]
193+
else:
194+
continue
193195
if isinstance(ss, Parameter):
194196
ss = ss.type
195197
# skip type checking if the variable is None

src/dflow/python/opio.py

+57-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import tarfile
23
from collections.abc import MutableMapping
34
from pathlib import Path
45
from typing import Any, Dict, List, Optional, Set, Union
@@ -8,27 +9,72 @@
89
from ..io import PVC, type_to_str
910

1011

11-
class nested_dict:
12-
def __init__(self, type):
13-
self.type = type
12+
class NestedDict:
13+
pass
14+
15+
16+
class NestedDictStr(NestedDict):
17+
pass
18+
19+
20+
class NestedDictPath(NestedDict):
21+
pass
22+
23+
24+
class HDF5Dataset:
    """Wrapper around an h5py dataset holding one serialized artifact item.

    The dataset's attrs record how the payload was produced:
    ``type`` is ``"file"`` (a file's contents), ``"dir"`` (a gzipped tar of a
    directory) or plain data; ``dtype`` is ``"utf-8"`` or ``"binary"`` for
    file-backed payloads; ``path`` is the original location to restore to.
    """

    def __init__(self, dataset):
        # dataset: an h5py.Dataset-like object (must support `[()]` and `.attrs`)
        self.dataset = dataset

    def get_data(self):
        """Return the stored payload, decoded according to the ``dtype`` attr.

        ``"utf-8"`` payloads are decoded to str, ``"binary"`` payloads are
        converted to bytes; anything else is returned as stored.
        """
        data = self.dataset[()]
        if self.dataset.attrs.get("dtype") == "utf-8":
            data = data.decode("utf-8")
        elif self.dataset.attrs.get("dtype") == "binary":
            data = data.tobytes()
        return data

    def recover(self):
        """Materialize the dataset back under the current working directory.

        Returns the restored Path for ``"file"``/``"dir"`` datasets, or the
        raw payload for plain-data datasets.
        """
        kind = self.dataset.attrs["type"]
        if kind == "file":
            path = Path(self.dataset.attrs["path"])
            if path.is_absolute():
                # strip the root so absolute paths are re-created relative to cwd
                path = path.relative_to(path.root)
            path.parent.mkdir(parents=True, exist_ok=True)
            data = self.get_data()
            if isinstance(data, str):
                path.write_text(data)
            elif isinstance(data, bytes):
                path.write_bytes(data)
            return path
        elif kind == "dir":
            path = Path(self.dataset.attrs["path"])
            if path.is_absolute():
                path = path.relative_to(path.root)
            path.parent.mkdir(parents=True, exist_ok=True)
            tgz_path = path.parent / (path.name + ".tgz")
            tgz_path.write_bytes(self.get_data())
            # context manager ensures the archive handle is closed even if
            # extraction raises (the previous code leaked it on error)
            # NOTE(review): extractall is unsafe on untrusted archives (path
            # traversal); members are assumed to come from our own serializer
            with tarfile.open(tgz_path, "r:gz") as tf:
                tf.extractall(".")
            # drop the temporary archive once the directory is restored
            tgz_path.unlink()
            return path
        else:
            return self.get_data()
1461

15-
def __repr__(self):
16-
return "dflow.python.NestedDict[%s]" % type_to_str(self.type)
1762

18-
def __eq__(self, other):
19-
if not isinstance(other, nested_dict):
20-
return False
21-
return self.type == other.type
class HDF5Datasets:
    """Marker type for OP signatures declaring an artifact exchanged as
    HDF5 (.h5) dataset files.

    Pure tag class — carries no behavior; it is matched by identity/union
    in signature handling (see ArtifactAllowedTypes and the artifact
    input/output handlers).
    """
    pass
2265

2366

2467
# Rebind NestedDict to a dispatch table so signature code can write
# NestedDict[str] / NestedDict[Path] and get the matching marker class
# through ordinary dict subscription.
NestedDict = {
    str: NestedDictStr,
    Path: NestedDictPath,
}

# Base artifact signature types; each is additionally accepted in union
# with HDF5Datasets, and HDF5Datasets is accepted on its own.
ArtifactAllowedTypes = [str, Path, Set[str], Set[Path], List[str], List[Path],
                        Dict[str, str], Dict[str, Path], NestedDict[str],
                        NestedDict[Path]]
ArtifactAllowedTypes += [Union[t, HDF5Datasets]
                         for t in tuple(ArtifactAllowedTypes)]
ArtifactAllowedTypes.append(HDF5Datasets)
3278

3379

3480
@CustomHandler.handles

src/dflow/python/utils.py

+75-10
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
import os
22
import shutil
33
import signal
4+
import tarfile
45
import traceback
56
import uuid
67
from pathlib import Path
7-
from typing import Dict, List, Set
8+
from typing import Dict, List, Set, Union
89

910
from ..common import jsonpickle
1011
from ..config import config
1112
from ..utils import (artifact_classes, assemble_path_object,
1213
catalog_of_local_artifact, convert_dflow_list, copy_file,
1314
expand, flatten, randstr, remove_empty_dir_tag)
14-
from .opio import Artifact, BigParameter, NestedDict, Parameter
15+
from .opio import (Artifact, BigParameter, HDF5Dataset, HDF5Datasets,
16+
NestedDict, Parameter)
1517

1618

1719
def get_slices(path_object, slices):
@@ -78,7 +80,35 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
7880

7981
path_object = get_slices(path_object, slices)
8082

81-
if sign.type in [str, Path]:
83+
sign_type = sign.type
84+
if getattr(sign_type, "__origin__", None) == Union:
85+
args = sign_type.__args__
86+
if HDF5Datasets in args:
87+
if isinstance(path_object, list) and all([isinstance(
88+
p, str) and p.endswith(".h5") for p in path_object]):
89+
sign_type = HDF5Datasets
90+
elif args[0] == HDF5Datasets:
91+
sign_type = args[1]
92+
elif args[1] == HDF5Datasets:
93+
sign_type = args[0]
94+
95+
if sign_type == HDF5Datasets:
96+
import h5py
97+
assert isinstance(path_object, list)
98+
res = None
99+
for path in path_object:
100+
f = h5py.File(path, "r")
101+
datasets = {k: HDF5Dataset(f[k]) for k in f.keys()}
102+
datasets = expand(datasets)
103+
if isinstance(datasets, list):
104+
if res is None:
105+
res = []
106+
res += datasets
107+
elif isinstance(datasets, dict):
108+
if res is None:
109+
res = {}
110+
res.update(datasets)
111+
if sign_type in [str, Path]:
82112
if path_object is None or isinstance(path_object, str):
83113
res = path_object
84114
elif isinstance(path_object, list) and len(path_object) == 1 and (
@@ -87,8 +117,8 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
87117
res = path_object[0]
88118
else:
89119
res = art_path
90-
res = path_or_none(res) if sign.type == Path else res
91-
elif sign.type in [List[str], List[Path], Set[str], Set[Path]]:
120+
res = path_or_none(res) if sign_type == Path else res
121+
elif sign_type in [List[str], List[Path], Set[str], Set[Path]]:
92122
if path_object is None:
93123
return None
94124
elif isinstance(path_object, str):
@@ -99,17 +129,17 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
99129
else:
100130
res = list(flatten(path_object).values())
101131

102-
if sign.type == List[str]:
132+
if sign_type == List[str]:
103133
pass
104-
elif sign.type == List[Path]:
134+
elif sign_type == List[Path]:
105135
res = path_or_none(res)
106-
elif sign.type == Set[str]:
136+
elif sign_type == Set[str]:
107137
res = set(res)
108138
else:
109139
res = set(path_or_none(res))
110-
elif sign.type in [Dict[str, str], NestedDict[str]]:
140+
elif sign_type in [Dict[str, str], NestedDict[str]]:
111141
res = path_object
112-
elif sign.type in [Dict[str, Path], NestedDict[Path]]:
142+
elif sign_type in [Dict[str, Path], NestedDict[Path]]:
113143
res = path_or_none(path_object)
114144

115145
if res is None:
@@ -169,6 +199,41 @@ def slice_to_dir(slice):
169199
def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
170200
create_dir=False):
171201
path_list = []
202+
if sign.type == HDF5Datasets:
203+
import h5py
204+
os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
205+
h5_name = "%s.h5" % uuid.uuid4()
206+
h5_path = '%s/outputs/artifacts/%s/%s' % (data_root, name, h5_name)
207+
with h5py.File(h5_path, "w") as f:
208+
for s, v in flatten(value).items():
209+
if isinstance(v, Path):
210+
if v.is_file():
211+
try:
212+
data = v.read_text(encoding="utf-8")
213+
dtype = "utf-8"
214+
except Exception:
215+
import numpy as np
216+
data = np.void(v.read_bytes())
217+
dtype = "binary"
218+
d = f.create_dataset(s, data=data)
219+
d.attrs["type"] = "file"
220+
d.attrs["path"] = str(v)
221+
d.attrs["dtype"] = dtype
222+
elif v.is_dir():
223+
tgz_path = Path("%s.tgz" % v)
224+
tf = tarfile.open(tgz_path, "w:gz", dereference=True)
225+
tf.add(v)
226+
tf.close()
227+
import numpy as np
228+
d = f.create_dataset(s, data=np.void(
229+
tgz_path.read_bytes()))
230+
d.attrs["type"] = "dir"
231+
d.attrs["path"] = str(v)
232+
d.attrs["dtype"] = "binary"
233+
else:
234+
d = f.create_dataset(s, data=v)
235+
d.attrs["type"] = "data"
236+
path_list.append({"dflow_list_item": h5_name, "order": slices or 0})
172237
if sign.type in [str, Path]:
173238
os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
174239
if slices is None:

0 commit comments

Comments
 (0)