Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow subquery in ad-hoc SQL #19242

Merged
merged 14 commits into from
Mar 18, 2022
1 change: 1 addition & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
"ALLOW_FULL_CSV_EXPORT": False,
"UX_BETA": False,
"GENERIC_CHART_AXES": False,
"ALLOW_ADHOC_SUBQUERY": False,
}

# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
Expand Down
7 changes: 7 additions & 0 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from superset.connectors.sqla.utils import (
get_physical_table_metadata,
get_virtual_table_metadata,
validate_adhoc_subquery,
)
from superset.datasets.models import Dataset as NewDataset
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
Expand Down Expand Up @@ -885,6 +886,7 @@ def adhoc_metric_to_sqla(
elif expression_type == utils.AdhocMetricExpressionType.SQL:
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
validate_adhoc_subquery(expression)
sqla_metric = literal_column(expression)
else:
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
Expand All @@ -908,6 +910,8 @@ def adhoc_column_to_sqla(
expression = col["sqlExpression"]
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
validate_adhoc_subquery(expression)
sqla_metric = literal_column(expression)
return self.make_sqla_column_compatible(sqla_metric, label)

Expand Down Expand Up @@ -1166,6 +1170,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
elif selected in columns_by_name:
outer = columns_by_name[selected].get_sqla_col()
else:
validate_adhoc_subquery(selected)
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
else:
Expand All @@ -1178,6 +1183,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
select_exprs.append(outer)
elif columns:
for selected in columns:
validate_adhoc_subquery(selected)
select_exprs.append(
columns_by_name[selected].get_sqla_col()
if selected in columns_by_name
Expand Down Expand Up @@ -1389,6 +1395,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
and db_engine_spec.allows_hidden_cc_in_orderby
and col.name in [select_col.name for select_col in select_exprs]
):
validate_adhoc_subquery(str(col.expression))
col = literal_column(col.name)
direction = asc if ascending else desc
qry = qry.order_by(direction(col))
Expand Down
28 changes: 27 additions & 1 deletion superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from contextlib import closing
from typing import Dict, List, Optional, TYPE_CHECKING

import sqlparse
from flask_babel import lazy_gettext as _
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql.type_api import TypeEngine
Expand All @@ -28,7 +29,7 @@
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.sql_parse import has_table_query, ParsedQuery

if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
Expand Down Expand Up @@ -119,3 +120,28 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
return cols


def validate_adhoc_subquery(raw_sql: str) -> None:
"""
Check if adhoc SQL contains sub-queries or nested sub-queries with table
:param raw_sql: adhoc sql expression
:raise SupersetSecurityException if sql contains sub-queries or
nested sub-queries with table
"""
# pylint: disable=import-outside-toplevel
from superset import is_feature_enabled

if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
return

for statement in sqlparse.parse(raw_sql):
if has_table_query(statement):
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
message=_("Custom SQL fields cannot contain sub-queries."),
level=ErrorLevel.ERROR,
)
)
return
3 changes: 3 additions & 0 deletions superset/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class SupersetErrorType(str, Enum):
SQLLAB_TIMEOUT_ERROR = "SQLLAB_TIMEOUT_ERROR"
RESULTS_BACKEND_ERROR = "RESULTS_BACKEND_ERROR"
ASYNC_WORKERS_ERROR = "ASYNC_WORKERS_ERROR"
ADHOC_SUBQUERY_NOT_ALLOWED_ERROR = "ADHOC_SUBQUERY_NOT_ALLOWED_ERROR"

# Generic errors
GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR"
Expand Down Expand Up @@ -138,10 +139,12 @@ class SupersetErrorType(str, Enum):
1034: _("The port number is invalid."),
1035: _("Failed to start remote query on a worker."),
1036: _("The database was deleted."),
1037: _("Custom SQL fields cannot contain sub-queries."),
}


ERROR_TYPES_TO_ISSUE_CODES_MAPPING = {
SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR: [1037],
SupersetErrorType.BACKEND_TIMEOUT_ERROR: [1000, 1001],
SupersetErrorType.GENERIC_DB_ENGINE_ERROR: [1002],
SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR: [1003, 1004],
Expand Down
31 changes: 30 additions & 1 deletion tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from superset.constants import EMPTY_STRING, NULL_STRING
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException
from superset.models.core import Database
from superset.utils.core import (
AdhocMetricExpressionType,
Expand Down Expand Up @@ -239,6 +239,35 @@ def test_jinja_metrics_and_calc_columns(self, flask_g):
db.session.delete(table)
db.session.commit()

def test_adhoc_metrics_and_calc_columns(self):
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["user", "expr"],
"metrics": [
{
"expressionType": AdhocMetricExpressionType.SQL,
"sqlExpression": "(SELECT (SELECT * from birth_names) "
"from test_validate_adhoc_sql)",
"label": "adhoc_metrics",
}
],
"is_timeseries": False,
"filter": [],
}

table = SqlaTable(
table_name="test_validate_adhoc_sql", database=get_example_database()
)
db.session.commit()

with pytest.raises(SupersetSecurityException):
table.get_sqla_query(**base_query_obj)
# Cleanup
db.session.delete(table)
db.session.commit()

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_where_operators(self):
filters: Tuple[FilterTestCase, ...] = (
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,8 @@ def test_sqlparse_issue_652():
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
("SELECT * FROM other_table", True),
("extract(HOUR from from_unixtime(hour_ts)", False),
("(SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) from birth_names)", True),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
Expand Down