From 7e8aaacab0e80466294b59d5fced7d53cb3e7eda Mon Sep 17 00:00:00 2001 From: Federico Mon Date: Thu, 22 Feb 2024 14:15:46 +0100 Subject: [PATCH] fix(iast): improve overhead control logic (#8452) IAST: Improve overhead control logic so the decision to analyze a request is done at span start and is saved at the span level using the core API. This should fix issues where requests were analyzed when they shouldn't be and viceversa. ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. - [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/appsec/_asm_request_context.py | 9 ++ ddtrace/appsec/_constants.py | 1 + ddtrace/appsec/_handlers.py | 17 ++- .../appsec/_iast/_overhead_control_engine.py | 28 ++-- ddtrace/appsec/_iast/_patch.py | 9 ++ .../appsec/_iast/_patches/json_tainting.py | 4 + .../appsec/_iast/_taint_tracking/__init__.py | 4 +- ddtrace/appsec/_iast/processor.py | 30 ++++- ddtrace/appsec/_iast/taint_sinks/_base.py | 4 +- .../_iast/taint_sinks/path_traversal.py | 3 +- ddtrace/appsec/_iast/taint_sinks/ssrf.py | 3 +- .../iast-fix-oce-logic-4369ebeed72759fc.yaml | 4 + tests/appsec/iast/conftest.py | 7 +- .../iast/test_overhead_control_engine.py | 123 ++++++++++++------ 14 files changed, 175 insertions(+), 71 deletions(-) create mode 100644 releasenotes/notes/iast-fix-oce-logic-4369ebeed72759fc.yaml diff --git a/ddtrace/appsec/_asm_request_context.py b/ddtrace/appsec/_asm_request_context.py index 4bc868b7a26..eda18517bde 100644 --- a/ddtrace/appsec/_asm_request_context.py +++ b/ddtrace/appsec/_asm_request_context.py @@ -436,6 +436,10 @@ def _on_wrapped_view(kwargs): if _is_iast_enabled() and kwargs: from ddtrace.appsec._iast._taint_tracking import OriginType from ddtrace.appsec._iast._taint_tracking import taint_pyobject + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return return_value _kwargs = {} for k, v in kwargs.items(): @@ -451,9 +455,14 @@ def _on_set_request_tags(request, span, flask_config): from ddtrace.appsec._iast._metrics import _set_metric_iast_instrumented_source from ddtrace.appsec._iast._taint_tracking import OriginType from ddtrace.appsec._iast._taint_utils import taint_structure + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor _set_metric_iast_instrumented_source(OriginType.COOKIE_NAME) _set_metric_iast_instrumented_source(OriginType.COOKIE) + + if not AppSecIastSpanProcessor.is_span_analyzed(span._local_root or span): + return + request.cookies = taint_structure( request.cookies, OriginType.COOKIE_NAME, diff --git a/ddtrace/appsec/_constants.py b/ddtrace/appsec/_constants.py index dd0512d822d..fc0f4a4abba 100644 --- a/ddtrace/appsec/_constants.py +++ b/ddtrace/appsec/_constants.py @@ -82,6 +82,7 @@ class IAST(metaclass=Constant_Class): PATCH_MODULES = "_DD_IAST_PATCH_MODULES" DENY_MODULES = "_DD_IAST_DENY_MODULES" SEP_MODULES = "," + REQUEST_IAST_ENABLED = "_dd.iast.request_enabled" class IAST_SPAN_TAGS(metaclass=Constant_Class): diff --git a/ddtrace/appsec/_handlers.py b/ddtrace/appsec/_handlers.py index 6439632969d..32efb83722f 100644 --- a/ddtrace/appsec/_handlers.py +++ b/ddtrace/appsec/_handlers.py @@ -190,6 +190,13 @@ def _on_request_init(wrapped, instance, args, kwargs): from ddtrace.appsec._iast._metrics import _set_metric_iast_instrumented_source from ddtrace.appsec._iast._taint_tracking import OriginType from ddtrace.appsec._iast._taint_tracking import taint_pyobject + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor + + _set_metric_iast_instrumented_source(OriginType.PATH) + _set_metric_iast_instrumented_source(OriginType.QUERY) + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return # TODO: instance.query_string = ?? instance.query_string = taint_pyobject( @@ -204,8 +211,6 @@ def _on_request_init(wrapped, instance, args, kwargs): source_value=instance.path, source_origin=OriginType.PATH, ) - _set_metric_iast_instrumented_source(OriginType.PATH) - _set_metric_iast_instrumented_source(OriginType.QUERY) except Exception: log.debug("Unexpected exception while tainting pyobject", exc_info=True) @@ -269,6 +274,10 @@ def _on_django_func_wrapped(fn_args, fn_kwargs, first_arg_expected_type, *_): from ddtrace.appsec._iast._taint_tracking import is_pyobject_tainted from ddtrace.appsec._iast._taint_tracking import taint_pyobject from ddtrace.appsec._iast._taint_utils import taint_structure + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return http_req = fn_args[0] @@ -318,6 +327,7 @@ def _on_wsgi_environ(wrapped, _instance, args, kwargs): from ddtrace.appsec._iast._metrics import _set_metric_iast_instrumented_source from ddtrace.appsec._iast._taint_tracking import OriginType # noqa: F401 from ddtrace.appsec._iast._taint_utils import taint_structure + from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor _set_metric_iast_instrumented_source(OriginType.HEADER_NAME) _set_metric_iast_instrumented_source(OriginType.HEADER) @@ -330,6 +340,9 @@ def _on_wsgi_environ(wrapped, _instance, args, kwargs): _set_metric_iast_instrumented_source(OriginType.PARAMETER_NAME) _set_metric_iast_instrumented_source(OriginType.BODY) + if not AppSecIastSpanProcessor.is_span_analyzed(): + return wrapped(*args, **kwargs) + return wrapped(*((taint_structure(args[0], OriginType.HEADER_NAME, OriginType.HEADER),) + args[1:]), **kwargs) return wrapped(*args, **kwargs) diff --git a/ddtrace/appsec/_iast/_overhead_control_engine.py b/ddtrace/appsec/_iast/_overhead_control_engine.py index 18e15553106..252d2398176 100644 --- a/ddtrace/appsec/_iast/_overhead_control_engine.py +++ b/ddtrace/appsec/_iast/_overhead_control_engine.py @@ -90,8 +90,8 @@ class OverheadControl(object): The goal is to do sampling at different levels of the IAST analysis (per process, per request, etc) """ + _lock = threading.Lock() _request_quota = MAX_REQUESTS - _enabled = False _vulnerabilities = set() # type: Set[Type[Operation]] _sampler = RateSampler(sample_rate=get_request_sampling_value() / 100.0) @@ -99,23 +99,26 @@ def reconfigure(self): self._sampler = RateSampler(sample_rate=get_request_sampling_value() / 100.0) def acquire_request(self, span): - # type: (Span) -> None + # type: (Span) -> bool """Decide whether if IAST analysis will be done for this request. - Block a request's quota at start of the request to limit simultaneous requests analyzed. - Use sample rating to analyze only a percentage of the total requests (30% by default). """ - if self._request_quota > 0 and self._sampler.sample(span): + if self._request_quota <= 0 or not self._sampler.sample(span): + return False + + with self._lock: + if self._request_quota <= 0: + return False + self._request_quota -= 1 - self._enabled = True - def release_request(self): - """increment request's quota at end of the request. + return True - TODO: figure out how to check maximum requests per thread - """ - if self._request_quota < MAX_REQUESTS: + def release_request(self): + """increment request's quota at end of the request.""" + with self._lock: self._request_quota += 1 - self._enabled = False self.vulnerabilities_reset_quota() def register(self, klass): @@ -124,11 +127,6 @@ def register(self, klass): self._vulnerabilities.add(klass) return klass - @property - def request_has_quota(self): - # type: () -> bool - return self._enabled - def vulnerabilities_reset_quota(self): # type: () -> None for k in self._vulnerabilities: diff --git a/ddtrace/appsec/_iast/_patch.py b/ddtrace/appsec/_iast/_patch.py index 9b97afef435..d6d8fefce08 100644 --- a/ddtrace/appsec/_iast/_patch.py +++ b/ddtrace/appsec/_iast/_patch.py @@ -145,6 +145,10 @@ def if_iast_taint_returned_object_for(origin, wrapped, instance, args, kwargs): try: from ._taint_tracking import is_pyobject_tainted from ._taint_tracking import taint_pyobject + from .processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return value if not is_pyobject_tainted(value): name = str(args[0]) if len(args) else "http.request.body" @@ -157,6 +161,11 @@ def if_iast_taint_returned_object_for(origin, wrapped, instance, args, kwargs): def if_iast_taint_yield_tuple_for(origins, wrapped, instance, args, kwargs): if _is_iast_enabled(): from ._taint_tracking import taint_pyobject + from .processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + for key, value in wrapped(*args, **kwargs): + yield key, value for key, value in wrapped(*args, **kwargs): new_key = taint_pyobject(pyobject=key, source_name=key, source_value=key, source_origin=origins[0]) diff --git a/ddtrace/appsec/_iast/_patches/json_tainting.py b/ddtrace/appsec/_iast/_patches/json_tainting.py index b76564971be..0984b7a2572 100644 --- a/ddtrace/appsec/_iast/_patches/json_tainting.py +++ b/ddtrace/appsec/_iast/_patches/json_tainting.py @@ -47,6 +47,10 @@ def wrapped_loads(wrapped, instance, args, kwargs): from .._taint_tracking import get_tainted_ranges from .._taint_tracking import is_pyobject_tainted from .._taint_tracking import taint_pyobject + from ..processor import AppSecIastSpanProcessor + + if not AppSecIastSpanProcessor.is_span_analyzed(): + return obj if is_pyobject_tainted(args[0]) and obj: # tainting object diff --git a/ddtrace/appsec/_iast/_taint_tracking/__init__.py b/ddtrace/appsec/_iast/_taint_tracking/__init__.py index ba168d485c2..9506d7a4568 100644 --- a/ddtrace/appsec/_iast/_taint_tracking/__init__.py +++ b/ddtrace/appsec/_iast/_taint_tracking/__init__.py @@ -88,9 +88,7 @@ def taint_pyobject(pyobject, source_name, source_value, source_origin=None): # type: (Any, Any, Any, OriginType) -> Any - # Request is not analyzed - if not oce.request_has_quota: - return pyobject + # Pyobject must be Text with len > 1 if not pyobject or not isinstance(pyobject, (str, bytes, bytearray)): return pyobject diff --git a/ddtrace/appsec/_iast/processor.py b/ddtrace/appsec/_iast/processor.py index d93cfe14b2c..7f72709cb37 100644 --- a/ddtrace/appsec/_iast/processor.py +++ b/ddtrace/appsec/_iast/processor.py @@ -20,6 +20,8 @@ if TYPE_CHECKING: # pragma: no cover + from typing import Optional # noqa:F401 + from ddtrace._trace.span import Span # noqa:F401 log = get_logger(__name__) @@ -27,14 +29,34 @@ @attr.s(eq=False) class AppSecIastSpanProcessor(SpanProcessor): + @staticmethod + def is_span_analyzed(span=None): + # type: (Optional[Span]) -> bool + if span is None: + from ddtrace import tracer + + span = tracer.current_root_span() + + if span and span.span_type == SpanTypes.WEB and core.get_item(IAST.REQUEST_IAST_ENABLED, span=span): + return True + return False + def on_span_start(self, span): # type: (Span) -> None if span.span_type != SpanTypes.WEB: return - oce.acquire_request(span) - from ._taint_tracking import create_context - create_context() + if not _is_iast_enabled(): + return + + request_iast_enabled = False + if oce.acquire_request(span): + from ._taint_tracking import create_context + + request_iast_enabled = True + create_context() + + core.set_item(IAST.REQUEST_IAST_ENABLED, request_iast_enabled, span=span) def on_span_finish(self, span): # type: (Span) -> None @@ -48,7 +70,7 @@ def on_span_finish(self, span): if span.span_type != SpanTypes.WEB: return - if not oce._enabled or not _is_iast_enabled(): + if not core.get_item(IAST.REQUEST_IAST_ENABLED, span=span): span.set_metric(IAST.ENABLED, 0.0) return diff --git a/ddtrace/appsec/_iast/taint_sinks/_base.py b/ddtrace/appsec/_iast/taint_sinks/_base.py index 0613fbbb726..e934655f802 100644 --- a/ddtrace/appsec/_iast/taint_sinks/_base.py +++ b/ddtrace/appsec/_iast/taint_sinks/_base.py @@ -11,12 +11,12 @@ from ddtrace.settings.asm import config as asm_config from ..._deduplications import deduplication -from .. import oce from .._overhead_control_engine import Operation from .._stacktrace import get_info_frame from .._utils import _has_to_scrub from .._utils import _is_evidence_value_parts from .._utils import _scrub +from ..processor import AppSecIastSpanProcessor from ..reporter import Evidence from ..reporter import IastSpanReporter from ..reporter import Location @@ -83,7 +83,7 @@ def wrapper(wrapped, instance, args, kwargs): """Get the current root Span and attach it to the wrapped function. We need the span to report the vulnerability and update the context with the report information. """ - if oce.request_has_quota and cls.has_quota(): + if AppSecIastSpanProcessor.is_span_analyzed() and cls.has_quota(): return func(wrapped, instance, args, kwargs) else: log.debug("IAST: no vulnerability quota to analyze more sink points") diff --git a/ddtrace/appsec/_iast/taint_sinks/path_traversal.py b/ddtrace/appsec/_iast/taint_sinks/path_traversal.py index 43b026247d5..c7618000d05 100644 --- a/ddtrace/appsec/_iast/taint_sinks/path_traversal.py +++ b/ddtrace/appsec/_iast/taint_sinks/path_traversal.py @@ -10,6 +10,7 @@ from .._patch import set_module_unpatched from ..constants import EVIDENCE_PATH_TRAVERSAL from ..constants import VULN_PATH_TRAVERSAL +from ..processor import AppSecIastSpanProcessor from ._base import VulnerabilityBase @@ -49,7 +50,7 @@ def patch(): def check_and_report_path_traversal(*args: Any, **kwargs: Any) -> None: - if oce.request_has_quota and PathTraversal.has_quota(): + if AppSecIastSpanProcessor.is_span_analyzed() and PathTraversal.has_quota(): try: from .._metrics import _set_metric_iast_executed_sink from .._taint_tracking import is_pyobject_tainted diff --git a/ddtrace/appsec/_iast/taint_sinks/ssrf.py b/ddtrace/appsec/_iast/taint_sinks/ssrf.py index bee94a020c8..1deb8bb7c16 100644 --- a/ddtrace/appsec/_iast/taint_sinks/ssrf.py +++ b/ddtrace/appsec/_iast/taint_sinks/ssrf.py @@ -16,6 +16,7 @@ from ..constants import EVIDENCE_SSRF from ..constants import VULN_SSRF from ..constants import VULNERABILITY_TOKEN_TYPE +from ..processor import AppSecIastSpanProcessor from ..reporter import IastSpanReporter # noqa:F401 from ..reporter import Vulnerability from ._base import VulnerabilityBase @@ -165,7 +166,7 @@ def _iast_report_ssrf(func: Callable, *args, **kwargs): increment_iast_span_metric(IAST_SPAN_TAGS.TELEMETRY_EXECUTED_SINK, SSRF.vulnerability_type) _set_metric_iast_executed_sink(SSRF.vulnerability_type) if report_ssrf: - if oce.request_has_quota and SSRF.has_quota(): + if AppSecIastSpanProcessor.is_span_analyzed() and SSRF.has_quota(): try: from .._taint_tracking import is_pyobject_tainted diff --git a/releasenotes/notes/iast-fix-oce-logic-4369ebeed72759fc.yaml b/releasenotes/notes/iast-fix-oce-logic-4369ebeed72759fc.yaml new file mode 100644 index 00000000000..bdcc6b965a3 --- /dev/null +++ b/releasenotes/notes/iast-fix-oce-logic-4369ebeed72759fc.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Vulnerability Management for Code-level (IAST): Fixes an issue where requests stopped being analyzed after some time due. diff --git a/tests/appsec/iast/conftest.py b/tests/appsec/iast/conftest.py index a5a80cd5642..1025f672015 100644 --- a/tests/appsec/iast/conftest.py +++ b/tests/appsec/iast/conftest.py @@ -5,6 +5,7 @@ from ddtrace.appsec._iast._patches.json_tainting import unpatch_iast as json_unpatch from ddtrace.appsec._iast._taint_tracking import create_context from ddtrace.appsec._iast._taint_tracking import reset_context +from ddtrace.appsec._iast.processor import AppSecIastSpanProcessor from ddtrace.appsec._iast.taint_sinks._base import VulnerabilityBase from ddtrace.appsec._iast.taint_sinks.command_injection import patch as cmdi_patch from ddtrace.appsec._iast.taint_sinks.command_injection import unpatch as cmdi_unpatch @@ -40,10 +41,12 @@ def iast_span(tracer, env, request_sampling="100", deduplication="false"): psycopg_unpatch = lambda: True # noqa: E731 env.update({"DD_IAST_REQUEST_SAMPLING": request_sampling, "_DD_APPSEC_DEDUPLICATION_ENABLED": deduplication}) + iast_span_processor = AppSecIastSpanProcessor() VulnerabilityBase._reset_cache() with override_global_config(dict(_iast_enabled=True)), override_env(env): oce.reconfigure() with tracer.trace("test") as span: + span.span_type = "web" weak_hash_patch() weak_cipher_patch() path_traversal_patch() @@ -53,9 +56,9 @@ def iast_span(tracer, env, request_sampling="100", deduplication="false"): sqlalchemy_patch() cmdi_patch() langchain_patch() - oce.acquire_request(span) + iast_span_processor.on_span_start(span) yield span - oce.release_request() + iast_span_processor.on_span_finish(span) weak_hash_unpatch() weak_cipher_unpatch() sqli_sqlite_unpatch() diff --git a/tests/appsec/iast/test_overhead_control_engine.py b/tests/appsec/iast/test_overhead_control_engine.py index f37eb666c28..8f64ff8a5c6 100644 --- a/tests/appsec/iast/test_overhead_control_engine.py +++ b/tests/appsec/iast/test_overhead_control_engine.py @@ -70,53 +70,94 @@ def test_oce_reset_vulnerabilities_report(iast_span_defaults): assert len(span_report.vulnerabilities) == MAX_VULNERABILITIES_PER_REQUEST + 1 -def test_oce_max_requests(tracer, iast_span_defaults): +def test_oce_no_race_conditions(tracer, iast_span_defaults): + from ddtrace.appsec._iast._overhead_control_engine import OverheadControl + + oc = OverheadControl() + oc.reconfigure() + + assert oc._request_quota == MAX_REQUESTS + + # Request 1 tries to acquire the lock + assert oc.acquire_request(iast_span_defaults) is True + + # oce should have quota + assert oc._request_quota > 0 + + # Request 2 tries to acquire the lock + assert oc.acquire_request(iast_span_defaults) is True + + # oce should not have quota + assert oc._request_quota == 0 + + # Request 3 tries to acquire the lock and fails + assert oc.acquire_request(iast_span_defaults) is False + + # oce should have quota + assert oc._request_quota == 0 + + # Request 1 releases the lock + oc.release_request() + + assert oc._request_quota > 0 + + # Request 4 tries to acquire the lock + assert oc.acquire_request(iast_span_defaults) is True + + # oce should have quota + assert oc._request_quota == 0 + + # Request 4 releases the lock + oc.release_request() + + # oce should have quota again + assert oc._request_quota > 0 + + # Request 5 tries to acquire the lock + assert oc.acquire_request(iast_span_defaults) is True + + # oce should not have quota + assert oc._request_quota == 0 + + +def acquire_and_release_quota(oc, iast_span_defaults): + """ + Just acquires the request quota and releases it with some + random sleeps + """ + import random + import time + + random_int = random.randint(1, 10) + time.sleep(0.01 * random_int) + if oc.acquire_request(iast_span_defaults): + time.sleep(0.01 * random_int) + oc.release_request() + + +def test_oce_concurrent_requests(tracer, iast_span_defaults): + """ + Ensures quota is always within bounds after multithreading scenario + """ import threading + from ddtrace.appsec._iast._overhead_control_engine import MAX_REQUESTS + from ddtrace.appsec._iast._overhead_control_engine import OverheadControl + + oc = OverheadControl() + oc.reconfigure() + results = [] - num_requests = 5 - total_vulnerabilities = 0 + num_requests = 5000 - threads = [threading.Thread(target=function_with_vulnerabilities_1, args=(tracer,)) for _ in range(0, num_requests)] + threads = [ + threading.Thread(target=acquire_and_release_quota, args=(oc, iast_span_defaults)) + for _ in range(0, num_requests) + ] for thread in threads: thread.start() for thread in threads: results.append(thread.join()) - spans = tracer.pop() - for span in spans: - span_report = core.get_item(IAST.CONTEXT_KEY, span=span) - if span_report: - total_vulnerabilities += len(span_report.vulnerabilities) - - assert len(results) == num_requests - assert len(spans) == num_requests - assert total_vulnerabilities == 1 - - -def test_oce_max_requests_py3(tracer, iast_span_defaults): - import concurrent.futures - - results = [] - num_requests = 5 - total_vulnerabilities = 0 - - with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [] - for _ in range(0, num_requests): - futures.append(executor.submit(function_with_vulnerabilities_1, tracer)) - futures.append(executor.submit(function_with_vulnerabilities_2, tracer)) - futures.append(executor.submit(function_with_vulnerabilities_3, tracer)) - - for future in concurrent.futures.as_completed(futures): - results.append(future.result()) - - spans = tracer.pop() - for span in spans: - span_report = core.get_item(IAST.CONTEXT_KEY, span=span) - if span_report: - total_vulnerabilities += len(span_report.vulnerabilities) - - assert len(results) == num_requests * 3 - assert len(spans) == num_requests * 3 - assert total_vulnerabilities == MAX_REQUESTS + # Ensures quota is always within bounds after multithreading scenario + assert 0 <= oc._request_quota <= MAX_REQUESTS