Skip to content

Commit 8e3b885

Browse files
committed
ci: fix type
1 parent 4ab5d7f commit 8e3b885

File tree

10 files changed

+66
-55
lines changed

10 files changed

+66
-55
lines changed

ddtrace/appsec/_iast/_ast/visitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _merge_dicts(*args_functions: Set[str]) -> Set[str]:
227227

228228
@staticmethod
229229
def _is_string_node(node: Any) -> bool:
230-
if PY3 and (isinstance(node, ast.Constant) and isinstance(node.value, (str, bytes, bytearray))):
230+
if PY3 and (isinstance(node, ast.Constant) and isinstance(node.value, IAST.TEXT_TYPES)):
231231
return True
232232

233233
return False

ddtrace/appsec/_iast/_handlers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from wrapt import when_imported
55
from wrapt import wrap_function_wrapper as _w
66

7+
from ddtrace.appsec._constants import IAST
78
from ddtrace.appsec._iast._iast_request_context import get_iast_stacktrace_reported
89
from ddtrace.appsec._iast._iast_request_context import set_iast_stacktrace_reported
910
from ddtrace.appsec._iast._logs import iast_instrumentation_wrapt_debug_log
@@ -249,7 +250,7 @@ def _on_django_func_wrapped(fn_args, fn_kwargs, first_arg_expected_type, *_):
249250

250251
def _custom_protobuf_getattribute(self, name):
251252
ret = type(self).__saved_getattr(self, name)
252-
if isinstance(ret, (str, bytes, bytearray)):
253+
if isinstance(ret, IAST.TEXT_TYPES):
253254
ret = taint_pyobject(
254255
pyobject=ret,
255256
source_name=OriginType.GRPC_BODY,

ddtrace/appsec/_iast/_patches/json_tainting.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ddtrace.internal.logger import get_logger
55
from ddtrace.settings.asm import config as asm_config
66

7+
from ..._constants import IAST
78
from .._patch import set_and_check_module_is_patched
89
from .._patch import set_module_unpatched
910
from .._patch import try_wrap_function_wrapper
@@ -55,7 +56,7 @@ def wrapped_loads(wrapped, instance, args, kwargs):
5556
obj = taint_structure(obj, source.origin, source.origin)
5657
elif isinstance(obj, list):
5758
obj = taint_structure(obj, source.origin, source.origin)
58-
elif isinstance(obj, (str, bytes, bytearray)):
59+
elif isinstance(obj, IAST.TEXT_TYPES):
5960
obj = taint_pyobject(obj, source.name, source.value, source.origin)
6061
except Exception:
6162
log.debug("Unexpected exception while reporting vulnerability", exc_info=True)

ddtrace/appsec/_iast/_taint_tracking/aspects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def format_value_aspect(
435435
element: Any,
436436
options: int = 0,
437437
format_spec: Optional[str] = None,
438-
) -> Union[str, bytes, bytearray]:
438+
) -> TEXT_TYPES:
439439
if options == 115:
440440
new_text = str_aspect(str, 0, element)
441441
elif options == 114:

ddtrace/appsec/_iast/_taint_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Optional
66
from typing import Union
77

8+
from ddtrace.appsec._constants import IAST
89
from ddtrace.appsec._iast._taint_tracking._taint_objects import is_pyobject_tainted
910
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1011
from ddtrace.internal.logger import get_logger
@@ -101,7 +102,7 @@ def taint_structure(main_obj, source_key, source_value, override_pyobject_tainte
101102
if command.pre: # first processing of the object
102103
if not command.obj:
103104
command.store(command.obj)
104-
elif isinstance(command.obj, (str, bytes, bytearray)):
105+
elif isinstance(command.obj, IAST.TEXT_TYPES):
105106
if override_pyobject_tainted or not is_pyobject_tainted(command.obj):
106107
new_obj = taint_pyobject(
107108
pyobject=command.obj,
@@ -161,7 +162,7 @@ def __init__(self, original_list, origins=(0, 0), override_pyobject_tainted=Fals
161162

162163
def _taint(self, value):
163164
if value:
164-
if isinstance(value, (str, bytes, bytearray)):
165+
if isinstance(value, IAST.TEXT_TYPES):
165166
if not is_pyobject_tainted(value) or self._override_pyobject_tainted:
166167
try:
167168
# TODO: migrate this part to shift ranges instead of creating a new one
@@ -342,7 +343,7 @@ def _taint(self, value, key, origin=None):
342343
if origin is None:
343344
origin = self._origin_value
344345
if value:
345-
if isinstance(value, (str, bytes, bytearray)):
346+
if isinstance(value, IAST.TEXT_TYPES):
346347
if not is_pyobject_tainted(value) or self._override_pyobject_tainted:
347348
try:
348349
# TODO: migrate this part to shift ranges instead of creating a new one

ddtrace/appsec/_iast/taint_sinks/_base.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from typing import Any
33
from typing import Callable
44
from typing import Optional
5-
from typing import Text
65
from typing import Tuple
6+
from typing import Union
77

88
from ddtrace.appsec._deduplications import deduplication
99
from ddtrace.appsec._iast._taint_tracking import get_ranges
@@ -12,6 +12,7 @@
1212
from ddtrace.settings.asm import config as asm_config
1313
from ddtrace.trace import tracer
1414

15+
from ..._constants import IAST
1516
from .._iast_request_context import get_iast_reporter
1617
from .._iast_request_context import set_iast_reporter
1718
from .._overhead_control_engine import Operation
@@ -27,6 +28,8 @@
2728

2829
CWD = os.path.abspath(os.getcwd())
2930

31+
TEXT_TYPES = Union[str, bytes, bytearray]
32+
3033

3134
class taint_sink_deduplication(deduplication):
3235
def _check_deduplication(self):
@@ -77,12 +80,12 @@ def wrapper(wrapped: Callable, instance: Any, args: Any, kwargs: Any) -> Any:
7780
@taint_sink_deduplication
7881
def _prepare_report(
7982
cls,
80-
vulnerability_type: Text,
83+
vulnerability_type: str,
8184
evidence: Evidence,
82-
file_name: Optional[Text],
85+
file_name: Optional[str],
8386
line_number: int,
84-
function_name: Text = "",
85-
class_name: Text = "",
87+
function_name: str = "",
88+
class_name: str = "",
8689
*args,
8790
**kwargs,
8891
) -> bool:
@@ -126,7 +129,7 @@ def _prepare_report(
126129
return True
127130

128131
@classmethod
129-
def _compute_file_line(cls) -> Tuple[Optional[Text], Optional[int], Optional[Text], Optional[Text]]:
132+
def _compute_file_line(cls) -> Tuple[Optional[str], Optional[int], Optional[str], Optional[str]]:
130133
file_name = line_number = function_name = class_name = None
131134

132135
frame_info = get_info_frame(CWD)
@@ -146,18 +149,18 @@ def _compute_file_line(cls) -> Tuple[Optional[Text], Optional[int], Optional[Tex
146149
@classmethod
147150
def _create_evidence_and_report(
148151
cls,
149-
vulnerability_type: Text,
150-
evidence_value: Text = "",
151-
dialect: Optional[Text] = None,
152-
file_name: Optional[Text] = None,
152+
vulnerability_type: str,
153+
evidence_value: TEXT_TYPES = "",
154+
dialect: Optional[str] = None,
155+
file_name: Optional[str] = None,
153156
line_number: Optional[int] = None,
154-
function_name: Optional[Text] = None,
155-
class_name: Optional[Text] = None,
157+
function_name: Optional[str] = None,
158+
class_name: Optional[str] = None,
156159
*args,
157160
**kwargs,
158161
):
159-
if isinstance(evidence_value, (str, bytes, bytearray)):
160-
if isinstance(evidence_value, bytearray):
162+
if isinstance(evidence_value, IAST.TEXT_TYPES):
163+
if isinstance(evidence_value, (bytes, bytearray)):
161164
evidence_value = evidence_value.decode("utf-8")
162165
evidence = Evidence(value=evidence_value, dialect=dialect)
163166
else:
@@ -168,7 +171,7 @@ def _create_evidence_and_report(
168171
)
169172

170173
@classmethod
171-
def report(cls, evidence_value: Text = "", dialect: Optional[Text] = None) -> None:
174+
def report(cls, evidence_value: TEXT_TYPES = "", dialect: Optional[str] = None) -> None:
172175
"""Build a IastSpanReporter instance to report it in the `AppSecIastSpanProcessor` as a string JSON"""
173176
if cls.acquire_quota():
174177
file_name = line_number = function_name = class_name = None
@@ -192,7 +195,7 @@ def report(cls, evidence_value: Text = "", dialect: Optional[Text] = None) -> No
192195
cls.increment_quota()
193196

194197
@classmethod
195-
def is_tainted_pyobject(cls, string_to_check: Text) -> bool:
198+
def is_tainted_pyobject(cls, string_to_check: TEXT_TYPES) -> bool:
196199
"""Check if a string contains tainted ranges that are not marked as secure.
197200
198201
A string is considered tainted when:

tests/appsec/iast/aspects/aspect_utils.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,17 @@
33
from typing import List
44
from typing import NamedTuple
55
from typing import Optional
6-
from typing import Union
7-
8-
from hypothesis.strategies import binary
9-
from hypothesis.strategies import builds
10-
from hypothesis.strategies import text
116

127
from ddtrace.appsec._iast._taint_tracking import OriginType
138
from ddtrace.appsec._iast._taint_tracking import Source
149
from ddtrace.appsec._iast._taint_tracking import TaintRange
1510
from ddtrace.appsec._iast._taint_tracking import as_formatted_evidence
1611
from ddtrace.appsec._iast._taint_tracking import set_ranges
1712
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject_with_ranges
13+
from tests.appsec.iast.iast_utils import TEXT_TYPE
1814
from tests.appsec.iast.iast_utils import _iast_patched_module
1915

2016

21-
TEXT_TYPE = Union[str, bytes, bytearray]
22-
23-
24-
class CustomStr(str):
25-
pass
26-
27-
28-
class CustomBytes(bytes):
29-
pass
30-
31-
32-
class CustomBytearray(bytearray):
33-
pass
34-
35-
36-
string_strategies: List[Any] = [
37-
text(), # regular str
38-
binary(), # regular bytes
39-
builds(bytearray, binary()), # regular bytearray
40-
builds(CustomStr, text()), # custom str subclass
41-
builds(CustomBytes, binary()), # custom bytes subclass
42-
builds(CustomBytearray, binary()), # custom bytearray subclass
43-
]
44-
45-
4617
mod = _iast_patched_module("benchmarks.bm.iast_fixtures.str_methods")
4718

4819
EscapeContext = NamedTuple("EscapeContext", [("id", Any), ("position", int)])

tests/appsec/iast/aspects/test_add_aspect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1717
import ddtrace.appsec._iast._taint_tracking.aspects as ddtrace_aspects
1818
from ddtrace.appsec._iast._taint_tracking.aspects import add_aspect
19-
from tests.appsec.iast.aspects.aspect_utils import string_strategies
2019
from tests.appsec.iast.conftest import _end_iast_context_and_oce
2120
from tests.appsec.iast.conftest import _start_iast_context_and_oce
21+
from tests.appsec.iast.iast_utils import string_strategies
2222
from tests.utils import override_env
2323
from tests.utils import override_global_config
2424

tests/appsec/iast/aspects/test_common_aspects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from hypothesis.strategies import one_of
99
import pytest
1010

11-
from tests.appsec.iast.aspects.aspect_utils import string_strategies
1211
import tests.appsec.iast.fixtures.aspects.callees
1312
from tests.appsec.iast.iast_utils import _iast_patched_module
13+
from tests.appsec.iast.iast_utils import string_strategies
1414

1515

1616
def generate_callers_from_callees(callees_module, callers_file="", callees_module_str=""):

tests/appsec/iast/iast_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
import importlib
22
import re
33
import types
4+
from typing import Any
5+
from typing import List
46
from typing import Optional
57
from typing import Text
8+
from typing import Union
69
import zlib
710

11+
from hypothesis.strategies import binary
12+
from hypothesis.strategies import builds
13+
from hypothesis.strategies import text
14+
815
from ddtrace.appsec._constants import IAST
916
from ddtrace.appsec._iast._ast.ast_patching import astpatch_module
1017
from ddtrace.appsec._iast._ast.ast_patching import iastpatch
@@ -60,3 +67,30 @@ def _iast_patched_module(module_name, new_module_object=False):
6067
else:
6168
raise IastTestException(f"IAST Test Error: module {module_name} was excluded: {res}")
6269
return module
70+
71+
72+
TEXT_TYPE = Union[str, bytes, bytearray]
73+
74+
75+
class CustomStr(str):
76+
pass
77+
78+
79+
class CustomBytes(bytes):
80+
pass
81+
82+
83+
class CustomBytearray(bytearray):
84+
pass
85+
86+
87+
non_empty_text = text().filter(lambda x: x not in ("", "0", "1", "2"))
88+
89+
string_strategies: List[Any] = [
90+
text(), # regular str
91+
binary(), # regular bytes
92+
builds(bytearray, binary()), # regular bytearray
93+
builds(CustomStr, text()), # custom str subclass
94+
builds(CustomBytes, binary()), # custom bytes subclass
95+
builds(CustomBytearray, binary()), # custom bytearray subclass
96+
]

0 commit comments

Comments
 (0)