add ldm model
Sanster committed Mar 6, 2022
1 parent f09d40c commit f9b96cf
Showing 7 changed files with 605 additions and 65 deletions.
42 changes: 32 additions & 10 deletions README.md
@@ -1,37 +1,58 @@
# Lama-cleaner: Image inpainting tool powered by [LaMa](https://github.com/saic-mdal/lama)

This project is mainly used for self-hosting the LaMa model; some interaction improvements may be added later.
# Lama-cleaner: Image inpainting tool powered by SOTA AI model

https://user-images.githubusercontent.com/3998421/153323093-b664bb68-2928-480b-b59b-7c1ee24a4507.mp4


- [x] Support multiple model architectures
    1. [LaMa](https://github.com/saic-mdal/lama)
    1. [LDM](https://github.com/CompVis/latent-diffusion)
- [x] High resolution support
- [x] Multi stroke support. Press and hold the `cmd/ctrl` key to enable multi stroke mode.
- [x] Zoom & Pan
- [ ] Keep image EXIF data

## Quick Start

- Install requirements: `pip3 install -r requirements.txt`
- Start server: `python3 main.py --device=cuda --port=8080`
Install requirements: `pip3 install -r requirements.txt`

### Start server with LaMa model

```bash
python3 main.py --device=cuda --port=8080 --model=lama
```

### Start server with LDM model

```bash
python3 main.py --device=cuda --port=8080 --model=ldm --ldm-steps=50
```

`--ldm-steps`: The larger the value, the better the result, but the longer inference takes.

The diffusion model is **MUCH** slower than GANs (a 1080x720 image takes 8s on a 3090), but it's possible to get better
results than LaMa.

Blogs about diffusion models:

- https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
- https://yang-song.github.io/blog/2021/score/
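
The flags used in the commands above could be parsed along the following lines. This is only a sketch: the flag names come from the README commands, but `build_parser`, the defaults, and the allowed choices (including `cpu`) are assumptions and are not taken from `main.py`.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: only the flag names are taken from the README;
    # defaults and choices below are assumptions for illustration.
    parser = argparse.ArgumentParser(description="Inpainting server options (sketch)")
    parser.add_argument("--port", type=int, default=8080)
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
    parser.add_argument(
        "--ldm-steps",
        type=int,
        default=50,
        help="number of LDM sampling steps; more steps take longer",
    )
    return parser


# Parse the same arguments as the LDM command in the README.
args = build_parser().parse_args(
    ["--device=cuda", "--port=8080", "--model=ldm", "--ldm-steps=50"]
)
print(args.model, args.ldm_steps, args.port)
```

Note that argparse exposes `--ldm-steps` as `args.ldm_steps`, since dashes in option names become underscores in the attribute name.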

## Development

Only needed if you plan to modify the frontend and recompile yourself.

### Frontend

The frontend code is modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures).
You can try their great online service [here](https://cleanup.pictures/).
The frontend code is modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures). You can try their
great online service [here](https://cleanup.pictures/).

- Install dependencies: `cd lama_cleaner/app/ && yarn`
- Start development server: `yarn dev`
- Build: `yarn build`

## Docker

Run within a Docker container. Set `CACHE_DIR` to the path where the models are stored.
Optionally add a `-d` option to the `docker run` command below to run as a daemon.
Run within a Docker container. Set `CACHE_DIR` to the path where the models are stored. Optionally add a `-d` option to
the `docker run` command below to run as a daemon.

### Build Docker image

@@ -54,6 +75,7 @@ docker run --gpus all -p 8080:8080 -e CACHE_DIR=/app/models -v $(pwd)/models:/ap
Then open [http://localhost:8080](http://localhost:8080)

## Like My Work?

<a href="https://www.buymeacoffee.com/Sanster">
<img height="50em" src="https://cdn.buymeacoffee.com/buttons/v2/default-blue.png" alt="Sanster" />
</a>
7 changes: 1 addition & 6 deletions lama_cleaner/helper.py
@@ -7,13 +7,8 @@
 import torch
 from torch.hub import download_url_to_file, get_dir

-LAMA_MODEL_URL = os.environ.get(
-    "LAMA_MODEL_URL",
-    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-)
-

-def download_model(url=LAMA_MODEL_URL):
+def download_model(url):
     parts = urlparse(url)
     hub_dir = get_dir()
     model_dir = os.path.join(hub_dir, "checkpoints")
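
The rest of `download_model` is collapsed in this view. As a rough illustration of how a download URL might be mapped to a cached checkpoint path, here is a minimal sketch. `cached_model_path` is a hypothetical helper, and deriving the filename from the URL path is an assumption rather than the confirmed body of the function; `hub_dir` is passed in explicitly so the example runs without torch.

```python
import os
from urllib.parse import urlparse


def cached_model_path(url: str, hub_dir: str) -> str:
    """Hypothetical helper: map a download URL to a local checkpoint path."""
    parts = urlparse(url)
    # Same "checkpoints" subdirectory as in the visible part of download_model.
    model_dir = os.path.join(hub_dir, "checkpoints")
    # Assumption: the cached file is named after the last URL path component.
    filename = os.path.basename(parts.path)
    return os.path.join(model_dir, filename)


url = "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt"
path = cached_model_path(url, hub_dir="/tmp/torch_hub")
print(path)
```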
56 changes: 56 additions & 0 deletions lama_cleaner/lama/__init__.py
@@ -0,0 +1,56 @@
import os
import time

import cv2
import torch
import numpy as np

from lama_cleaner.helper import pad_img_to_modulo, download_model

LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)


class LaMa:
    def __init__(self, device):
        self.device = device

        if os.environ.get("LAMA_MODEL"):
            model_path = os.environ.get("LAMA_MODEL")
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
        else:
            model_path = download_model(LAMA_MODEL_URL)

        model = torch.jit.load(model_path, map_location="cpu")
        model = model.to(device)
        model.eval()
        self.model = model

    @torch.no_grad()
    def __call__(self, image, mask):
        """
        image: [C, H, W] RGB
        mask: [1, H, W]
        return: BGR IMAGE
        """
        device = self.device
        origin_height, origin_width = image.shape[1:]
        image = pad_img_to_modulo(image, mod=8)
        mask = pad_img_to_modulo(mask, mod=8)

        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(device)

        start = time.time()
        inpainted_image = self.model(image, mask)

        print(f"process time: {(time.time() - start) * 1000}ms")
        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = cur_res[0:origin_height, 0:origin_width, :]
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
        return cur_res
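
`__call__` pads the image and mask to a multiple of 8 before inference and crops the result back to the original size afterwards. `pad_img_to_modulo` lives in `lama_cleaner/helper.py` and is not shown in this diff, so the sketch below is an assumed stand-in: `pad_to_modulo` and `ceil_modulo` are hypothetical names, and padding on the bottom and right with symmetric reflection is an assumption about the behaviour, not the confirmed implementation.

```python
import numpy as np


def ceil_modulo(x: int, mod: int) -> int:
    # Round x up to the nearest multiple of mod.
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def pad_to_modulo(img: np.ndarray, mod: int) -> np.ndarray:
    """Pad a [C, H, W] array on the bottom and right (assumed behaviour)."""
    _, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return np.pad(
        img,
        ((0, 0), (0, out_height - height), (0, out_width - width)),
        mode="symmetric",
    )


image = np.random.rand(3, 13, 21).astype(np.float32)
padded = pad_to_modulo(image, mod=8)
# Cropping back to the original height and width undoes the padding.
restored = padded[:, :13, :21]
print(padded.shape, restored.shape)
```

Because the padding is only added at the bottom and right, slicing with the original height and width (as `__call__` does with `origin_height` and `origin_width`) recovers the unpadded region exactly.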