-
Notifications
You must be signed in to change notification settings - Fork 3
/
diffusion.py
66 lines (58 loc) · 2.01 KB
/
diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu
Noise and diffusion utils.
"""
from scipy.stats import norm
import torch
from torchtyping import TensorType
from core import utils
def noise_schedule(
    time: TensorType[float],
    function: str = "uniform",
    sigma_data: float = 10.0,
    psigma_mean: float = -1.2,
    psigma_std: float = 1.2,
    s_min: float = 0.001,
    s_max: float = 60,
    rho: float = 7.0,
    time_power: float = 4.0,
    constant_val: float = 0.0,
):
    """Map diffusion time values to noise levels under the chosen schedule.

    Args:
        time: Tensor of time values. For "uniform" and "mpnn", 1 is the
            highest noise level and 0 the lowest (the reverse of the Karras
            et al. convention). For "lognormal", values are interpreted as
            quantiles of a standard normal, so 0 and 1 map to noise levels
            of 0 and infinity respectively.
        function: Schedule name; one of "lognormal", "uniform", "mpnn" or
            "constant".
        sigma_data: Multiplicative scale applied by every schedule except
            "constant".
        psigma_mean: Mean of the log noise level ("lognormal" only).
        psigma_std: Standard deviation of the log noise level ("lognormal"
            only).
        s_min: Noise level (before scaling by sigma_data) reached at time 0
            for "uniform" and "mpnn".
        s_max: Noise level (before scaling by sigma_data) reached at time 1
            for "uniform" and "mpnn".
        rho: Exponent that controls the curvature of the interpolation
            between s_min and s_max.
        time_power: Power applied to time before interpolating ("mpnn"
            only).
        constant_val: Value returned for every element ("constant" only).

    Returns:
        Tensor of noise levels with the same shape as time.

    Raises:
        ValueError: If function is not one of the supported schedule names.
    """

    def sampling_noise(t):
        # high noise = 1; low noise = 0. opposite of Karras et al. schedule
        term1 = s_max ** (1 / rho)
        term2 = (1 - t) * (s_min ** (1 / rho) - s_max ** (1 / rho))
        return sigma_data * ((term1 + term2) ** rho)

    if function == "lognormal":
        # scipy works on NumPy data: evaluate the inverse normal CDF on a
        # detached float64 CPU copy, then restore time's dtype and device.
        # torch.as_tensor (unlike the legacy torch.Tensor constructor) also
        # accepts the NumPy scalar that scipy returns for 0-dim input.
        quantiles = time.detach().cpu().double().numpy()
        normal_sample = torch.as_tensor(
            norm.ppf(quantiles), dtype=time.dtype, device=time.device
        )
        noise_level = sigma_data * torch.exp(psigma_mean + psigma_std * normal_sample)
    elif function == "uniform":
        noise_level = sampling_noise(time)
    elif function == "mpnn":
        # Same interpolation as "uniform", evaluated at time ** time_power.
        noise_level = sampling_noise(time**time_power)
    elif function == "constant":
        noise_level = torch.ones_like(time) * constant_val
    else:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(
            f"Unknown noise schedule function {function!r}; expected one of "
            "'lognormal', 'uniform', 'mpnn', 'constant'."
        )
    return noise_level
def noise_coords(
    coords: TensorType["b n a x", float],
    noise_level: TensorType["b", float],
    dummy_fill_masked_atoms: bool = False,
    atom_mask: TensorType["b n a"] = None,
):
    """Add Gaussian noise scaled by a per-example noise level to coordinates.

    The atom mask is NOT re-applied after the noise is added, so masked atom
    positions are noised as well.

    Args:
        coords: Atom coordinates, annotated as (b, n, a, x).
        noise_level: Per-example noise scale, annotated as (b,).
        dummy_fill_masked_atoms: If True, masked-out atoms are first replaced
            by the coordinates at atom index 1 of the same residue (CA)
            before noise is added.
        atom_mask: Mask annotated as (b, n, a); nonzero/True marks atoms to
            keep. Required when dummy_fill_masked_atoms is True.

    Returns:
        Tensor of noised coordinates with the same shape as coords.

    Raises:
        ValueError: If dummy_fill_masked_atoms is True and atom_mask is None.
    """
    if dummy_fill_masked_atoms:
        # Explicit check rather than `assert`, which is stripped under -O.
        if atom_mask is None:
            raise ValueError(
                "atom_mask must be provided when dummy_fill_masked_atoms is True"
            )
        # torch does not support `1 - mask` for bool tensors; cast those to
        # the coords dtype. Other mask dtypes keep the original arithmetic.
        if atom_mask.dtype == torch.bool:
            atom_mask = atom_mask.to(coords.dtype)
        dummy_fill_mask = 1 - atom_mask
        # CA is the fill position; a CB-based alternative would be
        # utils.fill_in_cbeta_for_atom37(coords)[..., 3:4, :].
        dummy_fill_value = coords[..., 1:2, :]  # CA
        coords = (
            coords * atom_mask[..., None]
            + dummy_fill_value * dummy_fill_mask[..., None]
        )
    # NOTE(review): utils.expand presumably reshapes noise_level so that it
    # broadcasts against coords -- confirm against core.utils.
    noise = torch.randn_like(coords) * utils.expand(noise_level, coords)
    noisy_coords = coords + noise
    return noisy_coords