-
Notifications
You must be signed in to change notification settings - Fork 36
/
wandb_utils.py
55 lines (42 loc) · 1.43 KB
/
wandb_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import wandb
import torch
from torchvision.utils import make_grid
import torch.distributed as dist
from PIL import Image
import os
import argparse
import hashlib
import math
def is_main_process():
    """Return True if this process should perform rank-0-only work.

    When ``torch.distributed`` is unavailable or no default process group has
    been initialised, ``dist.get_rank()`` raises instead of returning 0. In
    that case this is a single-process run, so it is treated as the main
    process rather than crashing the caller.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0
def namespace_to_dict(namespace):
    """Convert an ``argparse.Namespace`` into a plain ``dict``, recursively.

    Attributes whose value is itself an ``argparse.Namespace`` are converted
    the same way. Every other value is stored as-is; it is not copied.
    """
    converted = {}
    for name, value in vars(namespace).items():
        if isinstance(value, argparse.Namespace):
            value = namespace_to_dict(value)
        converted[name] = value
    return converted
def generate_run_id(exp_name):
    """Derive a deterministic numeric id string from an experiment name.

    The UTF-8 bytes of ``exp_name`` are hashed with SHA-256, and the digest
    (read as a big-endian integer) is reduced modulo 10**8. The result is a
    decimal string of at most 8 digits, without zero padding. The same name
    always yields the same id.
    """
    # https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
    digest = hashlib.sha256(exp_name.encode('utf-8')).digest()
    return str(int.from_bytes(digest, "big") % 10 ** 8)
def initialize(args, entity, exp_name, project_name):
    """Log in to Weights & Biases and start (or resume) a run.

    Args:
        args: ``argparse.Namespace`` of run settings. It is converted
            recursively to a dict and passed as the run ``config``.
        entity: Passed unchanged as ``wandb.init(entity=...)``.
        exp_name: Used as the run's display name. It is also hashed into a
            deterministic run ``id``, so with ``resume="allow"`` relaunching
            the same experiment name targets the same W&B run id.
        project_name: Passed as ``wandb.init(project=...)``.
    """
    config_dict = namespace_to_dict(args)
    # Use WANDB_KEY when it is set. Otherwise pass None (wandb.login's default)
    # so wandb uses its own credential lookup rather than this function dying
    # with a bare KeyError.
    wandb.login(key=os.environ.get("WANDB_KEY"))
    wandb.init(
        entity=entity,
        project=project_name,
        name=exp_name,
        config=config_dict,
        id=generate_run_id(exp_name),
        resume="allow",
    )
def log(stats, step=None):
    """Log a mapping of metrics to W&B, from the main process only.

    Args:
        stats: Mapping of metric name to value. A shallow copy is handed to
            ``wandb.log``, never the caller's own dict.
        step: Forwarded unchanged as ``wandb.log(step=...)``.
    """
    if not is_main_process():
        return
    wandb.log(dict(stats), step=step)
def log_image(sample, step=None):
    """Tile an image batch into one grid and log it to W&B (main process only).

    Args:
        sample: Image batch passed to ``array2grid``. This presumably means a
            (B, C, H, W) tensor in [-1, 1]; confirm against callers.
        step: Logged as the ``"train_step"`` value next to the image.
            NOTE(review): this is a logged metric, not ``wandb.log(step=...)``,
            so W&B's internal step counter is not set from it. It is presumably
            meant to be paired with
            ``wandb.define_metric(..., step_metric="train_step")``; confirm.
    """
    if not is_main_process():
        return
    grid = array2grid(sample)
    wandb.log({"samples": wandb.Image(grid), "train_step": step})
def array2grid(x):
    """Tile a batch of images into a single HWC ``uint8`` numpy array.

    The number of images per row is ``round(sqrt(x.size(0)))``, which gives a
    roughly square grid. ``make_grid`` is called with ``normalize=True`` and
    ``value_range=(-1, 1)``. That presumably assumes inputs in [-1, 1], but it
    is not checked here.
    """
    per_row = round(math.sqrt(x.size(0)))
    grid = make_grid(x, nrow=per_row, normalize=True, value_range=(-1,1))
    # Scale to [0, 255]. The +0.5 makes the later uint8 cast (which
    # truncates) round to the nearest integer instead.
    pixels = grid.mul(255).add_(0.5).clamp_(0,255)
    # CHW -> HWC, then move to CPU as uint8 for numpy / PIL consumers.
    channels_last = pixels.permute(1,2,0)
    return channels_last.to('cpu', torch.uint8).numpy()