-
Notifications
You must be signed in to change notification settings - Fork 6
/
example.py
101 lines (87 loc) · 2.73 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import time
import pathlib
import kgen.models as models
from kgen.formatter import seperate_tags, apply_format
from kgen.executor.dtg import apply_dtg_prompt, tag_gen
from kgen.metainfo import TARGET
from kgen.logging import logger
# Largest allowed RNG seed; seeds are reduced modulo this in process()
# so they fit in a signed 32-bit integer.
SEED_MAX = 2**31 - 1
# Human-readable labels for the tag-length targets.  process() converts
# a value back to its TARGET key by replacing spaces with underscores.
TOTAL_TAG_LENGTH = {
    "VERY_SHORT": "very short",
    "SHORT": "short",
    "LONG": "long",
    "VERY_LONG": "very long",
}
# Default output template consumed by kgen.formatter.apply_format;
# each <|...|> placeholder is filled from the corresponding tag-map category.
DEFAULT_FORMAT = """<|special|>,
<|characters|>, <|copyrights|>,
<|artist|>,
<|general|>,
<|quality|>, <|meta|>, <|rating|>"""
def process(
    prompt: str,
    aspect_ratio: float,
    seed: int,
    tag_length: str,
    ban_tags: str,
    format: str,
    temperature: float,
) -> str:
    """Expand a comma-separated tag prompt with DTG-generated tags.

    Args:
        prompt: Comma-separated tag string (e.g. "1girl, masterpiece").
        aspect_ratio: Target width/height ratio fed into the DTG prompt.
        seed: RNG seed; reduced modulo SEED_MAX before use.
        tag_length: One of the TOTAL_TAG_LENGTH values
            ("very short" .. "very long").
        ban_tags: Comma-separated tags to exclude from generation.
        format: Output template passed to apply_format (see DEFAULT_FORMAT).
        temperature: Sampling temperature for tag generation.

    Returns:
        The formatted prompt string produced by apply_format.
    """
    # Fixed typo: log message previously read "Processing propmt:".
    prompt_preview = prompt.replace("\n", " ")[:40]
    logger.info(f"Processing prompt: {prompt_preview}...")
    logger.info(f"Processing with seed: {seed}")

    black_list = [tag.strip() for tag in ban_tags.split(",") if tag.strip()]
    all_tags = [tag.strip().lower() for tag in prompt.strip().split(",") if tag.strip()]

    # TARGET keys use underscores ("very_short"); the public labels use spaces.
    tag_length = tag_length.replace(" ", "_")
    len_target = TARGET[tag_length]

    tag_map = seperate_tags(all_tags)
    dtg_prompt = apply_dtg_prompt(tag_map, tag_length, aspect_ratio)

    # Drain the generator; only the final yield's extra_tokens/iter_count are
    # used.  Defaults guard against a NameError if tag_gen yields nothing.
    extra_tokens, iter_count = [], 0
    for _, extra_tokens, iter_count in tag_gen(
        models.text_model,
        models.tokenizer,
        dtg_prompt,
        tag_map["special"] + tag_map["general"],
        len_target,
        black_list,
        temperature=temperature,
        top_p=0.95,
        top_k=100,
        max_new_tokens=256,
        max_retry=20,
        max_same_output=15,
        seed=seed % SEED_MAX,
    ):
        pass
    tag_map["general"] += extra_tokens

    prompt_by_dtg = apply_format(tag_map, format)
    logger.info(
        "Prompt processing done. General Tags Count: "
        f"{len(tag_map['general'] + tag_map['special'])}"
        f" | Total iterations: {iter_count}"
    )
    return prompt_by_dtg
if __name__ == "__main__":
    # Point kgen at a model directory next to this script;
    # adjust if your model files live elsewhere.
    models.model_dir = pathlib.Path(__file__).parent / "models"
    # file = models.download_gguf(gguf_name="ggml-model-Q6_K.gguf")
    local_ggufs = models.list_gguf()
    gguf_path = local_ggufs[-1]
    logger.info(f"Use gguf model from local file: {gguf_path}")
    models.load_model(gguf_path, gguf=True, device="cpu")
    # models.load_model()
    # models.text_model.half().cuda()

    prompt = """
1girl, Umamusume, ask (askzy), horse girl, masterpiece, absurdres, sensitive
"""

    # Time the whole prompt-expansion run in nanoseconds.
    started_ns = time.time_ns()
    result = process(
        prompt,
        aspect_ratio=1.0,
        seed=1,
        tag_length=TOTAL_TAG_LENGTH["LONG"],
        ban_tags="",
        format=DEFAULT_FORMAT,
        temperature=1.35,
    )
    finished_ns = time.time_ns()

    logger.info(f"Result:\n{result}")
    logger.info(f"Time cost: {(finished_ns - started_ns) / 10**6:.1f}ms")