diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 83ec9ba37c4fa..f846c19ec8ddb 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -68,7 +68,12 @@ ) from superset.extensions import feature_flag_manager from superset.jinja_context import BaseTemplateProcessor -from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, sanitize_clause +from superset.sql_parse import ( + has_table_query, + insert_rls_in_predicate, + ParsedQuery, + sanitize_clause, +) from superset.superset_typing import ( AdhocMetric, Column as ColumnTyping, @@ -128,7 +133,7 @@ def validate_adhoc_subquery( level=ErrorLevel.ERROR, ) ) - statement = insert_rls(statement, database_id, default_schema) + statement = insert_rls_in_predicate(statement, database_id, default_schema) statements.append(statement) return ";\n".join(str(statement) for statement in statements) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 4d71e23d88cee..efbef6560a366 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -48,7 +48,12 @@ from superset.models.core import Database from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet -from superset.sql_parse import CtasMethod, insert_rls, ParsedQuery +from superset.sql_parse import ( + CtasMethod, + insert_rls_as_subquery, + insert_rls_in_predicate, + ParsedQuery, +) from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.utils import write_ipc_buffer from superset.utils.celery import session_scope @@ -191,7 +196,7 @@ def get_sql_results( # pylint: disable=too-many-arguments return handle_query_error(ex, query, session) -def execute_sql_statement( # pylint: disable=too-many-arguments +def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-locals sql_statement: str, query: Query, session: Session, @@ -205,6 +210,16 @@ def execute_sql_statement( # pylint: disable=too-many-arguments parsed_query = ParsedQuery(sql_statement) 
if is_feature_enabled("RLS_IN_SQLLAB"): + # There are two ways to insert RLS: either replacing the table with a subquery + # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is + # safer, but not supported in all databases. + insert_rls = ( + insert_rls_as_subquery + if database.db_engine_spec.allows_subqueries + and database.db_engine_spec.allows_alias_in_select + else insert_rls_in_predicate + ) + # Insert any applicable RLS predicates parsed_query = ParsedQuery( str( diff --git a/superset/sql_parse.py b/superset/sql_parse.py index d75551bef0767..cecd673276976 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -44,6 +44,7 @@ Punctuation, String, Whitespace, + Wildcard, ) from sqlparse.utils import imt @@ -660,18 +661,116 @@ def get_rls_for_table( return None rls = sqlparse.parse(predicate)[0] - add_table_name(rls, str(dataset)) + add_table_name(rls, table.table) return rls -def insert_rls( +def insert_rls_as_subquery( token_list: TokenList, database_id: int, default_schema: Optional[str], ) -> TokenList: """ Update a statement inplace applying any associated RLS predicates. + + The RLS predicate is applied as a subquery replacing the original table: + + before: SELECT * FROM some_table WHERE 1=1 + after: SELECT * FROM ( + SELECT * FROM some_table WHERE some_table.id=42 + ) AS some_table + WHERE 1=1 + + This method is safer than ``insert_rls_in_predicate``, but doesn't work in all + databases.
+ """ + rls: Optional[TokenList] = None + state = InsertRLSState.SCANNING + for token in token_list.tokens: + # Recurse into child token list + if isinstance(token, TokenList): + i = token_list.tokens.index(token) + token_list.tokens[i] = insert_rls_as_subquery( + token, + database_id, + default_schema, + ) + + # Found a source keyword (FROM/JOIN) + if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]): + state = InsertRLSState.SEEN_SOURCE + + # Found identifier/keyword after FROM/JOIN, test for table + elif state == InsertRLSState.SEEN_SOURCE and ( + isinstance(token, Identifier) or token.ttype == Keyword + ): + rls = get_rls_for_table(token, database_id, default_schema) + if rls: + # replace table with subquery + subquery_alias = ( + token.tokens[-1].value + if isinstance(token, Identifier) + else token.value + ) + i = token_list.tokens.index(token) + + # strip alias from table name + if isinstance(token, Identifier) and token.has_alias(): + whitespace_index = token.token_next_by(t=Whitespace)[0] + token.tokens = token.tokens[:whitespace_index] + + token_list.tokens[i] = Identifier( + [ + Parenthesis( + [ + Token(Punctuation, "("), + Token(DML, "SELECT"), + Token(Whitespace, " "), + Token(Wildcard, "*"), + Token(Whitespace, " "), + Token(Keyword, "FROM"), + Token(Whitespace, " "), + token, + Token(Whitespace, " "), + Where( + [ + Token(Keyword, "WHERE"), + Token(Whitespace, " "), + rls, + ] + ), + Token(Punctuation, ")"), + ] + ), + Token(Whitespace, " "), + Token(Keyword, "AS"), + Token(Whitespace, " "), + Identifier([Token(Name, subquery_alias)]), + ] + ) + state = InsertRLSState.SCANNING + + # Found nothing, leaving source + elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace: + state = InsertRLSState.SCANNING + + return token_list + + +def insert_rls_in_predicate( + token_list: TokenList, + database_id: int, + default_schema: Optional[str], +) -> TokenList: + """ + Update a statement inplace applying any associated RLS predicates. 
+ + The RLS predicate is ``AND``ed to any existing predicates: + + before: SELECT * FROM some_table WHERE 1=1 + after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42 + """ rls: Optional[TokenList] = None state = InsertRLSState.SCANNING @@ -679,7 +778,11 @@ def insert_rls( # Recurse into child token list if isinstance(token, TokenList): i = token_list.tokens.index(token) - token_list.tokens[i] = insert_rls(token, database_id, default_schema) + token_list.tokens[i] = insert_rls_in_predicate( + token, + database_id, + default_schema, + ) # Found a source keyword (FROM/JOIN) if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]): diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index edc1fd2ec4a5d..200ee091ec558 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -87,7 +87,7 @@ def test_execute_sql_statement_with_rls( cursor = mocker.MagicMock() SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") mocker.patch( - "superset.sql_lab.insert_rls", + "superset.sql_lab.insert_rls_as_subquery", return_value=sqlparse.parse("SELECT * FROM sales WHERE organization_id=42")[0], ) mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) @@ -112,12 +112,12 @@ def test_execute_sql_statement_with_rls( SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) -def test_sql_lab_insert_rls( +def test_sql_lab_insert_rls_as_subquery( mocker: MockerFixture, session: Session, ) -> None: """ - Integration test for `insert_rls`. + Integration test for `insert_rls_as_subquery`. 
""" from flask_appbuilder.security.sqla.models import Role, User @@ -213,4 +213,7 @@ def test_sql_lab_insert_rls( | 2 | 8 | | 3 | 9 |""".strip() ) - assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6" + assert ( + query.executed_sql + == "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6" + ) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 341ba9d789396..efd883810147e 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines +# pylint: disable=invalid-name, redefined-outer-name, too-many-lines from typing import Optional @@ -31,7 +31,8 @@ extract_table_references, get_rls_for_table, has_table_query, - insert_rls, + insert_rls_as_subquery, + insert_rls_in_predicate, ParsedQuery, sanitize_clause, strip_comments_from_sql, @@ -1318,6 +1319,184 @@ def test_has_table_query(sql: str, expected: bool) -> None: assert has_table_query(statement) == expected +@pytest.mark.parametrize( + "sql,table,rls,expected", + [ + # Basic test + ( + "SELECT * FROM some_table WHERE 1=1", + "some_table", + "id=42", + ( + "SELECT * FROM (SELECT * FROM some_table WHERE some_table.id=42) " + "AS some_table WHERE 1=1" + ), + ), + # Here "table" is a reserved word; since sqlparse is too aggressive when + # characterizing reserved words we need to support them even when not quoted. 
+ ( + "SELECT * FROM table WHERE 1=1", + "table", + "id=42", + "SELECT * FROM (SELECT * FROM table WHERE table.id=42) AS table WHERE 1=1", + ), + # RLS is only applied to queries reading from the associated table + ( + "SELECT * FROM table WHERE 1=1", + "other_table", + "id=42", + "SELECT * FROM table WHERE 1=1", + ), + ( + "SELECT * FROM other_table WHERE 1=1", + "table", + "id=42", + "SELECT * FROM other_table WHERE 1=1", + ), + # JOINs are supported + ( + "SELECT * FROM table JOIN other_table ON table.id = other_table.id", + "other_table", + "id=42", + ( + "SELECT * FROM table JOIN " + "(SELECT * FROM other_table WHERE other_table.id=42) AS other_table " + "ON table.id = other_table.id" + ), + ), + # Subqueries + ( + "SELECT * FROM (SELECT * FROM other_table)", + "other_table", + "id=42", + ( + "SELECT * FROM (SELECT * FROM (" + "SELECT * FROM other_table WHERE other_table.id=42" + ") AS other_table)" + ), + ), + # UNION + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + "table", + "id=42", + ( + "SELECT * FROM (SELECT * FROM table WHERE table.id=42) AS table " + "UNION ALL SELECT * FROM other_table" + ), + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + "other_table", + "id=42", + ( + "SELECT * FROM table UNION ALL SELECT * FROM (" + "SELECT * FROM other_table WHERE other_table.id=42) AS other_table" + ), + ), + # When comparing fully qualified table names (eg, schema.table) to simple names + # (eg, table) we are also conservative, assuming the schema is the same, since + # we don't have information on the default schema. 
+ ( + "SELECT * FROM schema.table_name", + "table_name", + "id=42", + ( + "SELECT * FROM (SELECT * FROM schema.table_name " + "WHERE table_name.id=42) AS table_name" + ), + ), + ( + "SELECT * FROM schema.table_name", + "schema.table_name", + "id=42", + ( + "SELECT * FROM (SELECT * FROM schema.table_name " + "WHERE schema.table_name.id=42) AS table_name" + ), + ), + ( + "SELECT * FROM table_name", + "schema.table_name", + "id=42", + ( + "SELECT * FROM (SELECT * FROM table_name WHERE " + "schema.table_name.id=42) AS table_name" + ), + ), + # Aliases + ( + "SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col", + "tbl_a", + "id=42", + ( + "SELECT a.*, b.* FROM " + "(SELECT * FROM tbl_a WHERE tbl_a.id=42) AS a " + "INNER JOIN tbl_b AS b " + "ON a.col = b.col" + ), + ), + ( + "SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col", + "tbl_a", + "id=42", + ( + "SELECT a.*, b.* FROM " + "(SELECT * FROM tbl_a WHERE tbl_a.id=42) AS a " + "INNER JOIN tbl_b b ON a.col = b.col" + ), + ), + ], +) +def test_insert_rls_as_subquery( + mocker: MockerFixture, sql: str, table: str, rls: str, expected: str +) -> None: + """ + Insert into a statement a given RLS condition associated with a table. + """ + condition = sqlparse.parse(rls)[0] + add_table_name(condition, table) + + # pylint: disable=unused-argument + def get_rls_for_table( + candidate: Token, + database_id: int, + default_schema: str, + ) -> Optional[TokenList]: + """ + Return the RLS ``condition`` if ``candidate`` matches ``table``. 
+ """ + if not isinstance(candidate, Identifier): + candidate = Identifier([Token(Name, candidate.value)]) + + candidate_table = ParsedQuery.get_table(candidate) + if not candidate_table: + return None + candidate_table_name = ( + f"{candidate_table.schema}.{candidate_table.table}" + if candidate_table.schema + else candidate_table.table + ) + for left, right in zip( + candidate_table_name.split(".")[::-1], table.split(".")[::-1] + ): + if left != right: + return None + return condition + + mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table) + + statement = sqlparse.parse(sql)[0] + assert ( + str( + insert_rls_as_subquery( + token_list=statement, database_id=1, default_schema="my_schema" + ) + ).strip() + == expected.strip() + ) + + @pytest.mark.parametrize( "sql,table,rls,expected", [ @@ -1492,7 +1671,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: ), ], ) -def test_insert_rls( +def test_insert_rls_in_predicate( mocker: MockerFixture, sql: str, table: str, rls: str, expected: str ) -> None: """ @@ -1521,7 +1700,11 @@ def get_rls_for_table( statement = sqlparse.parse(sql)[0] assert ( str( - insert_rls(token_list=statement, database_id=1, default_schema="my_schema") + insert_rls_in_predicate( + token_list=statement, + database_id=1, + default_schema="my_schema", + ) ).strip() == expected.strip() )