From bfe1b56e5b9100536d73d6b67b99e246bdc9231f Mon Sep 17 00:00:00 2001 From: Rohin Bhasin Date: Fri, 3 May 2024 14:09:06 -0400 Subject: [PATCH] Switch parallel embedding example to use async primitives. --- .../parallel_hf_embedding_ec2.py | 57 ++++++++++--------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/examples/parallel-hf-embedding-ec2/parallel_hf_embedding_ec2.py b/examples/parallel-hf-embedding-ec2/parallel_hf_embedding_ec2.py index abf61d9ad..380d0f295 100644 --- a/examples/parallel-hf-embedding-ec2/parallel_hf_embedding_ec2.py +++ b/examples/parallel-hf-embedding-ec2/parallel_hf_embedding_ec2.py @@ -35,7 +35,7 @@ # Imports of libraries that are needed on the remote machine (in this case, the `langchain` dependencies) # can happen within the functions that will be sent to the Runhouse cluster. -import concurrent.futures +import asyncio import time from typing import List from urllib.parse import urljoin, urlparse @@ -163,7 +163,7 @@ def embed_docs(self, urls: List[str]): # Make sure that your code runs within a `if __name__ == "__main__":` block, as shown below. Otherwise, # the script code will run when Runhouse attempts to import your code remotely. # ::: -if __name__ == "__main__": +async def main(): cluster = rh.cluster("rh-a10g", instance_type="A10G:4").save().up_if_not() # We set up some parameters for our embedding task. @@ -211,33 +211,34 @@ def embed_docs(self, urls: List[str]): print(f"Time to initialize {num_replicas} replicas: {time.time() - start_time}") # ## Calling the Runhouse modules in parallel - # We set up a loop to call each replica in parallel with the partitioned URLs. + # We'll simply use the `embed_docs` function on the remote module to embed all the URLs in parallel. Note that + # we can call this function exactly as if it were a local module. In this case, because we are using the + # `asyncio` library to make parallel calls, we need to use a special `run_async=True` argument to the + # Runhouse function. 
This tells Runhouse to return a coroutine that we can await, rather than making + # a blocking network call to the server.