
Is the warning emitted by self.log-ing an integer intentional? #18739

Closed
awaelchli opened this issue Oct 6, 2023 · 8 comments · Fixed by #18847
Labels
logging Related to the `LoggerConnector` and `log()` question Further information is requested ver: 1.9.x ver: 2.0.x ver: 2.1.x

Comments

@awaelchli
Contributor

awaelchli commented Oct 6, 2023

Bug description

When you call

self.log("integer", 1)

you get the warning:

/Users/adrian/repositories/lightning/src/lightning/pytorch/trainer/connectors/logger_connector/result.py:232: UserWarning: You called self.log('integer', ...) in your training_step but the value needs to be floating point. Converting it to torch.float32.

Is this intentional?

What version are you seeing the problem on?

v1.9, v2.0, master

How to reproduce the bug

import torch
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("integer", 1)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = BoringModel()
trainer = Trainer(max_steps=1)
trainer.fit(model, train_data)

Error messages and logs

/Users/adrian/repositories/lightning/src/lightning/pytorch/trainer/connectors/logger_connector/result.py:232: UserWarning: You called self.log('integer', ...) in your training_step but the value needs to be floating point. Converting it to torch.float32.

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 1.9+
#- Lightning App Version (e.g., 0.5.2): -
#- PyTorch Version (e.g., 2.0): 2.1
#- Python version (e.g., 3.9): 3.11
#- OS (e.g., Linux): macOS
#- CUDA/cuDNN version: -
#- GPU models and configuration: -
#- How you installed Lightning(`conda`, `pip`, source): source
#- Running environment of LightningApp (e.g. local, cloud): -

More info

The question was raised in #18723 (comment); I'm not sure there is a good reason for it.

cc @carmocca @Blaizzy @stas00

@awaelchli awaelchli added bug Something isn't working needs triage Waiting to be triaged by maintainers logging Related to the `LoggerConnector` and `log()` and removed needs triage Waiting to be triaged by maintainers labels Oct 6, 2023
@carmocca
Contributor

carmocca commented Oct 9, 2023

I don't understand the question of whether it's intentional or not. The fact that there is a warning shows that it's intentional.

Logging is meant for floating point types, as the reduction logic will not preserve the integer types. If you log epoch=0 and epoch=1, then the mean will be 0.5 which is clearly not right if you want to track epochs. The warning aims to tell the user about this limitation.

For an example like this, we would expect that the user fixes it with:
self.log("epoch", float(trainer.current_epoch), reduce_fx=max)

Perhaps the message could be changed to make the implications of converting integers to floats clearer.
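Carlos's epoch example can be reproduced without Lightning at all. The sketch below uses the stdlib mean and max as stand-ins for Lightning's reduce_fx (an illustration of the reduction pitfall, not Lightning's actual internals):

```python
# Why mean-reduction mangles integer counters: the values below stand in
# for "epoch" logged once per epoch; mean/max play the role of reduce_fx
# (an assumption for illustration, not Lightning's real reduction code).
from statistics import mean

logged_epochs = [0, 1]

print(mean(logged_epochs))  # 0.5 -- not a meaningful "epoch"
print(max(logged_epochs))   # 1   -- what reduce_fx=max would report
```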

@stas00
Contributor

stas00 commented Oct 9, 2023

Logging is meant for floating point types

I'm not sure why you said that, Carlos, since not all logged values are `mean`-ed.

Some logged values are aggregates or counters, for example global_step or consumed_samples. These can't be floats; they are ints: there is no such thing as 3.4 steps or 22234.8 samples.

@carmocca
Contributor

carmocca commented Oct 9, 2023

@stas00 There's an unfortunate overload of nomenclature in lightning that is "self.logging" and "logging to a Logger". It's a source of confusion for new users.

With my comment, I meant specifically self.logging. As you point out, "logging to a Logger" integer types is perfectly normal.

self.log is a mechanism that supports aggregating data across steps/epochs and reducing across ranks. That mechanism is currently designed to work with floating-point types, which is the reason for this warning. After all of this happens, PL ends up calling self.trainer.logger.log_metrics(what_was_self_logged)

For cases when one does not want the aggregation/reduction, the user should skip self.log and simply call trainer.logger.log_metrics(whatever) themselves.

Sorry for the confusion, happy to hear suggestions about this, but changes to names are impossible as we want to avoid annoying deprecations.
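The two paths Carlos contrasts can be sketched with a toy model (an assumption for illustration only, not Lightning's real internals): self.log-ged values are accumulated and reduced before reaching the Logger, while logger.log_metrics passes values straight through.

```python
# Toy model of the two logging paths. ToySelfLog imitates the self.log
# path (accumulate, float-convert, reduce); ToyLogger imitates
# trainer.logger.log_metrics (values received verbatim).
from statistics import mean

class ToySelfLog:
    """Accumulates values and reduces them, like the self.log path."""
    def __init__(self, reduce_fx=mean):
        self.values = []
        self.reduce_fx = reduce_fx
    def log(self, value):
        self.values.append(float(value))  # ints converted, hence the warning
    def flush(self):
        return self.reduce_fx(self.values)

class ToyLogger:
    """Receives metrics verbatim, like trainer.logger.log_metrics."""
    def __init__(self):
        self.metrics = []
    def log_metrics(self, metrics):
        self.metrics.append(metrics)

aggregated = ToySelfLog()
direct = ToyLogger()
for step in (10, 11):
    aggregated.log(step)                       # goes through reduction
    direct.log_metrics({"global_step": step})  # bypasses it

print(aggregated.flush())  # 10.5 -- the mean, an unhelpful "step"
print(direct.metrics[-1])  # {'global_step': 11} -- ints preserved
```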

@awaelchli
Copy link
Contributor Author

@carmocca I just can't find anything in the history explaining why the warning needs to be there, as opposed to converting the value automatically. If we look at the PR where it was introduced, #10076, the motivation seems to be different, and there is no clear explanation of why we couldn't just work with floats internally. Why does the user need to be informed?

@awaelchli awaelchli added question Further information is requested and removed bug Something isn't working labels Oct 9, 2023
@carmocca
Contributor

carmocca commented Oct 9, 2023

We do work internally with floats: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/trainer/connectors/logger_connector/result.py#L216

The warning just aims to let the user know about the potential implications of this, such as that your epoch could become 0.5 as I described in an example above.

One option is to change the warning so that it also suggests logger.log_metrics({"epoch": epoch}) as an alternative to calling self.log("epoch", epoch.float())

@stas00
Contributor

stas00 commented Oct 9, 2023

Thank you for explaining, Carlos. Yes, the naming is indeed unfortunately not self-documenting. log_reduce or something similar would have been more intuitive.

The incorrect use is then happening in NeMo's integration of PTL (https://github.com/NVIDIA/nemo):

$ grep '\.log(' | grep global
nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py: self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1)
[...]
nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py: self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
[...]
$ grep '\.log(' | grep global | grep -v float | wc -l
10

so sometimes it is converted to float, which reads oddly in the output: global_step 3.0

edit: I filed an Issue there: NVIDIA/NeMo#7665

One option is that the warning is changed to also suggest logger.log_metrics({"epoch": epoch}) as an alternative to calling self.log("epoch", epoch.float())

That would make for a better warning, Carlos.

Also, it's not just about float vs. integer: it is also a wasted cross-rank reduction of data that doesn't need to be reduced, since all ranks should already hold the same counters deterministically.

@LWprogramming

I'm confused about using global_step for anything besides logging, then. If we shouldn't use self.log, how are we supposed to log it while also keeping it accessible for e.g. checkpoint filenames (checkpoint-100, checkpoint-200, etc.)?

trainer.logger.log_metrics(whatever) doesn't work because the key won't be available to the ModelCheckpoint callback.

@carmocca
Copy link
Contributor

@LWprogramming You can do trainer.callback_metrics["global_step"] = ...
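A minimal illustration of why this works, under the assumption that ModelCheckpoint resolves the {placeholders} in its filename template from trainer.callback_metrics (the plain dict below is a stand-in for that store):

```python
# Hypothetical stand-in for trainer.callback_metrics: ModelCheckpoint
# fills {placeholders} in its filename from this metrics mapping, so
# manually inserting a key makes it available for checkpoint naming.
callback_metrics = {}
callback_metrics["global_step"] = 200  # the manual assignment suggested above

filename_template = "checkpoint-{global_step}"
print(filename_template.format(**callback_metrics))  # checkpoint-200
```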
