|
1 | 1 | import asyncio
|
2 |
| - |
3 | 2 | import logging
|
4 | 3 | import os
|
5 | 4 | import shutil
|
6 |
| -import time |
| 5 | +import sys |
| 6 | +import tarfile |
| 7 | +import tempfile |
7 | 8 | import zipfile
|
8 | 9 | from pathlib import Path
|
9 | 10 |
|
10 | 11 | from simcore_sdk import node_ports
|
11 | 12 |
|
12 |
| -log = logging.getLogger(__name__) |
| 13 | +logger = logging.getLogger(__name__) |
13 | 14 |
|
14 |
| -_INPUT_PATH = Path(os.environ.get("RAWGRAPHS_INPUT_PATH")) |
| 15 | +_INPUTS_FOLDER = Path(os.environ.get("RAWGRAPHS_INPUT_PATH")) |
| 16 | +_OUTPUTS_FOLDER = Path(os.environ.get("RAWGRAPHS_OUTPUT_PATH")) |
| 17 | +_FILE_TYPE_PREFIX = "data:" |
| 18 | +_KEY_VALUE_FILE_NAME = "key_values.json" |
15 | 19 |
|
16 | 20 | # clean the directory
|
17 |
| -shutil.rmtree(str(_INPUT_PATH), ignore_errors=True) |
| 21 | +shutil.rmtree(str(_INPUTS_FOLDER), ignore_errors=True) |
| 22 | + |
| 23 | +if not _INPUTS_FOLDER.exists(): |
| 24 | + _INPUTS_FOLDER.mkdir() |
| 25 | + logger.debug("Created input folder at %s", _INPUTS_FOLDER) |
| 26 | + |
| 27 | +if not _OUTPUTS_FOLDER.exists(): |
| 28 | + _OUTPUTS_FOLDER.mkdir() |
| 29 | + logger.debug("Created output folder at %s", _OUTPUTS_FOLDER) |
| 30 | + |
| 31 | +def _no_relative_path_tar(members: tarfile.TarFile): |
| 32 | + for tarinfo in members: |
| 33 | + path = Path(tarinfo.name) |
| 34 | + if path.is_absolute(): |
| 35 | + # absolute path are not allowed |
| 36 | + continue |
| 37 | + if path.match("/../"): |
| 38 | + # relative paths are not allowed |
| 39 | + continue |
| 40 | + yield tarinfo |
18 | 41 |
|
19 |
| -if not _INPUT_PATH.exists(): |
20 |
| - _INPUT_PATH.mkdir() |
21 |
| - log.debug("Created input folder at %s", _INPUT_PATH) |
| 42 | +def _no_relative_path_zip(members: zipfile.ZipFile): |
| 43 | + for zipinfo in members.infolist(): |
| 44 | + path = Path(zipinfo.filename) |
| 45 | + if path.is_absolute(): |
| 46 | + # absolute path are not allowed |
| 47 | + continue |
| 48 | + if path.match("/../"): |
| 49 | + # relative paths are not allowed |
| 50 | + continue |
| 51 | + yield zipinfo |
22 | 52 |
|
23 |
| -async def retrieve_data(): |
24 |
| - log.debug("retrieving data...") |
25 |
| - print("retrieving data...") |
| 53 | +async def download_data(): |
| 54 | + logger.info("retrieving data from simcore...") |
| 55 | + print("retrieving data from simcore...") |
26 | 56 |
|
27 | 57 | # get all files in the local system and copy them to the input folder
|
28 |
| - start_time = time.time() |
29 | 58 | PORTS = node_ports.ports()
|
30 |
| - download_tasks = [] |
31 |
| - for node_input in PORTS.inputs: |
32 |
| - if not node_input or node_input.value is None: |
| 59 | + for port in PORTS.inputs: |
| 60 | + if not port or port.value is None: |
33 | 61 | continue
|
34 |
| - |
35 |
| - # collect coroutines |
36 |
| - download_tasks.append(node_input.get()) |
37 |
| - if download_tasks: |
38 |
| - downloaded_files = await asyncio.gather(*download_tasks) |
39 |
| - print("downloaded {} files /tmp <br>".format(len(download_tasks))) |
40 |
| - for local_path in downloaded_files: |
41 |
| - if local_path is None: |
42 |
| - continue |
43 |
| - # log.debug("Completed download of %s in local path %s", node_input.value, local_path) |
44 |
| - if local_path.exists(): |
45 |
| - if zipfile.is_zipfile(str(local_path)): |
46 |
| - zip_ref = zipfile.ZipFile(str(local_path), 'r') |
47 |
| - zip_ref.extractall(str(_INPUT_PATH)) |
48 |
| - zip_ref.close() |
49 |
| - log.debug("Unzipped") |
50 |
| - print("unzipped {file} to {path}<br>".format(file=str(local_path), path=str(_INPUT_PATH))) |
51 |
| - else: |
52 |
| - log.debug("Start moving %s to input path %s", local_path, _INPUT_PATH) |
53 |
| - shutil.move(str(local_path), str(_INPUT_PATH / local_path.name)) |
54 |
| - log.debug("Move completed") |
55 |
| - print("moved {file} to {path}<br>".format(file=str(local_path), path=str(_INPUT_PATH))) |
56 |
| - end_time = time.time() |
57 |
| - print("time to download: {} seconds".format(end_time - start_time)) |
58 |
| - |
59 |
| -asyncio.get_event_loop().run_until_complete(retrieve_data()) |
| 62 | + |
| 63 | + local_path = await port.get() |
| 64 | + dest_path = _INPUTS_FOLDER / port.key |
| 65 | + dest_path.mkdir(exist_ok=True, parents=True) |
| 66 | + |
| 67 | + # clean up destination directory |
| 68 | + for path in dest_path.iterdir(): |
| 69 | + if path.is_file(): |
| 70 | + path.unlink() |
| 71 | + elif path.is_dir(): |
| 72 | + shutil.rmtree(path) |
| 73 | + # check if local_path is a compressed file |
| 74 | + if tarfile.is_tarfile(local_path): |
| 75 | + with tarfile.open(local_path) as tar_file: |
| 76 | + tar_file.extractall(dest_path, members=_no_relative_path_tar(tar_file)) |
| 77 | + elif zipfile.is_zipfile(local_path): |
| 78 | + with zipfile.ZipFile(local_path) as zip_file: |
| 79 | + zip_file.extractall(dest_path, members=_no_relative_path_zip(zip_file)) |
| 80 | + else: |
| 81 | + dest_path_name = _INPUTS_FOLDER / (port.key + ":" + Path(local_path).name) |
| 82 | + shutil.move(local_path, dest_path_name) |
| 83 | + shutil.rmtree(Path(local_path).parents[0]) |
| 84 | + |
| 85 | +async def upload_data(): |
| 86 | + logger.info("uploading data to simcore...") |
| 87 | + PORTS = node_ports.ports() |
| 88 | + outputs_path = Path(_OUTPUTS_FOLDER).expanduser() |
| 89 | + for port in PORTS.outputs: |
| 90 | + logger.debug("uploading data to port '%s' with value '%s'...", port.key, port.value) |
| 91 | + src_folder = outputs_path / port.key |
| 92 | + list_files = list(src_folder.glob("*")) |
| 93 | + if len(list_files) == 1: |
| 94 | + # special case, direct upload |
| 95 | + await port.set(list_files[0]) |
| 96 | + continue |
| 97 | + # generic case let's create an archive |
| 98 | + if len(list_files) > 1: |
| 99 | + temp_file = tempfile.NamedTemporaryFile(suffix=".tgz") |
| 100 | + temp_file.close() |
| 101 | + for _file in list_files: |
| 102 | + with tarfile.open(temp_file.name, mode='w:gz') as tar_ptr: |
| 103 | + for file_path in list_files: |
| 104 | + tar_ptr.add(file_path, arcname=file_path.name, recursive=False) |
| 105 | + try: |
| 106 | + await port.set(temp_file.name) |
| 107 | + finally: |
| 108 | + #clean up |
| 109 | + Path(temp_file.name).unlink() |
| 110 | + |
| 111 | + logger.info("all data uploaded to simcore") |
| 112 | + |
| 113 | +async def sync_data(): |
| 114 | + try: |
| 115 | + await download_data() |
| 116 | + await upload_data() |
| 117 | + # self.set_status(200) |
| 118 | + except node_ports.exceptions.NodeportsException as exc: |
| 119 | + # self.set_status(500, reason=str(exc)) |
| 120 | + logger.error("error when syncing '%s'", str(exc)) |
| 121 | + sys.exit(1) |
| 122 | + finally: |
| 123 | + # self.finish('completed retrieve!') |
| 124 | + logger.info("download and upload finished") |
| 125 | + |
| 126 | +asyncio.get_event_loop().run_until_complete(sync_data()) |
0 commit comments