Skip to content

Commit

Permalink
fix: colab GPU environment
Browse files Browse the repository at this point in the history
  • Loading branch information
borisdayma committed Aug 22, 2023
1 parent 2126f40 commit f96f1e5
Showing 1 changed file with 9 additions and 50 deletions.
59 changes: 9 additions & 50 deletions tools/inference/inference_pipeline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@
},
"outputs": [],
"source": [
"# Required only for colab environments + GPU\n",
"!pip install \"jax[cuda12_pip]==0.3.25\" jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"\n",
"# Install required libraries\n",
"!pip install -q dalle-mini\n",
"!pip install -q dalle-mini orbax==0.0.23\n",
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
]
},
Expand Down Expand Up @@ -339,54 +342,9 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SDjEx9JxR3v8",
"outputId": "48388f06-4733-43b4-f5e0-cf4a7d1c5b69",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 124,
"referenced_widgets": [
"9281546b8f124c35ac8f583fb81b97b5"
]
}
},
"outputs": [
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Prompts: ['sunset over a lake in the mountains', 'the Eiffel tower landing on the moon']\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9281546b8f124c35ac8f583fb81b97b5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float16 to dtype=float32. In future JAX releases this will result in an error.\n",
" warnings.warn(\"scatter inputs have incompatible types: cannot safely cast \"\n"
]
}
],
"id": "SDjEx9JxR3v8"
},
"outputs": [],
"source": [
"from flax.training.common_utils import shard_prng_key\n",
"import numpy as np\n",
Expand Down Expand Up @@ -571,11 +529,12 @@
}
],
"metadata": {
"accelerator": "TPU",
"accelerator": "GPU",
"colab": {
"machine_shape": "hm",
"name": "DALL·E mini - Inference pipeline.ipynb",
"provenance": [],
"gpuType": "A100",
"include_colab_link": true
},
"kernelspec": {
Expand Down

0 comments on commit f96f1e5

Please sign in to comment.