Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: helper functions for RLS #19055

Merged
merged 9 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
remove_quotes,
Token,
TokenList,
Where,
)
from sqlparse.tokens import (
CTE,
Expand Down Expand Up @@ -458,3 +459,183 @@ def validate_filter_clause(clause: str) -> None:
)
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")


class InsertRLSState(str, Enum):
"""
State machine that scans for WHERE and ON clauses referencing tables.
"""

SCANNING = "SCANNING"
SEEN_SOURCE = "SEEN_SOURCE"
FOUND_TABLE = "FOUND_TABLE"


def has_table_query(token_list: TokenList) -> bool:
"""
Return if a stament has a query reading from a table.

>>> has_table_query(sqlparse.parse("COUNT(*)")[0])
False
>>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
True

Note that queries reading from constant values return false:

>>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
False

"""
state = InsertRLSState.SCANNING
for token in token_list.tokens:

# # Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE

# Found identifier/keyword after FROM/JOIN
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
):
return True

# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING

return False


def add_table_name(rls: TokenList, table: str) -> None:
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved
"""
Modify a RLS expression ensuring columns are fully qualified.
"""
tokens = rls.tokens[:]
while tokens:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You likely could use flatten here. It uses a generator so likely a copy should be made given you're mutating the tokens, i.e.,

for token in list(rls.flatten()):
    if imt(token, i=Identifier) and token.get_parent_name() is None:
        ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue, if we call .flatten() we would never get an Identifier.

token = tokens.pop(0)

if isinstance(token, Identifier) and token.get_parent_name() is None:
token.tokens = [
Token(Name, table),
Token(Punctuation, "."),
Token(Name, token.get_name()),
]
elif isinstance(token, TokenList):
tokens.extend(token.tokens)


def matches_table_name(token: Token, table: str) -> bool:
"""
Return the name of a table.
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved

A table should be represented as an identifier, but due to sqlparse's aggressive list
of keywords (spanning multiple dialects) often it gets classified as a keyword.
"""
candidate = token.value

# match from right to left, splitting on the period, eg, schema.table == table
for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]):
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved
if left != right:
return False

return True


def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
"""
Update a statement inpalce applying an RLS associated with a given table.
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved
"""
# make sure the identifier has the table name
add_table_name(rls, table)

state = InsertRLSState.SCANNING
for token in token_list.tokens:

# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls(token, table, rls)

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE

# Found identifier/keyword after FROM/JOIN, test for table
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
if matches_table_name(token, table):
state = InsertRLSState.FOUND_TABLE

# found table at the end of the statement; append a WHERE clause
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved
if token == token_list[-1]:
token_list.tokens.extend(
[
Token(Whitespace, " "),
Where(
[Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]
),
]
)
return token_list

# Found WHERE clause, insert RLS if not present
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
if str(rls) not in {str(t) for t in token.tokens}:
token.tokens.extend(
[
Token(Whitespace, " "),
Token(Keyword, "AND"),
Token(Whitespace, " "),
]
+ rls.tokens
)
state = InsertRLSState.SCANNING

# Found ON clause, insert RLS if not present
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved
elif (
state == InsertRLSState.FOUND_TABLE
and token.ttype == Keyword
and token.value.upper() == "ON"
):
i = token_list.tokens.index(token)
token.parent.tokens[i + 1 : i + 1] = [
Token(Whitespace, " "),
rls,
Token(Whitespace, " "),
Token(Keyword, "AND"),
]
state = InsertRLSState.SCANNING

# Found table but no WHERE clause found, insert one
elif state == InsertRLSState.FOUND_TABLE and token.ttype != Whitespace:
i = token_list.tokens.index(token)

# Left pad with space, if needed
if i > 0 and token_list.tokens[i - 1].ttype != Whitespace:
token_list.tokens.insert(i, Token(Whitespace, " "))
i += 1

# Insert predicate
token_list.tokens.insert(
i, Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]),
)

# Right pad with space, if needed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does sqlparse even tokenize whitespace?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's because it makes it easier to convert the parse tree back to a string. Not sure.

if (
i < len(token_list.tokens) - 2
and token_list.tokens[i + 2] != Whitespace
):
token_list.tokens.insert(i + 1, Token(Whitespace, " "))

state = InsertRLSState.SCANNING

# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING

return token_list
174 changes: 172 additions & 2 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name
# pylint: disable=invalid-name, too-many-lines

import unittest
from typing import Set
Expand All @@ -25,6 +25,9 @@

from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
add_table_name,
has_table_query,
insert_rls,
ParsedQuery,
strip_comments_from_sql,
Table,
Expand Down Expand Up @@ -1111,7 +1114,8 @@ def test_sqlparse_formatting():

"""
assert sqlparse.format(
"SELECT extract(HOUR from from_unixtime(hour_ts) AT TIME ZONE 'America/Los_Angeles') from table",
"SELECT extract(HOUR from from_unixtime(hour_ts) "
"AT TIME ZONE 'America/Los_Angeles') from table",
reindent=True,
) == (
"SELECT extract(HOUR\n from from_unixtime(hour_ts) "
Expand Down Expand Up @@ -1189,3 +1193,169 @@ def test_sqlparse_issue_652():
stmt = sqlparse.parse(r"foo = '\' AND bar = 'baz'")[0]
assert len(stmt.tokens) == 5
assert str(stmt.tokens[0]) == "foo = '\\'"


@pytest.mark.parametrize(
"sql,expected",
[
("SELECT * FROM table", True),
("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
("COUNT(*)", False),
("SELECT a FROM (SELECT 1 AS a)", False),
("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
("SELECT * FROM other_table", True),
],
)
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved
def test_has_table_query(sql: str, expected: bool) -> None:
"""
Test if a given statement queries a table.

This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing
row-level security.
"""
statement = sqlparse.parse(sql)[0]
assert has_table_query(statement) == expected


@pytest.mark.parametrize(
"sql,table,rls,expected",
[
# append RLS to an existing WHERE clause
(
"SELECT * FROM other_table WHERE 1=1",
"other_table",
"id=42",
"SELECT * FROM other_table WHERE 1=1 AND other_table.id=42",
),
# "table" is a reserved word; since sqlparse is too aggressive when characterizing
# reserved words we need to support them even when not quoted
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved
(
"SELECT * FROM table WHERE 1=1",
"table",
"id=42",
"SELECT * FROM table WHERE 1=1 AND table.id=42",
),
# RLS applies to a different table
(
"SELECT * FROM table WHERE 1=1",
"other_table",
"id=42",
"SELECT * FROM table WHERE 1=1",
),
(
"SELECT * FROM other_table WHERE 1=1",
"table",
"id=42",
"SELECT * FROM other_table WHERE 1=1",
),
# insert the WHERE clause if there isn't one
(
"SELECT * FROM table",
"table",
"id=42",
"SELECT * FROM table WHERE table.id=42",
),
(
"SELECT * FROM other_table",
"other_table",
"id=42",
"SELECT * FROM other_table WHERE other_table.id=42",
),
(
"SELECT * FROM table ORDER BY id",
"table",
"id=42",
"SELECT * FROM table WHERE table.id=42 ORDER BY id",
),
# do not add RLS if already present...
(
"SELECT * FROM table WHERE 1=1 AND table.id=42",
"table",
"id=42",
"SELECT * FROM table WHERE 1=1 AND table.id=42",
),
# ...but when in doubt add it
(
"SELECT * FROM table WHERE 1=1 AND id=42",
"table",
"id=42",
"SELECT * FROM table WHERE 1=1 AND id=42 AND table.id=42",
),
# test with joins
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
"other_table",
"id=42",
(
"SELECT * FROM table JOIN other_table ON other_table.id=42 "
"AND table.id = other_table.id"
),
),
# test with inner selects
(
"SELECT * FROM (SELECT * FROM other_table)",
"other_table",
"id=42",
"SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42)",
),
# union
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
"table",
"id=42",
"SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table",
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
"other_table",
"id=42",
(
"SELECT * FROM table UNION ALL "
"SELECT * FROM other_table WHERE other_table.id=42"
),
),
# fully qualified table names
(
"SELECT * FROM schema.table_name",
"table_name",
"id=42",
"SELECT * FROM schema.table_name WHERE table_name.id=42",
),
(
"SELECT * FROM schema.table_name",
"schema.table_name",
"id=42",
"SELECT * FROM schema.table_name WHERE schema.table_name.id=42",
),
(
"SELECT * FROM table_name",
"schema.table_name",
"id=42",
"SELECT * FROM table_name WHERE schema.table_name.id=42",
),
],
)
def test_insert_rls(sql, table, rls, expected) -> None:
"""
Insert into a statement a given RLS condition associated with a table.
"""
statement = sqlparse.parse(sql)[0]
condition = sqlparse.parse(rls)[0]
assert str(insert_rls(statement, table, condition)).strip() == expected.strip()


@pytest.mark.parametrize(
"rls,table,expected",
[
("id=42", "users", "users.id=42"),
("users.id=42", "users", "users.id=42"),
("schema.users.id=42", "users", "schema.users.id=42"),
("false", "users", "false"),
],
)
def test_add_table_name(rls, table, expected) -> None:
condition = sqlparse.parse(rls)[0]
add_table_name(condition, table)
assert str(condition) == expected