-
Notifications
You must be signed in to change notification settings - Fork 7
/
InverseFaceNetEncoderPredict.py
71 lines (57 loc) · 2.07 KB
/
InverseFaceNetEncoderPredict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from InverseFaceNetEncoder import InverseFaceNetEncoder
from LoadDataset import LoadDataset
from ImageFormationLayer import ImageFormationLayer
import numpy as np
from FaceNet3D import FaceNet3D as Helpers
# Switch TensorFlow to eager execution at import time. Presumably kept for
# TF 1.x compatibility; under TF 2.x eager mode is already the default, so
# this call should be a no-op there — confirm against the pinned TF version.
tf.compat.v1.enable_eager_execution()
class InverseFaceNetEncoderPredict(Helpers):
    """
    Inference wrapper around a trained InverseFaceNetEncoder.

    On construction it builds the encoder, restores its weights from the
    ``cp-resnet50.ckpt`` checkpoint inside ``self.trained_models_dir`` (an
    attribute provided by the ``Helpers`` base class) and compiles it, so the
    instance is immediately usable for evaluation and prediction.
    """

    def __init__(self):
        """
        Class initializer: resolve the checkpoint path and load the model.
        """
        import os

        super().__init__()
        # os.path.join works whether or not trained_models_dir ends with a
        # path separator (plain "+" concatenation would silently produce
        # "<dir>cp-resnet50.ckpt" without one) and also accepts path-like
        # objects, which "+" with a str rejects.
        self.latest = os.path.join(self.trained_models_dir, "cp-resnet50.ckpt")
        print("Latest checkpoint: ", self.latest)
        self.encoder = InverseFaceNetEncoder()
        self.model = self.load_model()

    def load_model(self):
        """
        Build the encoder network, restore the trained weights and compile it.

        :return: Compiled Keras model (``self.encoder.model``)
        """
        self.encoder.build_model()
        self.encoder.model.load_weights(self.latest)
        self.encoder.compile()
        # Read the attribute again after compile() in case the encoder
        # rebinds self.model while compiling.
        return self.encoder.model

    def evaluate_model(self):
        """
        Evaluate the restored model and print its loss and error metrics.

        The dataset comes from ``LoadDataset().load_dataset_single_image``
        called with ``self._case`` (an attribute presumably set by the
        ``Helpers`` base class). Loading and evaluation are pinned to the CPU.
        """
        with tf.device('/device:CPU:0'):
            test_ds = LoadDataset().load_dataset_single_image(self._case)
            # Assumes the encoder was compiled with exactly two metrics
            # (MSE, MAE), so evaluate() returns [loss, mse, mae].
            # TODO confirm against InverseFaceNetEncoder.compile().
            loss, mse, mae = self.model.evaluate(test_ds)
        print("\nRestored model, Loss: {0} \nMean Squared Error: {1}\n"
              "Mean Absolute Error: {2}\n".format(loss, mse, mae))

    def model_predict(self, image_path):
        """
        Run the encoder on the image stored at ``image_path``.

        :param image_path: path of the image to load and preprocess
        :return: <class 'numpy.ndarray'>: the transpose of ``self.model.predict``'s
                 output. Presumably shape (scv_length, 1) for a single-image
                 4-D batch; TODO confirm.
        """
        image = LoadDataset().load_and_preprocess_image_4d(image_path)
        x = self.model.predict(image)
        return np.transpose(x)

    @staticmethod
    def calculate_decoder_output(x):
        """
        Reconstruct an image from a semantic code vector.

        :param x: <class 'numpy.ndarray'> with shape (self.scv_length, ) : semantic code vector
        :return: <class 'numpy.ndarray'> with shape self.IMG_SHAPE

        NOTE(review): model_predict returns a transposed 2-D array. Confirm
        that ImageFormationLayer accepts that shape, or whether callers
        flatten it first.
        """
        decoder = ImageFormationLayer(x)
        return decoder.get_reconstructed_image()