diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 9237a1f715b0..e9c6029e7720 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -44,7 +44,9 @@ Text, UniqueConstraint, ) +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, relationship, Session +from sqlalchemy.sql import expression from sqlalchemy_utils import EncryptedType from superset import conf, db, is_feature_enabled, security_manager @@ -280,12 +282,16 @@ def refresh( datasource.refresh_metrics() session.commit() - @property + @hybrid_property def perm(self) -> str: - return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self) + return f"[{self.cluster_name}].(id:{self.id})" + + @perm.expression # type: ignore + def perm(cls) -> str: # pylint: disable=no-self-argument + return "[" + cls.cluster_name + "].(id:" + expression.cast(cls.id, String) + ")" def get_perm(self) -> str: - return self.perm + return self.perm # type: ignore @property def name(self) -> str: diff --git a/superset/migrations/versions/a72cb0ebeb22_deprecate_dbs_perm_column.py b/superset/migrations/versions/a72cb0ebeb22_deprecate_dbs_perm_column.py new file mode 100644 index 000000000000..4e39b673128c --- /dev/null +++ b/superset/migrations/versions/a72cb0ebeb22_deprecate_dbs_perm_column.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""deprecate dbs.perm column + +Revision ID: a72cb0ebeb22 +Revises: 743a117f0d98 +Create Date: 2020-06-21 19:50:51.630917 +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a72cb0ebeb22" +down_revision = "743a117f0d98" + + +def upgrade(): + with op.batch_alter_table("dbs") as batch_op: + batch_op.drop_column("perm") + + +def downgrade(): + with op.batch_alter_table("dbs") as batch_op: + batch_op.add_column(sa.Column("perm", sa.String(1000), nullable=True)) diff --git a/superset/models/core.py b/superset/models/core.py index 562e9c523d81..42c43453a025 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -45,10 +45,11 @@ from sqlalchemy.engine import Dialect, Engine, url from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import make_url, URL +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import relationship from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import Select +from sqlalchemy.sql import expression, Select from sqlalchemy_utils import EncryptedType from superset import app, db_engine_specs, is_feature_enabled, security_manager @@ -138,7 +139,6 @@ class Database( ), ) encrypted_extra = Column(EncryptedType(Text, config["SECRET_KEY"]), nullable=True) - perm = Column(String(1000)) impersonate_user = Column(Boolean, default=False) server_cert = Column(EncryptedType(Text, config["SECRET_KEY"]), nullable=True) export_fields = [ @@ -641,9 +641,19 @@ def sqlalchemy_uri_decrypted(self) -> str: def sql_url(self) -> str: return f"/superset/sql/{self.id}/" - def get_perm(self) -> str: + @hybrid_property + def perm(self) -> str: return f"[{self.database_name}].(id:{self.id})" + @perm.expression # type: ignore + def perm(cls) -> str: # pylint: disable=no-self-argument + return ( + "[" + cls.database_name + "].(id:" + expression.cast(cls.id, String) + ")" + ) + + def get_perm(self) -> str: + return self.perm # type: ignore + def has_table(self, table: Table) -> bool: engine = self.get_sqla_engine() return engine.has_table(table.table_name, table.schema or None) diff --git a/superset/security/manager.py b/superset/security/manager.py index 9a001cfa87d3..3d03ad058bc3 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -247,7 +247,7 @@ def can_access_database(self, database: Union["Database", "DruidCluster"]) -> bo return ( self.can_access_all_datasources() or self.can_access_all_databases() - or self.can_access("database_access", database.perm) + or self.can_access("database_access", database.perm) # type: ignore ) def can_access_schema(self, datasource: "BaseDatasource") -> bool: diff --git a/tests/security_tests.py b/tests/security_tests.py index a36a49998aad..b4279827cd5a 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -352,6 +352,53 @@ def test_set_perm_database(self): session.delete(stored_db) session.commit() + def test_hybrid_perm_druid_cluster(self): + cluster = DruidCluster(cluster_name="tmp_druid_cluster3") + db.session.add(cluster) + + id_ = ( + db.session.query(DruidCluster.id) + .filter_by(cluster_name="tmp_druid_cluster3") + .scalar() + ) + + record = ( + db.session.query(DruidCluster) + .filter_by(perm=f"[tmp_druid_cluster3].(id:{id_})") + .one() + ) + + self.assertEquals(record.get_perm(), record.perm) + self.assertEquals(record.id, id_) + self.assertEquals(record.cluster_name, "tmp_druid_cluster3") + db.session.delete(cluster) + db.session.commit() + + def test_hybrid_perm_database(self): + database = Database( + database_name="tmp_database3", sqlalchemy_uri="sqlite://test" + ) + + db.session.add(database) + + id_ = ( + db.session.query(Database.id) + .filter_by(database_name="tmp_database3") + .scalar() + ) + + record = ( + db.session.query(Database) + .filter_by(perm=f"[tmp_database3].(id:{id_})") + .one() + ) + + self.assertEquals(record.get_perm(), record.perm) + self.assertEquals(record.id, id_) + self.assertEquals(record.database_name, "tmp_database3") + db.session.delete(database) + db.session.commit() + def test_set_perm_slice(self): session = db.session database = Database(