reduce VRAM memory usage by half during model loading
* This moves the call to half() before model.to(device), so the full-precision
weights are never copied to the GPU. Improves load speed and cuts peak VRAM
usage during model loading roughly in half

* This fix contributed by @mh-dm (Mihai)
lstein committed Sep 10, 2022
1 parent 9912270 commit 5c43988
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions ldm/generate.py
@@ -536,9 +536,6 @@ def _load_model_from_config(self, config, ckpt):
         sd = pl_sd['state_dict']
         model = instantiate_from_config(config.model)
         m, u = model.load_state_dict(sd, strict=False)
-        model.to(self.device)
-        model.eval()
 
 
         if self.full_precision:
             print(
@@ -549,6 +546,8 @@ def _load_model_from_config(self, config, ckpt):
                 '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
             )
             model.half()
+        model.to(self.device)
+        model.eval()
 
         # usage statistics
         toc = time.time()
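The effect of the reordering can be sketched outside the repository. This is a minimal, hypothetical PyTorch example (a small `torch.nn.Linear` stands in for the actual Stable Diffusion model): converting to fp16 *before* the device transfer means only half-precision tensors ever cross to the GPU, so the transfer and the peak device memory during loading are half the size.

```python
import torch

# Stand-in for the model loaded in _load_model_from_config; a tiny
# Linear layer keeps the sketch runnable on any machine.
model = torch.nn.Linear(1024, 1024)

fp32_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

# As in the commit: convert to half precision first, while the weights
# are still in host memory...
model.half()
fp16_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

# ...then move to the device, so only fp16 tensors are copied over.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

print(fp32_bytes // fp16_bytes)  # → 2: the fp16 copy is half the size
```

Doing `model.to(device)` first and `model.half()` second would briefly hold the full fp32 weights on the GPU before down-converting them, which is exactly the transient allocation the commit eliminates.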
