Skip to content

Commit

Permalink
Merge pull request #10 from TuragaLab/toml
Browse files Browse the repository at this point in the history
Switch to pyproject.toml and use conda-forge channel
  • Loading branch information
vaxenburg authored Sep 26, 2024
2 parents 99836be + 3781ca5 commit a4f7219
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 82 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Follow these steps to install `flybody`:
```bash
git clone https://github.com/TuragaLab/flybody.git
cd flybody
conda env create -f flybody.yml
conda create --name flybody -c conda-forge python=3.10 pip ipython cudatoolkit=11.8.0
conda activate flybody
```
`flybody` can be installed in one of the three modes described next. Also, for installation in editable (developer) mode, use the commands as shown. For installation in regular, not editable, mode, drop the `-e` flag.
Expand All @@ -96,7 +96,7 @@ Follow these steps to install `flybody`:
### Option 2: Installation from remote repo
1. Create a new conda environment:
```bash
conda create --name flybody python=3.10 pip ipython cudatoolkit cudnn=8.2.1=cuda11.3_0
conda create --name flybody -c conda-forge python=3.10 pip ipython cudatoolkit=11.8.0
conda activate flybody
```
Proceed with installation in one of the three modes (described above):
Expand Down
7 changes: 0 additions & 7 deletions flybody.yml

This file was deleted.

20 changes: 11 additions & 9 deletions flybody/train_dmpo_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@


parser = argparse.ArgumentParser()
parser.add_argument('--test',
parser.add_argument('--test', '-t',
help='Run job in test mode with one actor and output to current terminal.',
action='store_true')
args = parser.parse_args()
Expand Down Expand Up @@ -99,7 +99,9 @@ def environment_factory(training: bool) -> 'composer.Environment':

# This callable will be calculating penalization cost by converting canonical
# actions to real (not wrapped) environment actions inside DMPO agent.
penalization_cost = PenalizationCostRealActions(dummy_env.action_spec())
# Note that we need the action_spec of the underlying environment so we unwrap
# with dummy_env._environment.
penalization_cost = PenalizationCostRealActions(dummy_env._environment.action_spec())

# Distributed DMPO agent configuration.
dmpo_config = DMPOConfig(
Expand All @@ -113,13 +115,13 @@ def environment_factory(training: bool) -> 'composer.Environment':
n_step=5,
num_samples=20,
policy_loss_module=policy_loss_module_dmpo(
epsilon=0.1,
epsilon_mean=0.0025,
epsilon_stddev=1e-7,
action_penalization=True,
epsilon_penalty=0.1,
penalization_cost=penalization_cost,
),
epsilon=0.1,
epsilon_mean=0.0025,
epsilon_stddev=1e-7,
action_penalization=True,
epsilon_penalty=0.1,
penalization_cost=penalization_cost,
),
policy_optimizer=snt.optimizers.Adam(1e-4),
critic_optimizer=snt.optimizers.Adam(1e-4),
dual_optimizer=snt.optimizers.Adam(1e-3),
Expand Down
58 changes: 58 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
[build-system]
requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "flybody"
version = "0.1.0"
dependencies = [
"numpy==1.26.4",
"dm_control",
"h5py",
"pytest",
"mediapy",
]
requires-python = ">=3.10"
authors = [
{name = "Roman Vaxenburg", email="vaxenburgr@hhmi.org"},
{name = "Gert-Jan Both", email="bothg@hhmi.org"},
{name = "Yuval Tassa", email="tassa@google.com"},
{name = "Zinovia Stefanidi", email="zinovia.stefanidi@uni-tuebingen.de"},
]
description = "MuJoCo fruit fly body model and reinforcement learning tasks."
readme = "README.md"
license = {file = "LICENSE"}
keywords = ["mujoco", "reinforcement learning", "flybody"]
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
]

[project.urls]
repository = "https://github.com/TuragaLab/flybody"

[project.optional-dependencies]
tf = [
"dm-acme[tf,envs,jax]",
"nvidia-cudnn-cu11==8.9.*",
"tensorflow==2.8.0",
"tensorflow-probability==0.16.0",
"dm-reverb==0.7.0",
]
ray = [
"flybody[tf]",
"ray[default]",
]
dev = [
"ruff",
"jupyterlab",
"tqdm",
]
all = ["flybody[ray,dev]"]

[tool.setuptools.packages.find]
include = ["flybody*"]
namespaces = false

[tool.setuptools.package-data]
flybody = ["fruitfly/assets/*.obj", "fruitfly/assets/*.xml"]
64 changes: 0 additions & 64 deletions setup.py

This file was deleted.

0 comments on commit a4f7219

Please sign in to comment.