diff --git a/README.md b/README.md
index a717f587..245ec534 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,10 @@ Install requirements: `pip3 install -r requirements.txt`
 python3 main.py --device=cuda --port=8080 --model=lama
 ```
 
+- `--crop-trigger-size`: If the image is larger than crop-trigger-size, crop each masked area out of the original image to do inference.
+  Mainly for performance and memory reasons on **very** large images. Default is 2042,2042.
+- `--crop-size`: Crop size used when `--crop-trigger-size` is exceeded. Default is 512,512.
+
 ### Start server with LDM model
 
 ```bash
@@ -35,7 +39,6 @@ results than LaMa.
 |--------------|------|----|
 |![photo-1583445095369-9c651e7e5d34](https://user-images.githubusercontent.com/3998421/156923525-d6afdec3-7b98-403f-ad20-88ebc6eb8d6d.jpg)|![photo-1583445095369-9c651e7e5d34_cleanup_lama](https://user-images.githubusercontent.com/3998421/156923620-a40cc066-fd4a-4d85-a29f-6458711d1247.png)|![photo-1583445095369-9c651e7e5d34_cleanup_ldm](https://user-images.githubusercontent.com/3998421/156923652-0d06c8c8-33ad-4a42-a717-9c99f3268933.png)|
 
-
 Blogs about diffusion models:
 - https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
 
diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py
index 943d1be4..9c6e9863 100644
--- a/lama_cleaner/helper.py
+++ b/lama_cleaner/helper.py
@@ -1,5 +1,6 @@
 import os
 import sys
+from typing import List
 from urllib.parse import urlparse
 
 import cv2
@@ -80,3 +81,27 @@ def pad_img_to_modulo(img, mod):
         ((0, 0), (0, out_height - height), (0, out_width - width)),
         mode="symmetric",
     )
+
+
+def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
+    """
+    Args:
+        mask: (1, h, w), values in 0~1
+
+    Returns:
+        List of boxes, each [left, top, right, bottom]
+    """
+    height, width = mask.shape[1:]
+    _, thresh = cv2.threshold((mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0)
+    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    boxes = []
+    for cnt in contours:
+        x, y, w, h = cv2.boundingRect(cnt)
+        box = np.array([x, y, x + w, y + h]).astype(np.int32)
+
+        box[::2] = np.clip(box[::2], 0, width)
+        box[1::2] = np.clip(box[1::2], 0, height)
+        boxes.append(box)
+
+    return boxes
diff --git a/lama_cleaner/lama/__init__.py b/lama_cleaner/lama/__init__.py
index 3075e153..49f2596e 100644
--- a/lama_cleaner/lama/__init__.py
+++ b/lama_cleaner/lama/__init__.py
@@ -1,10 +1,11 @@
 import os
+from typing import List
 
 import cv2
 import torch
 import numpy as np
 
-from lama_cleaner.helper import pad_img_to_modulo, download_model
+from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
 
 LAMA_MODEL_URL = os.environ.get(
     "LAMA_MODEL_URL",
@@ -13,7 +14,16 @@
 
 
 class LaMa:
-    def __init__(self, device):
+    def __init__(self, crop_trigger_size: List[int], crop_size: List[int], device):
+        """
+        Args:
+            crop_trigger_size: (h, w). If the image is larger than this size,
+                inference runs on crops around each masked area instead of the full image.
+            crop_size: (h, w) of each crop
+            device: torch device
+        """
+        self.crop_trigger_size = crop_trigger_size
+        self.crop_size = crop_size
         self.device = device
 
         if os.environ.get("LAMA_MODEL"):
@@ -32,6 +42,63 @@ def __init__(self, device):
 
     @torch.no_grad()
     def __call__(self, image, mask):
+        """
+        image: [C, H, W] RGB
+        mask: [1, H, W]
+        return: BGR IMAGE
+        """
+        area = image.shape[1] * image.shape[2]
+        if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
+            return self._run(image, mask)
+
+        print("Image is larger than crop-trigger-size, running inference on crops")
+        boxes = boxes_from_mask(mask)
+        crop_result = []
+        for box in boxes:
+            crop_image, crop_box = self._run_box(image, mask, box)
+            crop_result.append((crop_image, crop_box))
+
+        image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
+        for crop_image, crop_box in crop_result:
+            x1, y1, x2, y2 = crop_box
+            image[y1:y2, x1:x2, :] = crop_image
+        return image
+
+    def _run_box(self, image, mask, box):
+        """
+        Run inpainting on a crop around the given mask box.
+
+        Args:
+            image: [C, H, W] RGB
+            mask: [1, H, W]
+            box: [left, top, right, bottom]
+        Returns:
+            (BGR crop image, [left, top, right, bottom] of the crop)
+        """
+        box_h = box[3] - box[1]
+        box_w = box[2] - box[0]
+        cx = (box[0] + box[2]) // 2
+        cy = (box[1] + box[3]) // 2
+        crop_h, crop_w = self.crop_size
+        img_h, img_w = image.shape[1:]
+
+        # TODO: when box_w > crop_w, add some margin around?
+        w = max(crop_w, box_w)
+        h = max(crop_h, box_h)
+
+        l = max(cx - w // 2, 0)
+        t = max(cy - h // 2, 0)
+        r = min(cx + w // 2, img_w)
+        b = min(cy + h // 2, img_h)
+
+        crop_img = image[:, t:b, l:r]
+        crop_mask = mask[:, t:b, l:r]
+
+        print(f"Run inference on crop, shape (C, H, W): {crop_img.shape}")
+
+        return self._run(crop_img, crop_mask), [l, t, r, b]
+
+    def _run(self, image, mask):
         """
         image: [C, H, W] RGB
         mask: [1, H, W]
@@ -51,5 +118,5 @@ def __call__(self, image, mask):
         cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
         cur_res = cur_res[0:origin_height, 0:origin_width, :]
         cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
-        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
+        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
         return cur_res
diff --git a/lama_cleaner/tests/mask.jpg b/lama_cleaner/tests/mask.jpg
new file mode 100644
index 00000000..a2aec11c
Binary files /dev/null and b/lama_cleaner/tests/mask.jpg differ
diff --git a/lama_cleaner/tests/test_boxes_from_mask.py b/lama_cleaner/tests/test_boxes_from_mask.py
new file mode 100644
index 00000000..3faa4c6c
--- /dev/null
+++ b/lama_cleaner/tests/test_boxes_from_mask.py
@@ -0,0 +1,15 @@
+import os
+
+import cv2
+import numpy as np
+
+from lama_cleaner.helper import boxes_from_mask
+
+CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def test_boxes_from_mask():
+    mask = cv2.imread(os.path.join(CURRENT_DIR, "mask.jpg"), cv2.IMREAD_GRAYSCALE)
+    mask = mask[:, :, np.newaxis]
+    mask = (mask / 255).transpose(2, 0, 1)
+    assert len(boxes_from_mask(mask)) > 0
diff --git a/main.py b/main.py
index 18968a1a..a50df20a 100644
--- a/main.py
+++ b/main.py
@@ -97,12 +97,18 @@ def get_args_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", default=8080, type=int)
    parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
+    parser.add_argument("--crop-trigger-size", default="2042,2042",
+                        help="If the image is larger than crop-trigger-size, "
+                             "crop each masked area out of the original image to do inference. "
+                             "Mainly for performance and memory reasons on very large images. "
+                             "Only used by the lama model.")
+    parser.add_argument("--crop-size", default="512,512")
     parser.add_argument(
         "--ldm-steps",
         default=50,
         type=int,
         help="Steps for DDIM sampling process."
-        "The larger the value, the better the result, but it will be more time-consuming",
+             " The larger the value, the better the result, but it will be more time-consuming",
     )
     parser.add_argument("--device", default="cuda", type=str)
     parser.add_argument("--debug", action="store_true")
@@ -115,8 +121,11 @@ def main():
     args = get_args_parser()
     device = torch.device(args.device)
 
+    crop_trigger_size = [int(it) for it in args.crop_trigger_size.split(",")]
+    crop_size = [int(it) for it in args.crop_size.split(",")]
+
     if args.model == "lama":
-        model = LaMa(device)
+        model = LaMa(crop_trigger_size=crop_trigger_size, crop_size=crop_size, device=device)
     elif args.model == "ldm":
         model = LDM(device, steps=args.ldm_steps)
     else:
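
As a quick illustration of how the new `boxes_from_mask` helper behaves, here is a minimal sketch (not part of the patch; it assumes the `lama_cleaner` package from this repo is importable, and the mask values are made up): a synthetic `(1, H, W)` mask with two painted blobs should yield one bounding box per blob.

```python
import numpy as np

from lama_cleaner.helper import boxes_from_mask

# Synthetic (1, H, W) mask in the 0~1 range, mimicking what the frontend sends:
# 1.0 where the user painted, 0.0 elsewhere. Two separate blobs -> two boxes.
mask = np.zeros((1, 200, 300), dtype=np.float32)
mask[:, 20:60, 30:90] = 1.0      # first painted area
mask[:, 120:180, 200:280] = 1.0  # second painted area

boxes = boxes_from_mask(mask)
# Expect two [left, top, right, bottom] boxes, roughly [30, 20, 90, 60] and
# [200, 120, 280, 180]; the order depends on cv2.findContours traversal.
print(boxes)
```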
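To make the crop-window math in `_run_box` easier to follow: each mask box is grown to at least `crop_size`, centered on the masked area, then clipped to the image bounds, so crops near an edge may come out smaller than `crop_size`. Below is a standalone sketch of that computation (the `crop_window` helper name and the example numbers are hypothetical, not code from the diff):

```python
from typing import List, Tuple


def crop_window(box: List[int], crop_size: Tuple[int, int], img_h: int, img_w: int):
    """Mirror of the window computation in LaMa._run_box (sketch only)."""
    left, top, right, bottom = box
    box_w, box_h = right - left, bottom - top
    cx, cy = (left + right) // 2, (top + bottom) // 2
    crop_h, crop_w = crop_size

    # Grow the window to at least crop_size, centered on the masked box.
    w = max(crop_w, box_w)
    h = max(crop_h, box_h)

    # Clip to the image bounds.
    l = max(cx - w // 2, 0)
    t = max(cy - h // 2, 0)
    r = min(cx + w // 2, img_w)
    b = min(cy + h // 2, img_h)
    return l, t, r, b


# A 40x40 masked box centered at (1000, 800) in a 2000x3000 (H x W) image with
# the default --crop-size of 512,512 gives a 512x512 window around the mask:
print(crop_window([980, 780, 1020, 820], (512, 512), img_h=2000, img_w=3000))
# -> (744, 544, 1256, 1056)
```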