Skip to content

Commit 0fc8006

Browse files
Update RTA common.py for py3 (#2287)
* add run-all argument and initial p2 conversion * remove unicode * format with black
1 parent 3ba777c commit 0fc8006

File tree

3 files changed

+143
-94
lines changed

3 files changed

+143
-94
lines changed

rta/__init__.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,38 @@
33
# 2.0; you may not use this file except in compliance with the Elastic License
44
# 2.0.
55

6-
import glob
76
import importlib
8-
import os
7+
from pathlib import Path
8+
from typing import List, Optional
99

1010
from . import common
1111

12-
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
12+
CURRENT_DIR = Path(__file__).resolve().parent
1313

1414

15-
def get_ttp_list(os_types=None):
15+
def get_ttp_list(os_types: Optional[List[str]] = None) -> List[str]:
1616
scripts = []
1717
if os_types and not isinstance(os_types, (list, tuple)):
1818
os_types = [os_types]
1919

20-
for script in sorted(glob.glob(os.path.join(CURRENT_DIR, "*.py"))):
21-
base_name, _ = os.path.splitext(os.path.basename(script))
20+
for script in CURRENT_DIR.glob("*.py"):
21+
base_name = script.stem
2222
if base_name not in ("common", "main") and not base_name.startswith("_"):
2323
if os_types:
2424
# Import it and skip it if it's not supported
2525
importlib.import_module(__name__ + "." + base_name)
2626
if not any(base_name in common.OS_MAPPING[os_type] for os_type in os_types):
2727
continue
2828

29-
scripts.append(script)
29+
scripts.append(str(script))
3030

3131
return scripts
3232

3333

34-
def get_ttp_names(os_types=None):
34+
def get_ttp_names(os_types: Optional[List[str]] = None) -> List[str]:
3535
names = []
3636
for script in get_ttp_list(os_types):
37-
basename, ext = os.path.splitext(os.path.basename(script))
37+
basename = Path(script).stem
3838
names.append(basename)
3939
return names
4040

rta/__main__.py

+49-10
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,57 @@
55

66
import argparse
77
import importlib
8-
import os
8+
import subprocess
9+
import sys
10+
import time
11+
from pathlib import Path
912

10-
from . import get_ttp_names
13+
from . import get_ttp_list, get_ttp_names
14+
from .common import CURRENT_OS
1115

12-
parser = argparse.ArgumentParser("rta")
13-
parser.add_argument("ttp_name")
1416

15-
parsed_args, remaining = parser.parse_known_args()
16-
ttp_name, _ = os.path.splitext(os.path.basename(parsed_args.ttp_name))
17+
DELAY = 1
1718

18-
if ttp_name not in get_ttp_names():
19-
raise ValueError("Unknown RTA {}".format(ttp_name))
2019

21-
module = importlib.import_module("rta." + ttp_name)
22-
exit(module.main(*remaining))
20+
def run_all():
21+
"""Run a single RTA."""
22+
errors = []
23+
for ttp_file in get_ttp_list(CURRENT_OS):
24+
print(f"---- {Path(ttp_file).name} ----")
25+
p = subprocess.Popen([sys.executable, ttp_file])
26+
p.wait()
27+
code = p.returncode
28+
29+
if p.returncode:
30+
errors.append((ttp_file, code))
31+
32+
time.sleep(DELAY)
33+
print("")
34+
35+
return len(errors)
36+
37+
38+
def run(ttp_name: str, *args):
39+
"""Run all RTAs compatible with OS."""
40+
if ttp_name not in get_ttp_names():
41+
raise ValueError(f"Unknown RTA {ttp_name}")
42+
43+
module = importlib.import_module("rta." + ttp_name)
44+
return module.main(*args)
45+
46+
47+
if __name__ == '__main__':
48+
parser = argparse.ArgumentParser("rta")
49+
parser.add_argument("--ttp-name")
50+
parser.add_argument("--run-all", action="store_true")
51+
parser.add_argument("--delay", type=int, help="For run-all, the delay between executions")
52+
parsed_args, remaining = parser.parse_known_args()
53+
54+
if parsed_args.ttp_name and parsed_args.run_all:
55+
raise ValueError(f"Pass --ttp-name or --run-all, not both")
56+
57+
if parsed_args.run_all:
58+
exit(run_all())
59+
else:
60+
rta_name = Path(parsed_args.run).stem
61+
exit(run(rta_name, *remaining))

0 commit comments

Comments
 (0)