Skip to content

Commit

Permalink
Update tensorboard logging
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Oct 26, 2020
1 parent a20ff8f commit 4d0cd9d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ def test(data,

# Plot images
if plots and batch_i < 1:
f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename
f = save_dir / f'test_batch{batch_i}_gt.jpg' # filename
plot_images(img, targets, paths, str(f), names) # ground truth
f = save_dir / ('test_batch%g_pred.jpg' % batch_i)
f = save_dir / f'test_batch{batch_i}_pred.jpg'
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions

# Compute statistics
Expand Down
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,11 @@ def train(hyp, opt, device, tb_writer=None):

# Plot
if ni < 3:
f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename
f = str(log_dir / f'train_batch{ni}.jpg') # filename
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
if tb_writer and result is not None:
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard
# if tb_writer and result is not None:
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard

# end batch ------------------------------------------------------------------------------------------------

Expand Down
5 changes: 3 additions & 2 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch
import torch.nn as nn
import yaml
from PIL import Image
from scipy.cluster.vq import kmeans
from scipy.signal import butter, filtfilt
from tqdm import tqdm
Expand Down Expand Up @@ -1096,8 +1097,8 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max

if fname is not None:
mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))

# cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
Image.fromarray(mosaic).save(fname) # PIL save
return mosaic


Expand Down

0 comments on commit 4d0cd9d

Please sign in to comment.