diff --git a/setup.py b/setup.py index 2df791c..511d394 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,9 @@ -from setuptools import setup, find_packages -import torch +from setuptools import find_packages, setup + +install_requires = [ + 'torch>=2.3.0,<2.4', + 'triton>=2.1.0', +] extra_deps = {} @@ -7,7 +11,7 @@ 'absl-py', ] -extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps) +extra_deps['all'] = list(set(dep for deps in extra_deps.values() for dep in deps)) setup( name="stanford-stk", @@ -24,8 +28,6 @@ "Operating System :: Unix", ], packages=find_packages(), - install_requires=[ - "triton>=2.1.0", - ], + install_requires=install_requires, extras_require=extra_deps, )