Skip to content

Commit

Permalink
chore: Migrate /superset/estimate_query_cost/<database_id>/<schema>/ …
Browse files Browse the repository at this point in the history
…to API v1
  • Loading branch information
diegomedina248 committed Jan 30, 2023
1 parent b94052e commit f929d56
Show file tree
Hide file tree
Showing 8 changed files with 325 additions and 15 deletions.
20 changes: 11 additions & 9 deletions superset-frontend/src/SqlLab/actions/sqlLab.js
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,20 @@ export function estimateQueryCost(queryEditor) {
const { dbId, schema, sql, selectedText, templateParams } =
getUpToDateQuery(getState(), queryEditor);
const requestSql = selectedText || sql;
const endpoint =
schema === null
? `/superset/estimate_query_cost/${dbId}/`
: `/superset/estimate_query_cost/${dbId}/${schema}/`;

const postPayload = {
database_id: dbId,
schema,
sql: requestSql,
template_params: JSON.parse(templateParams || '{}'),
};

return Promise.all([
dispatch({ type: COST_ESTIMATE_STARTED, query: queryEditor }),
SupersetClient.post({
endpoint,
postPayload: {
sql: requestSql,
templateParams: JSON.parse(templateParams || '{}'),
},
endpoint: '/api/v1/sqllab/estimate/',
body: JSON.stringify(postPayload),
headers: { 'Content-Type': 'application/json' },
})
.then(({ json }) =>
dispatch({ type: COST_ESTIMATE_RETURNED, query: queryEditor, json }),
Expand Down
2 changes: 1 addition & 1 deletion superset-frontend/src/SqlLab/reducers/sqlLab.js
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ export default function sqlLabReducer(state = {}, action) {
...state.queryCostEstimates,
[action.query.id]: {
completed: true,
cost: action.json,
cost: action.json.result,
error: null,
},
},
Expand Down
1 change: 1 addition & 0 deletions superset/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods
"delete_ssh_tunnel": "write",
"get_updated_since": "read",
"stop_query": "read",
"estimate_query_cost": "read",
}

EXTRA_FORM_DATA_APPEND_KEYS = {
Expand Down
59 changes: 58 additions & 1 deletion superset/sqllab/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,21 @@
from typing import Any, cast, Dict, Optional

import simplejson as json
from flask import request
from flask import request, Response
from flask_appbuilder.api import expose, protect, rison
from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import ValidationError

from superset import app, is_feature_enabled
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP
from superset.databases.dao import DatabaseDAO
from superset.extensions import event_logger
from superset.jinja_context import get_template_processor
from superset.models.sql_lab import Query
from superset.queries.dao import QueryDAO
from superset.sql_lab import get_sql_results
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.commands.estimate import QueryEstimationCommand
from superset.sqllab.commands.execute import CommandResult, ExecuteSqlCommand
from superset.sqllab.commands.results import SqlExecutionResultsCommand
from superset.sqllab.exceptions import (
Expand All @@ -40,6 +42,7 @@
from superset.sqllab.execution_context_convertor import ExecutionContextConvertor
from superset.sqllab.query_render import SqlQueryRenderImpl
from superset.sqllab.schemas import (
EstimateQueryCostSchema,
ExecutePayloadSchema,
QueryExecutionResponseSchema,
sql_lab_get_results_schema,
Expand Down Expand Up @@ -68,6 +71,8 @@ class SqlLabRestApi(BaseSupersetApi):

class_permission_name = "Query"

method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
estimate_model_schema = EstimateQueryCostSchema()
execute_model_schema = ExecutePayloadSchema()

apispec_parameter_schemas = {
Expand All @@ -79,6 +84,58 @@ class SqlLabRestApi(BaseSupersetApi):
QueryExecutionResponseSchema,
)

@expose("/estimate/", methods=["POST"])
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
    action=lambda self, *args, **kwargs: (
        f"{self.__class__.__name__}.estimate_query_cost"
    ),
    log_to_statsd=False,
)
@requires_json
def estimate_query_cost(self, **kwargs: Any) -> Response:
    """Estimates the SQL query execution cost
    ---
    post:
      summary: >-
        Estimates the SQL query execution cost
      requestBody:
        description: SQL query and params
        required: true
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/EstimateQueryCostSchema'
      responses:
        200:
          description: Query estimation result
          content:
            application/json:
              schema:
                type: object
                properties:
                  result:
                    type: object
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        403:
          $ref: '#/components/responses/403'
        404:
          $ref: '#/components/responses/404'
        500:
          $ref: '#/components/responses/500'
    """
    # Deserialize and validate the JSON body; schema violations are
    # reported back to the client as a 400 carrying the per-field messages.
    try:
        payload = self.estimate_model_schema.load(request.json)
    except ValidationError as err:
        return self.response_400(message=err.messages)

    # Delegate the actual estimation to the command and wrap its output
    # under the "result" key of a 200 response.
    estimate = QueryEstimationCommand(payload).run()
    return self.response(200, result=estimate)

@expose("/results/")
@protect()
@statsd_metrics
Expand Down
106 changes: 106 additions & 0 deletions superset/sqllab/commands/estimate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-few-public-methods, too-many-arguments
from __future__ import annotations

import logging
from typing import Any, Dict, List

from flask_babel import gettext as __, lazy_gettext as _

from superset import app, db
from superset.commands.base import BaseCommand
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException, SupersetTimeoutException
from superset.jinja_context import get_template_processor
from superset.models.core import Database
from superset.sqllab.schemas import EstimateQueryCostSchema
from superset.utils import core as utils

# Application configuration mapping; the settings below are read from it once,
# when this module is imported.
config = app.config
# Number of seconds a cost estimation is allowed to run before it is aborted.
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"]
# Statistics logger taken from the "STATS_LOGGER" configuration entry.
stats_logger = config["STATS_LOGGER"]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class QueryEstimationCommand(BaseCommand):
    """Estimate the execution cost of a SQL query.

    Running the command looks up the target database, renders the SQL
    through the database's template processor when template parameters
    were supplied, asks the database's engine spec for a cost estimate
    (bounded by ``SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT`` seconds) and finally
    passes the raw estimate through a cost formatter.
    """

    _database_id: int
    _sql: str
    _template_params: Dict[str, Any]
    # NOTE(review): taken verbatim from the payload; if the caller sends an
    # explicit null for "schema" this holds None rather than "" -- confirm
    # that the engine spec accepts both.
    _schema: str
    # Populated by validate(); not available before it has run.
    _database: Database

    def __init__(self, params: EstimateQueryCostSchema) -> None:
        """Capture the estimation parameters.

        :param params: mapping-like payload read with ``.get``; the keys
            used are ``database_id``, ``sql``, ``template_params`` and
            ``schema``. Missing ``sql``/``schema`` default to ``""`` and
            missing ``template_params`` defaults to ``{}``.
        """
        self._database_id = params.get("database_id")
        self._sql = params.get("sql", "")
        self._template_params = params.get("template_params", {})
        self._schema = params.get("schema", "")

    def validate(self) -> None:
        """Load the target database and fail if it does not exist.

        :raises SupersetErrorException: with status 404 when no database
            matches the requested id.
        """
        self._database = db.session.query(Database).get(self._database_id)
        if not self._database:
            raise SupersetErrorException(
                SupersetError(
                    message=__("The database could not be found"),
                    error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
                    level=ErrorLevel.ERROR,
                ),
                status=404,
            )

    def run(
        self,
    ) -> List[Dict[str, Any]]:
        """Validate the request, estimate the query cost and format it.

        :returns: the output of the cost formatter applied to the engine
            spec's raw estimate.
        :raises SupersetErrorException: with status 404 when the database is
            missing, or with status 500 when the estimation exceeds the
            configured timeout.
        """
        self.validate()

        sql = self._sql
        # Templates are only rendered when parameters were provided;
        # otherwise the SQL is sent to the engine spec untouched.
        if self._template_params:
            template_processor = get_template_processor(self._database)
            sql = template_processor.process_template(sql, **self._template_params)

        timeout = SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT
        timeout_msg = f"The estimation exceeded the {timeout} seconds timeout."
        try:
            with utils.timeout(seconds=timeout, error_message=timeout_msg):
                cost = self._database.db_engine_spec.estimate_query_cost(
                    self._database, self._schema, sql, utils.QuerySource.SQL_LAB
                )
        except SupersetTimeoutException as ex:
            logger.exception(ex)
            # Chain the original timeout so the root cause stays visible in
            # tracebacks instead of being reported as an unrelated error
            # raised "during handling" of another exception.
            raise SupersetErrorException(
                SupersetError(
                    message=__(
                        "The query estimation was killed after %(sqllab_timeout)s seconds. It might "
                        "be too complex, or the database might be under heavy load.",
                        sqllab_timeout=timeout,
                    ),
                    error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                    level=ErrorLevel.ERROR,
                ),
                status=500,
            ) from ex

        spec = self._database.db_engine_spec
        query_cost_formatters: Dict[str, Any] = app.config[
            "QUERY_COST_FORMATTERS_BY_ENGINE"
        ]
        # A formatter configured for this engine takes precedence over the
        # engine spec's own formatter.
        query_cost_formatter = query_cost_formatters.get(
            spec.engine, spec.query_cost_formatter
        )
        cost = query_cost_formatter(cost)
        return cost
7 changes: 7 additions & 0 deletions superset/sqllab/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
}


class EstimateQueryCostSchema(Schema):
    """Request payload schema for a query cost estimation."""

    # Required integer id of the database the query targets.
    database_id = fields.Integer(required=True)
    # Required SQL text whose cost should be estimated.
    sql = fields.String(required=True)
    # Optional mapping with string keys; values are not constrained here.
    template_params = fields.Dict(keys=fields.String())
    # Optional schema name; an explicit null is accepted.
    schema = fields.String(allow_none=True)


class ExecutePayloadSchema(Schema):
database_id = fields.Integer(required=True)
sql = fields.String(required=True)
Expand Down
67 changes: 67 additions & 0 deletions tests/integration_tests/sql_lab/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,73 @@


class TestSqlLabApi(SupersetTestCase):
def test_estimate_required_params(self):
    """Requests lacking ``sql`` and/or ``database_id`` are rejected with 400."""
    missing = ["Missing data for required field."]
    # (request payload, expected per-field validation errors), exercised in
    # this order: both fields missing, then each one missing on its own.
    cases = [
        ({}, {"sql": missing, "database_id": missing}),
        ({"sql": "SELECT 1"}, {"database_id": missing}),
        ({"database_id": 1}, {"sql": missing}),
    ]

    self.login()

    for payload, field_errors in cases:
        rv = self.client.post(
            "/api/v1/sqllab/estimate/",
            json=payload,
        )
        resp_data = json.loads(rv.data.decode("utf-8"))
        self.assertDictEqual(resp_data, {"message": field_errors})
        self.assertEqual(rv.status_code, 400)

def test_estimate_valid_request(self):
    """A valid request returns the formatted estimate under ``result``."""
    self.login()

    expected_cost = [
        {
            "value": 100,
        }
    ]

    # Stand-in engine spec: the raw estimate is 100 and the formatter always
    # yields ``expected_cost``.
    engine_spec = mock.Mock()
    engine_spec.estimate_query_cost = mock.Mock(return_value=100)
    engine_spec.query_cost_formatter = mock.Mock(return_value=expected_cost)
    database = mock.Mock(db_engine_spec=engine_spec)

    with mock.patch("superset.sqllab.commands.estimate.db") as mock_superset_db:
        # Make the command's database lookup resolve to the stand-in.
        mock_superset_db.session.query.return_value.get.return_value = database

        rv = self.client.post(
            "/api/v1/sqllab/estimate/",
            json={"database_id": 1, "sql": "SELECT 1"},
        )

    resp_data = json.loads(rv.data.decode("utf-8"))
    self.assertDictEqual(resp_data, {"result": expected_cost})
    self.assertEqual(rv.status_code, 200)

@mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
def test_execute_required_params(self):
self.login()
Expand Down
Loading

0 comments on commit f929d56

Please sign in to comment.