Merge branch 'main' into query_profiling
daniel-sanche committed Aug 7, 2024
2 parents 4319ac7 + ba20019 commit 175c76a
Showing 8 changed files with 182 additions and 8 deletions.
@@ -1093,6 +1093,8 @@ async def sample_list_indexes():
method=rpc,
request=request,
response=response,
retry=retry,
timeout=timeout,
metadata=metadata,
)

@@ -1476,6 +1476,8 @@ def sample_list_indexes():
method=rpc,
request=request,
response=response,
retry=retry,
timeout=timeout,
metadata=metadata,
)

41 changes: 39 additions & 2 deletions google/cloud/datastore_admin_v1/services/datastore_admin/pagers.py
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.api_core import retry_async as retries_async
from typing import (
Any,
AsyncIterator,
@@ -22,8 +25,18 @@
Tuple,
Optional,
Iterator,
Union,
)

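# The fallback below is presumably for compatibility with older
# google-api-core releases in which gapic_v1.method._MethodDefault
# is not available, so plain ``object`` stands in for it.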
try:
OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None]
OptionalAsyncRetry = Union[
retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None
]
except AttributeError: # pragma: NO COVER
OptionalRetry = Union[retries.Retry, object, None] # type: ignore
OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore

from google.cloud.datastore_admin_v1.types import datastore_admin
from google.cloud.datastore_admin_v1.types import index

@@ -52,6 +65,8 @@ def __init__(
request: datastore_admin.ListIndexesRequest,
response: datastore_admin.ListIndexesResponse,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, str]] = ()
):
"""Instantiate the pager.
@@ -63,12 +78,17 @@
The initial request object.
response (google.cloud.datastore_admin_v1.types.ListIndexesResponse):
The initial response object.
retry (google.api_core.retry.Retry): Designation of what errors,
if any, should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
"""
self._method = method
self._request = datastore_admin.ListIndexesRequest(request)
self._response = response
self._retry = retry
self._timeout = timeout
self._metadata = metadata

def __getattr__(self, name: str) -> Any:
@@ -79,7 +99,12 @@ def pages(self) -> Iterator[datastore_admin.ListIndexesResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
self._response = self._method(
self._request,
retry=self._retry,
timeout=self._timeout,
metadata=self._metadata,
)
yield self._response

def __iter__(self) -> Iterator[index.Index]:
@@ -114,6 +139,8 @@ def __init__(
request: datastore_admin.ListIndexesRequest,
response: datastore_admin.ListIndexesResponse,
*,
retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT,
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, str]] = ()
):
"""Instantiates the pager.
@@ -125,12 +152,17 @@
The initial request object.
response (google.cloud.datastore_admin_v1.types.ListIndexesResponse):
The initial response object.
retry (google.api_core.retry.AsyncRetry): Designation of what errors,
if any, should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
"""
self._method = method
self._request = datastore_admin.ListIndexesRequest(request)
self._response = response
self._retry = retry
self._timeout = timeout
self._metadata = metadata

def __getattr__(self, name: str) -> Any:
@@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[datastore_admin.ListIndexesResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
self._response = await self._method(
self._request,
retry=self._retry,
timeout=self._timeout,
metadata=self._metadata,
)
yield self._response

def __aiter__(self) -> AsyncIterator[index.Index]:
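With this change, retry and timeout are stored on the pager and forwarded on
every page fetch rather than only the first request. A minimal caller-side
sketch, not part of this commit; the project id and retry settings are
illustrative assumptions:

from google.api_core import retry as retries
from google.cloud.datastore_admin_v1 import DatastoreAdminClient

client = DatastoreAdminClient()
# Retry transient errors with exponential backoff; values are examples.
retry = retries.Retry(initial=0.1, maximum=10.0, timeout=60.0)

# The pager keeps retry/timeout and re-sends them with each subsequent
# ListIndexes page request, not just the first one.
pager = client.list_indexes(
    request={"project_id": "my-project"},  # hypothetical project
    retry=retry,
    timeout=30.0,
)
for index in pager:
    print(index.index_id)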
34 changes: 34 additions & 0 deletions tests/system/test_query.py
@@ -337,6 +337,17 @@ def large_query_client(datastore_client):
return large_query_client


@pytest.fixture(scope="session")
def mergejoin_query_client(datastore_client):
mergejoin_query_client = _helpers.clone_client(
datastore_client,
namespace=populate_datastore.MERGEJOIN_DATASET_NAMESPACE,
)
populate_datastore.add_mergejoin_dataset_entities(client=mergejoin_query_client)

return mergejoin_query_client


@pytest.fixture(scope="function")
def large_query(large_query_client):
# Use the client for this test instead of the global.
@@ -346,6 +357,15 @@ def large_query(large_query_client):
)


@pytest.fixture(scope="function")
def mergejoin_query(mergejoin_query_client):
# Use the client for this test instead of the global.
return mergejoin_query_client.query(
kind=populate_datastore.MERGEJOIN_DATASET_KIND,
namespace=populate_datastore.MERGEJOIN_DATASET_NAMESPACE,
)


@pytest.mark.parametrize(
"limit,offset,expected",
[
@@ -385,6 +405,20 @@ def test_large_query(large_query, limit, offset, expected, database_id):
assert len(entities) == expected


@pytest.mark.parametrize("database_id", [_helpers.TEST_DATABASE], indirect=True)
def test_mergejoin_query(mergejoin_query, database_id):
query = mergejoin_query
query.add_filter(filter=PropertyFilter("a", "=", 1))
query.add_filter(filter=PropertyFilter("b", "=", 1))

# There should be 2 * MERGEJOIN_QUERY_NUM_RESULTS results total
expected_total = 2 * populate_datastore.MERGEJOIN_QUERY_NUM_RESULTS
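    # With the cursor bug, later pages could restart from the beginning and
    # return duplicates, so fetch at every offset and verify that exactly
    # (expected_total - offset) entities come back.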
for offset in range(0, expected_total + 1):
iterator = query.fetch(offset=offset)
num_entities = len([e for e in iterator])
assert num_entities == expected_total - offset


@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True)
def test_query_add_property_filter(ancestor_query, database_id):
query = ancestor_query
10 changes: 6 additions & 4 deletions tests/system/utils/clear_datastore.py
@@ -31,6 +31,8 @@
"Post",
"uuid_key",
"timestamp_key",
"LargeCharacter",
"Mergejoin",
)
TRANSACTION_MAX_GROUPS = 5
MAX_DEL_ENTITIES = 500
@@ -90,12 +92,10 @@ def remove_all_entities(client):


def run(database):
client = datastore.Client(database=database)
kinds = sys.argv[1:]

if len(kinds) == 0:
kinds = ALL_KINDS

print_func(
"This command will remove all entities from the database "
+ database
@@ -105,8 +105,10 @@ def run(database):
response = input("Is this OK [y/n]? ")

if response.lower() == "y":
for kind in kinds:
remove_kind(kind, client)
for namespace in ["", "LargeCharacterEntity", "MergejoinNamespace"]:
client = datastore.Client(database=database, namespace=namespace)
for kind in kinds:
remove_kind(kind, client)

else:
print_func("Doing nothing.")
93 changes: 92 additions & 1 deletion tests/system/utils/populate_datastore.py
@@ -58,6 +58,11 @@
LARGE_CHARACTER_NAMESPACE = "LargeCharacterEntity"
LARGE_CHARACTER_KIND = "LargeCharacter"

MERGEJOIN_QUERY_NUM_RESULTS = 7
MERGEJOIN_DATASET_INTERMEDIATE_OBJECTS = 20000
MERGEJOIN_DATASET_NAMESPACE = "MergejoinNamespace"
MERGEJOIN_DATASET_KIND = "Mergejoin"


def get_system_test_db():
return os.getenv("SYSTEM_TESTS_DATABASE") or "system-tests-named-db"
@@ -179,12 +184,92 @@ def add_timestamp_keys(client=None):
batch.put(entity)


def add_mergejoin_dataset_entities(client=None):
"""
Dataset to account for one bug that was seen in https://github.com/googleapis/python-datastore/issues/547
The root cause of this is us setting a subsequent query's start_cursor to skipped_cursor instead of end_cursor.
In niche scenarios involving mergejoins, skipped_cursor becomes empty and the query starts back from the beginning,
returning duplicate items.
This bug is able to be reproduced with a dataset shown in b/352377540, with 7 items of a=1, b=1
followed by 20k items of alternating a=1, b=0 and a=0, b=1, then 7 more a=1, b=1, then querying for all
items with a=1, b=1 and an offset of 8.
"""
client.namespace = MERGEJOIN_DATASET_NAMESPACE

# Query used for all tests
page_query = client.query(
kind=MERGEJOIN_DATASET_KIND, namespace=MERGEJOIN_DATASET_NAMESPACE
)

def create_entity(id, a, b):
key = client.key(MERGEJOIN_DATASET_KIND, id)
entity = datastore.Entity(key=key)
entity["a"] = a
entity["b"] = b
return entity

def put_objects(count):
id = 1
curr_intermediate_entries = 0

# Can only do 500 operations in a transaction with an overall
# size limit.
ENTITIES_TO_BATCH = 500

with client.transaction() as xact:
for _ in range(0, MERGEJOIN_QUERY_NUM_RESULTS):
entity = create_entity(id, 1, 1)
id += 1
xact.put(entity)

while curr_intermediate_entries < count - MERGEJOIN_QUERY_NUM_RESULTS:
start = curr_intermediate_entries
end = min(curr_intermediate_entries + ENTITIES_TO_BATCH, count)
with client.transaction() as xact:
# The name/ID for the new entity
for i in range(start, end):
if id % 2:
entity = create_entity(id, 0, 1)
else:
entity = create_entity(id, 1, 0)
id += 1

# Saves the entity
xact.put(entity)
curr_intermediate_entries += ENTITIES_TO_BATCH

with client.transaction() as xact:
for _ in range(0, MERGEJOIN_QUERY_NUM_RESULTS):
entity = create_entity(id, 1, 1)
id += 1
xact.put(entity)

# If anything exists in this namespace, delete it, since we need to
# set up something very specific.
all_entities = [e for e in page_query.fetch()]
if len(all_entities) > 0:
# Cleanup Collection if not an exact match
while all_entities:
entities = all_entities[:500]
all_entities = all_entities[500:]
client.delete_multi([e.key for e in entities])
# Put objects
put_objects(MERGEJOIN_DATASET_INTERMEDIATE_OBJECTS)

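For context, the cursor behavior this dataset is designed to exercise can be
sketched as follows (simplified pseudocode, not the library's actual
internals; see the docstring above):

def next_page_start_cursor(response):
    # Buggy behavior: skipped_cursor can be empty for merge-join query plans,
    # so the next page restarted from the beginning and produced duplicates:
    #   return response.skipped_cursor
    # Correct behavior: resume from where the previous page actually ended.
    return response.end_cursor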

def run(database):
client = datastore.Client(database=database)
flags = sys.argv[1:]

if len(flags) == 0:
flags = ["--characters", "--uuid", "--timestamps"]
flags = [
"--characters",
"--uuid",
"--timestamps",
"--large-characters",
"--mergejoin",
]

if "--characters" in flags:
add_characters(client)
@@ -195,6 +280,12 @@ def run(database):
if "--timestamps" in flags:
add_timestamp_keys(client)

if "--large-characters" in flags:
add_large_character_entities(client)

if "--mergejoin" in flags:
add_mergejoin_dataset_entities(client)


def main():
for database in ["", get_system_test_db()]:
7 changes: 6 additions & 1 deletion tests/unit/gapic/datastore_admin_v1/test_datastore_admin.py
@@ -47,6 +47,7 @@
from google.api_core import operation_async # type: ignore
from google.api_core import operations_v1
from google.api_core import path_template
from google.api_core import retry as retries
from google.auth import credentials as ga_credentials
from google.auth.exceptions import MutualTLSChannelError
from google.cloud.datastore_admin_v1.services.datastore_admin import (
@@ -3119,12 +3120,16 @@ def test_list_indexes_pager(transport_name: str = "grpc"):
)

expected_metadata = ()
retry = retries.Retry()
timeout = 5
expected_metadata = tuple(expected_metadata) + (
gapic_v1.routing_header.to_grpc_metadata((("project_id", ""),)),
)
pager = client.list_indexes(request={})
pager = client.list_indexes(request={}, retry=retry, timeout=timeout)

assert pager._metadata == expected_metadata
assert pager._retry == retry
assert pager._timeout == timeout

results = list(pager)
assert len(results) == 6
1 change: 1 addition & 0 deletions tests/unit/gapic/datastore_v1/test_datastore.py
@@ -43,6 +43,7 @@
from google.api_core import grpc_helpers
from google.api_core import grpc_helpers_async
from google.api_core import path_template
from google.api_core import retry as retries
from google.auth import credentials as ga_credentials
from google.auth.exceptions import MutualTLSChannelError
from google.cloud.datastore_v1.services.datastore import DatastoreAsyncClient
