From f27cece69a28edc197f2211eb78249578fc0d3ca Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Mon, 22 May 2023 13:35:58 +0300 Subject: [PATCH] fix(permalink): migrate to marshmallow codec (#24166) --- superset/dashboards/permalink/api.py | 6 +- .../dashboards/permalink/commands/base.py | 9 +- .../dashboards/permalink/commands/create.py | 3 + superset/dashboards/permalink/commands/get.py | 7 +- superset/dashboards/permalink/schemas.py | 11 +- superset/explore/permalink/api.py | 6 +- superset/explore/permalink/commands/base.py | 9 +- superset/explore/permalink/commands/create.py | 3 + superset/explore/permalink/commands/get.py | 7 +- superset/explore/permalink/schemas.py | 26 +++- superset/key_value/exceptions.py | 12 ++ superset/key_value/types.py | 36 +++++- .../explore/permalink/api_tests.py | 16 ++- tests/unit_tests/key_value/codec_test.py | 122 ++++++++++++++++++ 14 files changed, 251 insertions(+), 22 deletions(-) create mode 100644 tests/unit_tests/key_value/codec_test.py diff --git a/superset/dashboards/permalink/api.py b/superset/dashboards/permalink/api.py index a8664f0ddd80..d9211df2aa00 100644 --- a/superset/dashboards/permalink/api.py +++ b/superset/dashboards/permalink/api.py @@ -30,7 +30,7 @@ ) from superset.dashboards.permalink.commands.get import GetDashboardPermalinkCommand from superset.dashboards.permalink.exceptions import DashboardPermalinkInvalidStateError -from superset.dashboards.permalink.schemas import DashboardPermalinkPostSchema +from superset.dashboards.permalink.schemas import DashboardPermalinkStateSchema from superset.extensions import event_logger from superset.key_value.exceptions import KeyValueAccessDeniedError from superset.views.base_api import BaseSupersetApi, requires_json @@ -39,13 +39,13 @@ class DashboardPermalinkRestApi(BaseSupersetApi): - add_model_schema = DashboardPermalinkPostSchema() + add_model_schema = DashboardPermalinkStateSchema() method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP allow_browser_login = True class_permission_name = "DashboardPermalinkRestApi" resource_name = "dashboard" openapi_spec_tag = "Dashboard Permanent Link" - openapi_spec_component_schemas = (DashboardPermalinkPostSchema,) + openapi_spec_component_schemas = (DashboardPermalinkStateSchema,) @expose("//permalink", methods=["POST"]) @protect() diff --git a/superset/dashboards/permalink/commands/base.py b/superset/dashboards/permalink/commands/base.py index 82e24264ca92..4bfb78ea26e4 100644 --- a/superset/dashboards/permalink/commands/base.py +++ b/superset/dashboards/permalink/commands/base.py @@ -17,13 +17,18 @@ from abc import ABC from superset.commands.base import BaseCommand +from superset.dashboards.permalink.schemas import DashboardPermalinkSchema from superset.key_value.shared_entries import get_permalink_salt -from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey +from superset.key_value.types import ( + KeyValueResource, + MarshmallowKeyValueCodec, + SharedKey, +) class BaseDashboardPermalinkCommand(BaseCommand, ABC): resource = KeyValueResource.DASHBOARD_PERMALINK - codec = JsonKeyValueCodec() + codec = MarshmallowKeyValueCodec(DashboardPermalinkSchema()) @property def salt(self) -> str: diff --git a/superset/dashboards/permalink/commands/create.py b/superset/dashboards/permalink/commands/create.py index 2b6151fbb219..048704107001 100644 --- a/superset/dashboards/permalink/commands/create.py +++ b/superset/dashboards/permalink/commands/create.py @@ -23,6 +23,7 @@ from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError from superset.dashboards.permalink.types import DashboardPermalinkState from superset.key_value.commands.upsert import UpsertKeyValueCommand +from superset.key_value.exceptions import KeyValueCodecEncodeException from superset.key_value.utils import encode_permalink_key, get_deterministic_uuid from superset.utils.core import get_user_id @@ -62,6 +63,8 @@ def run(self) -> str: ).run() assert key.id # for type checks return encode_permalink_key(key=key.id, salt=self.salt) + except KeyValueCodecEncodeException as ex: + raise DashboardPermalinkCreateFailedError(str(ex)) from ex except SQLAlchemyError as ex: logger.exception("Error running create command") raise DashboardPermalinkCreateFailedError() from ex diff --git a/superset/dashboards/permalink/commands/get.py b/superset/dashboards/permalink/commands/get.py index 4206263a37fe..da54ae0b66e8 100644 --- a/superset/dashboards/permalink/commands/get.py +++ b/superset/dashboards/permalink/commands/get.py @@ -25,7 +25,11 @@ from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError from superset.dashboards.permalink.types import DashboardPermalinkValue from superset.key_value.commands.get import GetKeyValueCommand -from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError +from superset.key_value.exceptions import ( + KeyValueCodecDecodeException, + KeyValueGetFailedError, + KeyValueParseKeyError, +) from superset.key_value.utils import decode_permalink_id logger = logging.getLogger(__name__) @@ -51,6 +55,7 @@ def run(self) -> Optional[DashboardPermalinkValue]: return None except ( DashboardNotFoundError, + KeyValueCodecDecodeException, KeyValueGetFailedError, KeyValueParseKeyError, ) as ex: diff --git a/superset/dashboards/permalink/schemas.py b/superset/dashboards/permalink/schemas.py index ce222d7ed62c..acbfec5a1760 100644 --- a/superset/dashboards/permalink/schemas.py +++ b/superset/dashboards/permalink/schemas.py @@ -17,7 +17,7 @@ from marshmallow import fields, Schema -class DashboardPermalinkPostSchema(Schema): +class DashboardPermalinkStateSchema(Schema): dataMask = fields.Dict( required=False, allow_none=True, @@ -48,3 +48,12 @@ class DashboardPermalinkPostSchema(Schema): allow_none=True, description="Optional anchor link added to url hash", ) + + +class DashboardPermalinkSchema(Schema): + dashboardId = fields.String( + required=True, + allow_none=False, + metadata={"description": "The id or slug of the dasbhoard"}, + ) + state = fields.Nested(DashboardPermalinkStateSchema()) diff --git a/superset/explore/permalink/api.py b/superset/explore/permalink/api.py index 88e819aa2b0c..2a8ff1998dd7 100644 --- a/superset/explore/permalink/api.py +++ b/superset/explore/permalink/api.py @@ -32,7 +32,7 @@ from superset.explore.permalink.commands.create import CreateExplorePermalinkCommand from superset.explore.permalink.commands.get import GetExplorePermalinkCommand from superset.explore.permalink.exceptions import ExplorePermalinkInvalidStateError -from superset.explore.permalink.schemas import ExplorePermalinkPostSchema +from superset.explore.permalink.schemas import ExplorePermalinkStateSchema from superset.extensions import event_logger from superset.key_value.exceptions import KeyValueAccessDeniedError from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics @@ -41,13 +41,13 @@ class ExplorePermalinkRestApi(BaseSupersetApi): - add_model_schema = ExplorePermalinkPostSchema() + add_model_schema = ExplorePermalinkStateSchema() method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP allow_browser_login = True class_permission_name = "ExplorePermalinkRestApi" resource_name = "explore" openapi_spec_tag = "Explore Permanent Link" - openapi_spec_component_schemas = (ExplorePermalinkPostSchema,) + openapi_spec_component_schemas = (ExplorePermalinkStateSchema,) @expose("/permalink", methods=["POST"]) @protect() diff --git a/superset/explore/permalink/commands/base.py b/superset/explore/permalink/commands/base.py index a87183b7e9ed..0b7cfbb8ec42 100644 --- a/superset/explore/permalink/commands/base.py +++ b/superset/explore/permalink/commands/base.py @@ -17,13 +17,18 @@ from abc import ABC from superset.commands.base import BaseCommand +from superset.explore.permalink.schemas import ExplorePermalinkSchema from superset.key_value.shared_entries import get_permalink_salt -from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey +from superset.key_value.types import ( + KeyValueResource, + MarshmallowKeyValueCodec, + SharedKey, +) class BaseExplorePermalinkCommand(BaseCommand, ABC): resource: KeyValueResource = KeyValueResource.EXPLORE_PERMALINK - codec = JsonKeyValueCodec() + codec = MarshmallowKeyValueCodec(ExplorePermalinkSchema()) @property def salt(self) -> str: diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index 21c0f4e42f82..90e64f6df726 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -23,6 +23,7 @@ from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError from superset.explore.utils import check_access as check_chart_access from superset.key_value.commands.create import CreateKeyValueCommand +from superset.key_value.exceptions import KeyValueCodecEncodeException from superset.key_value.utils import encode_permalink_key from superset.utils.core import DatasourceType @@ -58,6 +59,8 @@ def run(self) -> str: if key.id is None: raise ExplorePermalinkCreateFailedError("Unexpected missing key id") return encode_permalink_key(key=key.id, salt=self.salt) + except KeyValueCodecEncodeException as ex: + raise ExplorePermalinkCreateFailedError(str(ex)) from ex except SQLAlchemyError as ex: logger.exception("Error running create command") raise ExplorePermalinkCreateFailedError() from ex diff --git a/superset/explore/permalink/commands/get.py b/superset/explore/permalink/commands/get.py index 4823117ecef5..1aa093b38058 100644 --- a/superset/explore/permalink/commands/get.py +++ b/superset/explore/permalink/commands/get.py @@ -25,7 +25,11 @@ from superset.explore.permalink.types import ExplorePermalinkValue from superset.explore.utils import check_access as check_chart_access from superset.key_value.commands.get import GetKeyValueCommand -from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError +from superset.key_value.exceptions import ( + KeyValueCodecDecodeException, + KeyValueGetFailedError, + KeyValueParseKeyError, +) from superset.key_value.utils import decode_permalink_id from superset.utils.core import DatasourceType @@ -59,6 +63,7 @@ def run(self) -> Optional[ExplorePermalinkValue]: return None except ( DatasetNotFoundError, + KeyValueCodecDecodeException, KeyValueGetFailedError, KeyValueParseKeyError, ) as ex: diff --git a/superset/explore/permalink/schemas.py b/superset/explore/permalink/schemas.py index e1f9d069b853..8b1ae129e802 100644 --- a/superset/explore/permalink/schemas.py +++ b/superset/explore/permalink/schemas.py @@ -17,7 +17,7 @@ from marshmallow import fields, Schema -class ExplorePermalinkPostSchema(Schema): +class ExplorePermalinkStateSchema(Schema): formData = fields.Dict( required=True, allow_none=False, @@ -37,3 +37,27 @@ class ExplorePermalinkPostSchema(Schema): allow_none=True, description="URL Parameters", ) + + +class ExplorePermalinkSchema(Schema): + chartId = fields.Integer( + required=False, + allow_none=True, + metadata={"description": "The id of the chart"}, + ) + datasourceType = fields.String( + required=True, + allow_none=False, + metadata={"description": "The type of the datasource"}, + ) + datasourceId = fields.Integer( + required=False, + allow_none=True, + metadata={"description": "The id of the datasource"}, + ) + datasource = fields.String( + required=False, + allow_none=True, + metadata={"description": "The fully qualified datasource reference"}, + ) + state = fields.Nested(ExplorePermalinkStateSchema()) diff --git a/superset/key_value/exceptions.py b/superset/key_value/exceptions.py index b05daf6b89e0..e16f961872c0 100644 --- a/superset/key_value/exceptions.py +++ b/superset/key_value/exceptions.py @@ -52,3 +52,15 @@ class KeyValueUpsertFailedError(UpdateFailedError): class KeyValueAccessDeniedError(ForbiddenError): message = _("You don't have permission to modify the value.") + + +class KeyValueCodecException(SupersetException): + pass + + +class KeyValueCodecEncodeException(KeyValueCodecException): + message = _("Unable to encode value") + + +class KeyValueCodecDecodeException(KeyValueCodecException): + message = _("Unable to decode value") diff --git a/superset/key_value/types.py b/superset/key_value/types.py index 07d06414f60e..fb9c31899f70 100644 --- a/superset/key_value/types.py +++ b/superset/key_value/types.py @@ -24,6 +24,13 @@ from typing import Any, Optional, TypedDict from uuid import UUID +from marshmallow import Schema, ValidationError + +from superset.key_value.exceptions import ( + KeyValueCodecDecodeException, + KeyValueCodecEncodeException, +) + @dataclass class Key: @@ -61,10 +68,16 @@ def decode(self, value: bytes) -> Any: class JsonKeyValueCodec(KeyValueCodec): def encode(self, value: dict[Any, Any]) -> bytes: - return bytes(json.dumps(value), encoding="utf-8") + try: + return bytes(json.dumps(value), encoding="utf-8") + except TypeError as ex: + raise KeyValueCodecEncodeException(str(ex)) from ex def decode(self, value: bytes) -> dict[Any, Any]: - return json.loads(value) + try: + return json.loads(value) + except TypeError as ex: + raise KeyValueCodecDecodeException(str(ex)) from ex class PickleKeyValueCodec(KeyValueCodec): @@ -73,3 +86,22 @@ def encode(self, value: dict[Any, Any]) -> bytes: def decode(self, value: bytes) -> dict[Any, Any]: return pickle.loads(value) + + +class MarshmallowKeyValueCodec(JsonKeyValueCodec): + def __init__(self, schema: Schema): + self.schema = schema + + def encode(self, value: dict[Any, Any]) -> bytes: + try: + obj = self.schema.dump(value) + return super().encode(obj) + except ValidationError as ex: + raise KeyValueCodecEncodeException(message=str(ex)) from ex + + def decode(self, value: bytes) -> dict[Any, Any]: + try: + obj = super().decode(value) + return self.schema.load(obj) + except ValidationError as ex: + raise KeyValueCodecEncodeException(message=str(ex)) from ex diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index 4c6a3c12ddfd..3a07bd977af3 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -22,8 +22,9 @@ from sqlalchemy.orm import Session from superset import db +from superset.explore.permalink.schemas import ExplorePermalinkSchema from superset.key_value.models import KeyValueEntry -from superset.key_value.types import JsonKeyValueCodec, KeyValueResource +from superset.key_value.types import KeyValueResource, MarshmallowKeyValueCodec from superset.key_value.utils import decode_permalink_id, encode_permalink_key from superset.models.slice import Slice from superset.utils.core import DatasourceType @@ -94,14 +95,17 @@ def test_get_missing_chart( chart_id = 1234 entry = KeyValueEntry( resource=KeyValueResource.EXPLORE_PERMALINK, - value=JsonKeyValueCodec().encode( + value=MarshmallowKeyValueCodec(ExplorePermalinkSchema()).encode( { "chartId": chart_id, "datasourceId": chart.datasource.id, - "datasourceType": DatasourceType.TABLE, - "formData": { - "slice_id": chart_id, - "datasource": f"{chart.datasource.id}__{chart.datasource.type}", + "datasourceType": DatasourceType.TABLE.value, + "state": { + "urlParams": [["foo", "bar"]], + "formData": { + "slice_id": chart_id, + "datasource": f"{chart.datasource.id}__{chart.datasource.type}", + }, }, } ), diff --git a/tests/unit_tests/key_value/codec_test.py b/tests/unit_tests/key_value/codec_test.py new file mode 100644 index 000000000000..1442a3a95acd --- /dev/null +++ b/tests/unit_tests/key_value/codec_test.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from contextlib import nullcontext +from typing import Any + +import pytest +from marshmallow import Schema + +from superset.dashboards.permalink.schemas import DashboardPermalinkSchema +from superset.key_value.exceptions import KeyValueCodecEncodeException +from superset.key_value.types import ( + JsonKeyValueCodec, + MarshmallowKeyValueCodec, + PickleKeyValueCodec, +) + + +@pytest.mark.parametrize( + "input_,expected_result", + [ + ( + {"foo": "bar"}, + {"foo": "bar"}, + ), + ( + {"foo": (1, 2, 3)}, + {"foo": [1, 2, 3]}, + ), + ( + {1, 2, 3}, + KeyValueCodecEncodeException(), + ), + ( + object(), + KeyValueCodecEncodeException(), + ), + ], +) +def test_json_codec(input_: Any, expected_result: Any): + cm = ( + pytest.raises(type(expected_result)) + if isinstance(expected_result, Exception) + else nullcontext() + ) + with cm: + codec = JsonKeyValueCodec() + encoded_value = codec.encode(input_) + assert expected_result == codec.decode(encoded_value) + + +@pytest.mark.parametrize( + "schema,input_,expected_result", + [ + ( + DashboardPermalinkSchema(), + { + "dashboardId": "1", + "state": { + "urlParams": [["foo", "bar"], ["foo", "baz"]], + }, + }, + { + "dashboardId": "1", + "state": { + "urlParams": [("foo", "bar"), ("foo", "baz")], + }, + }, + ), + ( + DashboardPermalinkSchema(), + {"foo": "bar"}, + KeyValueCodecEncodeException(), + ), + ], +) +def test_marshmallow_codec(schema: Schema, input_: Any, expected_result: Any): + cm = ( + pytest.raises(type(expected_result)) + if isinstance(expected_result, Exception) + else nullcontext() + ) + with cm: + codec = MarshmallowKeyValueCodec(schema) + encoded_value = codec.encode(input_) + assert expected_result == codec.decode(encoded_value) + + +@pytest.mark.parametrize( + "input_,expected_result", + [ + ( + {1, 2, 3}, + {1, 2, 3}, + ), + ( + {"foo": 1, "bar": {1: (1, 2, 3)}, "baz": {1, 2, 3}}, + { + "foo": 1, + "bar": {1: (1, 2, 3)}, + "baz": {1, 2, 3}, + }, + ), + ], +) +def test_pickle_codec(input_: Any, expected_result: Any): + codec = PickleKeyValueCodec() + encoded_value = codec.encode(input_) + assert expected_result == codec.decode(encoded_value)