-
Notifications
You must be signed in to change notification settings - Fork 0
/
usenetwork.py
78 lines (55 loc) · 1.85 KB
/
usenetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import glob
import os
import sys
import json
from pathlib import Path
from PIL import Image
# disable GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
import tensorflow.keras as keras
# Base name of the saved Keras model; main() loads f"{MODEL_NAME}.h5".
MODEL_NAME = "model"
# Input root: holds one sub-directory per band (entries whose names start with "B").
IN_DIR = Path("./netin")
# Destination directory for the predicted PNGs (one file per input file name).
OUT_DIR = Path("./netout")
# Maximum number of images loaded and passed to the model in one batch.
LOAD_SIZE = 64
# Band entries in lexicographic order; this order fixes the channel order of
# every stacked input. Resolved once, at import time, relative to the CWD.
# NOTE(review): glob does not filter out plain files, and lexicographic order
# puts e.g. "B10" before "B2" — confirm the band naming keeps this correct.
BAND_DIRS = sorted(IN_DIR.glob("B*"))
# Divisor that scales raw pixel values into [0, 1]; presumably the band images
# are 8-bit — TODO confirm.
MAX_X = 255
def read_fname(fname):
    """Load one file from every band directory and stack the bands depth-wise.

    For each directory in BAND_DIRS (in that order) the file ``fname`` is
    opened, its pixels are divided by MAX_X, and the per-band arrays are
    combined with ``np.dstack``.

    Parameters
    ----------
    fname : str or os.PathLike
        Bare file name (not a full path) that must exist in every band
        directory.

    Returns
    -------
    numpy.ndarray
        Float array. For single-channel (2-D) band images its shape is
        (height, width, len(BAND_DIRS)).

    Raises
    ------
    ValueError
        If BAND_DIRS is empty (no band directories were found under IN_DIR).
    """
    # Without this check an empty BAND_DIRS fails deep inside np.dstack with
    # "need at least one array to concatenate", which hides the real cause.
    if not BAND_DIRS:
        raise ValueError(f"no band directories matching 'B*' found under {IN_DIR}")
    bands = []
    for band_dir in BAND_DIRS:
        # The context manager closes the file handle as soon as the pixels are
        # copied out, instead of relying on Pillow internals or the GC.
        with Image.open(band_dir / fname) as img:
            bands.append(np.array(img))
    # Scale to floats, then move the band axis last; for 2-D bands this turns
    # (bands, H, W) into (H, W, bands).
    return np.dstack(np.asarray(bands) / MAX_X)
def rgb_transform(ds):
    """Select three band channels of a 4-D batch and reorder them for RGB display.

    Channels 1, 2 and 3 of the last axis are returned in reverse order
    (3, 2, 1). With bands stored as 1, 2, 3, 4, 5, 7, this puts bands
    4, 3, 2 on red, green and blue.

    Parameters
    ----------
    ds : numpy.ndarray
        Four-dimensional array whose last axis indexes bands (presumably
        (batch, height, width, bands) — confirm with the caller).

    Returns
    -------
    numpy.ndarray
        A view of ``ds`` with at most three channels, in reversed order.
    """
    # A negative-step slice picks channels 3, 2, 1 directly; this is the same
    # as slicing channels 1..3 and flipping the last axis.
    return ds[:, :, :, 3:0:-1]
def readin_batch(band1pths):
    """Load every image of one batch and return the images with their file names.

    Only the ``.name`` of each path is used. read_fname looks that name up in
    every band directory, so the paths only need to come from one band
    (main() passes paths from the first band directory).

    Parameters
    ----------
    band1pths : iterable of pathlib.Path
        Paths whose file names identify the images to load.

    Returns
    -------
    tuple
        ``(names, images)``: the file names in input order, and the
        read_fname results stacked along a new leading axis via np.asarray.
    """
    # Gather the names first (consuming the iterable exactly once), then load
    # the images in that same order.
    names = [pth.name for pth in band1pths]
    stacked = np.asarray([read_fname(name) for name in names])
    return names, stacked
def divide_chunks(l, n):
    """Yield consecutive slices of ``l`` that are at most ``n`` items long.

    The last chunk is shorter when ``len(l)`` is not a multiple of ``n``.
    Works for any sliceable sequence (list, tuple, str, ...), and each chunk
    has the same type as ``l``. Adapted from
    https://www.geeksforgeeks.org/break-list-chunks-size-n-python/

    Parameters
    ----------
    l : sequence
        Sequence to split; must support ``len()`` and slicing.
    n : int
        Maximum chunk size; must be at least 1.

    Yields
    ------
    sequence
        The slices ``l[0:n]``, ``l[n:2n]``, ... in order.

    Raises
    ------
    ValueError
        If ``n`` is less than 1. Because this is a generator, the error is
        raised on first iteration, not at call time.
    """
    # Without this guard, n == 0 crashes inside range() with an unrelated
    # message and a negative n silently yields nothing, dropping every item.
    if n < 1:
        raise ValueError(f"chunk size must be >= 1, got {n}")
    for start in range(0, len(l), n):
        yield l[start:start + n]
def write_out_img(name, img_data):
    """Encode one prediction as an 8-bit PNG and save it as OUT_DIR / name.

    Parameters
    ----------
    name : str
        Output file name (main() passes the input image's file name).
    img_data : array-like
        Prediction with values nominally in [0, 1]. ``tf.image.encode_png``
        requires a rank-3 image with 1-4 channels, so presumably
        (height, width, channels) — confirm against the model's output.
    """
    # Scale to 0..255, round to the nearest level, and clip BEFORE converting.
    # tf.cast does not clamp out-of-range values (tf.dtypes.saturate_cast
    # exists for that), so predictions slightly outside [0, 1] would otherwise
    # overflow into arbitrary pixel values. Rounding replaces truncation.
    pixels = np.clip(np.rint(np.asarray(img_data) * 255), 0, 255).astype(np.uint8)
    imgencoded = tf.image.encode_png(pixels)
    out_file = OUT_DIR / name
    # tf.io.write_file takes a string path.
    tf.io.write_file(out_file.as_posix(), imgencoded)
def main():
    """Run the saved model over every input image and write the predictions.

    File names are taken from the ``*.png`` files in the first band directory.
    Each batch of at most LOAD_SIZE names is loaded from all band directories
    (readin_batch), passed to ``model.predict``, and every prediction is
    written to OUT_DIR under the input's file name (write_out_img).
    """
    # Indexing BAND_DIRS[0] on an empty list would die with a bare IndexError;
    # exit with a message that names the directory that was searched instead.
    if not BAND_DIRS:
        sys.exit(f"No band directories matching 'B*' found in {IN_DIR}")
    # Path.glob yields entries in filesystem order; sort so batches and
    # progress output are deterministic across runs and machines.
    band1_paths = sorted(BAND_DIRS[0].glob("*.png"))
    if not band1_paths:
        # Nothing to do: skip the (potentially slow) model load entirely.
        print(f"No .png files found in {BAND_DIRS[0]}")
        return
    model = tf.keras.models.load_model(f'{MODEL_NAME}.h5')
    batch_groups = list(divide_chunks(band1_paths, LOAD_SIZE))
    for i, batch in enumerate(batch_groups, start=1):
        print(f"Running inference on batch {i} of {len(batch_groups)}")
        names, x_pred = readin_batch(batch)
        y_pred = model.predict(x_pred)
        for name, pred in zip(names, y_pred):
            write_out_img(name, pred)


if __name__ == '__main__':
    main()