Skip to content

Commit

Permalink
fix: preprocessing.
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul committed Oct 11, 2022
1 parent 60d9980 commit 12df753
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 43 deletions.
6 changes: 3 additions & 3 deletions convert_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ def main(args):
tf_params_path = f"{variant}_{task.lower()}_{dataset_name}.h5"
tf_model.save_weights(tf_params_path)
print(f"Model params serialized to {tf_params_path}.")
saved_model_path = tf_params_path.replace(".h5", "")
tf_model.save(saved_model_path)
print(f"SavedModel serialized to {saved_model_path}.")
# saved_model_path = tf_params_path.replace(".h5", "")
# tf_model.save(saved_model_path)
# print(f"SavedModel serialized to {saved_model_path}.")


def parse_args():
Expand Down
Binary file removed images/Results/0003_30.png
Binary file not shown.
Binary file removed images/Results/0011_23.png
Binary file not shown.
Binary file removed images/Results/0013_19.png
Binary file not shown.
Binary file removed images/Results/0039_04.png
Binary file not shown.
Binary file added images/Results/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/Results/111.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/Results/748.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/Results/a4541-DSC_0040-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
55 changes: 15 additions & 40 deletions run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run evaluation."""
"""Modified from https://github.com/google-research/maxim/blob/main/maxim/run_eval.py"""

import collections
import os

import numpy as np
Expand Down Expand Up @@ -52,33 +51,6 @@
_IMG_SIZE = 256


def recover_tree(keys, values):
"""Recovers a tree as a nested dict from flat names and values.
This function is useful to analyze checkpoints that are saved by our programs
without need to access the exact source code of the experiment. In particular,
it can be used to extract an reuse various subtrees of the scheckpoint, e.g.
subtree of parameters.
Args:
keys: a list of keys, where '/' is used as separator between nodes.
values: a list of leaf values.
Returns:
A nested tree-like dict.
"""
tree = {}
sub_trees = collections.defaultdict(list)
for k, v in zip(keys, values):
if "/" not in k:
tree[k] = v
else:
k_left, k_right = k.split("/", 1)
sub_trees[k_left].append((k_right, v))
for k, kv_pairs in sub_trees.items():
k_subtree, v_subtree = zip(*kv_pairs)
tree[k] = recover_tree(k_subtree, v_subtree)
return tree


def mod_padding_symmetric(image, factor=64):
"""Padding the image to be divided by factor."""
height, width = image.shape[0], image.shape[1]
Expand All @@ -93,16 +65,21 @@ def mod_padding_symmetric(image, factor=64):
return image


# def get_params(ckpt_path):
# """Get params checkpoint."""
# Since the model was not initialized to take variable-length sizes (None, None, 3),
# we need to be careful about how we are resizing the images.
# From https://www.tensorflow.org/lite/examples/style_transfer/overview#pre-process_the_inputs
def resize_image(image, target_dim):
# Resize the image so that the shorter dimension becomes `target_dim`.
shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
short_dim = min(shape)
scale = target_dim / short_dim
new_shape = tf.cast(shape * scale, tf.int32)
image = tf.image.resize(image, new_shape)

# with tf.io.gfile.GFile(ckpt_path, 'rb') as f:
# data = f.read()
# values = np.load(io.BytesIO(data))
# params = recover_tree(*zip(*values.items()))
# params = params['opt']['target']
# Central crop the image.
image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

# return params
return image


def calculate_psnr(img1, img2, crop_border, test_y_channel=False):
Expand Down Expand Up @@ -381,9 +358,7 @@ def _process_file(i):
else:
input_img = np.expand_dims(input_img, axis=0)

input_img = tf.keras.layers.CenterCrop(_IMG_SIZE, _IMG_SIZE)(
tf.convert_to_tensor(input_img)
)
input_img = resize_image(tf.convert_to_tensor(input_img), _IMG_SIZE)

# handle multi-stage outputs, obtain the last scale output of last stage
preds = model.predict(input_img)
Expand Down

0 comments on commit 12df753

Please sign in to comment.