diff --git a/ddtrace/appsec/_asm_request_context.py b/ddtrace/appsec/_asm_request_context.py index 4bc868b7a26..eda18517bde 100644 --- a/ddtrace/appsec/_asm_request_context.py +++ b/ddtrace/appsec/_asm_request_context.py @@ -436,6 +436,10 @@ def _on_wrapped_view(kwargs): if _is_iast_enabled() and kwargs: from ddtrace.appsec._iast._taint_tracking import OriginType from ddtrace.appsec._iast._taint_tracking import taint_pyobject + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return return_value _kwargs = {} for k, v in kwargs.items(): @@ -451,9 +455,14 @@ def _on_set_request_tags(request, span, flask_config): from ddtrace.appsec._iast._metrics import _set_metric_iast_instrumented_source from ddtrace.appsec._iast._taint_tracking import OriginType from ddtrace.appsec._iast._taint_utils import taint_structure + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor _set_metric_iast_instrumented_source(OriginType.COOKIE_NAME) _set_metric_iast_instrumented_source(OriginType.COOKIE) + + if not AppSecIastSpanProcessor.is_span_analyzed(span._local_root or span): + return + request.cookies = taint_structure( request.cookies, OriginType.COOKIE_NAME, diff --git a/ddtrace/appsec/_constants.py b/ddtrace/appsec/_constants.py index dd0512d822d..fc0f4a4abba 100644 --- a/ddtrace/appsec/_constants.py +++ b/ddtrace/appsec/_constants.py @@ -82,6 +82,7 @@ class IAST(metaclass=Constant_Class): PATCH_MODULES = "_DD_IAST_PATCH_MODULES" DENY_MODULES = "_DD_IAST_DENY_MODULES" SEP_MODULES = "," + REQUEST_IAST_ENABLED = "_dd.iast.request_enabled" class IAST_SPAN_TAGS(metaclass=Constant_Class): diff --git a/ddtrace/appsec/_handlers.py b/ddtrace/appsec/_handlers.py index 6439632969d..32efb83722f 100644 --- a/ddtrace/appsec/_handlers.py +++ b/ddtrace/appsec/_handlers.py @@ -190,6 +190,13 @@ def _on_request_init(wrapped, instance, args, kwargs): from ddtrace.appsec._iast._metrics import _set_metric_iast_instrumented_source from ddtrace.appsec._iast._taint_tracking import OriginType from ddtrace.appsec._iast._taint_tracking import taint_pyobject + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor + + _set_metric_iast_instrumented_source(OriginType.PATH) + _set_metric_iast_instrumented_source(OriginType.QUERY) + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return # TODO: instance.query_string = ?? instance.query_string = taint_pyobject( @@ -204,8 +211,6 @@ def _on_request_init(wrapped, instance, args, kwargs): source_value=instance.path, source_origin=OriginType.PATH, ) - _set_metric_iast_instrumented_source(OriginType.PATH) - _set_metric_iast_instrumented_source(OriginType.QUERY) except Exception: log.debug("Unexpected exception while tainting pyobject", exc_info=True) @@ -269,6 +274,10 @@ def _on_django_func_wrapped(fn_args, fn_kwargs, first_arg_expected_type, *_): from ddtrace.appsec._iast._taint_tracking import is_pyobject_tainted from ddtrace.appsec._iast._taint_tracking import taint_pyobject from ddtrace.appsec._iast._taint_utils import taint_structure + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return http_req = fn_args[0] @@ -318,6 +327,7 @@ def _on_wsgi_environ(wrapped, _instance, args, kwargs): from ddtrace.appsec._iast._metrics import _set_metric_iast_instrumented_source from ddtrace.appsec._iast._taint_tracking import OriginType # noqa: F401 from ddtrace.appsec._iast._taint_utils import taint_structure + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor _set_metric_iast_instrumented_source(OriginType.HEADER_NAME) _set_metric_iast_instrumented_source(OriginType.HEADER) @@ -330,6 +340,9 @@ def _on_wsgi_environ(wrapped, _instance, args, kwargs): _set_metric_iast_instrumented_source(OriginType.PARAMETER_NAME) _set_metric_iast_instrumented_source(OriginType.BODY) + if not AppSecIastSpanProcessor.is_span_analyzed(): + return wrapped(*args, **kwargs) + return wrapped(*((taint_structure(args[0], OriginType.HEADER_NAME, OriginType.HEADER),) + args[1:]), **kwargs) return wrapped(*args, **kwargs) diff --git a/ddtrace/appsec/_iast/_overhead_control_engine.py b/ddtrace/appsec/_iast/_overhead_control_engine.py index 18e15553106..252d2398176 100644 --- a/ddtrace/appsec/_iast/_overhead_control_engine.py +++ b/ddtrace/appsec/_iast/_overhead_control_engine.py @@ -90,8 +90,8 @@ class OverheadControl(object): The goal is to do sampling at different levels of the IAST analysis (per process, per request, etc) """ + _lock = threading.Lock() _request_quota = MAX_REQUESTS - _enabled = False _vulnerabilities = set() # type: Set[Type[Operation]] _sampler = RateSampler(sample_rate=get_request_sampling_value() / 100.0) @@ -99,23 +99,26 @@ def reconfigure(self): self._sampler = RateSampler(sample_rate=get_request_sampling_value() / 100.0) def acquire_request(self, span): - # type: (Span) -> None + # type: (Span) -> bool """Decide whether if IAST analysis will be done for this request. - Block a request's quota at start of the request to limit simultaneous requests analyzed. - Use sample rating to analyze only a percentage of the total requests (30% by default). """ - if self._request_quota > 0 and self._sampler.sample(span): + if self._request_quota <= 0 or not self._sampler.sample(span): + return False + + with self._lock: + if self._request_quota <= 0: + return False + self._request_quota -= 1 - self._enabled = True - def release_request(self): - """increment request's quota at end of the request. + return True - TODO: figure out how to check maximum requests per thread - """ - if self._request_quota < MAX_REQUESTS: + def release_request(self): + """increment request's quota at end of the request.""" + with self._lock: self._request_quota += 1 - self._enabled = False self.vulnerabilities_reset_quota() def register(self, klass): @@ -124,11 +127,6 @@ def register(self, klass): self._vulnerabilities.add(klass) return klass - @property - def request_has_quota(self): - # type: () -> bool - return self._enabled - def vulnerabilities_reset_quota(self): # type: () -> None for k in self._vulnerabilities: diff --git a/ddtrace/appsec/_iast/_patch.py b/ddtrace/appsec/_iast/_patch.py index 9b97afef435..d6d8fefce08 100644 --- a/ddtrace/appsec/_iast/_patch.py +++ b/ddtrace/appsec/_iast/_patch.py @@ -145,6 +145,10 @@ def if_iast_taint_returned_object_for(origin, wrapped, instance, args, kwargs): try: from ._taint_tracking import is_pyobject_tainted from ._taint_tracking import taint_pyobject + from .processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return value if not is_pyobject_tainted(value): name = str(args[0]) if len(args) else "http.request.body" @@ -157,6 +161,11 @@ def if_iast_taint_returned_object_for(origin, wrapped, instance, args, kwargs): def if_iast_taint_yield_tuple_for(origins, wrapped, instance, args, kwargs): if _is_iast_enabled(): from ._taint_tracking import taint_pyobject + from .processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + for key, value in wrapped(*args, **kwargs): + yield key, value for key, value in wrapped(*args, **kwargs): new_key = taint_pyobject(pyobject=key, source_name=key, source_value=key, source_origin=origins[0]) diff --git a/ddtrace/appsec/_iast/_patches/json_tainting.py b/ddtrace/appsec/_iast/_patches/json_tainting.py index b76564971be..0984b7a2572 100644 --- a/ddtrace/appsec/_iast/_patches/json_tainting.py +++ b/ddtrace/appsec/_iast/_patches/json_tainting.py @@ -47,6 +47,10 @@ def wrapped_loads(wrapped, instance, args, kwargs): from .._taint_tracking import get_tainted_ranges from .._taint_tracking import is_pyobject_tainted from .._taint_tracking import taint_pyobject + from ..processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return obj if is_pyobject_tainted(args[0]) and obj: # tainting object diff --git a/ddtrace/appsec/_iast/_taint_tracking/__init__.py b/ddtrace/appsec/_iast/_taint_tracking/__init__.py index ba168d485c2..9506d7a4568 100644 --- a/ddtrace/appsec/_iast/_taint_tracking/__init__.py +++ b/ddtrace/appsec/_iast/_taint_tracking/__init__.py @@ -88,9 +88,7 @@ def taint_pyobject(pyobject, source_name, source_value, source_origin=None): # type: (Any, Any, Any, OriginType) -> Any - # Request is not analyzed - if not oce.request_has_quota: - return pyobject + # Pyobject must be Text with len > 1 if not pyobject or not isinstance(pyobject, (str, bytes, bytearray)): return pyobject diff --git a/ddtrace/appsec/_iast/processor.py b/ddtrace/appsec/_iast/processor.py index d93cfe14b2c..7f72709cb37 100644 --- a/ddtrace/appsec/_iast/processor.py +++ b/ddtrace/appsec/_iast/processor.py @@ -20,6 +20,8 @@ if TYPE_CHECKING: # pragma: no cover + from typing import Optional # noqa:F401 + from ddtrace._trace.span import Span # noqa:F401 log = get_logger(__name__) @@ -27,14 +29,34 @@ @attr.s(eq=False) class AppSecIastSpanProcessor(SpanProcessor): + @staticmethod + def is_span_analyzed(span=None): + # type: (Optional[Span]) -> bool + if span is None: + from ddtrace import tracer + + span = tracer.current_root_span() + + if span and span.span_type == SpanTypes.WEB and core.get_item(IAST.REQUEST_IAST_ENABLED, span=span): + return True + return False + def on_span_start(self, span): # type: (Span) -> None if span.span_type != SpanTypes.WEB: return - oce.acquire_request(span) - from ._taint_tracking import create_context - create_context() + if not _is_iast_enabled(): + return + + request_iast_enabled = False + if oce.acquire_request(span): + from ._taint_tracking import create_context + + request_iast_enabled = True + create_context() + + core.set_item(IAST.REQUEST_IAST_ENABLED, request_iast_enabled, span=span) def on_span_finish(self, span): # type: (Span) -> None @@ -48,7 +70,7 @@ def on_span_finish(self, span): if span.span_type != SpanTypes.WEB: return - if not oce._enabled or not _is_iast_enabled(): + if not core.get_item(IAST.REQUEST_IAST_ENABLED, span=span): span.set_metric(IAST.ENABLED, 0.0) return diff --git a/ddtrace/appsec/_iast/taint_sinks/_base.py b/ddtrace/appsec/_iast/taint_sinks/_base.py index 0613fbbb726..e934655f802 100644 --- a/ddtrace/appsec/_iast/taint_sinks/_base.py +++ b/ddtrace/appsec/_iast/taint_sinks/_base.py @@ -11,12 +11,12 @@ from ddtrace.settings.asm import config as asm_config from ..._deduplications import deduplication -from .. import oce from .._overhead_control_engine import Operation from .._stacktrace import get_info_frame from .._utils import _has_to_scrub from .._utils import _is_evidence_value_parts from .._utils import _scrub +from ..processor import AppSecIastSpanProcessor from ..reporter import Evidence from ..reporter import IastSpanReporter from ..reporter import Location @@ -83,7 +83,7 @@ def wrapper(wrapped, instance, args, kwargs): """Get the current root Span and attach it to the wrapped function. We need the span to report the vulnerability and update the context with the report information. """ - if oce.request_has_quota and cls.has_quota(): + if AppSecIastSpanProcessor.is_span_analyzed() and cls.has_quota(): return func(wrapped, instance, args, kwargs) else: log.debug("IAST: no vulnerability quota to analyze more sink points") diff --git a/ddtrace/appsec/_iast/taint_sinks/path_traversal.py b/ddtrace/appsec/_iast/taint_sinks/path_traversal.py index 43b026247d5..c7618000d05 100644 --- a/ddtrace/appsec/_iast/taint_sinks/path_traversal.py +++ b/ddtrace/appsec/_iast/taint_sinks/path_traversal.py @@ -10,6 +10,7 @@ from .._patch import set_module_unpatched from ..constants import EVIDENCE_PATH_TRAVERSAL from ..constants import VULN_PATH_TRAVERSAL +from ..processor import AppSecIastSpanProcessor from ._base import VulnerabilityBase @@ -49,7 +50,7 @@ def patch(): def check_and_report_path_traversal(*args: Any, **kwargs: Any) -> None: - if oce.request_has_quota and PathTraversal.has_quota(): + if AppSecIastSpanProcessor.is_span_analyzed() and PathTraversal.has_quota(): try: from .._metrics import _set_metric_iast_executed_sink from .._taint_tracking import is_pyobject_tainted diff --git a/ddtrace/appsec/_iast/taint_sinks/ssrf.py b/ddtrace/appsec/_iast/taint_sinks/ssrf.py index bee94a020c8..1deb8bb7c16 100644 --- a/ddtrace/appsec/_iast/taint_sinks/ssrf.py +++ b/ddtrace/appsec/_iast/taint_sinks/ssrf.py @@ -16,6 +16,7 @@ from ..constants import EVIDENCE_SSRF from ..constants import VULN_SSRF from ..constants import VULNERABILITY_TOKEN_TYPE +from ..processor import AppSecIastSpanProcessor from ..reporter import IastSpanReporter # noqa:F401 from ..reporter import Vulnerability from ._base import VulnerabilityBase @@ -165,7 +166,7 @@ def _iast_report_ssrf(func: Callable, *args, **kwargs): increment_iast_span_metric(IAST_SPAN_TAGS.TELEMETRY_EXECUTED_SINK, SSRF.vulnerability_type) _set_metric_iast_executed_sink(SSRF.vulnerability_type) if report_ssrf: - if oce.request_has_quota and SSRF.has_quota(): + if AppSecIastSpanProcessor.is_span_analyzed() and SSRF.has_quota(): try: from .._taint_tracking import is_pyobject_tainted diff --git a/releasenotes/notes/iast-fix-oce-logic-4369ebeed72759fc.yaml b/releasenotes/notes/iast-fix-oce-logic-4369ebeed72759fc.yaml new file mode 100644 index 00000000000..bdcc6b965a3 --- /dev/null +++ b/releasenotes/notes/iast-fix-oce-logic-4369ebeed72759fc.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Vulnerability Management for Code-level (IAST): Fixes an issue where requests stopped being analyzed after some time due. diff --git a/tests/appsec/iast/conftest.py b/tests/appsec/iast/conftest.py index a5a80cd5642..1025f672015 100644 --- a/tests/appsec/iast/conftest.py +++ b/tests/appsec/iast/conftest.py @@ -5,6 +5,7 @@ from ddtrace.appsec._iast._patches.json_tainting import unpatch_iast as json_unpatch from ddtrace.appsec._iast._taint_tracking import create_context from ddtrace.appsec._iast._taint_tracking import reset_context +from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor from ddtrace.appsec._iast.taint_sinks._base import VulnerabilityBase from ddtrace.appsec._iast.taint_sinks.command_injection import patch as cmdi_patch from ddtrace.appsec._iast.taint_sinks.command_injection import unpatch as cmdi_unpatch @@ -40,10 +41,12 @@ def iast_span(tracer, env, request_sampling="100", deduplication="false"): psycopg_unpatch = lambda: True # noqa: E731 env.update({"DD_IAST_REQUEST_SAMPLING": request_sampling, "_DD_APPSEC_DEDUPLICATION_ENABLED": deduplication}) + iast_span_processor = AppSecIastSpanProcessor() VulnerabilityBase._reset_cache() with override_global_config(dict(_iast_enabled=True)), override_env(env): oce.reconfigure() with tracer.trace("test") as span: + span.span_type = "web" weak_hash_patch() weak_cipher_patch() path_traversal_patch() @@ -53,9 +56,9 @@ def iast_span(tracer, env, request_sampling="100", deduplication="false"): sqlalchemy_patch() cmdi_patch() langchain_patch() - oce.acquire_request(span) + iast_span_processor.on_span_start(span) yield span - oce.release_request() + iast_span_processor.on_span_finish(span) weak_hash_unpatch() weak_cipher_unpatch() sqli_sqlite_unpatch() diff --git a/tests/appsec/iast/test_overhead_control_engine.py b/tests/appsec/iast/test_overhead_control_engine.py index f37eb666c28..8f64ff8a5c6 100644 --- a/tests/appsec/iast/test_overhead_control_engine.py +++ b/tests/appsec/iast/test_overhead_control_engine.py @@ -70,53 +70,94 @@ def test_oce_reset_vulnerabilities_report(iast_span_defaults): assert len(span_report.vulnerabilities) == MAX_VULNERABILITIES_PER_REQUEST + 1 -def test_oce_max_requests(tracer, iast_span_defaults): +def test_oce_no_race_conditions(tracer, iast_span_defaults): + from ddtrace.appsec._iast._overhead_control_engine import OverheadControl + + oc = OverheadControl() + oc.reconfigure() + + assert oc._request_quota == MAX_REQUESTS + + # Request 1 tries to acquire the lock + assert oc.acquire_request(iast_span_defaults) is True + + # oce should have quota + assert oc._request_quota > 0 + + # Request 2 tries to acquire the lock + assert oc.acquire_request(iast_span_defaults) is True + + # oce should not have quota + assert oc._request_quota == 0 + + # Request 3 tries to acquire the lock and fails + assert oc.acquire_request(iast_span_defaults) is False + + # oce should have quota + assert oc._request_quota == 0 + + # Request 1 releases the lock + oc.release_request() + + assert oc._request_quota > 0 + + # Request 4 tries to acquire the lock + assert oc.acquire_request(iast_span_defaults) is True + + # oce should have quota + assert oc._request_quota == 0 + + # Request 4 releases the lock + oc.release_request() + + # oce should have quota again + assert oc._request_quota > 0 + + # Request 5 tries to acquire the lock + assert oc.acquire_request(iast_span_defaults) is True + + # oce should not have quota + assert oc._request_quota == 0 + + +def acquire_and_release_quota(oc, iast_span_defaults): + """ + Just acquires the request quota and releases it with some + random sleeps + """ + import random + import time + + random_int = random.randint(1, 10) + time.sleep(0.01 * random_int) + if oc.acquire_request(iast_span_defaults): + time.sleep(0.01 * random_int) + oc.release_request() + + +def test_oce_concurrent_requests(tracer, iast_span_defaults): + """ + Ensures quota is always within bounds after multithreading scenario + """ import threading + from ddtrace.appsec._iast._overhead_control_engine import MAX_REQUESTS + from ddtrace.appsec._iast._overhead_control_engine import OverheadControl + + oc = OverheadControl() + oc.reconfigure() + results = [] - num_requests = 5 - total_vulnerabilities = 0 + num_requests = 5000 - threads = [threading.Thread(target=function_with_vulnerabilities_1, args=(tracer,)) for _ in range(0, num_requests)] + threads = [ + threading.Thread(target=acquire_and_release_quota, args=(oc, iast_span_defaults)) + for _ in range(0, num_requests) + ] for thread in threads: thread.start() for thread in threads: results.append(thread.join()) - spans = tracer.pop() - for span in spans: - span_report = core.get_item(IAST.CONTEXT_KEY, span=span) - if span_report: - total_vulnerabilities += len(span_report.vulnerabilities) - - assert len(results) == num_requests - assert len(spans) == num_requests - assert total_vulnerabilities == 1 - - -def test_oce_max_requests_py3(tracer, iast_span_defaults): - import concurrent.futures - - results = [] - num_requests = 5 - total_vulnerabilities = 0 - - with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [] - for _ in range(0, num_requests): - futures.append(executor.submit(function_with_vulnerabilities_1, tracer)) - futures.append(executor.submit(function_with_vulnerabilities_2, tracer)) - futures.append(executor.submit(function_with_vulnerabilities_3, tracer)) - - for future in concurrent.futures.as_completed(futures): - results.append(future.result()) - - spans = tracer.pop() - for span in spans: - span_report = core.get_item(IAST.CONTEXT_KEY, span=span) - if span_report: - total_vulnerabilities += len(span_report.vulnerabilities) - - assert len(results) == num_requests * 3 - assert len(spans) == num_requests * 3 - assert total_vulnerabilities == MAX_REQUESTS + # Ensures quota is always within bounds after multithreading scenario + assert 0 <= oc._request_quota <= MAX_REQUESTS