diff --git a/runhouse/servers/http/http_server.py b/runhouse/servers/http/http_server.py index 4d69ef25b..5e365ac0f 100644 --- a/runhouse/servers/http/http_server.py +++ b/runhouse/servers/http/http_server.py @@ -14,6 +14,7 @@ import yaml from fastapi import Body, FastAPI, HTTPException, Request from fastapi.encoders import jsonable_encoder +from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.responses import StreamingResponse from runhouse.constants import ( @@ -47,6 +48,7 @@ from runhouse.servers.obj_store import ( ClusterServletSetupOption, ObjStore, + ObjStoreError, RaySetupOption, ) @@ -361,6 +363,45 @@ async def _call( from_http_server=True, ) + # Open API docs routes + @staticmethod + @app.get("/{key}/openapi.json") + @validate_cluster_access + async def get_openapi_spec(request: Request, key: str): + try: + module_openapi_spec = obj_store.call( + key, method_name="openapi_spec", serialization=None + ) + except (AttributeError, ObjStoreError): + # The object put on the server is not an `rh.Module`, so it doesn't have an openapi_spec method + # OR + # The object is not found in the object store at all + module_openapi_spec = None + + if not module_openapi_spec: + raise HTTPException(status_code=404, detail=f"Module {key} not found.") + + return module_openapi_spec + + @staticmethod + @app.get("/{key}/redoc") + @validate_cluster_access + async def get_redoc(request: Request, key: str): + return get_redoc_html( + openapi_url=f"/{key}/openapi.json", title="Developer Documentation" + ) + + @staticmethod + @app.get("/{key}/docs") + @validate_cluster_access + async def get_swagger_ui_html(request: Request, key: str): + return get_swagger_ui_html( + openapi_url=f"/{key}/openapi.json", + title=f"{key} - Swagger UI", + oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url, + init_oauth=app.swagger_ui_init_oauth, + ) + # TODO match "/{key}/{method_name}/{path:more_path}" for asgi / proxy requests @staticmethod @app.post("/{key}/{method_name}")