diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py
index b4d784b2..ebb1450f 100755
--- a/zoobot/pytorch/estimators/define_model.py
+++ b/zoobot/pytorch/estimators/define_model.py
@@ -8,6 +8,7 @@ from torchmetrics import Accuracy
 import timm
 
+from zoobot.shared import schemas
 from zoobot.pytorch.estimators import efficientnet_custom, custom_layers
 from zoobot.pytorch.training import losses
 
@@ -142,7 +143,13 @@ class ZoobotTree(GenericLightningModule):
     def __init__(
         self,
         output_dim: int,
-        question_index_groups: List,
+        # in the simplest case, this is all zoobot needs: grouping of label col indices as questions
+        question_index_groups: List=None,
+        # BUT
+        # if you pass these, it enables better per-question and per-survey logging (because we have names)
+        # must be passed as simple dicts, not objects, so can't just pass schema in
+        question_answer_pairs: dict=None,
+        dependencies: dict=None,
         # encoder args
         architecture_name="efficientnet_b0",
         channels=1,
@@ -162,8 +169,11 @@ def __init__(
         # now, finally, can pass only standard variables as hparams to save
         # will still need to actually use these variables later, this super init only saves them
         super().__init__(
+            # these all do nothing, they are simply saved by lightning as hparams
             output_dim,
             question_index_groups,
+            question_answer_pairs,
+            dependencies,
             architecture_name,
             channels,
             timm_kwargs,
@@ -178,6 +188,14 @@ def __init__(
 
         logging.info('Generic __init__ complete - moving to Zoobot __init__')
 
+        if question_answer_pairs is not None:
+            logging.info('question_answer_pairs/dependencies passed to Zoobot, constructing schema in __init__')
+            assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
+            assert dependencies is not None
+            self.schema = schemas.Schema(question_answer_pairs, dependencies)
+            # replace with schema-derived version
+            question_index_groups = self.schema.question_index_groups
+
         # set attributes for learning rate, betas, used by self.configure_optimizers()
         # TODO refactor to optimizer params
         self.learning_rate = learning_rate
@@ -253,7 +271,7 @@
     def log_outputs(self, outputs, step_name):
         if outputs['predictions'].shape[1] == 2:  # will only do for binary classifications
             self.log(
                 "{}_accuracy".format(step_name), self.train_accuracy(outputs['predictions'], torch.argmax(outputs['labels'], dim=1, keepdim=False)), prog_bar=True, sync_dist=True)
-        # pass
+
 
     def log_loss_per_question(self, multiq_loss, prefix):
@@ -261,9 +279,39 @@ def log_loss_per_question(self, multiq_loss, prefix):
         # TODO need schema attribute or similar to have access to question names, this will do for now
         # unlike Finetuneable..., does not use TorchMetrics, simply logs directly
         # TODO could use TorchMetrics and for q in schema, self.q_metric loop
-        for question_n in range(multiq_loss.shape[1]):
-            self.log(f'{prefix}/epoch_questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, sync_dist=True)
-        # pass
+
+        if hasattr(self, 'schema'):
+            # use schema metadata to log intelligently
+            # will have schema if question_answer_pairs and dependencies are passed to __init__
+
+            # assume that questions are named like smooth-or-featured-CAMPAIGN
+            for question_n, question in enumerate(self.schema.questions):
+                self.log(
+                    f'{prefix}/epoch_questions/loss_{question.text}',
+                    torch.mean(multiq_loss[:, question_n]),
+                    on_epoch=True,
+                    on_step=False,
+                    sync_dist=True
+                )
+
+            campaigns = [question.text.split('-')[-1] for question in self.schema.questions]
+            for campaign in set(campaigns):  # deduplicate: several questions share a campaign
+                campaign_questions = [q for q in self.schema.questions if q.text.endswith(f'-{campaign}')]
+                campaign_q_indices = [self.schema.questions.index(q) for q in campaign_questions]
+                self.log(
+                    f'{prefix}/epoch_campaigns/loss_{campaign}',
+                    torch.mean(multiq_loss[:, campaign_q_indices]),
+                    on_epoch=True,
+                    on_step=False,
+                    sync_dist=True
+                )
+
+        else:
+            # fallback to logging with question_n
+            for question_n in range(multiq_loss.shape[1]):
+                self.log(f'{prefix}/epoch_questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, sync_dist=True)
+
+
diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py
index d5cab9cb..6e6b45f8 100644
--- a/zoobot/pytorch/training/train_with_pytorch_lightning.py
+++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -257,7 +257,10 @@ def train_default_zoobot_from_scratch(
     # these args are automatically logged
     lightning_model = define_model.ZoobotTree(
         output_dim=len(schema.label_cols),
-        question_index_groups=schema.question_index_groups,
+        # question_index_groups=schema.question_index_groups,
+        # NEW - pass these from schema, for better logging
+        question_answer_pairs=schema.question_answer_pairs,
+        dependencies=schema.dependencies,
         architecture_name=architecture_name,
         channels=channels,
         # use_imagenet_weights=False,
diff --git a/zoobot/shared/schemas.py b/zoobot/shared/schemas.py
index 8d32f878..88a6c3bf 100755
--- a/zoobot/shared/schemas.py
+++ b/zoobot/shared/schemas.py
@@ -130,7 +130,7 @@ def set_dependencies(questions, dependencies):
 class Schema():
-    def __init__(self, question_answer_pairs:dict, dependencies):
+    def __init__(self, question_answer_pairs: dict, dependencies: dict):
         """
         Relate the df label columns to question/answer groups and to tfrecord label indices
         Requires that labels be contiguous by question - easily satisfied
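
Usage sketch (illustrative, not part of the patch): the new calling convention for ZoobotTree, with a hypothetical two-question tree in the style of zoobot.shared.label_metadata. The question texts, answer suffixes, and dependency values below are stand-ins, not values defined by this patch.

    from zoobot.shared import schemas
    from zoobot.pytorch.estimators import define_model

    # hypothetical campaign-suffixed questions ('-gz2'), as the new logging assumes
    question_answer_pairs = {
        'smooth-or-featured-gz2': ['_smooth', '_featured-or-disk'],
        'has-spiral-arms-gz2': ['_yes', '_no'],
    }
    # each question is either always asked (None) or depends on a previous answer
    dependencies = {
        'smooth-or-featured-gz2': None,
        'has-spiral-arms-gz2': 'smooth-or-featured-gz2_featured-or-disk',
    }

    schema = schemas.Schema(question_answer_pairs, dependencies)

    model = define_model.ZoobotTree(
        output_dim=len(schema.label_cols),
        # NEW: pass the plain dicts; ZoobotTree rebuilds the Schema internally,
        # enabling the per-question and per-campaign loss logging added above
        question_answer_pairs=question_answer_pairs,
        dependencies=dependencies,
    )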
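The per-campaign grouping in log_loss_per_question reduces to the standalone logic below, shown only to illustrate which metric keys result; campaign_indices is a hypothetical helper for this note, not library API.

    def campaign_indices(question_texts: list) -> dict:
        # group question indices by trailing '-CAMPAIGN' suffix
        groups = {}
        for idx, text in enumerate(question_texts):
            groups.setdefault(text.split('-')[-1], []).append(idx)
        return groups

    print(campaign_indices(
        ['smooth-or-featured-gz2', 'has-spiral-arms-gz2', 'smooth-or-featured-dr5']
    ))
    # {'gz2': [0, 1], 'dr5': [2]}
    # with prefix='val', this yields keys like 'val/epoch_campaigns/loss_gz2'
    # (the mean loss over both gz2 questions), alongside the per-question keys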