diff --git a/docs/rest_api.rst b/docs/rest_api.rst index 4fda4b0693..7f0c6cf7f1 100644 --- a/docs/rest_api.rst +++ b/docs/rest_api.rst @@ -138,7 +138,7 @@ so data can be translated back and forth without loss or guesswork:: if 'name' in kwargs['rison']: return self.response( 200, - message="Hello {}".format(kwargs['rison']['name']) + message=f"Hello {kwargs['rison']['name']}" ) return self.response_400(message="Please send your name") @@ -238,7 +238,7 @@ validate your Rison arguments, this way you can implement a very strict API easi def greeting4(self, **kwargs): return self.response( 200, - message="Hello {}".format(kwargs['rison']['name']) + message=f"Hello {kwargs['rison']['name']}" ) Finally to properly handle all possible exceptions use the ``safe`` decorator, @@ -396,7 +396,7 @@ easily reference them:: """ return self.response( 200, - message="Hello {}".format(kwargs['rison']['name']) + message=f"Hello {kwargs['rison']['name']}" ) @@ -1015,6 +1015,33 @@ the ``show_columns`` property. This takes precedence from the *Rison* arguments: datamodel = SQLAInterface(Contact) show_columns = ['name'] +By default FAB will issue a query containing the exact fields for `show_columns`, but these are also associated with +the response object. Sometimes it's useful to distinguish between the query select columns and the response itself. +Imagine the case you want to use a `@property` to further transform the output, and that transformation implies +two model fields (concat or sum for example):: + + class ContactModelApi(ModelRestApi): + resource_name = 'contact' + datamodel = SQLAInterface(Contact) + show_columns = ['name', 'age'] + show_select_columns = ['name', 'birthday'] + + +The Model:: + + class Contact(Model): + id = Column(Integer, primary_key=True) + name = Column(String(150), unique=True, nullable=False) + ... + birthday = Column(Date, nullable=True) + ... + + @property + def age(self): + return date.today().year - self.birthday.year + +Note: The same logic is applied on `list_select_columns` + We can add fields that are python functions also, for this on the SQLAlchemy definition, let's add a new function:: @@ -1034,7 +1061,7 @@ let's add a new function:: return self.name def some_function(self): - return "Hello {}".format(self.name) + return f"Hello {self.name}" And then on the REST API:: diff --git a/flask_appbuilder/api/__init__.py b/flask_appbuilder/api/__init__.py index 937fd1ba85..4ec158cd64 100644 --- a/flask_appbuilder/api/__init__.py +++ b/flask_appbuilder/api/__init__.py @@ -3,7 +3,7 @@ import logging import re import traceback -from typing import Dict, Optional +from typing import Callable, Dict, List, Optional, Set import urllib.parse from apispec import APISpec, yaml_utils @@ -11,7 +11,7 @@ from flask import Blueprint, current_app, jsonify, make_response, request, Response from flask_babel import lazy_gettext as _ import jsonschema -from marshmallow import ValidationError +from marshmallow import Schema, ValidationError from marshmallow_sqlalchemy.fields import Related, RelatedList import prison from sqlalchemy.exc import IntegrityError @@ -214,22 +214,22 @@ class BaseApi(object): appbuilder = None blueprint = None - endpoint = None + endpoint: Optional[str] = None - version = "v1" + version: Optional[str] = "v1" """ Define the Api version for this resource/class """ - route_base = None + route_base: Optional[str] = None """ Define the route base where all methods will suffix from """ - resource_name = None + resource_name: Optional[str] = None """ Defines a custom resource name, overrides the inferred from Class name makes no sense to use it with route base """ - base_permissions = None + base_permissions: Optional[List[str]] = None """ A list of allowed base permissions:: @@ -237,16 +237,16 @@ class ExampleApi(BaseApi): base_permissions = ['can_get'] """ - class_permission_name = None + class_permission_name: Optional[str] = None """ Override class permission name default fallback to self.__class__.__name__ """ - previous_class_permission_name = None + previous_class_permission_name: Optional[str] = None """ If set security converge will replace all permissions tuples with this name by the class_permission_name or self.__class__.__name__ """ - method_permission_name = None + method_permission_name: Optional[Dict[str, str]] = None """ Override method permission names, example:: @@ -258,7 +258,7 @@ class ExampleApi(BaseApi): 'delete': 'write' } """ - previous_method_permission_name = None + previous_method_permission_name: Optional[Dict[str, str]] = None """ Use same structure as method_permission_name. If set security converge will replace all method permissions by the new ones @@ -272,7 +272,7 @@ class ExampleApi(BaseApi): """ If using flask-wtf CSRFProtect exempt the API from check """ - apispec_parameter_schemas = None + apispec_parameter_schemas: Optional[Dict[str, Dict]] = None """ Set your custom Rison parameter schemas here so that they get registered on the OpenApi spec:: @@ -377,7 +377,7 @@ class ContactModelView(ModelRestApi): The previous examples will only register the `put`, `post` and `delete` routes """ - include_route_methods = None + include_route_methods: Set[str] = None """ If defined will assume a white list setup, where all endpoints are excluded except those define on this attribute @@ -412,7 +412,7 @@ class GreetingApi(BaseApi): Use this attribute to override the tag name """ - def __init__(self): + def __init__(self) -> None: """ Initialization of base permissions based on exposed methods and actions @@ -855,72 +855,83 @@ class ModelRestApi(BaseModelApi): List Title, if not configured the default is 'List ' with pretty model name """ - show_title = "" + show_title: Optional[str] = "" """ Show Title , if not configured the default is 'Show ' with pretty model name """ - add_title = "" + add_title: Optional[str] = "" """ Add Title , if not configured the default is 'Add ' with pretty model name """ - edit_title = "" + edit_title: Optional[str] = "" """ Edit Title , if not configured the default is 'Edit ' with pretty model name """ - - list_columns = None + list_select_columns: Optional[List[str]] = None + """ + A List of column names that will be included on the SQL select. + This is useful for including all necessary columns that are referenced + by properties listed on `list_columns` without generating N+1 queries. + """ + list_columns: Optional[List[str]] = None """ A list of columns (or model's methods) to be displayed on the list view. Use it to control the order of the display """ - show_columns = None + show_select_columns: Optional[List[str]] = None + """ + A List of column names that will be included on the SQL select. + This is useful for including all necessary columns that are referenced + by properties listed on `show_columns` without generating N+1 queries. + """ + show_columns: Optional[List[str]] = None """ A list of columns (or model's methods) for the get item endpoint. Use it to control the order of the results """ - add_columns = None + add_columns: Optional[List[str]] = None """ A list of columns (or model's methods) to be allowed to post """ - edit_columns = None + edit_columns: Optional[List[str]] = None """ A list of columns (or model's methods) to be allowed to update """ - list_exclude_columns = None + list_exclude_columns: Optional[List[str]] = None """ A list of columns to exclude from the get list endpoint. By default all columns are included. """ - show_exclude_columns = None + show_exclude_columns: Optional[List[str]] = None """ A list of columns to exclude from the get item endpoint. By default all columns are included. """ - add_exclude_columns = None + add_exclude_columns: Optional[List[str]] = None """ A list of columns to exclude from the add endpoint. By default all columns are included. """ - edit_exclude_columns = None + edit_exclude_columns: Optional[List[str]] = None """ A list of columns to exclude from the edit endpoint. By default all columns are included. """ - order_columns = None + order_columns: Optional[List[str]] = None """ Allowed order columns """ page_size = 20 """ Use this property to change default page size """ - max_page_size = None + max_page_size: Optional[int] = None """ class override for the FAB_API_MAX_SIZE, use special -1 to allow for any page size """ - description_columns = None + description_columns: Optional[Dict[str, str]] = None """ Dictionary with column descriptions that will be shown on the forms:: @@ -930,8 +941,8 @@ class MyView(ModelView): description_columns = {'name':'your models name column', 'address':'the address column'} """ - validators_columns = None - """ Dictionary to add your own validators for forms """ + validators_columns: Optional[Dict[str, Callable]] = None + """ Dictionary to add your own marshmallow validators """ add_query_rel_fields = None """ @@ -973,22 +984,22 @@ class ContactModelView(ModelRestApi): 'gender': ('name', 'asc') } """ - list_model_schema = None + list_model_schema: Optional[Schema] = None """ Override to provide your own marshmallow Schema for JSON to SQLA dumps """ - add_model_schema = None + add_model_schema: Optional[Schema] = None """ Override to provide your own marshmallow Schema for JSON to SQLA dumps """ - edit_model_schema = None + edit_model_schema: Optional[Schema] = None """ Override to provide your own marshmallow Schema for JSON to SQLA dumps """ - show_model_schema = None + show_model_schema: Optional[Schema] = None """ Override to provide your own marshmallow Schema for JSON to SQLA dumps @@ -1069,7 +1080,7 @@ def _init_titles(self): self.show_title = "Show " + self._prettify_name(class_name) self.title = self.list_title - def _init_properties(self): + def _init_properties(self) -> None: """ Init Properties """ @@ -1091,6 +1102,7 @@ def _init_properties(self): for x in self.datamodel.get_user_columns_list() if x not in self.list_exclude_columns ] + self.list_select_columns = self.list_select_columns or self.list_columns self.order_columns = ( self.order_columns @@ -1101,6 +1113,8 @@ def _init_properties(self): self.show_columns = [ x for x in list_cols if x not in self.show_exclude_columns ] + self.show_select_columns = self.show_select_columns or self.show_columns + if not self.add_columns: self.add_columns = [ x for x in list_cols if x not in self.add_exclude_columns @@ -1302,7 +1316,7 @@ def get_headless(self, pk, **kwargs) -> Response: :param kwargs: Query string parameter arguments :return: HTTP Response """ - item = self.datamodel.get(pk, self._base_filters) + item = self.datamodel.get(pk, self._base_filters, self.show_select_columns) if not item: return self.response_404() @@ -1417,13 +1431,15 @@ def get_list_headless(self, **kwargs) -> Response: # handle select columns select_cols = _args.get(API_SELECT_COLUMNS_RIS_KEY, []) _pruned_select_cols = [col for col in select_cols if col in self.list_columns] + # map decorated metadata self.set_response_key_mappings( _response, self.get_list, _args, **{API_SELECT_COLUMNS_RIS_KEY: _pruned_select_cols}, ) - + # Create a response schema with the computed response columns, + # defined or requested if _pruned_select_cols: _list_model_schema = self.model2schemaconverter.convert(_pruned_select_cols) else: @@ -1441,14 +1457,13 @@ def get_list_headless(self, **kwargs) -> Response: # handle pagination page_index, page_size = self._handle_page_args(_args) # Make the query - query_select_columns = _pruned_select_cols or self.list_columns count, lst = self.datamodel.query( joined_filters, order_column, order_direction, page=page_index, page_size=page_size, - select_columns=query_select_columns, + select_columns=self.list_select_columns, ) pks = self.datamodel.get_keys(lst) _response[API_RESULT_RES_KEY] = _list_model_schema.dump(lst, many=True) diff --git a/flask_appbuilder/exceptions.py b/flask_appbuilder/exceptions.py index 99eebaaeec..7f410ab0fd 100644 --- a/flask_appbuilder/exceptions.py +++ b/flask_appbuilder/exceptions.py @@ -20,3 +20,9 @@ class InvalidOrderByColumnFABException(FABException): """Invalid order by column""" pass + + +class InterfaceQueryWithoutSession(FABException): + """You need to setup a session on the interface to perform queries""" + + pass diff --git a/flask_appbuilder/models/base.py b/flask_appbuilder/models/base.py index 1e4c19ec2b..8971da66a8 100644 --- a/flask_appbuilder/models/base.py +++ b/flask_appbuilder/models/base.py @@ -1,10 +1,11 @@ import datetime from functools import reduce import logging +from typing import Type from flask_babel import lazy_gettext -from .filters import Filters +from .filters import BaseFilterConverter, Filters try: import enum @@ -22,9 +23,7 @@ class BaseInterface(object): Sub class it to implement your own interface for some data engine. """ - obj = None - - filter_converter_class = None + filter_converter_class = Type[BaseFilterConverter] """ when sub classing override with your own custom filter converter """ """ Messages to display on CRUD Events """ diff --git a/flask_appbuilder/models/filters.py b/flask_appbuilder/models/filters.py index 36579add5a..0c11a93a76 100644 --- a/flask_appbuilder/models/filters.py +++ b/flask_appbuilder/models/filters.py @@ -1,6 +1,6 @@ import copy import logging -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type from .._compat import as_unicode from ..exceptions import ( @@ -128,10 +128,10 @@ class Filters(object): def __init__( self, - filter_converter: BaseFilterConverter, + filter_converter: Type[BaseFilterConverter], datamodel, - search_columns: List[str] = None, - search_filters: Dict[str, List[BaseFilter]] = None, + search_columns: Optional[List[str]] = None, + search_filters: Optional[Dict[str, List[BaseFilter]]] = None, ): """ diff --git a/flask_appbuilder/models/sqla/interface.py b/flask_appbuilder/models/sqla/interface.py index af337aebc6..028c0d9f65 100644 --- a/flask_appbuilder/models/sqla/interface.py +++ b/flask_appbuilder/models/sqla/interface.py @@ -1,17 +1,21 @@ # -*- coding: utf-8 -*- import logging import sys -from typing import List, Tuple +from typing import Any, List, Optional, Tuple, Type, Union from flask_sqlalchemy import BaseQuery import sqlalchemy as sa -from sqlalchemy import func +from sqlalchemy import asc, desc from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import aliased, contains_eager, Load, load_only +from sqlalchemy.orm import aliased, contains_eager, Load from sqlalchemy.orm.descriptor_props import SynonymProperty +from sqlalchemy.orm.query import Query +from sqlalchemy.orm.session import Session as SessionBase from sqlalchemy.sql.elements import BinaryExpression +from sqlalchemy.sql.sqltypes import TypeEngine from sqlalchemy_utils.types.uuid import UUIDType + from . import filters, Model from ..base import BaseInterface from ..filters import Filters @@ -26,23 +30,18 @@ LOGMSG_WAR_DBI_DEL_INTEGRITY, LOGMSG_WAR_DBI_EDIT_INTEGRITY, ) +from ...exceptions import InterfaceQueryWithoutSession from ...filemanager import FileManager, ImageManager from ...utils.base import get_column_leaf, get_column_root_relation, is_column_dotted log = logging.getLogger(__name__) -def _include_filters(obj): - for key in filters.__all__: - if not hasattr(obj, key): - setattr(obj, key, getattr(filters, key)) - - -def _is_sqla_type(obj, sa_type): +def _is_sqla_type(model: Model, sa_type: Type[TypeEngine]) -> bool: return ( - isinstance(obj, sa_type) - or isinstance(obj, sa.types.TypeDecorator) - and isinstance(obj.impl, sa_type) + isinstance(model, sa_type) + or isinstance(model, sa.types.TypeDecorator) + and isinstance(model.impl, sa_type) ) @@ -56,7 +55,7 @@ class SQLAInterface(BaseInterface): filter_converter_class = filters.SQLAFilterConverter - def __init__(self, obj, session=None): + def __init__(self, obj: Model, session: Optional[SessionBase] = None) -> None: _include_filters(self) self.list_columns = dict() self.list_properties = dict() @@ -80,30 +79,15 @@ def model_name(self): return self.obj.__name__ @staticmethod - def is_model_already_joined(query, model): + def is_model_already_joined(query: BaseQuery, model: Model) -> bool: return model in [mapper.class_ for mapper in query._join_entities] - def _apply_query_order( - self, query, order_column: str, order_direction: str - ) -> BaseQuery: - if order_column != "": - # if Model has custom decorator **renders('')** - # this decorator will add a property to the method named *_col_name* - if hasattr(self.obj, order_column): - if hasattr(getattr(self.obj, order_column), "_col_name"): - order_column = getattr(self._get_attr(order_column), "_col_name") - if order_direction == "asc": - query = query.order_by(self._get_attr(order_column).asc()) - else: - query = query.order_by(self._get_attr(order_column).desc()) - return query - def _get_base_query( self, query=None, filters=None, order_column="", order_direction="" ): if filters: query = filters.apply_all(query) - return self._apply_query_order(query, order_column, order_direction) + return self.apply_order_by(query, order_column, order_direction) def _query_join_relation(self, query: BaseQuery, root_relation: str) -> BaseQuery: """ @@ -123,89 +107,199 @@ def _query_join_relation(self, query: BaseQuery, root_relation: str) -> BaseQuer # Since the join already exists apply a new aliased one model_relation = aliased(model_relation) # The binary expression needs to be inverted + relation_pk = self.get_pk(model_relation) relation_join = BinaryExpression( - relation_join.left, model_relation.id, relation_join.operator + relation_join.left, relation_pk, relation_join.operator ) query = query.join(model_relation, relation_join, isouter=True) return query - def _query_join_dotted_column(self, query: BaseQuery, column: str) -> BaseQuery: - """ + def apply_engine_specific_hack( + self, query: BaseQuery, page, page_size, order_column + ) -> BaseQuery: + # MSSQL exception page/limit must have an order by + if ( + page + and page_size + and not order_column + and self.session.bind.dialect.name == "mssql" + ): + pk_name = self.get_pk_name() + return query.order_by(pk_name) + return query - :param query: SQLAlchemy query object - :param column: If the column is dotted will join the root relation - :return: Transformed SQLAlchemy Query - """ - if is_column_dotted(column): - return self._query_join_relation(query, get_column_root_relation(column)) + def apply_order_by( + self, query: BaseQuery, order_column: str, order_direction: str + ) -> BaseQuery: + if order_column != "": + # if Model has custom decorator **renders('')** + # this decorator will add a property to the method named *_col_name* + if hasattr(self.obj, order_column): + if hasattr(getattr(self.obj, order_column), "_col_name"): + order_column = getattr(self._get_attr(order_column), "_col_name") + _order_column = self._get_attr(order_column) or order_column + if is_column_dotted(order_column): + query = self._query_join_relation( + query, get_column_root_relation(order_column) + ) + if order_direction == "asc": + query = query.order_by(asc(_order_column)) + else: + query = query.order_by(desc(_order_column)) + return query + + def apply_pagination( + self, query: BaseQuery, page: Optional[int], page_size: Optional[int] + ) -> BaseQuery: + if page and page_size: + query = query.offset(page * page_size) + if page_size: + query = query.limit(page_size) + return query + + def apply_filters(self, query: BaseQuery, filters: Optional[Filters]) -> BaseQuery: + if filters: + return filters.apply_all(query) + return query + + def _apply_normal_col_select_option(self, query, column) -> BaseQuery: + if not self.is_relation(column) and not self.is_property_or_function(column): + return query.options(Load(self.obj).load_only(column)) return query - def _query_select_options( + def apply_inner_select_joins( self, query: BaseQuery, select_columns: List[str] = None ) -> BaseQuery: """ Add select load options to query. The goal is to only SQL select what is requested and join all the necessary - models when dotted notation is used + models when dotted notation is used. Inner implies non dotted columns + and one to many and one to one - :param query: SQLAlchemy Query obj to apply joins and selects - :param select_columns: (list) of columns - :return: Transformed SQLAlchemy Query + :param query: + :param select_columns: + :return: """ - if select_columns: - load_options = list() - joined_models = list() - for column in select_columns: - if is_column_dotted(column): - root_relation = get_column_root_relation(column) - leaf_column = get_column_leaf(column) - if self.is_relation_many_to_many( - root_relation - ) or self.is_relation_one_to_many(root_relation): - load_options.append( - ( - Load(self.obj) - .joinedload(root_relation) - .load_only(leaf_column) - ) - ) - continue - elif root_relation not in joined_models: + if not select_columns: + return query + joined_models = list() + for column in select_columns: + if is_column_dotted(column): + root_relation = get_column_root_relation(column) + leaf_column = get_column_leaf(column) + if self.is_relation_many_to_one( + root_relation + ) or self.is_relation_one_to_one(root_relation): + if root_relation not in joined_models: query = self._query_join_relation(query, root_relation) + # only needed if we need to wrap this query, from_self + if select_columns and self.exists_col_to_many(select_columns): + related_model = self.get_related_model(root_relation) + query = query.add_entity(related_model) joined_models.append(root_relation) - load_options.append( + query = query.options( (contains_eager(root_relation).load_only(leaf_column)) ) + else: + query = self._apply_normal_col_select_option(query, column) + return query + + def apply_outer_select_joins( + self, query: BaseQuery, select_columns: List[str] = None + ) -> BaseQuery: + if not select_columns: + return query + for column in select_columns: + if is_column_dotted(column): + root_relation = get_column_root_relation(column) + leaf_column = get_column_leaf(column) + if self.is_relation_many_to_many( + root_relation + ) or self.is_relation_one_to_many(root_relation): + query = query.options( + Load(self.obj).joinedload(root_relation).load_only(leaf_column) + ) else: - if not self.is_relation( - column - ) and not self.is_property_or_function(column): - load_options.append(load_only(column)) - query = query.options(*tuple(load_options)) + related_model = self.get_related_model(root_relation) + query = query.options(Load(related_model).load_only(leaf_column)) + else: + query = self._apply_normal_col_select_option(query, column) return query - def _get_non_dotted_filters(self, filters): - dotted_filters = Filters(self.filter_converter_class, self, [], []) + def get_inner_filters(self, filters: Optional[Filters]) -> Filters: + """ + Inner filters are non dotted columns and + one to many or one to one relations + + :param filters: All filters + :return: New filtered filters to apply to an inner query + """ + inner_filters = Filters(self.filter_converter_class, self) _filters = [] if filters: for flt, value in zip(filters.filters, filters.values): if not is_column_dotted(flt.column_name): _filters.append((flt.column_name, flt.__class__, value)) - dotted_filters.add_filter_list(_filters) - return dotted_filters + elif self.is_relation_many_to_one( + flt.column_name + ) or self.is_relation_one_to_one(flt.column_name): + _filters.append((flt.column_name, flt.__class__, value)) + inner_filters.add_filter_list(_filters) + return inner_filters + + def exists_col_to_many(self, select_columns: List[str]) -> bool: + for column in select_columns: + if is_column_dotted(column): + root_relation = get_column_root_relation(column) + if self.is_relation_many_to_many( + root_relation + ) or self.is_relation_one_to_many(root_relation): + return True + return False + + def _apply_inner_all( + self, + query: Query, + filters: Optional[Filters] = None, + order_column: str = "", + order_direction: str = "", + page: Optional[int] = None, + page_size: Optional[int] = None, + select_columns: Optional[List[str]] = None, + ): + inner_filters = self.get_inner_filters(filters) + query = self.apply_inner_select_joins(query, select_columns) + query = self.apply_filters(query, inner_filters) + query = self.apply_engine_specific_hack(query, page, page_size, order_column) + query = self.apply_order_by(query, order_column, order_direction) + query = self.apply_pagination(query, page, page_size) + return query - def query( + def query_count( self, - filters=None, - order_column="", - order_direction="", - page=None, - page_size=None, - select_columns=None, + query: Query, + filters: Optional[Filters] = None, + select_columns: Optional[List[str]] = None, ): + return self._apply_inner_all( + query, filters, select_columns=select_columns + ).count() + + def apply_all( + self, + query: Query, + filters: Optional[Filters] = None, + order_column: str = "", + order_direction: str = "", + page: Optional[int] = None, + page_size: Optional[int] = None, + select_columns: Optional[List[str]] = None, + ) -> BaseQuery: """ - Returns the results for a model query, applies filters, sorting and pagination + Accepts a SQLAlchemy Query and applies all filtering logic, order by and + pagination. + :param query: The query to apply all :param filters: dict with filters {: Tuple[int, List[Model]]: + """ + Returns the results for a model query, applies filters, sorting and pagination - # Pagination comes first - if page and page_size: - query = query.offset(page * page_size) - if page_size: - query = query.limit(page_size) + :param filters: A Filter class that contains all filters to apply + :param order_column: name of the column to order + :param order_direction: the direction to order <'asc'|'desc'> + :param page: the current page + :param page_size: the current page size + :param select_columns: A List of columns to be specifically selected + on the query. Supports dotted notation. + :return: A tuple with the query count (non paginated) and the results + """ + if not self.session: + raise InterfaceQueryWithoutSession() + query = self.session.query(self.obj) - if select_columns and order_column: - # Use from self strategy - select_columns = select_columns + [order_column] - # Everything uses an inner query because of joins to m/m m/1 - query = self._query_select_options(query.from_self(), select_columns) - - query = self._get_base_query( - query=query, - filters=filters, - order_column=order_column, - order_direction=order_direction, + count = self.query_count(query, filters, select_columns) + query = self.apply_all( + query, + filters, + order_column, + order_direction, + page, + page_size, + select_columns, ) + query_results = query.all() - count = query_count.scalar() - return count, query.all() + result = list() + for item in query_results: + if hasattr(item, self.obj.__name__): + result.append(getattr(item, self.obj.__name__)) + else: + return count, query_results + return count, result def query_simple_group( self, group_by="", aggregate_func=None, aggregate_col=None, filters=None @@ -291,13 +410,13 @@ def query_year_group(self, group_by="", filters=None): def is_image(self, col_name: str) -> bool: try: return isinstance(self.list_columns[col_name].type, ImageColumn) - except Exception: + except KeyError: return False def is_file(self, col_name: str) -> bool: try: return isinstance(self.list_columns[col_name].type, FileColumn) - except Exception: + except KeyError: return False def is_string(self, col_name: str) -> bool: @@ -306,61 +425,61 @@ def is_string(self, col_name: str) -> bool: _is_sqla_type(self.list_columns[col_name].type, sa.types.String) or self.list_columns[col_name].type.__class__ == UUIDType ) - except Exception: + except KeyError: return False def is_text(self, col_name: str) -> bool: try: return _is_sqla_type(self.list_columns[col_name].type, sa.types.Text) - except Exception: + except KeyError: return False def is_binary(self, col_name: str) -> bool: try: return _is_sqla_type(self.list_columns[col_name].type, sa.types.LargeBinary) - except Exception: + except KeyError: return False def is_integer(self, col_name: str) -> bool: try: return _is_sqla_type(self.list_columns[col_name].type, sa.types.Integer) - except Exception: + except KeyError: return False def is_numeric(self, col_name: str) -> bool: try: return _is_sqla_type(self.list_columns[col_name].type, sa.types.Numeric) - except Exception: + except KeyError: return False def is_float(self, col_name: str) -> bool: try: return _is_sqla_type(self.list_columns[col_name].type, sa.types.Float) - except Exception: + except KeyError: return False def is_boolean(self, col_name: str) -> bool: try: return _is_sqla_type(self.list_columns[col_name].type, sa.types.Boolean) - except Exception: + except KeyError: return False def is_date(self, col_name: str) -> bool: try: return _is_sqla_type(self.list_columns[col_name].type, sa.types.Date) - except Exception: + except KeyError: return False def is_datetime(self, col_name: str) -> bool: try: return _is_sqla_type(self.list_columns[col_name].type, sa.types.DateTime) - except Exception: + except KeyError: return False def is_enum(self, col_name: str) -> bool: try: return _is_sqla_type(self.list_columns[col_name].type, sa.types.Enum) - except Exception: + except KeyError: return False def is_relation(self, col_name: str) -> bool: @@ -368,35 +487,39 @@ def is_relation(self, col_name: str) -> bool: return isinstance( self.list_properties[col_name], sa.orm.properties.RelationshipProperty ) - except Exception: + except KeyError: return False def is_relation_many_to_one(self, col_name: str) -> bool: try: if self.is_relation(col_name): return self.list_properties[col_name].direction.name == "MANYTOONE" - except Exception: + return False + except KeyError: return False def is_relation_many_to_many(self, col_name: str) -> bool: try: if self.is_relation(col_name): return self.list_properties[col_name].direction.name == "MANYTOMANY" - except Exception: + return False + except KeyError: return False def is_relation_one_to_one(self, col_name: str) -> bool: try: if self.is_relation(col_name): return self.list_properties[col_name].direction.name == "ONETOONE" - except Exception: + return False + except KeyError: return False def is_relation_one_to_many(self, col_name: str) -> bool: try: if self.is_relation(col_name): return self.list_properties[col_name].direction.name == "ONETOMANY" - except Exception: + return False + except KeyError: return False def is_nullable(self, col_name: str) -> bool: @@ -405,19 +528,19 @@ def is_nullable(self, col_name: str) -> bool: return col.nullable try: return self.list_columns[col_name].nullable - except Exception: + except KeyError: return False def is_unique(self, col_name: str) -> bool: try: return self.list_columns[col_name].unique is True - except Exception: + except KeyError: return False def is_pk(self, col_name: str) -> bool: try: return self.list_columns[col_name].primary_key - except Exception: + except KeyError: return False def is_pk_composite(self) -> bool: @@ -426,7 +549,7 @@ def is_pk_composite(self) -> bool: def is_fk(self, col_name: str) -> bool: try: return self.list_columns[col_name].foreign_keys - except Exception: + except KeyError: return False def is_property(self, col_name: str) -> bool: @@ -609,25 +732,24 @@ def get_related_model_and_join(self, col_name: str) -> List[Tuple[Model, object] ] return [(relation.mapper.class_, relation.primaryjoin)] - def query_model_relation(self, col_name): - model = self.get_related_model(col_name) - return self.session.query(model).all() - - def get_related_interface(self, col_name): + def get_related_interface(self, col_name: str): return self.__class__(self.get_related_model(col_name), self.session) - def get_related_obj(self, col_name, value): + def get_related_obj(self, col_name: str, value: Any) -> Optional[Type[Model]]: rel_model = self.get_related_model(col_name) - return self.session.query(rel_model).get(value) + if self.session: + return self.session.query(rel_model).get(value) + return None - def get_related_fks(self, related_views): + def get_related_fks(self, related_views) -> List[str]: return [view.datamodel.get_related_fk(self.obj) for view in related_views] - def get_related_fk(self, model): + def get_related_fk(self, model) -> Optional[str]: for col_name in self.list_properties.keys(): if self.is_relation(col_name): if model == self.get_related_model(col_name): return col_name + return None def get_info(self, col_name): if col_name in self.list_properties: @@ -640,15 +762,15 @@ def get_info(self, col_name): ------------- """ - def get_columns_list(self): + def get_columns_list(self) -> List[str]: """ - Returns all model's columns on SQLA properties + Returns all model's columns on SQLA properties """ return list(self.list_properties.keys()) - def get_user_columns_list(self): + def get_user_columns_list(self) -> List[str]: """ - Returns all model's columns except pk or fk + Returns all model's columns except pk or fk """ ret_lst = list() for col_name in self.get_columns_list(): @@ -657,7 +779,7 @@ def get_user_columns_list(self): return ret_lst # TODO get different solution, more integrated with filters - def get_search_columns_list(self): + def get_search_columns_list(self) -> List[str]: ret_lst = list() for col_name in self.get_columns_list(): if not self.is_relation(col_name): @@ -673,12 +795,12 @@ def get_search_columns_list(self): ret_lst.append(col_name) return ret_lst - def get_order_columns_list(self, list_columns=None): + def get_order_columns_list(self, list_columns: List[str] = None) -> List[str]: """ - Returns the columns that can be ordered + Returns the columns that can be ordered - :param list_columns: optional list of columns name, if provided will - use this list only. + :param list_columns: optional list of columns name, if provided will + use this list only. """ ret_lst = list() list_columns = list_columns or self.get_columns_list() @@ -707,32 +829,83 @@ def get_image_column_list(self): if isinstance(i.type, ImageColumn) ] - def get_property_first_col(self, col_name): + def get_property_first_col(self, col_name: str) -> str: # support for only one col for pk and fk return self.list_properties[col_name].columns[0] - def get_relation_fk(self, col_name): + def get_relation_fk(self, col_name: str) -> str: # support for only one col for pk and fk return list(self.list_properties[col_name].local_columns)[0] - def get(self, id, filters=None): + def get( + self, + id, + filters: Optional[Filters] = None, + select_columns: Optional[List[str]] = None, + ) -> Optional[Model]: + """ + Returns the result for a model get, applies filters and supports dotted + notation for joins and granular selecting query columns. + + :param id: The model id (pk). + :param filters: A Filter class that contains all filters to apply. + :param select_columns: A List of columns to be specifically selected. + on the query. Supports dotted notation. + :return: + """ + pk = self.get_pk_name() if filters: - query = self.session.query(self.obj) _filters = filters.copy() - pk = self.get_pk_name() - if self.is_pk_composite(): - for _pk, _id in zip(pk, id): - _filters.add_filter(_pk, self.FilterEqual, _id) - else: - _filters.add_filter(pk, self.FilterEqual, id) - query = self._get_base_query(query=query, filters=_filters) - return query.first() - return self.session.query(self.obj).get(id) + else: + _filters = Filters(self.filter_converter_class, self) + _filters.add_filter(pk, self.FilterEqual, id) + + if self.is_pk_composite(): + for _pk, _id in zip(pk, id): + _filters.add_filter(_pk, self.FilterEqual, _id) + else: + _filters.add_filter(pk, self.FilterEqual, id) + query = self.session.query(self.obj) + item = self.apply_all( + query, _filters, select_columns=select_columns + ).one_or_none() + if item: + if hasattr(item, self.obj.__name__): + return getattr(item, self.obj.__name__) + return item + + def get_pk_name(self) -> Optional[Union[List[str], str]]: + """ + Get the model primary key column name. + """ + return self._get_pk_name(self.obj) - def get_pk_name(self): - pk = [pk.name for pk in self.obj.__mapper__.primary_key] + def get_pk(self, model: Optional[Model] = None): + """ + Get the model primary key SQLAlchemy column. + Will not support composite keys + """ + model_ = model or self.obj + pk_name = self._get_pk_name(model_) + if pk_name and isinstance(pk_name, str): + return getattr(model_, pk_name) + return None + + def _get_pk_name(self, model: Model) -> Optional[Union[List[str], str]]: + pk = [pk.name for pk in model.__mapper__.primary_key] if pk: return pk if self.is_pk_composite() else pk[0] + return None + + +def _include_filters(interface: SQLAInterface) -> None: + """ + Injects all filters on the interface class itself + :param interface: + """ + for key in filters.__all__: + if not hasattr(interface, key): + setattr(interface, key, getattr(filters, key)) """ diff --git a/flask_appbuilder/security/sqla/manager.py b/flask_appbuilder/security/sqla/manager.py index 93bc01b977..35f1710e8c 100644 --- a/flask_appbuilder/security/sqla/manager.py +++ b/flask_appbuilder/security/sqla/manager.py @@ -272,7 +272,9 @@ def update_role(self, pk, name: str) -> Optional[Role]: return def find_role(self, name): - return self.get_session.query(self.role_model).filter_by(name=name).first() + return ( + self.get_session.query(self.role_model).filter_by(name=name).one_or_none() + ) def get_all_roles(self): return self.get_session.query(self.role_model).all() @@ -281,7 +283,7 @@ def get_public_role(self): return ( self.get_session.query(self.role_model) .filter_by(name=self.auth_role_public) - .first() + .one_or_none() ) def get_public_permissions(self): @@ -295,7 +297,9 @@ def find_permission(self, name): Finds and returns a Permission by name """ return ( - self.get_session.query(self.permission_model).filter_by(name=name).first() + self.get_session.query(self.permission_model) + .filter_by(name=name) + .one_or_none() ) def exist_permission_on_roles( @@ -417,7 +421,11 @@ def find_view_menu(self, name): """ Finds and returns a ViewMenu by name """ - return self.get_session.query(self.viewmenu_model).filter_by(name=name).first() + return ( + self.get_session.query(self.viewmenu_model) + .filter_by(name=name) + .one_or_none() + ) def get_all_view_menu(self): return self.get_session.query(self.viewmenu_model).all() @@ -485,7 +493,7 @@ def find_permission_view_menu(self, permission_name, view_menu_name): return ( self.get_session.query(self.permissionview_model) .filter_by(permission=permission, view_menu=view_menu) - .first() + .one_or_none() ) def find_permissions_view_menu(self, view_menu): @@ -537,7 +545,7 @@ def del_permission_view_menu(self, permission_name, view_menu_name, cascade=True roles_pvs = ( self.get_session.query(self.role_model) .filter(self.role_model.permissions.contains(pv)) - .first() + .one_or_none() ) if roles_pvs: log.warning( diff --git a/flask_appbuilder/tests/const.py b/flask_appbuilder/tests/const.py index 5321fb0ac2..bdc26c6be8 100644 --- a/flask_appbuilder/tests/const.py +++ b/flask_appbuilder/tests/const.py @@ -1,5 +1,7 @@ MODEL1_DATA_SIZE = 30 MODEL2_DATA_SIZE = 30 +MODELOMCHILD_DATA_SIZE = 30 + USERNAME_ADMIN = "testadmin" PASSWORD_ADMIN = "password" MAX_PAGE_SIZE = 25 diff --git a/flask_appbuilder/tests/sqla/models.py b/flask_appbuilder/tests/sqla/models.py index a01a27a1ed..2506187e85 100644 --- a/flask_appbuilder/tests/sqla/models.py +++ b/flask_appbuilder/tests/sqla/models.py @@ -18,6 +18,8 @@ ) from sqlalchemy.orm import backref, relationship +from ..const import MODELOMCHILD_DATA_SIZE + def validate_name(n): if n[0] != "A": @@ -35,8 +37,9 @@ def __repr__(self): return str(self.field_string) def full_concat(self): - return "{}.{}.{}.{}".format( - self.field_string, self.field_integer, self.field_float, self.field_date + return ( + f"{self.field_string}.{self.field_integer}" + f".{self.field_float}.{self.field_date}" ) @@ -320,18 +323,18 @@ def insert_data(session, count): session.add(model) session.commit() - model_oo_parents = list() + model_om_parents = list() for i in range(count): model = ModelOMParent() model.field_string = f"text{i}" session.add(model) session.commit() - model_oo_parents.append(model) + model_om_parents.append(model) for i in range(count): - for j in range(1, 4): + for j in range(1, MODELOMCHILD_DATA_SIZE): model = ModelOMChild() model.field_string = f"text{i}.{j}" - model.parent = model_oo_parents[i] + model.parent = model_om_parents[i] session.add(model) session.commit() diff --git a/flask_appbuilder/tests/test_api.py b/flask_appbuilder/tests/test_api.py index 6f59d08b9f..782c306bd4 100644 --- a/flask_appbuilder/tests/test_api.py +++ b/flask_appbuilder/tests/test_api.py @@ -39,6 +39,7 @@ MAX_PAGE_SIZE, MODEL1_DATA_SIZE, MODEL2_DATA_SIZE, + MODELOMCHILD_DATA_SIZE, PASSWORD_ADMIN, PASSWORD_READONLY, USERNAME_ADMIN, @@ -289,6 +290,13 @@ class ModelOMParentApi(ModelRestApi): self.appbuilder.add_api(ModelOMParentApi) + class ModelDottedOMParentApi(ModelRestApi): + datamodel = SQLAInterface(ModelOMParent) + list_columns = ["field_string", "children.field_string"] + show_columns = ["field_string", "children.field_string"] + + self.appbuilder.add_api(ModelDottedOMParentApi) + class ModelMMRequiredApi(ModelRestApi): datamodel = SQLAInterface(ModelMMParentRequired) @@ -452,7 +460,7 @@ def test_auth_authorization(self): token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) # Test unauthorized DELETE pk = 1 - uri = "api/v1/model1apirestrictedpermissions/{}".format(pk) + uri = f"api/v1/model1apirestrictedpermissions/{pk}" rv = self.auth_client_delete(client, token, uri) self.assertEqual(rv.status_code, 401) # Test unauthorized POST @@ -466,7 +474,7 @@ def test_auth_authorization(self): rv = self.auth_client_post(client, token, uri, item) self.assertEqual(rv.status_code, 401) # Test authorized GET - uri = "api/v1/model1apirestrictedpermissions/1" + uri = f"api/v1/model1apirestrictedpermissions/{pk}" rv = self.auth_client_get(client, token, uri) self.assertEqual(rv.status_code, 200) @@ -682,9 +690,9 @@ def test_get_item_select_cols(self): ) self.assertEqual(rv.status_code, 200) - def test_get_item_dotted_notation(self): + def test_get_item_dotted_mo_notation(self): """ - REST Api: Test get item with dotted notation + REST Api: Test get item with dotted M-O related field """ client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) @@ -771,9 +779,9 @@ def test_get_item_base_filters(self): rv = self.auth_client_get(client, token, f"api/v1/model1apifiltered/{pk}") self.assertEqual(rv.status_code, 200) - def test_get_item_1m_field(self): + def test_get_item_mo_field(self): """ - REST Api: Test get item with 1-N related field + REST Api: Test get item with M-O related field """ client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) @@ -801,7 +809,7 @@ def test_get_item_1m_field(self): def test_get_item_mm_field(self): """ - REST Api: Test get item with N-N related field + REST Api: Test get item with M-M related field """ client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) @@ -820,7 +828,7 @@ def test_get_item_mm_field(self): def test_get_item_dotted_mm_field(self): """ - REST Api: Test get item with dotted N-N related field + REST Api: Test get item with dotted M-M related field """ client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) @@ -855,7 +863,8 @@ def test_get_item_om_field(self): data = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 200) expected_rel_field = [ - {"field_string": f"text0.{i}", "id": i} for i in range(1, 4) + {"field_string": f"text0.{i}", "id": i} + for i in range(1, MODELOMCHILD_DATA_SIZE) ] self.assertEqual(data[API_RESULT_RES_KEY]["children"], expected_rel_field) @@ -874,9 +883,9 @@ def test_get_list(self): # Tests data result default page size self.assertEqual(len(data[API_RESULT_RES_KEY]), self.model1api.page_size) - def test_get_list_dotted_notation(self): + def test_get_list_dotted_mo_field(self): """ - REST Api: Test get list with dotted notation + REST Api: Test get list with dotted M-O related field """ client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) @@ -897,9 +906,44 @@ def test_get_list_dotted_notation(self): {"field_string": "test0", "group": {"field_string": "test0"}}, ) + def test_get_list_om_field(self): + """ + REST Api: Test get list with O-M related field + """ + client = self.app.test_client() + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + + rv = self.auth_client_get(client, token, "api/v1/modelomparentapi/") + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual(data["count"], MODEL1_DATA_SIZE) + self.assertEqual(len(data[API_RESULT_RES_KEY]), self.model1api.page_size) + expected_rel_field = [ + {"field_string": f"text0.{i}", "id": i} + for i in range(1, MODELOMCHILD_DATA_SIZE) + ] + self.assertEqual(data[API_RESULT_RES_KEY][0]["children"], expected_rel_field) + + def test_get_list_dotted_om_field(self): + """ + REST Api: Test get list with dotted O-M related field + """ + client = self.app.test_client() + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + + rv = self.auth_client_get(client, token, "api/v1/modeldottedomparentapi/") + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual(data["count"], MODEL1_DATA_SIZE) + self.assertEqual(len(data[API_RESULT_RES_KEY]), self.model1api.page_size) + expected_rel_field = [ + {"field_string": f"text0.{i}"} for i in range(1, MODELOMCHILD_DATA_SIZE) + ] + self.assertEqual(data[API_RESULT_RES_KEY][0]["children"], expected_rel_field) + def test_get_list_dotted_mm_field(self): """ - REST Api: Test get list with dotted N-N related field + REST Api: Test get list with dotted M-M related field """ client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) @@ -920,9 +964,9 @@ def test_get_list_dotted_mm_field(self): self.assertIn({"field_integer": 2}, data[API_RESULT_RES_KEY][i]["children"]) self.assertIn({"field_integer": 3}, data[API_RESULT_RES_KEY][i]["children"]) - def test_get_list_dotted_order(self): + def test_get_list_dotted_mo_order(self): """ - REST Api: Test get list and order dotted notation + REST Api: Test get list and order dotted M-O notation """ client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) diff --git a/flask_appbuilder/tests/test_mvc.py b/flask_appbuilder/tests/test_mvc.py index 18dab80f36..e2f7fe53e5 100644 --- a/flask_appbuilder/tests/test_mvc.py +++ b/flask_appbuilder/tests/test_mvc.py @@ -381,7 +381,7 @@ class Model22View(ModelView): class Model1View(ModelView): datamodel = SQLAInterface(Model1) related_views = [Model2View] - list_columns = ["field_string", "field_file"] + list_columns = ["field_string", "field_integer"] class Model3View(ModelView): datamodel = SQLAInterface(Model3) diff --git a/tox.ini b/tox.ini index 833a897be9..e772cfdc9d 100644 --- a/tox.ini +++ b/tox.ini @@ -16,29 +16,29 @@ deps = setenv = SQLALCHEMY_DATABASE_URI = sqlite:/// commands = - nosetests -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder.tests --ignore-files="test_mongoengine\.py" + nosetests --stop -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder.tests --ignore-files="test_mongoengine\.py" [testenv:mysql] setenv = SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@0.0.0.0/app?charset=utf8 commands = - nosetests -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder.tests --ignore-files="test_mongoengine\.py" + nosetests --stop -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder.tests --ignore-files="test_mongoengine\.py" [testenv:postgres] setenv = SQLALCHEMY_DATABASE_URI = postgresql+psycopg2://pguser:pguserpassword@0.0.0.0/app commands = - nosetests -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder.tests --ignore-files="test_mongoengine\.py" + nosetests --stop -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder.tests --ignore-files="test_mongoengine\.py" [testenv:mssql] setenv = SQLALCHEMY_DATABASE_URI = mssql+pymssql://sa:Password_123@localhost:1433/master commands = - nosetests -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder.tests --ignore-files="test_mongoengine\.py" + nosetests --stop -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder.tests --ignore-files="test_mongoengine\.py" [testenv:mongodb] commands = - nosetests -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder/tests/test_mongoengine.py + nosetests --stop -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder/tests/test_mongoengine.py [testenv:black] commands =