From 66e9d5e20dbf75c87055b360e06e84b87e26316c Mon Sep 17 00:00:00 2001
From: tlpss
Date: Fri, 20 Oct 2023 11:52:22 +0200
Subject: [PATCH] option to start training from checkpoint

---
 README.md                                    | 12 ++++++++----
 keypoint_detection/tasks/train.py            | 20 ++++++++++++++++++--
 keypoint_detection/utils/load_checkpoints.py | 12 ++++++++++--
 3 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 8d6f550..883dd4a 100644
--- a/README.md
+++ b/README.md
@@ -37,10 +37,14 @@ TODO: add integration example.
 To train a keypoint detector, run the `keypoint-detection train` CLI with the appropriate arguments.
 To create your own configuration: run `keypoint-detection train -h` to see all parameter options and their documentation.
 
-A starting point could be the bash script `bash test/integration_test.sh` to test on the provided test dataset, which contains 4 images. You should see the loss going down consistently until the detector has completely overfit the train set and the loss is around the entropy of the ground truth heatmaps (if you selected the default BCE loss).
+A good starting point could be the bash script `bash test/integration_test.sh` to test on the provided test dataset, which contains 4 images. You should see the loss going down consistently until the detector has completely overfit the train set and the loss is around the entropy of the ground truth heatmaps (if you selected the default BCE loss).
 
+### Wandb sweeps
 Alternatively, you can create a sweep on [wandb](https://wandb.ai) and to then start a (number of) wandb agent(s). This is very useful for running multiple configurations (hparam search, testing on multiple datasets,..)
 
+### Loading pretrained weights
+If you want to load pretrained keypoint detector weights, you can specify the wandb artifact of the checkpoint in the training parameters: `keypoint-detection train ..... --wandb_checkpoint_artifact <artifact-reference>`. This can be used for example to finetune on real data after pretraining on synthetic data.
+
 ## Dataset
 
 This package used the [COCO format](https://cocodataset.org/#format-data) for keypoint annotation and expects a dataset with the following structure:
@@ -48,7 +52,7 @@ This package used the [COCO format](https://cocodataset.org/#format-data) for ke
 dataset/
   images/
   ...
-  .json : a COCO-formatted keypoint annotation file.
+  .json : a COCO-formatted keypoint annotation file with filepaths relative to its parent directory.
 ```
 
 For an example, see the `test_dataset` at `test/test_dataset`.
@@ -66,7 +70,7 @@ TODO
 TODO
 `scripts/fiftyone_viewer`
 
-## Using a trained model (Inference)
+## Using a trained model for Inference
 During training Pytorch Lightning will have saved checkpoints. See `scripts/checkpoint_inference.py` for a simple example to run inference with a checkpoint.
 For benchmarking the inference (or training), see `scripts/benchmark.py`.
 
@@ -86,7 +90,7 @@ In general a lower threshold will result in a lower metric. The size of this gap
 
 #TODO: add a figure to illustrate this.
 
-We do not use OKS as in COCO for 2 reasons:
+We do not use OKS as in COCO for the following reasons:
 1. it requires bbox annotations, which are not always required for keypoint detection itself and represent additional label effort.
 2. More importantly, in robotics the size of an object does not always correlate with the required precision. If a large and a small mug stand on a table, they require the same precise localisation of keypoints for a robot to grasp them even though their apparent size is different.
 3. (you need to estimate label variance, though you could simply set k=1 and skip this part)
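A note on the `--wandb_checkpoint_artifact` argument documented above: wandb artifact references take the form `<entity>/<project>/<artifact-name>:<version>` (the `__main__` example at the end of this patch uses such a reference). As a purely illustrative command with placeholder values, a finetuning run could then be started as `keypoint-detection train <your usual training arguments> --wandb_checkpoint_artifact <entity>/<project>/<artifact-name>:<version>`.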
diff --git a/keypoint_detection/tasks/train.py b/keypoint_detection/tasks/train.py
index bd4f6e3..61cd51a 100644
--- a/keypoint_detection/tasks/train.py
+++ b/keypoint_detection/tasks/train.py
@@ -11,6 +11,7 @@ from keypoint_detection.models.backbones.backbone_factory import BackboneFactory
 from keypoint_detection.models.detector import KeypointDetector
 from keypoint_detection.tasks.train_utils import create_pl_trainer, parse_channel_configuration
+from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint
 from keypoint_detection.utils.path import get_wandb_log_dir_path
 
 
@@ -49,6 +50,12 @@ def add_system_args(parent_parser: ArgumentParser) -> ArgumentParser:
         help="do not use deterministic algorithms for pytorch. This can speed up training, but will make it non-reproducible.",
     )
+    parser.add_argument(
+        "--wandb_checkpoint_artifact",
+        type=str,
+        help="A checkpoint to resume/start training from. Keep in mind that you currently cannot specify hyperparameters other than the LR.",
+        required=False,
+    )
     parser.set_defaults(deterministic=True)
     return parent_parser
 
 
@@ -63,9 +70,18 @@ def train(hparams: dict) -> Tuple[KeypointDetector, pl.Trainer]:
 
     # use deterministic algorithms for torch to ensure exact reproducibility
     # we have to set it in the trainer! (see create_pl_trainer)
+    if "wandb_checkpoint_artifact" in hparams.keys():
+        print("Loading checkpoint from wandb")
+        # This will create a KeypointDetector model with the associated hyperparameters.
+        # Model weights will be loaded.
+        # The optimizer and LR scheduler will be initialized from scratch (if you want to truly resume training, you have to pass the checkpoint to the trainer).
+        # cf. https://lightning.ai/docs/pytorch/latest/common/checkpointing_basic.html#lightningmodule-from-checkpoint
+        model = get_model_from_wandb_checkpoint(hparams["wandb_checkpoint_artifact"])
+        # TODO: how can specific hparams be overwritten here? e.g. an LR reduction for finetuning.
+    else:
+        backbone = BackboneFactory.create_backbone(**hparams)
+        model = KeypointDetector(backbone=backbone, **hparams)
 
-    backbone = BackboneFactory.create_backbone(**hparams)
-    model = KeypointDetector(backbone=backbone, **hparams)
     data_module = KeypointsDataModule(**hparams)
     wandb_logger = WandbLogger(
         project=hparams["wandb_project"],
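The comments in `train()` above point out that this code path only restores the model weights; the optimizer and LR-scheduler state are discarded. Below is a minimal sketch of what truly resuming a run could look like instead, assuming the artifact is downloaded the same way as in `keypoint_detection/utils/load_checkpoints.py` and handed to pytorch-lightning via the `ckpt_path` argument of `Trainer.fit`. The helper name and the wiring are illustrative, not part of this patch.

```python
from pathlib import Path

import wandb


def download_wandb_checkpoint(checkpoint_reference: str) -> str:
    """Download a wandb model artifact and return the local path to its model.ckpt file."""
    # reuse the active wandb run if there is one, otherwise start a throwaway run
    run = wandb.run if wandb.run is not None else wandb.init(project="inference")
    artifact = run.use_artifact(checkpoint_reference, type="model")
    return str(Path(artifact.download()) / "model.ckpt")


# inside train(), after model, data_module and trainer have been created:
#   ckpt_path = download_wandb_checkpoint(hparams["wandb_checkpoint_artifact"])
#   trainer.fit(model, data_module, ckpt_path=ckpt_path)  # also restores optimizer/scheduler state
```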
diff --git a/keypoint_detection/utils/load_checkpoints.py b/keypoint_detection/utils/load_checkpoints.py
index 51e7473..d43feb9 100644
--- a/keypoint_detection/utils/load_checkpoints.py
+++ b/keypoint_detection/utils/load_checkpoints.py
@@ -15,14 +15,17 @@ def get_model_from_wandb_checkpoint(checkpoint_reference: str):
     import wandb
 
     # download checkpoint locally (if not already cached)
-    run = wandb.init(project="inference")
+    if wandb.run is None:
+        run = wandb.init(project="inference")
+    else:
+        run = wandb.run
     artifact = run.use_artifact(checkpoint_reference, type="model")
     artifact_dir = artifact.download()
     checkpoint_path = Path(artifact_dir) / "model.ckpt"
     return load_from_checkpoint(checkpoint_path)
 
 
-def load_from_checkpoint(checkpoint_path: str):
+def load_from_checkpoint(checkpoint_path: str, hparams_to_override: dict = None):
     """
     function to load a Keypoint Detector model from a local pytorch lightning checkpoint.
 
@@ -43,3 +46,8 @@ def load_from_checkpoint(checkpoint_path: str):
     backbone = BackboneFactory.create_backbone(**checkpoint["hyper_parameters"])
     model = KeypointDetector.load_from_checkpoint(checkpoint_path, backbone=backbone)
     return model
+
+
+if __name__ == "__main__":
+    model = get_model_from_wandb_checkpoint("tlips/synthetic-cloth-keypoints-tshirts/model-4um302zo:v0")
+    print(model.hparams)
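The new `hparams_to_override` argument of `load_from_checkpoint` is not used yet, which is also what the TODO in `train()` refers to (e.g. lowering the LR for finetuning). Below is a minimal sketch of one way it could be wired up, relying on pytorch-lightning's behaviour that keyword arguments passed to `LightningModule.load_from_checkpoint` override the hyperparameters stored in the checkpoint. The body is an assumption, not part of this patch.

```python
import torch

from keypoint_detection.models.backbones.backbone_factory import BackboneFactory
from keypoint_detection.models.detector import KeypointDetector


def load_from_checkpoint(checkpoint_path: str, hparams_to_override: dict = None):
    """Load a KeypointDetector from a local checkpoint, optionally overriding some hyperparameters."""
    hparams_to_override = hparams_to_override or {}
    # read the stored hyperparameters (with overrides applied) to rebuild the backbone
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    backbone_hparams = {**checkpoint["hyper_parameters"], **hparams_to_override}
    backbone = BackboneFactory.create_backbone(**backbone_hparams)
    # kwargs passed to load_from_checkpoint take precedence over the checkpointed hparams
    model = KeypointDetector.load_from_checkpoint(checkpoint_path, backbone=backbone, **hparams_to_override)
    return model
```

With this in place, `train()` could forward something like `{"learning_rate": hparams["learning_rate"]}` when a checkpoint artifact is given, assuming `learning_rate` is indeed the name of the corresponding `KeypointDetector` init argument.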