From 8d5241f8d16cabddfbc55bf8be803d8c596a3910 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 21 Mar 2023 14:31:54 -0700 Subject: [PATCH] Update usage of `get_worker()` in tests In https://github.com/dask/distributed/pull/7580/ `get_worker` was modified to return the worker of a task, thus it cannot be used by `client.run`, and we must now use `dask_worker` as the first argument to `client.run` to obtain the worker. --- dask_cuda/tests/test_explicit_comms.py | 6 +-- dask_cuda/tests/test_local_cuda_cluster.py | 9 +++-- dask_cuda/tests/test_spill.py | 44 +++++++++++++++------- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index 624815e7..d1024ff6 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -11,7 +11,7 @@ from dask import dataframe as dd from dask.dataframe.shuffle import partitioning_index from dask.dataframe.utils import assert_eq -from distributed import Client, get_worker +from distributed import Client from distributed.deploy.local import LocalCluster import dask_cuda @@ -314,8 +314,8 @@ def test_jit_unspill(protocol): def _test_lock_workers(scheduler_address, ranks): - async def f(_): - worker = get_worker() + async def f(info): + worker = info["worker"] if hasattr(worker, "running"): assert not worker.running worker.running = True diff --git a/dask_cuda/tests/test_local_cuda_cluster.py b/dask_cuda/tests/test_local_cuda_cluster.py index a72ec3f2..f2e48783 100644 --- a/dask_cuda/tests/test_local_cuda_cluster.py +++ b/dask_cuda/tests/test_local_cuda_cluster.py @@ -9,7 +9,6 @@ from dask.distributed import Client from distributed.system import MEMORY_LIMIT from distributed.utils_test import gen_test, raises_with_cause -from distributed.worker import get_worker from dask_cuda import CUDAWorker, LocalCUDACluster, utils from dask_cuda.initialize import initialize @@ -140,7 +139,9 @@ async def test_no_memory_limits_cluster(): ) as cluster: async with Client(cluster, asynchronous=True) as client: # Check that all workers use a regular dict as their "data store". - res = await client.run(lambda: isinstance(get_worker().data, dict)) + res = await client.run( + lambda dask_worker: isinstance(dask_worker.data, dict) + ) assert all(res.values()) @@ -161,7 +162,9 @@ async def test_no_memory_limits_cudaworker(): await new_worker await client.wait_for_workers(2) # Check that all workers use a regular dict as their "data store". - res = await client.run(lambda: isinstance(get_worker().data, dict)) + res = await client.run( + lambda dask_worker: isinstance(dask_worker.data, dict) + ) assert all(res.values()) await new_worker.close() diff --git a/dask_cuda/tests/test_spill.py b/dask_cuda/tests/test_spill.py index f93b83ec..bbd24d5a 100644 --- a/dask_cuda/tests/test_spill.py +++ b/dask_cuda/tests/test_spill.py @@ -6,7 +6,7 @@ import dask from dask import array as da -from distributed import Client, get_worker, wait +from distributed import Client, wait from distributed.metrics import time from distributed.sizeof import sizeof from distributed.utils_test import gen_cluster, gen_test, loop # noqa: F401 @@ -57,21 +57,25 @@ def assert_device_host_file_size( ) -def worker_assert(total_size, device_chunk_overhead, serialized_chunk_overhead): +def worker_assert( + dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead +): assert_device_host_file_size( - get_worker().data, total_size, device_chunk_overhead, serialized_chunk_overhead + dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead ) -def delayed_worker_assert(total_size, device_chunk_overhead, serialized_chunk_overhead): +def delayed_worker_assert( + dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead +): start = time() while not device_host_file_size_matches( - get_worker().data, total_size, device_chunk_overhead, serialized_chunk_overhead + dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead ): sleep(0.01) if time() < start + 3: assert_device_host_file_size( - get_worker().data, + dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead, @@ -143,17 +147,23 @@ async def test_cupy_cluster_device_spill(params): await wait(xx) # Allow up to 1024 bytes overhead per chunk serialized - await client.run(worker_assert, x.nbytes, 1024, 1024) + await client.run( + lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024) + ) y = client.compute(x.sum()) res = await y assert (abs(res / x.size) - 0.5) < 1e-3 - await client.run(worker_assert, x.nbytes, 1024, 1024) - host_chunks = await client.run(lambda: len(get_worker().data.host)) + await client.run( + lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024) + ) + host_chunks = await client.run( + lambda dask_worker: len(dask_worker.data.host) + ) disk_chunks = await client.run( - lambda: len(get_worker().data.disk or list()) + lambda dask_worker: len(dask_worker.data.disk or list()) ) for hc, dc in zip(host_chunks.values(), disk_chunks.values()): if params["spills_to_disk"]: @@ -245,9 +255,11 @@ async def test_cudf_cluster_device_spill(params): del cdf - host_chunks = await client.run(lambda: len(get_worker().data.host)) + host_chunks = await client.run( + lambda dask_worker: len(dask_worker.data.host) + ) disk_chunks = await client.run( - lambda: len(get_worker().data.disk or list()) + lambda dask_worker: len(dask_worker.data.disk or list()) ) for hc, dc in zip(host_chunks.values(), disk_chunks.values()): if params["spills_to_disk"]: @@ -256,8 +268,12 @@ async def test_cudf_cluster_device_spill(params): assert hc > 0 assert dc == 0 - await client.run(worker_assert, nbytes, 32, 2048) + await client.run( + lambda dask_worker: worker_assert(dask_worker, nbytes, 32, 2048) + ) del cdf2 - await client.run(delayed_worker_assert, 0, 0, 0) + await client.run( + lambda dask_worker: delayed_worker_assert(dask_worker, 0, 0, 0) + )