From 955a4fe7925c76b2ffdd838c68099952f9595ed7 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 24 Apr 2020 08:18:22 -0700 Subject: [PATCH] [fix] Automatically add relevant Jinja methods to cache key if present (#9572) * [fix] Adding URL params to cache key if present * [cache] Wrapping Jinja methods Co-authored-by: John Bodley --- UPDATING.md | 2 + docs/sqllab.rst | 8 +- superset/connectors/sqla/models.py | 25 +++-- superset/jinja_context.py | 161 +++++++++++++++++------------ superset/views/tags.py | 7 +- tests/core_tests.py | 1 - tests/sqla_models_tests.py | 62 ++++++----- 7 files changed, 156 insertions(+), 110 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index ef9405e1fce0..046585214ef9 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -23,6 +23,8 @@ assists people when migrating to a new version. ## Next +* [9572](https://github.com/apache/incubator-superset/pull/9572): a change which by defau;t means that the Jinja `current_user_id`, `current_username`, and `url_param` context calls no longer need to be wrapped via `cache_key_wrapper` in order to be included in the cache key. The `cache_key_wrapper` function should only be required for Jinja add-ons. + * [8867](https://github.com/apache/incubator-superset/pull/8867): a change which adds the `tmp_schema_name` column to the `query` table which requires locking the table. Given the `query` table is heavily used performance may be degraded during the migration. Scheduled downtime may be advised. * [9238](https://github.com/apache/incubator-superset/pull/9238): the config option `TIME_GRAIN_FUNCTIONS` has been renamed to `TIME_GRAIN_EXPRESSIONS` to better reflect the content of the dictionary. diff --git a/docs/sqllab.rst b/docs/sqllab.rst index aace28f119e9..c74ce487a12e 100644 --- a/docs/sqllab.rst +++ b/docs/sqllab.rst @@ -79,15 +79,15 @@ Superset's Jinja context: `Jinja's builtin filters `_ can be also be applied where needed. -.. autofunction:: superset.jinja_context.current_user_id +.. autofunction:: superset.jinja_context.ExtraCache.current_user_id -.. autofunction:: superset.jinja_context.current_username +.. autofunction:: superset.jinja_context.ExtraCache.current_username -.. autofunction:: superset.jinja_context.url_param +.. autofunction:: superset.jinja_context.ExtraCache.url_param .. autofunction:: superset.jinja_context.filter_values -.. autofunction:: superset.jinja_context.CacheKeyWrapper.cache_key_wrapper +.. autofunction:: superset.jinja_context.ExtraCache.cache_key_wrapper .. autoclass:: superset.jinja_context.PrestoTemplateProcessor :members: diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index bb3cf59c9fc2..7b3ec04b00dd 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -54,7 +54,7 @@ from superset.constants import NULL_STRING from superset.db_engine_specs.base import TimestampExpression from superset.exceptions import DatabaseNotFound -from superset.jinja_context import get_template_processor +from superset.jinja_context import ExtraCache, get_template_processor from superset.models.annotations import Annotation from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, QueryResult @@ -1216,18 +1216,17 @@ def query_datasources_by_name( def default_query(qry) -> Query: return qry.filter_by(is_sqllab_view=False) - def has_calls_to_cache_key_wrapper(self, query_obj: Dict[str, Any]) -> bool: + def has_extra_cache_key_calls(self, query_obj: Dict[str, Any]) -> bool: """ - Detects the presence of calls to `cache_key_wrapper` in items in query_obj that + Detects the presence of calls to `ExtraCache` methods in items in query_obj that can be templated. If any are present, the query must be evaluated to extract - additional keys for the cache key. This method is needed to avoid executing - the template code unnecessarily, as it may contain expensive calls, e.g. to - extract the latest partition of a database. + additional keys for the cache key. This method is needed to avoid executing the + template code unnecessarily, as it may contain expensive calls, e.g. to extract + the latest partition of a database. :param query_obj: query object to analyze - :return: True if at least one item calls `cache_key_wrapper`, otherwise False + :return: True if there are call(s) to an `ExtraCache` method, False otherwise """ - regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}") templatable_statements: List[str] = [] if self.sql: templatable_statements.append(self.sql) @@ -1239,20 +1238,20 @@ def has_calls_to_cache_key_wrapper(self, query_obj: Dict[str, Any]) -> bool: if "having" in extras: templatable_statements.append(extras["having"]) for statement in templatable_statements: - if regex.search(statement): + if ExtraCache.regex.search(statement): return True return False def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: """ - The cache key of a SqlaTable needs to consider any keys added by the parent class - and any keys added via `cache_key_wrapper`. + The cache key of a SqlaTable needs to consider any keys added by the parent + class and any keys added via `ExtraCache`. :param query_obj: query object to analyze - :return: True if at least one item calls `cache_key_wrapper`, otherwise False + :return: The extra cache keys """ extra_cache_keys = super().get_extra_cache_keys(query_obj) - if self.has_calls_to_cache_key_wrapper(query_obj): + if self.has_extra_cache_key_calls(query_obj): sqla_query = self.get_sqla_query(**query_obj) extra_cache_keys += sqla_query.extra_cache_keys return extra_cache_keys diff --git a/superset/jinja_context.py b/superset/jinja_context.py index b4f915917588..5438f0f3b5a3 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -17,6 +17,7 @@ """Defines the templating context for SQL Lab""" import inspect import json +import re from typing import Any, List, Optional, Tuple from flask import g, request @@ -27,51 +28,6 @@ from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters -def url_param(param: str, default: Optional[str] = None) -> Optional[Any]: - """Read a url or post parameter and use it in your SQL Lab query - - When in SQL Lab, it's possible to add arbitrary URL "query string" - parameters, and use those in your SQL code. For instance you can - alter your url and add `?foo=bar`, as in - `{domain}/superset/sqllab?foo=bar`. Then if your query is something like - SELECT * FROM foo = '{{ url_param('foo') }}', it will be parsed at - runtime and replaced by the value in the URL. - - As you create a visualization form this SQL Lab query, you can pass - parameters in the explore view as well as from the dashboard, and - it should carry through to your queries. - - Default values for URL parameters can be defined in chart metdata by - adding the key-value pair `url_params: {'foo': 'bar'}` - - :param param: the parameter to lookup - :param default: the value to return in the absence of the parameter - """ - if request.args.get(param): - return request.args.get(param, default) - # Supporting POST as well as get - form_data = request.form.get("form_data") - if isinstance(form_data, str): - form_data = json.loads(form_data) - url_params = form_data.get("url_params") or {} - return url_params.get(param, default) - return default - - -def current_user_id() -> Optional[int]: - """The id of the user who is currently logged in""" - if hasattr(g, "user") and g.user: - return g.user.id - return None - - -def current_username() -> Optional[str]: - """The username of the user who is currently logged in""" - if g.user: - return g.user.username - return None - - def filter_values(column: str, default: Optional[str] = None) -> List[str]: """ Gets a values for a particular filter as a list @@ -122,33 +78,63 @@ def filter_values(column: str, default: Optional[str] = None) -> List[str]: return [] -class CacheKeyWrapper: # pylint: disable=too-few-public-methods - """ Dummy class that exposes a method used to store additional values used in - calculation of query object cache keys""" +class ExtraCache: + """ + Dummy class that exposes a method used to store additional values used in + calculation of query object cache keys. + """ + + # Regular expression for detecting the presence of templated methods which could + # be added to the cache key. + regex = re.compile( + r"\{\{.*(" + r"current_user_id\(.*\)|" + r"current_username\(.*\)|" + r"cache_key_wrapper\(.*\)|" + r"url_param\(.*\)" + r").*\}\}" + ) def __init__(self, extra_cache_keys: Optional[List[Any]] = None): self.extra_cache_keys = extra_cache_keys + def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]: + """ + Return the user ID of the user who is currently logged in. + + :param add_to_cache_keys: Whether the value should be included in the cache key + :returns: The user ID + """ + + if hasattr(g, "user") and g.user: + if add_to_cache_keys: + self.cache_key_wrapper(g.user.id) + return g.user.id + return None + + def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]: + """ + Return the username of the user who is currently logged in. + + :param add_to_cache_keys: Whether the value should be included in the cache key + :returns: The username + """ + + if g.user: + if add_to_cache_keys: + self.cache_key_wrapper(g.user.username) + return g.user.username + return None + def cache_key_wrapper(self, key: Any) -> Any: - """ Adds values to a list that is added to the query object used for calculating - a cache key. + """ + Adds values to a list that is added to the query object used for calculating a + cache key. This is needed if the following applies: - Caching is enabled - The query is dynamically generated using a jinja template - - A username or similar is used as a filter in the query - - Example when using a SQL query as a data source :: - - SELECT action, count(*) as times - FROM logs - WHERE logged_in_user = '{{ cache_key_wrapper(current_username()) }}' - GROUP BY action - - This will ensure that the query results that were cached by `user_1` will - **not** be seen by `user_2`, as the `cache_key` for the query will be - different. ``cache_key_wrapper`` can be used similarly for regular table data - sources by adding a `Custom SQL` filter. + - A `JINJA_CONTEXT_ADDONS` or similar is used as a filter in the query :param key: Any value that should be considered when calculating the cache key :return: the original value ``key`` passed to the function @@ -157,6 +143,44 @@ def cache_key_wrapper(self, key: Any) -> Any: self.extra_cache_keys.append(key) return key + def url_param( + self, param: str, default: Optional[str] = None, add_to_cache_keys: bool = True + ) -> Optional[Any]: + """ + Read a url or post parameter and use it in your SQL Lab query. + + When in SQL Lab, it's possible to add arbitrary URL "query string" parameters, + and use those in your SQL code. For instance you can alter your url and add + `?foo=bar`, as in `{domain}/superset/sqllab?foo=bar`. Then if your query is + something like SELECT * FROM foo = '{{ url_param('foo') }}', it will be parsed + at runtime and replaced by the value in the URL. + + As you create a visualization form this SQL Lab query, you can pass parameters + in the explore view as well as from the dashboard, and it should carry through + to your queries. + + Default values for URL parameters can be defined in chart metadata by adding the + key-value pair `url_params: {'foo': 'bar'}` + + :param param: the parameter to lookup + :param default: the value to return in the absence of the parameter + :param add_to_cache_keys: Whether the value should be included in the cache key + :returns: The URL parameters + """ + + if request.args.get(param): + return request.args.get(param, default) + # Supporting POST as well as get + form_data = request.form.get("form_data") + if isinstance(form_data, str): + form_data = json.loads(form_data) + url_params = form_data.get("url_params") or {} + result = url_params.get(param, default) + if add_to_cache_keys: + self.cache_key_wrapper(result) + return result + return default + class BaseTemplateProcessor: # pylint: disable=too-few-public-methods """Base class for database-specific jinja context @@ -190,11 +214,14 @@ def __init__( self.schema = query.schema elif table: self.schema = table.schema + + extra_cache = ExtraCache(extra_cache_keys) + self.context = { - "url_param": url_param, - "current_user_id": current_user_id, - "current_username": current_username, - "cache_key_wrapper": CacheKeyWrapper(extra_cache_keys).cache_key_wrapper, + "url_param": extra_cache.url_param, + "current_user_id": extra_cache.current_user_id, + "current_username": extra_cache.current_username, + "cache_key_wrapper": extra_cache.cache_key_wrapper, "filter_values": filter_values, "form_data": {}, } diff --git a/superset/views/tags.py b/superset/views/tags.py index decbdc2f6ad2..e12df2a7c904 100644 --- a/superset/views/tags.py +++ b/superset/views/tags.py @@ -24,7 +24,7 @@ from sqlalchemy import and_, func from superset import db, utils -from superset.jinja_context import current_user_id, current_username +from superset.jinja_context import ExtraCache from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import SavedQuery @@ -36,7 +36,10 @@ def process_template(content): env = SandboxedEnvironment() template = env.from_string(content) - context = {"current_user_id": current_user_id, "current_username": current_username} + context = { + "current_user_id": ExtraCache.current_user_id, + "current_username": ExtraCache.current_username, + } return template.render(context) diff --git a/tests/core_tests.py b/tests/core_tests.py index ecbfa0588133..b3d8cad9aa98 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -115,7 +115,6 @@ def test_viz_cache_key(self): viz = slc.viz qobj = viz.query_obj() cache_key = viz.cache_key(qobj) - self.assertEqual(cache_key, viz.cache_key(qobj)) qobj["groupby"] = [] self.assertNotEqual(cache_key, viz.cache_key(qobj)) diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index 2cabe92aafa5..aa2daf9148e0 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -16,6 +16,7 @@ # under the License. # isort:skip_file from typing import Any, Dict, NamedTuple, List, Tuple, Union +from unittest.mock import patch import tests.test_app from superset.connectors.sqla.models import SqlaTable, TableColumn @@ -69,14 +70,10 @@ def test_db_column_types(self): self.assertEqual(col.is_numeric, db_col_type == DbColumnType.NUMERIC) self.assertEqual(col.is_string, db_col_type == DbColumnType.STRING) - def test_has_extra_cache_keys(self): - query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user" - table = SqlaTable( - table_name="test_has_extra_cache_keys_table", - sql=query, - database=get_example_database(), - ) - query_obj = { + @patch("superset.jinja_context.g") + def test_extra_cache_keys(self, flask_g): + flask_g.user.username = "abc" + base_query_obj = { "granularity": None, "from_dttm": None, "to_dttm": None, @@ -84,33 +81,52 @@ def test_has_extra_cache_keys(self): "metrics": [], "is_timeseries": False, "filter": [], - "extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"}, } + + # Table with Jinja callable. + table = SqlaTable( + table_name="test_has_extra_cache_keys_table", + sql="SELECT '{{ current_username() }}' as user", + database=get_example_database(), + ) + + query_obj = dict(**base_query_obj, extras={}) extra_cache_keys = table.get_extra_cache_keys(query_obj) - self.assertTrue(table.has_calls_to_cache_key_wrapper(query_obj)) - self.assertListEqual(extra_cache_keys, ["user_1", "user_2"]) + self.assertTrue(table.has_extra_cache_key_calls(query_obj)) + self.assertListEqual(extra_cache_keys, ["abc"]) - def test_has_no_extra_cache_keys(self): + # Table with Jinja callable disabled. + table = SqlaTable( + table_name="test_has_extra_cache_keys_disabled_table", + sql="SELECT '{{ current_username(False) }}' as user", + database=get_example_database(), + ) + query_obj = dict(**base_query_obj, extras={}) + extra_cache_keys = table.get_extra_cache_keys(query_obj) + self.assertTrue(table.has_extra_cache_key_calls(query_obj)) + self.assertListEqual(extra_cache_keys, []) + + # Table with no Jinja callable. query = "SELECT 'abc' as user" table = SqlaTable( table_name="test_has_no_extra_cache_keys_table", sql=query, database=get_example_database(), ) - query_obj = { - "granularity": None, - "from_dttm": None, - "to_dttm": None, - "groupby": ["user"], - "metrics": [], - "is_timeseries": False, - "filter": [], - "extras": {"where": "(user != 'abc')"}, - } + + query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"}) extra_cache_keys = table.get_extra_cache_keys(query_obj) - self.assertFalse(table.has_calls_to_cache_key_wrapper(query_obj)) + self.assertFalse(table.has_extra_cache_key_calls(query_obj)) self.assertListEqual(extra_cache_keys, []) + # With Jinja callable in SQL expression. + query_obj = dict( + **base_query_obj, extras={"where": "(user != '{{ current_username() }}')"} + ) + extra_cache_keys = table.get_extra_cache_keys(query_obj) + self.assertTrue(table.has_extra_cache_key_calls(query_obj)) + self.assertListEqual(extra_cache_keys, ["abc"]) + def test_where_operators(self): class FilterTestCase(NamedTuple): operator: str