-
Notifications
You must be signed in to change notification settings - Fork 81
/
utils.py
89 lines (77 loc) · 2.57 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import numpy as np
import scipy
import os
#import re
from pathlib import Path
import folder_paths
# Path of the "fonts" directory that sits next to this file (symlinks in the
# file's own path are resolved by realpath before taking the directory).
FONTS_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "fonts")
# Directory containing this file, as a pathlib.Path (not symlink-resolved).
SCRIPT_DIR = Path(__file__).parent
# Register two locations under the "luts" key with folder_paths: first the
# "luts" directory beside this file, then "luts" under folder_paths.models_dir.
# NOTE(review): folder_paths is presumably the host application's model-path
# registry; whether registration order affects lookup priority is not visible
# here -- confirm before reordering these calls.
folder_paths.add_model_folder_path("luts", (SCRIPT_DIR / "luts").as_posix())
folder_paths.add_model_folder_path(
"luts", (Path(folder_paths.models_dir) / "luts").as_posix()
)
# from https://github.com/pythongosssss/ComfyUI-Custom-Scripts
class AnyType(str):
    """String subclass whose ``!=`` comparison always reports "not different".

    Only ``__ne__`` is overridden; equality (``==``) and hashing keep the
    regular ``str`` behaviour. Because this is a ``str`` subclass overriding
    the comparison, Python also consults this method first when a plain
    ``str`` is on the left-hand side of ``!=``.
    """

    def __ne__(self, other: object) -> bool:
        # Never admit to being different, whatever it is compared against.
        return False
def min_(tensor_list):
    """Return the element-wise minimum over ``tensor_list``, floored at 0.

    The tensors are stacked along a new leading dimension (so they must share
    a shape), reduced with ``min`` over that dimension, and any value below
    zero in the result is raised to zero.
    """
    stacked = torch.stack(tensor_list)
    lowest, _ = stacked.min(dim=0)  # (values, indices) -> keep the values
    return lowest.clamp(min=0)
def max_(tensor_list):
    """Return the element-wise maximum over ``tensor_list``, capped at 1.

    The tensors are stacked along a new leading dimension (so they must share
    a shape), reduced with ``max`` over that dimension, and any value above
    one in the result is lowered to one.
    """
    stacked = torch.stack(tensor_list)
    highest, _ = stacked.max(dim=0)  # (values, indices) -> keep the values
    return highest.clamp(max=1)
def expand_mask(mask, expand, tapered_corners):
    """Grow or shrink every mask in a batch by ``abs(expand)`` pixels.

    Args:
        mask: Tensor whose last two dimensions are the mask plane; all leading
            dimensions are flattened into a single batch dimension.
        expand: Number of 3x3 morphological passes. Positive values dilate
            (grow), negative values erode (shrink), zero leaves values as-is.
        tapered_corners: If truthy, use a cross-shaped footprint (corners
            excluded); otherwise use the full 3x3 square.

    Returns:
        A tensor of shape ``(batch, H, W)`` on the same device as ``mask``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` is not
    # guaranteed to expose ``scipy.ndimage`` on every SciPy version.
    from scipy import ndimage

    corner = 0 if tapered_corners else 1
    footprint = np.array([[corner, 1, corner],
                          [1, 1, 1],
                          [corner, 1, corner]])

    mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
    if mask.shape[0] == 0:
        # Nothing to process; torch.stack would raise on an empty list.
        return mask

    # The operation depends only on the sign of ``expand``; pick it once.
    morph = ndimage.grey_erosion if expand < 0 else ndimage.grey_dilation
    passes = abs(expand)
    device = mask.device

    processed = []
    # NumPy conversion needs a CPU tensor that is detached from autograd.
    for plane in mask.detach().cpu():
        frame = plane.numpy()
        for _ in range(passes):
            frame = morph(frame, footprint=footprint)
        processed.append(torch.from_numpy(frame))
    return torch.stack(processed, dim=0).to(device)
def _parse_number(text):
    """Parse ``text`` as a float if it contains '.', else as an int; 0 if invalid."""
    try:
        return float(text) if '.' in text else int(text)
    except ValueError:
        return 0


def _decimal_places(text):
    """Return the number of characters after the first '.' in ``text`` (0 if none)."""
    if '.' in text:
        return len(text.split('.')[1])
    return 0


def _expand_range(spec):
    """Expand a ``start...end+step`` specification into its list of values."""
    start_text, _, rest = spec.partition('...')
    end_text, plus, step_text = rest.partition('+')
    start_text = start_text.strip()
    end_text = end_text.strip()
    # A range written without "+step" advances by 1.
    step_text = step_text.strip() if plus else '1'

    # Round to the finest precision the user wrote, so that a start such as
    # 0.25 is not truncated by a coarser step such as 0.5.
    decimals = max(_decimal_places(start_text), _decimal_places(step_text))
    start = _parse_number(start_text)
    end = _parse_number(end_text)
    # Direction comes from start/end; only the step's magnitude matters.
    step = abs(_parse_number(step_text))

    if step == 0:
        # A zero (or unparsable) step can never reach ``end``; emit the start
        # value once instead of looping forever.
        return [round(start, decimals)]

    direction = 1 if end >= start else -1
    # Derive each value from its index instead of accumulating ``+= step``,
    # so float drift cannot push the last value past ``end``. The small
    # epsilon absorbs representation error in the division (0.3 / 0.1 < 3).
    count = int(abs(end - start) / step + 1e-9)
    return [round(start + direction * i * step, decimals) for i in range(count + 1)]


def parse_string_to_list(s):
    """Parse a comma-separated string of numbers and ranges into a flat list.

    Each element is either a single number (``"3"``, ``"0.5"``) or a range
    written as ``start...end+step`` (``"0...1+0.25"``). Ranges include both
    ends when the step lands on them, may run in either direction, and
    default to a step of 1 when ``+step`` is omitted. Text containing a '.'
    is parsed as float, anything else as int; unparsable numbers become 0.

    Args:
        s: The string to parse.

    Returns:
        A list of ints and/or floats in the order they appear in ``s``.
    """
    result = []
    for element in s.split(','):
        element = element.strip()
        if '...' in element:
            result.extend(_expand_range(element))
        else:
            result.append(round(_parse_number(element), _decimal_places(element)))
    return result