From 5262d9c44aa4305f15656ab77cd896734c0ee843 Mon Sep 17 00:00:00 2001 From: justin-park Date: Fri, 27 Jan 2023 16:21:54 -0800 Subject: [PATCH 1/5] fix(explore): unable to update linked charts --- superset/charts/commands/update.py | 2 +- superset/dao/base.py | 14 ++++-- tests/integration_tests/charts/api_tests.py | 55 +++++++++++++++++++++ tests/unit_tests/datasets/dao/dao_tests.py | 26 ++++++++++ 4 files changed, 92 insertions(+), 5 deletions(-) diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index e613222b36f7..89e8f5de9c3b 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -105,7 +105,7 @@ def validate(self) -> None: # Validate/Populate dashboards only if it's a list if dashboard_ids is not None: - dashboards = DashboardDAO.find_by_ids(dashboard_ids) + dashboards = DashboardDAO.find_by_ids(dashboard_ids, skip_base_filter=True) if len(dashboards) != len(dashboard_ids): exceptions.append(DashboardsNotFoundValidationError()) self._properties["dashboards"] = dashboards diff --git a/superset/dao/base.py b/superset/dao/base.py index c6890e53a5ce..126238f66132 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -73,16 +73,22 @@ def find_by_id( return None @classmethod - def find_by_ids(cls, model_ids: Union[List[str], List[int]]) -> List[Model]: + def find_by_ids( + cls, + model_ids: Union[List[str], List[int]], + session: Session = None, + skip_base_filter: bool = False, + ) -> List[Model]: """ Find a List of models by a list of ids, if defined applies `base_filter` """ id_col = getattr(cls.model_cls, cls.id_column_name, None) if id_col is None: return [] - query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids)) - if cls.base_filter: - data_model = SQLAInterface(cls.model_cls, db.session) + session = session or db.session + query = session.query(cls.model_cls).filter(id_col.in_(model_ids)) + if cls.base_filter and not skip_base_filter: + data_model = SQLAInterface(cls.model_cls, session) query = cls.base_filter( # pylint: disable=not-callable cls.id_column_name, data_model ).apply(query, None) diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 3d8a4695f4eb..965a9c137ba8 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -692,6 +692,61 @@ def test_update_chart_not_owned(self): db.session.delete(user_alpha2) db.session.commit() + def test_update_chart_linked_with_not_owned_dashboard(self): + """ + Chart API: Test update chart which is linked to not owned dashboard + """ + user_alpha1 = self.create_user( + "alpha1", "password", "Alpha", email="alpha1@superset.org" + ) + user_alpha2 = self.create_user( + "alpha2", "password", "Alpha", email="alpha2@superset.org" + ) + chart = self.insert_chart("title", [user_alpha1.id], 1) + + original_dashboard = Dashboard() + original_dashboard.dashboard_title = "Original Dashboard" + original_dashboard.slug = "slug" + original_dashboard.owners = [user_alpha1] + original_dashboard.slices = [chart] + original_dashboard.published = False + db.session.add(original_dashboard) + + new_dashboard = Dashboard() + new_dashboard.dashboard_title = "Cloned Dashboard" + new_dashboard.slug = "new_slug" + new_dashboard.owners = [user_alpha2] + new_dashboard.slices = [chart] + new_dashboard.published = False + db.session.add(new_dashboard) + + self.login(username="alpha1", password="password") + chart_data_with_invalid_dashboard = { + "slice_name": "title1_changed", + "dashboards": [original_dashboard.id, 0], + } + chart_data = { + "slice_name": "title1_changed", + "dashboards": [original_dashboard.id, new_dashboard.id], + } + uri = f"api/v1/chart/{chart.id}" + + rv = self.put_assert_metric(uri, chart_data_with_invalid_dashboard, "put") + self.assertEqual(rv.status_code, 422) + response = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": {"dashboards": ["Dashboards do not exist"]}} + self.assertEqual(response, expected_response) + + rv = self.put_assert_metric(uri, chart_data, "put") + self.assertEqual(rv.status_code, 200) + + db.session.delete(chart) + db.session.delete(original_dashboard) + db.session.delete(new_dashboard) + db.session.delete(user_alpha1) + db.session.delete(user_alpha2) + db.session.commit() + def test_update_chart_validate_datasource(self): """ Chart API: Test update validate datasource diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py index 31aa9f27d085..4bfad02d6e30 100644 --- a/tests/unit_tests/datasets/dao/dao_tests.py +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -71,3 +71,29 @@ def test_datasource_find_by_id_skip_base_filter_not_found( skip_base_filter=True, ) assert result is None + +def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> None: + from superset.connectors.sqla.models import SqlaTable + from superset.datasets.dao import DatasetDAO + + result = DatasetDAO.find_by_ids( + [1, 125326326], + session=session_with_data, + skip_base_filter=True, + ) + + assert result + assert [1] == map(lambda x: x.id, result) + assert ["my_sqla_table"] == map(lambda x: x.table_name, result) + assert isinstance(result[0], SqlaTable) + +def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> None: + from superset.datasets.dao import DatasetDAO + + result = DatasetDAO.find_by_ids( + [125326326, 125326326125326326], + session=session_with_data, + skip_base_filter=True, + ) + + assert len(result) == 0 From a958bbae2f5305ecb13bafc3d294991f9ea3414c Mon Sep 17 00:00:00 2001 From: justin-park Date: Fri, 27 Jan 2023 17:58:45 -0800 Subject: [PATCH 2/5] update test name --- tests/unit_tests/datasets/dao/dao_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py index 4bfad02d6e30..4a2f1c544e0b 100644 --- a/tests/unit_tests/datasets/dao/dao_tests.py +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -87,7 +87,7 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> assert ["my_sqla_table"] == map(lambda x: x.table_name, result) assert isinstance(result[0], SqlaTable) -def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> None: +def test_datasource_find_by_ids_skip_base_filter_not_found(session_with_data: Session) -> None: from superset.datasets.dao import DatasetDAO result = DatasetDAO.find_by_ids( From 8dc57df5b9cbc59de2aa5fdb5267cef70193b786 Mon Sep 17 00:00:00 2001 From: justin-park Date: Fri, 27 Jan 2023 20:16:07 -0800 Subject: [PATCH 3/5] missing list --- tests/unit_tests/datasets/dao/dao_tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py index 4a2f1c544e0b..10115ba9374d 100644 --- a/tests/unit_tests/datasets/dao/dao_tests.py +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -83,8 +83,8 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> ) assert result - assert [1] == map(lambda x: x.id, result) - assert ["my_sqla_table"] == map(lambda x: x.table_name, result) + assert [1] == list(map(lambda x: x.id, result)) + assert ["my_sqla_table"] == list(map(lambda x: x.table_name, result)) assert isinstance(result[0], SqlaTable) def test_datasource_find_by_ids_skip_base_filter_not_found(session_with_data: Session) -> None: From a62d7f35ca18b5025e426d97d5a17b30536221c9 Mon Sep 17 00:00:00 2001 From: justin-park Date: Fri, 27 Jan 2023 20:49:55 -0800 Subject: [PATCH 4/5] black codestyle --- superset/charts/commands/update.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index 89e8f5de9c3b..042c85a930f9 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -105,7 +105,10 @@ def validate(self) -> None: # Validate/Populate dashboards only if it's a list if dashboard_ids is not None: - dashboards = DashboardDAO.find_by_ids(dashboard_ids, skip_base_filter=True) + dashboards = DashboardDAO.find_by_ids( + dashboard_ids, + skip_base_filter=True, + ) if len(dashboards) != len(dashboard_ids): exceptions.append(DashboardsNotFoundValidationError()) self._properties["dashboards"] = dashboards From 6d14cc20430456cab21c3476afd0b1749e8e60f7 Mon Sep 17 00:00:00 2001 From: justin-park Date: Sat, 28 Jan 2023 21:34:33 -0800 Subject: [PATCH 5/5] code styling by black --- tests/unit_tests/datasets/dao/dao_tests.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py index 10115ba9374d..350425d08e89 100644 --- a/tests/unit_tests/datasets/dao/dao_tests.py +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -72,6 +72,7 @@ def test_datasource_find_by_id_skip_base_filter_not_found( ) assert result is None + def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> None: from superset.connectors.sqla.models import SqlaTable from superset.datasets.dao import DatasetDAO @@ -87,7 +88,10 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> assert ["my_sqla_table"] == list(map(lambda x: x.table_name, result)) assert isinstance(result[0], SqlaTable) -def test_datasource_find_by_ids_skip_base_filter_not_found(session_with_data: Session) -> None: + +def test_datasource_find_by_ids_skip_base_filter_not_found( + session_with_data: Session, +) -> None: from superset.datasets.dao import DatasetDAO result = DatasetDAO.find_by_ids(