Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pipeline tuning wandb integration #398

Merged
merged 8 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 108 additions & 15 deletions dance/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import importlib
import inspect
from copy import deepcopy
from pprint import pformat

from omegaconf import DictConfig, OmegaConf

from dance import logger
from dance.config import Config
from dance.exceptions import DevError
from dance.registry import REGISTRY, REGISTRY_PREFIX, Registry, resolve_from_registry
from dance.typing import Any, Callable, ConfigLike, Dict, FileExistHandle, List, Optional, PathLike, Union
from dance.utils import default
from dance.typing import Any, Callable, ConfigLike, Dict, FileExistHandle, List, Optional, PathLike, Tuple, Union
from dance.utils import Color, default


class Action:
Expand Down Expand Up @@ -245,6 +246,7 @@ class PipelinePlaner(Pipeline):
DEFAULT_PARAMS_KEY = "default_params"
PELEM_INCLUDE_KEY = "include"
PELEM_EXCLUDE_KEY = "exclude"
WANDB_KEY = "wandb"
VALID_TUNE_MODES = ("pipeline", "params")

def __init__(self, cfg: ConfigLike, **kwargs):
Expand Down Expand Up @@ -272,13 +274,21 @@ def candidate_pipelines(self) -> Optional[List[List[str]]]:
def candidate_params(self) -> Optional[List[Dict[str, Any]]]:
return getattr(self, "_candidate_params", None)

def _resolve_pelem_plan(self, idx: int) -> List[str]:
@property
def wandb_config(self) -> Optional[Dict[str, Any]]:
return self._wandb_config

def _resolve_pelem_plan(self, idx: int) -> Optional[List[str]]:
# NOTE: we need to use the raw config here instaed of the pipeline
# element action object, as obtained by self[idx], since that does not
# contain the extra information about tuning settings we need, e.g.,
# the inclusion and exlusion settings.
pelem_config = self.config[self.PIPELINE_KEY][idx]

# Use fixed target if available
if pelem_config.get(self.TARGET_KEY) is not None:
return None

# Disallow setting includes and excludes at the same time
if all(pelem_config.get(i) is not None for i in (self.PELEM_INCLUDE_KEY, self.PELEM_EXCLUDE_KEY)):
raise ValueError(f"Cannot set {self.PELEM_INCLUDE_KEY!r} and {self.PELEM_EXCLUDE_KEY!r}"
Expand Down Expand Up @@ -331,9 +341,10 @@ def config(self, cfg: ConfigLike):
raise ValueError("Empty pipeline.")

# Set up base config
base_keys = pelem_keys = (self.TYPE_KEY, self.DESC_KEY)
if self.tune_mode == "params":
pelem_keys = pelem_keys + (self.TARGET_KEY, )
base_keys = pelem_keys = (self.TYPE_KEY, self.DESC_KEY, self.TARGET_KEY)
if self.tune_mode == "pipeline":
# NOTE: params reserved for planing when tuning mode is ``params``
pelem_keys = pelem_keys + (self.PARAMS_KEY, )

base_config = {}
for key in base_keys:
Expand Down Expand Up @@ -376,6 +387,11 @@ def config(self, cfg: ConfigLike):
else:
raise ValueError(f"Unknown tune mode {self.tune_mode!r}, supported options are {self.VALID_TUNE_MODES}")

# Other configs
self._wandb_config = self.config.get(self.WANDB_KEY)
if self._wandb_config is not None:
self._wandb_config = OmegaConf.to_container(self._wandb_config)

@staticmethod
def _sanitize_pipeline(
pipeline: Optional[Union[Dict[str, Any], List[str]]],
Expand All @@ -393,11 +409,16 @@ def _sanitize_pipeline(
logger.debug(f"Setting pipeline element {idx} to {j}")
pipeline[idx] = j

if pipeline is None:
return

# Make sure pipeline length matches
if pipeline is not None and len(pipeline) != pipeline_length:
if len(pipeline) != pipeline_length:
raise ValueError(f"Expecting {pipeline_length} targets specifications, "
f"but only got {len(pipeline)}: {pipeline}")

logger.info(f"Pipeline plane:\n{Color('green')(pformat(pipeline))}")

return pipeline

@staticmethod
Expand All @@ -421,14 +442,26 @@ def _sanitize_params(
params[idx] = {}
params[idx][key] = j

if params is not None and len(params) != pipeline_length:
if params is None:
return

# Make sure pipeline length matches
if len(params) != pipeline_length:
raise ValueError(f"Expecting {pipeline_length} targets specifications, "
f"but only got {len(params)}: {params}")

logger.info(f"Params plane:\n{Color('green')(pformat(params))}")

return params

def _validate_pipeline(self, validate: bool, pipeline: List[str], i: int):
if validate and pipeline[i] not in self.candidate_pipelines[i]:
if not validate:
return

if self.candidate_pipelines[i] is None: # use fixed target
return

if pipeline[i] not in self.candidate_pipelines[i]: # invalid specified target
raise ValueError(f"Specified target {pipeline[i]} ({i=}) not supported. "
f"Available options are: {self.candidate_pipelines[i]}")

Expand Down Expand Up @@ -457,8 +490,9 @@ def _validate_params(
f"params specification for {full_scope!r} ({i=}): {unknown_keys}")
if strict_params_check:
raise ValueError(msg)
else:
logger.warning(msg)
# FIX: need to figure out a way to get inherited kwargs as well, e.g., ``out``...
# else:
# logger.warning(msg)

def generate_config(
self,
Expand Down Expand Up @@ -496,7 +530,7 @@ def get_ith_pelem(i: int):
# TODO: nested pipeline support?
for i in range(pipeline_length):
# Parse pipeline plan
if pipeline is not None:
if pipeline is not None and pipeline[i] is not None:
self._validate_pipeline(validate, pipeline, i)
get_ith_pelem(i)[self.TARGET_KEY] = pipeline[i]

Expand Down Expand Up @@ -614,7 +648,7 @@ def search_space(self) -> Dict[str, Any]:
}
)

dict(planer.search_space()) == {
planer.search_space() == {
"pipeline.0.target": {
"values": [
"FilterGenesScanpy",
Expand Down Expand Up @@ -662,7 +696,7 @@ def search_space(self) -> Dict[str, Any]:
}
)

dict(planer.search_space()) == {
planer.search_space() == {
"params.1.n_components": {
"values": [128, 256, 512, 1024],
},
Expand All @@ -681,7 +715,12 @@ def search_space(self) -> Dict[str, Any]:
return search_space

def _pipeline_search_space(self) -> Dict[str, str]:
search_space = {f"{self.PIPELINE_KEY}.{i}": {"values": j} for i, j in enumerate(self.candidate_pipelines)}
search_space = {
f"{self.PIPELINE_KEY}.{i}": {
"values": j
}
for i, j in enumerate(self.candidate_pipelines) if j is not None
}
return search_space

def _params_search_space(self) -> Dict[str, Dict[str, Optional[Union[str, float]]]]:
Expand All @@ -691,3 +730,57 @@ def _params_search_space(self) -> Dict[str, Dict[str, Optional[Union[str, float]
for key, val in param_dict.items():
search_space[f"{self.PARAMS_KEY}.{i}.{key}"] = val
return search_space

def wandb_sweep_config(self) -> Dict[str, Any]:
    """Assemble the full wandb sweep configuration.

    Combines the user-provided ``wandb`` section of the raw config with the
    search space derived from the tuning plan, stored under the ``parameters``
    key as expected by the wandb sweep API.

    Raises:
        ValueError: If no ``wandb`` section was specified in the raw config.

    """
    wandb_cfg = self.wandb_config
    if wandb_cfg is None:
        raise ValueError("wandb config not specified in the raw config.")
    sweep_cfg = dict(wandb_cfg)
    sweep_cfg["parameters"] = self.search_space()
    return sweep_cfg

def wandb_sweep(self) -> Tuple[str, str, str]:
    """Register a new wandb sweep from the planer's config.

    Returns:
        Tuple of ``(entity, project, sweep_id)`` identifying the created sweep.

    Raises:
        ImportError: If the ``wandb`` package is not installed.
        ValueError: If the ``wandb`` section, or its required ``entity`` and
            ``project`` entries, is missing from the raw config.

    """
    try:
        import wandb
    except ModuleNotFoundError as e:
        raise ImportError("wandb not installed. Please install wandb first: $ pip install wandb") from e

    if "wandb" not in self.config:
        raise ValueError(f"{self.config_yaml}\nMissing wandb config.")

    # Both entity and project must be present to register the sweep.
    # NOTE: local names are embedded in the error message via ``{var=!r}``,
    # so they must stay as-is.
    wandb_entity = self.config.wandb.get("entity")
    wandb_project = self.config.wandb.get("project")
    if any(val is None for val in (wandb_entity, wandb_project)):
        raise ValueError(f"{self.config_yaml}\nMissing either one (or both) of wandb configs "
                         f"'entity' and 'project': {wandb_entity=!r}, {wandb_project=!r}")

    full_sweep_config = self.wandb_sweep_config()
    logger.info(f"Sweep config:\n{pformat(full_sweep_config)}")
    wandb_sweep_id = wandb.sweep(sweep=full_sweep_config, entity=wandb_entity, project=wandb_project)
    logger.info(Color("blue")(f"\n\n\t[*] Sweep ID: {wandb_sweep_id}\n"))

    return wandb_entity, wandb_project, wandb_sweep_id

def wandb_sweep_agent(
    self,
    function: Callable,
    *,
    sweep_id: Optional[str] = None,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    count: Optional[int] = None,
) -> Tuple[str, str, str]:
    """Spawn a wandb sweep agent to run the tuning function.

    Args:
        function: Callable executed by the agent for each sweep run.
        sweep_id: Existing sweep to attach to. If not set, a new sweep is
            created from the config via :meth:`wandb_sweep`.
        entity: wandb entity. Falls back to the ``wandb`` config entry when
            not provided (only valid together with ``sweep_id``).
        project: wandb project. Same fallback rule as ``entity``.
        count: Maximum number of runs for this agent (``None`` = unlimited).

    Returns:
        Tuple of ``(entity, project, sweep_id)`` used by the agent.

    Raises:
        ImportError: If the ``wandb`` package is not installed.
        ValueError: If ``entity`` or ``project`` is given without ``sweep_id``.

    """
    try:
        import wandb
    except ModuleNotFoundError as e:
        raise ImportError("wandb not installed. Please install wandb first: $ pip install wandb") from e

    if sweep_id is None:
        if entity is not None or project is not None:
            raise ValueError("Cannot specify entity or project when sweep_id is not specified "
                             "(will be inferred from config)")
        entity, project, sweep_id = self.wandb_sweep()
    else:
        # FIX: previously, caller-supplied entity/project were silently
        # overwritten by the config values whenever sweep_id was given.
        # Only fall back to the config when they are not explicitly set.
        if entity is None:
            entity = self.config.wandb.get("entity")
        if project is None:
            project = self.config.wandb.get("project")

    logger.info(f"Spawning agent: {sweep_id=}, {entity=}, {project=}, {count=}")
    wandb.agent(sweep_id, function=function, entity=entity, project=project, count=count)

    return entity, project, sweep_id
27 changes: 27 additions & 0 deletions dance/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,33 @@ def __getitem__(self, index):
return x


class Color:
    """Wrap text with ANSI escape codes so it renders colored in a terminal.

    Example:
        >>> Color("green")("hello")  # '\\033[92mhello\\033[0m'

    """

    COLOR_DICT = {
        "blue": "\033[94m",
        "cyan": "\033[96m",
        "green": "\033[92m",
        "yellow": "\033[93m",
        "red": "\033[91m",
    }
    ENDC = "\033[0m"  # reset sequence appended after the colored text

    def __init__(self, color: str):
        code = self.COLOR_DICT.get(color)
        if code is None:
            raise ValueError(f"Unknown color {color}, supported options: {sorted(self.COLOR_DICT)}")
        self._start = code

    @property
    def start(self) -> str:
        # Escape sequence that switches the terminal to the selected color.
        return self._start

    @property
    def end(self) -> str:
        # Escape sequence that restores the default terminal color.
        return self.ENDC

    def __call__(self, txt: str) -> str:
        """Return ``txt`` surrounded by the color start/end escape codes."""
        return f"{self.start}{txt}{self.end}"


def set_seed(rndseed, cuda: bool = True, extreme_mode: bool = False):
os.environ["PYTHONHASHSEED"] = str(rndseed)
random.seed(rndseed)
Expand Down
19 changes: 19 additions & 0 deletions examples/tuning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
First create a sweep and then optionally spawn multiple agents given the sweep id.

```bash
$ python main.py
```

Record the sweep id ("\[\*\] Sweep ID: \<sweep_id>") and use it to spawn another agent

```bash
$ python main.py --sweep_id <sweep_id>
```

### Known issue

Currently there seems to be an issue where the wandb sweep agent *might be* throwing a segfault at the end of the sweep.
This error might carry over to new runs. To fix this, users need to remove the old data and redownload from scratch.

\[Update 2023-06-04\] The segfault seems to happen when writing to the source code
(even with "no changes", e.g., adding blank lines) while running the sweep agent.
72 changes: 72 additions & 0 deletions examples/tuning/cta_svm/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import argparse
import pprint
from typing import get_args

import wandb

from dance import logger
from dance.datasets.singlemodality import CellTypeAnnotationDataset
from dance.modules.single_modality.cell_type_annotation.svm import SVM
from dance.pipeline import PipelinePlaner
from dance.typing import LogLevel
from dance.utils import set_seed

if __name__ == "__main__":
    # Hyperparameter-tuning example: sweep over preprocessing pipeline
    # choices for the cell-type annotation SVM model using wandb.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--cache", action="store_true", help="Cache processed data.")
    parser.add_argument("--dense_dim", type=int, default=400, help="dim of PCA")
    parser.add_argument("--gpu", type=int, default=-1, help="GPU id, set to -1 for CPU")
    parser.add_argument("--log_level", type=str, default="INFO", choices=get_args(LogLevel))
    parser.add_argument("--species", default="mouse")
    parser.add_argument("--test_dataset", nargs="+", default=[2695], type=int, help="list of dataset id")
    parser.add_argument("--tissue", default="Brain")  # TODO: Add option for different tissue name for train/test
    parser.add_argument("--train_dataset", nargs="+", default=[753, 3285], type=int, help="list of dataset id")
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--sweep_id", type=str, default=None)

    args = parser.parse_args()
    logger.setLevel(args.log_level)
    logger.info(f"\n{pprint.pformat(vars(args))}")

    # Load the tuning plan (pipeline candidates + wandb sweep settings).
    pipeline_planer = PipelinePlaner.from_config_file("tuning_config.yaml")

    def evaluate_pipeline():
        # One sweep run: wandb.config supplies the pipeline plan to evaluate.
        wandb.init()

        set_seed(args.seed)
        model = SVM(args, random_state=args.seed)

        # Load raw data
        data = CellTypeAnnotationDataset(train_dataset=args.train_dataset, test_dataset=args.test_dataset,
                                         species=args.species, tissue=args.tissue).load_data()

        # Prepare preprocessing pipeline and apply it to data
        preprocessing_pipeline = pipeline_planer.generate(pipeline=dict(wandb.config))
        print(f"Pipeline config:\n{preprocessing_pipeline.to_yaml()}")
        preprocessing_pipeline(data)

        # Obtain training and testing data
        x_train, y_train = data.get_train_data()
        y_train_converted = y_train.argmax(1)  # convert one-hot representation into label index representation
        x_test, y_test = data.get_test_data()

        # Train and evaluate the model; accuracy is the sweep's target metric.
        model.fit(x_train, y_train_converted)
        score = model.score(x_test, y_test)
        wandb.log({"acc": score})

        wandb.finish()

    # Create a sweep (or attach to an existing one via --sweep_id) and run
    # up to 3 evaluation trials with this agent.
    pipeline_planer.wandb_sweep_agent(evaluate_pipeline, sweep_id=args.sweep_id, count=3)
"""To reproduce SVM benchmarks, please refer to command lines below:

Mouse Brain
$ python main.py --species mouse --tissue Brain --train_dataset 753 3285 --test_dataset 2695

Mouse Spleen
$ python main.py --species mouse --tissue Spleen --train_dataset 1970 --test_dataset 1759

Mouse Kidney
$ python main.py --species mouse --tissue Kidney --train_dataset 4682 --test_dataset 203

"""
27 changes: 27 additions & 0 deletions examples/tuning/cta_svm/tuning_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Tuning configuration for the cell-type annotation SVM example.
# ``tune_mode: pipeline`` makes the sweep search over pipeline element
# choices (targets) rather than their parameters.
type: preprocessor
tune_mode: pipeline
pipeline:
  # Cell feature construction step -- the sweep picks one of the included
  # targets below.
  - type: feature.cell
    include:
      - WeightedFeaturePCA
      - CellPCA
      - CellSVD
    params:
      n_components: 400
      out: feature.cell
    default_params:
      # Extra defaults applied only when this target is selected.
      WeightedFeaturePCA:
        split_name: train
  # Fixed final step that configures the feature/label channels for the model.
  - type: misc
    target: SetConfig
    params:
      config_dict:
        feature_channel: feature.cell
        label_channel: cell_type
# wandb sweep settings: entity/project are required; method/metric configure
# a Bayesian search that maximizes accuracy.
wandb:
  entity: danceteam
  project: dance-dev
  method: bayes
  metric:
    name: acc # val/acc
    goal: maximize
Loading