-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
49 lines (35 loc) · 1.27 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from src import config
from src.models import Generator
from src.utils import load_model, get_device, denorm, configure_logger
import torch
from torchvision.utils import save_image
import os
from PIL import Image
# Module-level logger for this script.
logger = configure_logger(__name__)
# Device on which the generator and the latent tensor are placed.
device = get_device()


def main(model_name: str = 'Generator_best.pth',
         image_name: str = 'prediction.jpg',
         show: bool = True) -> None:
    """Sample one image from the trained generator, save it, and optionally show it.

    Loads the generator checkpoint ``config.MODELS_PATH/<model_name>``, draws a
    single latent vector from a standard normal distribution, runs the generator
    under ``torch.inference_mode()``, de-normalizes the output and writes it to
    ``config.IMAGES_PATH/<image_name>``.

    Args:
        model_name: File name of the generator checkpoint inside
            ``config.MODELS_PATH``. Defaults to the original hard-coded name.
        image_name: File name of the output image inside ``config.IMAGES_PATH``.
            The extension selects the format used by ``save_image``.
        show: When True (the default, matching the original behavior), open the
            saved image with the system image viewer.
    """
    generator = load_model(Generator,
                           model_path=os.path.join(config.MODELS_PATH, model_name),
                           device=device,
                           latent_size=config.LATENT_SIZE,
                           **config.GENR_PARAMS
                           ).to(device)
    # Force inference behavior for dropout/batch-norm layers. This is harmless
    # if load_model already switched the model to eval mode (not visible here).
    generator.eval()
    logger.info(f'Generator model created successfully and placed on device: {device}.')

    # A single latent vector shaped (1, LATENT_SIZE, 1, 1), i.e. a batch of one.
    latent = torch.randn(1, config.LATENT_SIZE, 1, 1, device=device)

    with torch.inference_mode():
        # Drop the batch dimension so exactly one image is saved.
        img_tensor = generator(latent).squeeze(dim=0)
    logger.info('Generator created image successfully.')

    # save_image does not create missing parent directories.
    os.makedirs(config.IMAGES_PATH, exist_ok=True)
    image_path = os.path.join(config.IMAGES_PATH, image_name)
    # NOTE(review): denorm(..., 0.5, 0.5) presumably undoes a mean=0.5/std=0.5
    # normalization applied during training. TODO confirm against the dataset transforms.
    save_image(denorm(img_tensor, 0.5, 0.5), image_path, nrow=1)
    logger.info(f'Image saved successfully to `{image_path}`.')

    if show:
        # Use a context manager so the file handle opened by Image.open is released.
        with Image.open(image_path) as image:
            image.show()


if __name__ == '__main__':
    main()