Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create Pip Package #21

Closed
wants to merge 13 commits into from
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.o
*.so
*.a
.cache/
.vs/
Expand Down Expand Up @@ -32,3 +33,8 @@ compile_commands.json
.venv
__pycache__
.swiftpm

_skbuild
*.egg-info
dist/
CMakeFiles/
25 changes: 25 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,28 @@ if (BUILD_SHARED_LIBS)
set_target_properties(rwkv PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(rwkv PRIVATE RWKV_SHARED RWKV_BUILD)
endif()

if(SKBUILD)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if(SKBUILD)
if (SKBUILD)

if (UNIX)
add_custom_command(
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/librwkv.so
COMMAND make librwkv.so
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
)
add_custom_target(
run ALL
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/librwkv.so
)
install(
FILES ${CMAKE_CURRENT_SOURCE_DIR}/librwkv.so
DESTINATION rwkv
)
else()
set(BUILD_SHARED_LIBS "On")
install(
TARGETS rwkv
LIBRARY DESTINATION rwkv
RUNTIME DESTINATION rwkv
)
endif(UNIX)
endif(SKBUILD)
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,6 @@ ggml.o: ggml.c ggml.h

rwkv.o: rwkv.cpp rwkv.h
$(CXX) $(CXXFLAGS) -c rwkv.cpp -o rwkv.o

librwkv.so: rwkv.o ggml.o
saharNooby marked this conversation as resolved.
Show resolved Hide resolved
$(CXX) $(CXXFLAGS) -shared -fPIC -o librwkv.so rwkv.o ggml.o $(LDFLAGS)
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[build-system]
requires = [
"setuptools>=42",
"scikit-build>=0.13",
"cmake>=3.18",
"ninja",
]
build-backend = "setuptools.build_meta"
2 changes: 2 additions & 0 deletions rwkv/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .rwkv_cpp_shared_library import *
from .rwkv_cpp_model import *
4 changes: 3 additions & 1 deletion rwkv/rwkv_cpp_shared_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
if 'win32' in sys.platform or 'cygwin' in sys.platform:
file_name = 'rwkv.dll'
elif 'darwin' in sys.platform:
file_name = 'librwkv.dylib'
file_name = 'librwkv.so'
else:
file_name = 'librwkv.so'

Expand All @@ -201,6 +201,8 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
f'bin/Release/{file_name}',
# Search relative to this file
str(repo_root_dir / 'bin' / 'Release' / file_name),
# Search in python package
str(repo_root_dir / 'rwkv' / file_name),
# Fallback
str(repo_root_dir / file_name)
]
Expand Down
27 changes: 27 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from skbuild import setup

setup(
name="rwkv_cpp_python",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess just rwkv_cpp or more preferred rwkv-cpp is fine, since this is already a Python package :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I'll update this as well. I believe both alias to the same name on PyPI.

description="A Python wrapper for rwkv.cpp",
long_description_content_type="text/markdown",
version="0.0.1",
author="",
author_email="",
license="MIT",
package_dir={"rwkv_cpp": "rwkv"},
packages=["rwkv_cpp"],
install_requires=[
"numpy>=1.24.1",
"torch>=2.0.0",
"tokenizers>=0.13.3",
],
python_requires=">=3.7",
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
],
)