diff --git a/tools/inference/inference_pipeline.ipynb b/tools/inference/inference_pipeline.ipynb index 92e0f3d15..298a550ff 100644 --- a/tools/inference/inference_pipeline.ipynb +++ b/tools/inference/inference_pipeline.ipynb @@ -47,7 +47,7 @@ "outputs": [], "source": [ "# Required only for colab environments + GPU\n", - "!pip install \"jax[cuda12_pip]==0.3.25\" jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", "\n", "# Install required libraries\n", "!pip install -q dalle-mini orbax==0.0.23\n",