Skip to content

Commit

Permalink
Update usage of get_worker() in tests
Browse files Browse the repository at this point in the history
In dask/distributed#7580 `get_worker` was
modified to return the worker of a task, thus it cannot be used by
`client.run`, and we must now use `dask_worker` as the first argument to
`client.run` to obtain the worker.
  • Loading branch information
pentschev committed Mar 21, 2023
1 parent ec2de78 commit 8d5241f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 20 deletions.
6 changes: 3 additions & 3 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dask import dataframe as dd
from dask.dataframe.shuffle import partitioning_index
from dask.dataframe.utils import assert_eq
from distributed import Client, get_worker
from distributed import Client
from distributed.deploy.local import LocalCluster

import dask_cuda
Expand Down Expand Up @@ -314,8 +314,8 @@ def test_jit_unspill(protocol):


def _test_lock_workers(scheduler_address, ranks):
async def f(_):
worker = get_worker()
async def f(info):
worker = info["worker"]
if hasattr(worker, "running"):
assert not worker.running
worker.running = True
Expand Down
9 changes: 6 additions & 3 deletions dask_cuda/tests/test_local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from dask.distributed import Client
from distributed.system import MEMORY_LIMIT
from distributed.utils_test import gen_test, raises_with_cause
from distributed.worker import get_worker

from dask_cuda import CUDAWorker, LocalCUDACluster, utils
from dask_cuda.initialize import initialize
Expand Down Expand Up @@ -140,7 +139,9 @@ async def test_no_memory_limits_cluster():
) as cluster:
async with Client(cluster, asynchronous=True) as client:
# Check that all workers use a regular dict as their "data store".
res = await client.run(lambda: isinstance(get_worker().data, dict))
res = await client.run(
lambda dask_worker: isinstance(dask_worker.data, dict)
)
assert all(res.values())


Expand All @@ -161,7 +162,9 @@ async def test_no_memory_limits_cudaworker():
await new_worker
await client.wait_for_workers(2)
# Check that all workers use a regular dict as their "data store".
res = await client.run(lambda: isinstance(get_worker().data, dict))
res = await client.run(
lambda dask_worker: isinstance(dask_worker.data, dict)
)
assert all(res.values())
await new_worker.close()

Expand Down
44 changes: 30 additions & 14 deletions dask_cuda/tests/test_spill.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import dask
from dask import array as da
from distributed import Client, get_worker, wait
from distributed import Client, wait
from distributed.metrics import time
from distributed.sizeof import sizeof
from distributed.utils_test import gen_cluster, gen_test, loop # noqa: F401
Expand Down Expand Up @@ -57,21 +57,25 @@ def assert_device_host_file_size(
)


def worker_assert(total_size, device_chunk_overhead, serialized_chunk_overhead):
def worker_assert(
dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
):
assert_device_host_file_size(
get_worker().data, total_size, device_chunk_overhead, serialized_chunk_overhead
dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
)


def delayed_worker_assert(total_size, device_chunk_overhead, serialized_chunk_overhead):
def delayed_worker_assert(
dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
):
start = time()
while not device_host_file_size_matches(
get_worker().data, total_size, device_chunk_overhead, serialized_chunk_overhead
dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
):
sleep(0.01)
if time() < start + 3:
assert_device_host_file_size(
get_worker().data,
dask_worker.data,
total_size,
device_chunk_overhead,
serialized_chunk_overhead,
Expand Down Expand Up @@ -143,17 +147,23 @@ async def test_cupy_cluster_device_spill(params):
await wait(xx)

# Allow up to 1024 bytes overhead per chunk serialized
await client.run(worker_assert, x.nbytes, 1024, 1024)
await client.run(
lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024)
)

y = client.compute(x.sum())
res = await y

assert (abs(res / x.size) - 0.5) < 1e-3

await client.run(worker_assert, x.nbytes, 1024, 1024)
host_chunks = await client.run(lambda: len(get_worker().data.host))
await client.run(
lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024)
)
host_chunks = await client.run(
lambda dask_worker: len(dask_worker.data.host)
)
disk_chunks = await client.run(
lambda: len(get_worker().data.disk or list())
lambda dask_worker: len(dask_worker.data.disk or list())
)
for hc, dc in zip(host_chunks.values(), disk_chunks.values()):
if params["spills_to_disk"]:
Expand Down Expand Up @@ -245,9 +255,11 @@ async def test_cudf_cluster_device_spill(params):

del cdf

host_chunks = await client.run(lambda: len(get_worker().data.host))
host_chunks = await client.run(
lambda dask_worker: len(dask_worker.data.host)
)
disk_chunks = await client.run(
lambda: len(get_worker().data.disk or list())
lambda dask_worker: len(dask_worker.data.disk or list())
)
for hc, dc in zip(host_chunks.values(), disk_chunks.values()):
if params["spills_to_disk"]:
Expand All @@ -256,8 +268,12 @@ async def test_cudf_cluster_device_spill(params):
assert hc > 0
assert dc == 0

await client.run(worker_assert, nbytes, 32, 2048)
await client.run(
lambda dask_worker: worker_assert(dask_worker, nbytes, 32, 2048)
)

del cdf2

await client.run(delayed_worker_assert, 0, 0, 0)
await client.run(
lambda dask_worker: delayed_worker_assert(dask_worker, 0, 0, 0)
)

0 comments on commit 8d5241f

Please sign in to comment.