From e4249f48402d1b9dfc585f19bd854d21ae12e6dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Can=20G=C3=BCney=20Aksakalli?= Date: Fri, 13 Oct 2023 14:00:11 +0200 Subject: [PATCH] Fix fetchall returning already returned rows If fetchall was called after some rows had already been fetched via fetchmany or fetchone, those rows were returned again. cur = conn.cursor() cur.execute("SELECT * FROM ( VALUES (1), (2), (3), (4), (5), (6))") print(cur.fetchmany(2)) print(cur.fetchall()) # should return 4 rows but returns all rows print(cur.fetchmany(10)) # should return no rows but returns next 4 rows --- tests/integration/test_dbapi_integration.py | 9 +++++++++ trino/dbapi.py | 12 +++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 726c91bc..cdb28043 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -1328,6 +1328,15 @@ def test_select_tpch_1000(trino_connection): assert len(rows) == 1000 +def test_fetch_cursor(trino_connection): + cur = trino_connection.cursor() + cur.execute("SELECT * FROM tpch.sf1.customer LIMIT 1000") + for _ in range(100): + cur.fetchone() + assert len(cur.fetchmany(400)) == 400 + assert len(cur.fetchall()) == 500 + + def test_cancel_query(trino_connection): cur = trino_connection.cursor() cur.execute("SELECT * FROM tpch.sf1.customer") diff --git a/trino/dbapi.py b/trino/dbapi.py index 62ce893b..a417f526 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -23,6 +23,7 @@ import uuid from collections import OrderedDict from decimal import Decimal +from itertools import islice from threading import Lock from time import time from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types @@ -661,14 +662,7 @@ def fetchmany(self, size=None) -> List[List[Any]]: if size is None: size = self.arraysize - result = [] - for _ in range(size): - row = self.fetchone() - if row is 
None: - break - result.append(row) - - return result + return list(islice(iter(self.fetchone, None), size)) def describe(self, sql: str) -> List[DescribeOutput]: """ @@ -696,7 +690,7 @@ def genall(self): return self._query.result def fetchall(self) -> List[List[Any]]: - return list(self.genall()) + return list(iter(self.fetchone, None)) def cancel(self): if self._query is None: