diff --git a/README.md b/README.md index aa36cb09..ac1529e2 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,8 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla - [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. - `pip install mamba-ssm`: the core Mamba package. +- `pip install mamba-ssm[causal-conv1d]`: To install core Mamba package and causal-conv1d. +- `pip install mamba-ssm[dev]`: To install core Mamba package and dev depdencies. It can also be built from source with `pip install .` from this repository. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..8703ad35 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,46 @@ +[project] +name = "mamba_ssm" +description = "Mamba state-space model" +readme = "README.md" +authors = [ + { name = "Tri Dao", email = "tri@tridao.me" }, + { name = "Albert Gu", email = "agu@cs.cmu.edu" } +] +requires-python = ">= 3.7" +dynamic = ["version"] +license = { file = "LICENSE" } # Include a LICENSE file in your repo +keywords = ["cuda", "pytorch", "state-space model"] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix" +] +dependencies = [ + "torch", + "ninja", + "einops", + "triton", + "transformers", + "packaging", + "setuptools>=61.0.0", +] +urls = { name = "Repository", url = "https://github.com/state-spaces/mamba"} + +[project.optional-dependencies] +causal-conv1d = [ + "causal-conv1d>=1.2.0" +] +dev = [ + "pytest" +] + + +[build-system] +requires = [ + "setuptools>=61.0.0", + "wheel", + "torch", + "packaging", + "ninja", +] +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index c17ab0bb..b4380f01 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,6 @@ import torch from torch.utils.cpp_extension import ( BuildExtension, - CppExtension, CUDAExtension, CUDA_HOME, ) @@ -254,31 +253,13 @@ def run(self): "mamba_ssm.egg-info", ) ), - author="Tri Dao, Albert Gu", - author_email="tri@tridao.me, agu@cs.cmu.edu", - description="Mamba state-space model", long_description=long_description, long_description_content_type="text/markdown", - url="https://github.com/state-spaces/mamba", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - ], + ext_modules=ext_modules, cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} if ext_modules else { "bdist_wheel": CachedWheelsCommand, - }, - python_requires=">=3.7", - install_requires=[ - "torch", - "packaging", - "ninja", - "einops", - "triton", - "transformers", - # "causal_conv1d>=1.2.0", - ], + } )