diff --git a/runhouse/resources/module.py b/runhouse/resources/module.py index 1dc184aa6..d4a4cd65f 100644 --- a/runhouse/resources/module.py +++ b/runhouse/resources/module.py @@ -967,6 +967,24 @@ def _extract_module_path(raw_cls_or_fn: Union[Type, Callable]): return module_path + @staticmethod + def _is_running_in_notebook(module_path: Union[str, None]) -> bool: + """Returns True if running in a notebook, False otherwise""" + + # Check if running in an IPython notebook + # TODO better way of detecting if in a notebook or interactive Python env + if not module_path or module_path.endswith("ipynb"): + return True + + # Check if running in a marimo notebook + try: + import marimo as mo + + return mo.running_in_notebook() + except (ImportError, ModuleNotFoundError): + # marimo not installed + return False + @staticmethod def _extract_pointers(raw_cls_or_fn: Union[Type, Callable], reqs: List[str]): """Get the path to the module, module name, and function name to be able to import it on the server""" @@ -980,8 +998,7 @@ def _extract_pointers(raw_cls_or_fn: Union[Type, Callable], reqs: List[str]): # Need to resolve in case just filename is given module_path = Module._extract_module_path(raw_cls_or_fn) - # TODO better way of detecting if in a notebook or interactive Python env - if not module_path or module_path.endswith("ipynb"): + if Module._is_running_in_notebook(module_path): # The only time __file__ wouldn't be present is if the function is defined in an interactive # interpreter or a notebook. We can't import on the server in that case, so we need to cloudpickle # the fn to send it over. The __call__ function will serialize the function if we return it this way.