-
Notifications
You must be signed in to change notification settings - Fork 557
/
arguments.py
163 lines (142 loc) · 4.83 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import argparse
from dreambooth_helpers.joepenna_dreambooth_config import JoePennaDreamboothConfigSchemaV1
def parse_arguments(argv=None) -> JoePennaDreamboothConfigSchemaV1:
    """Parse command-line flags into a ``JoePennaDreamboothConfigSchemaV1``.

    If ``--config_file_path`` is given, the config is loaded from that file via
    ``config.saturate_from_file`` and every other training flag on the command
    line is ignored. Otherwise the config is populated from the individual
    flags via ``config.saturate``.

    Unrecognized arguments do not cause an error: ``parse_known_args`` is used
    and the leftover arguments are discarded.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            original behavior; pass an explicit list to parse arguments
            programmatically.

    Returns:
        The saturated ``JoePennaDreamboothConfigSchemaV1`` instance.
    """

    def _get_parser(**parser_kwargs):
        """Build the CLI parser; ``parser_kwargs`` are forwarded to ``ArgumentParser``."""

        def str2bool(v):
            """Convert a yes/no style string (or an actual bool) to a bool.

            Raises:
                argparse.ArgumentTypeError: if ``v`` is not a recognized
                    boolean spelling, so argparse reports a usage error.
            """
            if isinstance(v, bool):
                return v
            lowered = v.lower()
            if lowered in ("yes", "true", "t", "y", "1"):
                return True
            elif lowered in ("no", "false", "f", "n", "0"):
                return False
            else:
                raise argparse.ArgumentTypeError("Boolean value expected.")

        parser = argparse.ArgumentParser(**parser_kwargs)
        parser.add_argument(
            "--config_file_path",
            type=str,
            required=False,
            default=None,
            help="A config file containing all of your variables"
        )
        parser.add_argument(
            "--project_name",
            type=str,
            required=False,
            default=None,
            help="Name of the project"
        )
        # nargs="?" + const=True: a bare "--debug" means True, while
        # "--debug false" (etc.) is parsed through str2bool.
        parser.add_argument(
            "--debug",
            type=str2bool,
            nargs="?",
            const=True,
            default=False,
            help="Enable debug logging",
        )
        parser.add_argument(
            "--seed",
            type=int,
            default=23,
            help="seed for seed_everything",
        )
        parser.add_argument(
            "--max_training_steps",
            type=int,
            required=False,
            help="Number of training steps to run"
        )
        parser.add_argument(
            "--token",
            type=str,
            required=False,
            help="Unique token you want to represent your trained model. Ex: firstNameLastName."
        )
        parser.add_argument(
            "--token_only",
            type=str2bool,
            const=True,
            default=False,
            nargs="?",
            help="Train only using the token and no class."
        )
        parser.add_argument(
            "--training_model",
            type=str,
            required=False,
            help="Path to model to train (model.ckpt)"
        )
        parser.add_argument(
            "--training_images",
            type=str,
            required=False,
            help="Path to training images directory"
        )
        parser.add_argument(
            "--regularization_images",
            type=str,
            required=False,
            help="Path to directory with regularization images"
        )
        parser.add_argument(
            "--class_word",
            type=str,
            required=False,
            help="Match class_word to the category of images you want to train. Example: 'man', 'woman', 'dog', or 'artstyle'."
        )
        # argparse applies %-formatting to help strings (for %(default)s etc.),
        # so a literal percent sign must be written as "%%"; a bare "%" makes
        # `--help` fail. Each fragment also ends with a space because adjacent
        # string literals are concatenated with no separator.
        parser.add_argument(
            "--flip_p",
            type=float,
            required=False,
            default=0.5,
            help="Flip Percentage. "
                 "Example: if set to 0.5, will flip (mirror) your training images 50%% of the time. "
                 "This helps expand your dataset without needing to include more training images. "
                 "This can lead to worse results for face training since most people's faces are not perfectly symmetrical."
        )
        parser.add_argument(
            "--learning_rate",
            type=float,
            required=False,
            default=1.0e-06,
            help="Set the learning rate. Defaults to 1.0e-06 (0.000001). Accepts scientific notation."
        )
        parser.add_argument(
            "--save_every_x_steps",
            type=int,
            required=False,
            default=0,
            help="Saves a checkpoint every x steps"
        )
        parser.add_argument(
            "--gpu",
            type=int,
            default=0,
            required=False,
            help="Specify a GPU other than 0 to use for training. Multi-GPU support is not currently implemented."
        )
        return parser

    parser = _get_parser()
    # parse_known_args returns unrecognized arguments separately instead of
    # erroring; they are intentionally discarded here.
    opt, _unknown = parser.parse_known_args(argv)

    config = JoePennaDreamboothConfigSchemaV1()
    if opt.config_file_path is not None:
        # A config file takes precedence: all other CLI flags are ignored.
        config.saturate_from_file(config_file_path=opt.config_file_path)
    else:
        config.saturate(
            project_name=opt.project_name,
            seed=opt.seed,
            debug=opt.debug,
            gpu=opt.gpu,
            max_training_steps=opt.max_training_steps,
            save_every_x_steps=opt.save_every_x_steps,
            training_images_folder_path=opt.training_images,
            regularization_images_folder_path=opt.regularization_images,
            token=opt.token,
            token_only=opt.token_only,
            class_word=opt.class_word,
            flip_percent=opt.flip_p,
            learning_rate=opt.learning_rate,
            # No repo id is passed from the CLI; the model comes from --training_model.
            model_repo_id='',
            model_path=opt.training_model,
        )
    return config