forked from bghira/SimpleTuner
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
66 lines (56 loc) · 2.08 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from helpers.training.trainer import Trainer
from helpers.training.state_tracker import StateTracker
from helpers import log_format
import logging
from os import environ
logger = logging.getLogger("SimpleTuner")
logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))


def _notify_webhook(message):
    """Send *message* through the configured webhook handler, if one is set.

    Args:
        message: Text passed as the ``message=`` keyword to the handler's
            ``send`` method.
    """
    webhook_handler = StateTracker.get_webhook_handler()
    if webhook_handler is not None:
        webhook_handler.send(message=message)


if __name__ == "__main__":
    trainer = None

    # Request the 'fork' start method. Failure (e.g. the method is unsupported
    # on this platform or a start method was already chosen) is non-fatal and
    # is only logged.
    try:
        import multiprocessing

        multiprocessing.set_start_method("fork")
    except Exception as e:
        logger.error(
            "Failed to set the multiprocessing start method to 'fork'. Unexpected behaviour such as high memory overhead or poor performance may result."
            f"\nError: {e}"
        )

    try:
        trainer = Trainer()
        # The initialisation steps are kept in this exact order; later steps
        # presumably rely on state prepared by earlier ones.
        trainer.configure_webhook()
        trainer.init_noise_schedule()
        trainer.init_seed()
        trainer.init_huggingface_hub()
        trainer.init_preprocessing_models()
        trainer.init_data_backend()
        trainer.init_validation_prompts()
        trainer.init_unload_text_encoder()
        trainer.init_unload_vae()
        trainer.init_load_base_model()
        trainer.init_precision()
        trainer.init_controlnet_model()
        trainer.init_freeze_models()
        trainer.init_trainable_peft_adapter()
        trainer.init_ema_model()
        trainer.init_validations()
        trainer.init_benchmark_base_model()
        trainer.resume_and_prepare()
        trainer.init_trackers()
        trainer.train()
    except KeyboardInterrupt:
        _notify_webhook(
            "Training has been interrupted by user action (lost terminal, or ctrl+C)."
        )
    except Exception as e:
        import traceback

        _notify_webhook(
            f"Training has failed. Please check the logs for more information: {e}"
        )
        print(e)
        print(traceback.format_exc())
    finally:
        # Always stop the trainer's own `bf` object (its `stop_fetching()`
        # suggests a background fetcher), however training ended. getattr
        # covers both `trainer is None` (construction failed) and a trainer
        # that never set `bf`.
        fetcher = getattr(trainer, "bf", None)
        if fetcher is not None:
            fetcher.stop_fetching()