diff --git a/src/levanter/utils/thread_utils.py b/src/levanter/utils/thread_utils.py index 0b4abcdaf..4ac4e1f1e 100644 --- a/src/levanter/utils/thread_utils.py +++ b/src/levanter/utils/thread_utils.py @@ -34,24 +34,41 @@ class AsyncIteratorWrapper(Iterator): def __init__(self, async_iter): self.async_iter = async_iter self.loop = asyncio.new_event_loop() - self.executor = ThreadPoolExecutor(max_workers=1) self.thread = threading.Thread(target=self._run_loop, daemon=True) self.thread.start() + self._exhausted = False # Flag to indicate if the iterator is exhausted def _run_loop(self): asyncio.set_event_loop(self.loop) self.loop.run_forever() def _run_async_task(self, coro): - return asyncio.run_coroutine_threadsafe(coro, self.loop).result() + if not self.loop.is_running() or not self.thread.is_alive(): + raise StopIteration # Loop is not running or thread has been joined + try: + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + return future.result() + except (RuntimeError, asyncio.CancelledError): + raise StopIteration # Either the loop was closed or the coroutine was cancelled def __iter__(self): return self def __next__(self): + if self._exhausted: + raise StopIteration try: return self._run_async_task(self.async_iter.__anext__()) except StopAsyncIteration: - self.loop.call_soon_threadsafe(self.loop.stop) - self.thread.join() + self._exhausted = True # Mark the iterator as exhausted + if self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join() # Ensure the thread is safely joined raise StopIteration + + def close(self): + """Close the event loop and thread gracefully.""" + if self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join() # Join the thread to ensure the loop is fully stopped + self.loop.close() # Explicitly close the loop to avoid dangling tasks