
Does setting grad_accum="auto" cause the model to be called in parallel? #1329

Closed · vedantroy opened this issue Jul 30, 2022 · 3 comments
Labels: bug (Something isn't working)

@vedantroy

I added a print("start forward") statement at the start of my model's forward method and a print("end forward") statement at the end.

I noticed that when I turn on grad_accum="auto" (along with a batch_size > 1), instead of the stdout output looking like:

start forward
end forward
start forward
end forward

I'll get something like:

start forward
start forward

(My model crashes after this, because my forward method contains stateful skip-connection code that relies on the model not being called in parallel.)
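For anyone debugging something similar: a small guard like the following (my own sketch, not a Composer API; ReentrancyGuard is a hypothetical name) turns the second start forward into an explicit error instead of a confusing crash further downstream:

import torch.nn as nn

class ReentrancyGuard(nn.Module):
    """Wraps a module and raises if forward is entered again before the
    previous call completed successfully."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self._in_forward = False

    def forward(self, *args, **kwargs):
        if self._in_forward:
            raise RuntimeError(
                "forward re-entered: the previous call never completed "
                "(it may have been aborted partway through)"
            )
        self._in_forward = True
        out = self.module(*args, **kwargs)
        # Deliberately cleared only on success, so a forward that raised
        # partway through is caught on the next call rather than silently
        # retried on top of half-updated state.
        self._in_forward = False
        return out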

vedantroy added the bug label on Jul 30, 2022
@vedantroy (Author)

Whoops, the bug label shouldn't be on this, but I can't figure out how to remove it.

@vedantroy (Author)

To give context on what I mean by "stateful code" in the forward method:

import torch as th

# TimestepBlock is a base class defined elsewhere in my model code; its
# forward takes a timestep embedding in addition to the input.
class AddSkipConnection(TimestepBlock):
    def __init__(self, fn, skips):
        super().__init__()
        self.fn = fn
        self.skips = skips  # shared list, popped by TakeFromSkipConnection

    def forward(self, x, emb=None, **kwargs):
        if isinstance(self.fn, TimestepBlock):
            assert emb is not None, "Missing embedding"
            y = self.fn(x, emb, **kwargs)
        else:
            y = self.fn(x, **kwargs)
        # Push the activation for the matching TakeFromSkipConnection to pop.
        self.skips.append(y)
        return y


class TakeFromSkipConnection(TimestepBlock):
    def __init__(self, fn, skips, expected_channels):
        super().__init__()
        self.fn = fn
        self.skips = skips
        self.expected_channels = expected_channels

    def forward(self, x, emb, **kwargs):
        # Pop the most recent skip activation; this assumes pushes and pops
        # from a single, uninterrupted forward pass stay perfectly paired.
        skip = self.skips.pop()
        assert (
            skip.shape[1] == self.expected_channels
        ), f"skip connection channels {skip.shape[1]} != expected {self.expected_channels}"
        with_skip = th.cat([skip, x], dim=1)
        y = self.fn(with_skip, emb, **kwargs)
        return y

and then in my __init__ method, I might have things like:

        self.skips = []  # one list shared by every Add/TakeFromSkipConnection
        self.in_conv = AddSkipConnection(
            nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1), self.skips
        )
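A defensive variant I considered (again my own sketch, assuming the classes above plus import torch.nn as nn; TinySkipNet is a hypothetical name) is to clear the shared list at the top of the outermost forward, so a forward pass that dies partway through cannot leak stale activations into the next call:

import torch.nn as nn

class TinySkipNet(nn.Module):
    def __init__(self, in_channels, model_channels):
        super().__init__()
        self.skips = []
        self.in_conv = AddSkipConnection(
            nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1),
            self.skips,
        )

    def forward(self, x, emb=None, **kwargs):
        # If a previous forward never finished, its un-popped activations
        # are still in self.skips; clearing restores the push/pop invariant.
        self.skips.clear()
        return self.in_conv(x, emb, **kwargs)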

@vedantroy (Author) commented Jul 30, 2022

Figured it out (see #1331, bullet 5).
