Skip to content

Commit

Permalink
Fix fetchall returning already returned rows
Browse files Browse the repository at this point in the history
If fetchall was called after fetching some rows already via either
fetchmany or fetchone then some rows were duplicated.

    cur = conn.cursor()
    cur.execute("SELECT * FROM ( VALUES (1), (2), (3), (4), (5), (6))")
    print(cur.fetchmany(2))
    print(cur.fetchall()) # should return 4 rows but returns all rows
    print(cur.fetchmany(10)) # should return no rows but returns next 4 rows
  • Loading branch information
aksakalli committed Oct 13, 2023
1 parent 2d2888b commit e4249f4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
9 changes: 9 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,15 @@ def test_select_tpch_1000(trino_connection):
assert len(rows) == 1000


def test_fetch_cursor(trino_connection):
cur = trino_connection.cursor()
cur.execute("SELECT * FROM tpch.sf1.customer LIMIT 1000")
for _ in range(100):
cur.fetchone()
assert len(cur.fetchmany(400)) == 400
assert len(cur.fetchall()) == 500


def test_cancel_query(trino_connection):
cur = trino_connection.cursor()
cur.execute("SELECT * FROM tpch.sf1.customer")
Expand Down
12 changes: 3 additions & 9 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import uuid
from collections import OrderedDict
from decimal import Decimal
from itertools import islice
from threading import Lock
from time import time
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types
Expand Down Expand Up @@ -661,14 +662,7 @@ def fetchmany(self, size=None) -> List[List[Any]]:
if size is None:
size = self.arraysize

result = []
for _ in range(size):
row = self.fetchone()
if row is None:
break
result.append(row)

return result
return list(islice(iter(self.fetchone, None), size))

def describe(self, sql: str) -> List[DescribeOutput]:
"""
Expand Down Expand Up @@ -696,7 +690,7 @@ def genall(self):
return self._query.result

def fetchall(self) -> List[List[Any]]:
return list(self.genall())
return list(iter(self.fetchone, None))

def cancel(self):
if self._query is None:
Expand Down

0 comments on commit e4249f4

Please sign in to comment.