Skip to content

Commit

Permalink
Fix, load options and limits for many to many truncating results (#1389)
Browse files Browse the repository at this point in the history
* Fix, load options and limits for many to many truncating results

This fix introduces inner queries so that many to many relations are not truncated, so now the limit and offset is applied on the inner query allowing for a result to have more rows then the limit caused by a many to many or many to one relation.

* Fix, non dotted (aka related) filters are applied first on the inner query

* sync with latest fix on 2.3.4

* fix test
  • Loading branch information
dpgaspar committed Jun 9, 2020
1 parent 00907ce commit 1c91e7a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 53 deletions.
141 changes: 88 additions & 53 deletions flask_appbuilder/models/sqla/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import aliased, Load
from sqlalchemy.orm import aliased, contains_eager, Load, load_only
from sqlalchemy.orm.descriptor_props import SynonymProperty
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy_utils.types.uuid import UUIDType

from . import filters, Model
from ..base import BaseInterface
from ..filters import Filters
from ..group import GroupByCol, GroupByDateMonth, GroupByDateYear
from ..mixins import FileColumn, ImageColumn
from ..._compat import as_unicode
Expand Down Expand Up @@ -82,11 +83,9 @@ def model_name(self):
def is_model_already_joined(query, model):
return model in [mapper.class_ for mapper in query._join_entities]

def _get_base_query(
self, query=None, filters=None, order_column="", order_direction=""
):
if filters:
query = filters.apply_all(query)
def _apply_query_order(
self, query, order_column: str, order_direction: str
) -> BaseQuery:
if order_column != "":
# if Model has custom decorator **renders('<COL_NAME>')**
# this decorator will add a property to the method named *_col_name*
Expand All @@ -99,6 +98,13 @@ def _get_base_query(
query = query.order_by(self._get_attr(order_column).desc())
return query

def _get_base_query(
self, query=None, filters=None, order_column="", order_direction=""
):
if filters:
query = filters.apply_all(query)
return self._apply_query_order(query, order_column, order_direction)

def _query_join_relation(self, query: BaseQuery, root_relation: str) -> BaseQuery:
"""
Helper function that applies necessary joins for dotted columns on a
Expand Down Expand Up @@ -153,31 +159,41 @@ def _query_select_options(
if is_column_dotted(column):
root_relation = get_column_root_relation(column)
leaf_column = get_column_leaf(column)
if root_relation not in joined_models:
if self.is_relation_many_to_many(
root_relation
) or self.is_relation_one_to_many(root_relation):
load_options.append(
(
Load(self.obj)
.joinedload(root_relation)
.load_only(leaf_column)
)
)
continue
elif root_relation not in joined_models:
query = self._query_join_relation(query, root_relation)
joined_models.append(root_relation)
load_options.append(
(
Load(self.obj)
.joinedload(root_relation)
.load_only(leaf_column)
)
(contains_eager(root_relation).load_only(leaf_column))
)
else:
# is a custom property method field?
if hasattr(getattr(self.obj, column), "fget"):
pass
# is not a relation and not a function?
elif not self.is_relation(column) and not hasattr(
getattr(self.obj, column), "__call__"
):
load_options.append(Load(self.obj).load_only(column))
# it's a normal column
else:
load_options.append(Load(self.obj))
if not self.is_relation(
column
) and not self.is_property_or_function(column):
load_options.append(load_only(column))
query = query.options(*tuple(load_options))
return query

def _get_non_dotted_filters(self, filters):
dotted_filters = Filters(self.filter_converter_class, self, [], [])
_filters = []
if filters:
for flt, value in zip(filters.filters, filters.values):
if not is_column_dotted(flt.column_name):
_filters.append((flt.column_name, flt.__class__, value))
dotted_filters.add_filter_list(_filters)
return dotted_filters

def query(
self,
filters=None,
Expand All @@ -202,8 +218,6 @@ def query(
the current page size
"""
query = self.session.query(self.obj)
query = self._query_join_dotted_column(query, order_column)
query = self._query_select_options(query, select_columns)
query_count = self.session.query(func.count("*")).select_from(self.obj)

query_count = self._get_base_query(query=query_count, filters=filters)
Expand All @@ -218,6 +232,23 @@ def query(
pk_name = self.get_pk_name()
query = query.order_by(pk_name)

# If order by is not dotted (related) we need to apply it first
if not is_column_dotted(order_column):
query = self._get_non_dotted_filters(filters).apply_all(query)
query = self._apply_query_order(query, order_column, order_direction)

# Pagination comes first
if page and page_size:
query = query.offset(page * page_size)
if page_size:
query = query.limit(page_size)

if select_columns and order_column:
# Use from self strategy
select_columns = select_columns + [order_column]
# Everything uses an inner query because of joins to m/m m/1
query = self._query_select_options(query.from_self(), select_columns)

query = self._get_base_query(
query=query,
filters=filters,
Expand All @@ -226,11 +257,6 @@ def query(
)

count = query_count.scalar()

if page and page_size:
query = query.offset(page * page_size)
if page_size:
query = query.limit(page_size)
return count, query.all()

def query_simple_group(
Expand Down Expand Up @@ -262,19 +288,19 @@ def query_year_group(self, group_by="", filters=None):
-----------------------------------------
"""

def is_image(self, col_name):
def is_image(self, col_name: str) -> bool:
try:
return isinstance(self.list_columns[col_name].type, ImageColumn)
except Exception:
return False

def is_file(self, col_name):
def is_file(self, col_name: str) -> bool:
try:
return isinstance(self.list_columns[col_name].type, FileColumn)
except Exception:
return False

def is_string(self, col_name):
def is_string(self, col_name: str) -> bool:
try:
return (
_is_sqla_type(self.list_columns[col_name].type, sa.types.String)
Expand All @@ -283,97 +309,97 @@ def is_string(self, col_name):
except Exception:
return False

def is_text(self, col_name):
def is_text(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Text)
except Exception:
return False

def is_binary(self, col_name):
def is_binary(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.LargeBinary)
except Exception:
return False

def is_integer(self, col_name):
def is_integer(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Integer)
except Exception:
return False

def is_numeric(self, col_name):
def is_numeric(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Numeric)
except Exception:
return False

def is_float(self, col_name):
def is_float(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Float)
except Exception:
return False

def is_boolean(self, col_name):
def is_boolean(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Boolean)
except Exception:
return False

def is_date(self, col_name):
def is_date(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Date)
except Exception:
return False

def is_datetime(self, col_name):
def is_datetime(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.DateTime)
except Exception:
return False

def is_enum(self, col_name):
def is_enum(self, col_name: str) -> bool:
try:
return _is_sqla_type(self.list_columns[col_name].type, sa.types.Enum)
except Exception:
return False

def is_relation(self, col_name):
def is_relation(self, col_name: str) -> bool:
try:
return isinstance(
self.list_properties[col_name], sa.orm.properties.RelationshipProperty
)
except Exception:
return False

def is_relation_many_to_one(self, col_name):
def is_relation_many_to_one(self, col_name: str) -> bool:
try:
if self.is_relation(col_name):
return self.list_properties[col_name].direction.name == "MANYTOONE"
except Exception:
return False

def is_relation_many_to_many(self, col_name):
def is_relation_many_to_many(self, col_name: str) -> bool:
try:
if self.is_relation(col_name):
return self.list_properties[col_name].direction.name == "MANYTOMANY"
except Exception:
return False

def is_relation_one_to_one(self, col_name):
def is_relation_one_to_one(self, col_name: str) -> bool:
try:
if self.is_relation(col_name):
return self.list_properties[col_name].direction.name == "ONETOONE"
except Exception:
return False

def is_relation_one_to_many(self, col_name):
def is_relation_one_to_many(self, col_name: str) -> bool:
try:
if self.is_relation(col_name):
return self.list_properties[col_name].direction.name == "ONETOMANY"
except Exception:
return False

def is_nullable(self, col_name):
def is_nullable(self, col_name: str) -> bool:
if self.is_relation_many_to_one(col_name):
col = self.get_relation_fk(col_name)
return col.nullable
Expand All @@ -382,28 +408,37 @@ def is_nullable(self, col_name):
except Exception:
return False

def is_unique(self, col_name):
def is_unique(self, col_name: str) -> bool:
try:
return self.list_columns[col_name].unique is True
except Exception:
return False

def is_pk(self, col_name):
def is_pk(self, col_name: str) -> bool:
try:
return self.list_columns[col_name].primary_key
except Exception:
return False

def is_pk_composite(self):
def is_pk_composite(self) -> bool:
return len(self.obj.__mapper__.primary_key) > 1

def is_fk(self, col_name):
def is_fk(self, col_name: str) -> bool:
try:
return self.list_columns[col_name].foreign_keys
except Exception:
return False

def get_max_length(self, col_name):
def is_property(self, col_name: str) -> bool:
return hasattr(getattr(self.obj, col_name), "fget")

def is_function(self, col_name: str) -> bool:
return hasattr(getattr(self.obj, col_name), "__call__")

def is_property_or_function(self, col_name: str) -> bool:
return self.is_property(col_name) or self.is_function(col_name)

def get_max_length(self, col_name: str) -> int:
try:
if self.is_enum(col_name):
return -1
Expand Down
4 changes: 4 additions & 0 deletions flask_appbuilder/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ class ModelDottedMMApi(ModelRestApi):
list_columns = ["field_string", "children.field_integer"]
show_columns = ["field_string", "children.field_integer"]

self.modeldottedmmapi = ModelDottedMMApi
self.appbuilder.add_api(ModelDottedMMApi)

class ModelOMParentApi(ModelRestApi):
Expand Down Expand Up @@ -910,8 +911,11 @@ def test_get_list_dotted_mm_field(self):
rv = self.auth_client_get(client, token, uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(data["count"], MODEL2_DATA_SIZE)
self.assertEqual(len(data[API_RESULT_RES_KEY]), self.modeldottedmmapi.page_size)
i = 0
self.assertEqual(data[API_RESULT_RES_KEY][i]["field_string"], "0")
self.assertEqual(len(data[API_RESULT_RES_KEY][i]["children"]), 3)
self.assertIn({"field_integer": 1}, data[API_RESULT_RES_KEY][i]["children"])
self.assertIn({"field_integer": 2}, data[API_RESULT_RES_KEY][i]["children"])
self.assertIn({"field_integer": 3}, data[API_RESULT_RES_KEY][i]["children"])
Expand Down

0 comments on commit 1c91e7a

Please sign in to comment.