Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup MegaBlocks Installation #126

Closed
wants to merge 26 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix reading repo version
  • Loading branch information
Eitan Turok committed Jul 24, 2024
commit f62c03492e0561da83633ad8bcf290378539b406
20 changes: 6 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
"""MegaBlocks package setup."""

import os
import re
import warnings

from setuptools import find_packages, setup
@@ -17,7 +16,7 @@
CUDAExtension,)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"No module named 'torch'. Torch is required to install this repo."
"No module named 'torch'. `torch` is required to install this repo."
eitanturok marked this conversation as resolved.
Show resolved Hide resolved
) from e


@@ -28,18 +27,11 @@

# Read the package version
# We can't use `.__version__` from the library since it's not installed yet
with open(os.path.join(_PACKAGE_REAL_PATH, '__init__.py'), encoding='utf-8') as f:
content = f.read()
print(f"Content: {content}")
# regex: '__version__', whitespace?, '=', whitespace, quote, version, quote
# we put parens around the version so that it becomes elem 1 of the match
expr = re.compile(
r"""^__version__\s*=\s*['"]([0-9]+\.[0-9]+\.[0-9]+(?:\.\w+)?)['"]""",
re.MULTILINE,
)
repo_version = expr.findall(content)[0]


with open(os.path.join(_PACKAGE_REAL_PATH, '_version.py'), encoding='utf-8') as f:
version_globals = {}
version_locals = {}
exec(f.read(), version_globals, version_locals)
repo_version = version_locals['__version__']

install_requires = [
'numpy>=1.21.5,<2.1.0',