Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: put tensors on device during creation #103

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/2D_tutorials/SF2M_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
"with torch.no_grad():\n",
" traj = node.trajectory(\n",
" x0,\n",
" t_span=torch.linspace(0, 1, 100).to(device),\n",
" t_span=torch.linspace(0, 1, 100, device=device),\n",
" )\n",
"\n",
"\n",
Expand Down
6 changes: 3 additions & 3 deletions examples/images/cifar10/compute_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@

def gen_1_img(unused_latent):
with torch.no_grad():
x = torch.randn(500, 3, 32, 32).to(device)
x = torch.randn(500, 3, 32, 32, device=device)
if FLAGS.integration_method == "euler":
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1).to(device)
t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device)
traj = node.trajectory(x, t_span=t_span)
else:
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, 2).to(device)
t_span = torch.linspace(0, 1, 2, device=device)
traj = odeint(
new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
)
Expand Down
4 changes: 2 additions & 2 deletions examples/images/cifar10/utils_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def generate_samples(model, parallel, savedir, step, net_="normal"):
node_ = NeuralODE(model_, solver="euler", sensitivity="adjoint")
with torch.no_grad():
traj = node_.trajectory(
torch.randn(64, 3, 32, 32).to(device),
t_span=torch.linspace(0, 1, 100).to(device),
torch.randn(64, 3, 32, 32, device=device),
t_span=torch.linspace(0, 1, 100, device=device),
)
traj = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1)
traj = traj / 2 + 0.5
Expand Down
36 changes: 18 additions & 18 deletions examples/images/conditional_mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,21 +117,21 @@
],
"source": [
"USE_TORCH_DIFFEQ = True\n",
"generated_class_list = torch.arange(10).repeat(10).to(device)\n",
"generated_class_list = torch.arange(10, device=device).repeat(10)\n",
"with torch.no_grad():\n",
" if USE_TORCH_DIFFEQ:\n",
" traj = torchdiffeq.odeint(\n",
" lambda t, x: model.forward(t, x, generated_class_list),\n",
" torch.randn(100, 1, 28, 28).to(device),\n",
" torch.linspace(0, 1, 2).to(device),\n",
" torch.randn(100, 1, 28, 28, device=device),\n",
" torch.linspace(0, 1, 2, device=device),\n",
" atol=1e-4,\n",
" rtol=1e-4,\n",
" method=\"dopri5\",\n",
" )\n",
" else:\n",
" traj = node.trajectory(\n",
" torch.randn(100, 1, 28, 28).to(device),\n",
" t_span=torch.linspace(0, 1, 2).to(device),\n",
" torch.randn(100, 1, 28, 28, device=device),\n",
" t_span=torch.linspace(0, 1, 2, device=device),\n",
" )\n",
"grid = make_grid(\n",
" traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
Expand Down Expand Up @@ -206,21 +206,21 @@
],
"source": [
"USE_TORCH_DIFFEQ = True\n",
"generated_class_list = torch.arange(10).repeat(10).to(device)\n",
"generated_class_list = torch.arange(10, device=device).repeat(10)\n",
"with torch.no_grad():\n",
" if USE_TORCH_DIFFEQ:\n",
" traj = torchdiffeq.odeint(\n",
" lambda t, x: model.forward(t, x, generated_class_list),\n",
" torch.randn(100, 1, 28, 28).to(device),\n",
" torch.linspace(0, 1, 2).to(device),\n",
" torch.randn(100, 1, 28, 28, device=device),\n",
" torch.linspace(0, 1, 2, device=device),\n",
" atol=1e-4,\n",
" rtol=1e-4,\n",
" method=\"dopri5\",\n",
" )\n",
" else:\n",
" traj = node.trajectory(\n",
" torch.randn(100, 1, 28, 28).to(device),\n",
" t_span=torch.linspace(0, 1, 2).to(device),\n",
" torch.randn(100, 1, 28, 28, device=device),\n",
" t_span=torch.linspace(0, 1, 2, device=device),\n",
" )\n",
"grid = make_grid(\n",
" traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
Expand Down Expand Up @@ -326,24 +326,24 @@
],
"source": [
"USE_TORCH_DIFFEQ = True\n",
"generated_class_list = torch.arange(10).repeat(10).to(device)\n",
"generated_class_list = torch.arange(10, device=device).repeat(10)\n",
"\n",
"node = NeuralODE(model, solver=\"euler\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
"# Evaluate the ODE\n",
"with torch.no_grad():\n",
" if USE_TORCH_DIFFEQ:\n",
" traj = torchdiffeq.odeint(\n",
" lambda t, x: model.forward(t, x, generated_class_list),\n",
" torch.randn(100, 1, 28, 28).to(device),\n",
" torch.linspace(0, 1, 2).to(device),\n",
" torch.randn(100, 1, 28, 28, device=device),\n",
" torch.linspace(0, 1, 2, device=device),\n",
" atol=1e-4,\n",
" rtol=1e-4,\n",
" method=\"dopri5\",\n",
" )\n",
" else:\n",
" traj = node.trajectory(\n",
" torch.randn(100, 1, 28, 28).to(device),\n",
" t_span=torch.linspace(0, 1, 2).to(device),\n",
" torch.randn(100, 1, 28, 28, device=device),\n",
" t_span=torch.linspace(0, 1, 2, device=device),\n",
" )\n",
"grid = make_grid(\n",
" traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
Expand Down Expand Up @@ -395,13 +395,13 @@
"metadata": {},
"outputs": [],
"source": [
"sde = SDE(model, score_model, labels=torch.arange(10).repeat(10).to(device), sigma=0.1)\n",
"sde = SDE(model, score_model, labels=torch.arange(10, device=device).repeat(10), sigma=0.1)\n",
"with torch.no_grad():\n",
" sde_traj = torchsde.sdeint(\n",
" sde,\n",
" # x0.view(x0.size(0), -1),\n",
" torch.randn(100, 1 * 28 * 28).to(device),\n",
" ts=torch.linspace(0, 1, 2).to(device),\n",
" torch.randn(100, 1 * 28 * 28, device=device),\n",
" ts=torch.linspace(0, 1, 2, device=device),\n",
" dt=0.01,\n",
" )"
]
Expand Down
12 changes: 6 additions & 6 deletions examples/images/mnist_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@
"source": [
"with torch.no_grad():\n",
" traj = node.trajectory(\n",
" torch.randn(100, 1, 28, 28).to(device),\n",
" t_span=torch.linspace(0, 1, 2).to(device),\n",
" torch.randn(100, 1, 28, 28, device=device),\n",
" t_span=torch.linspace(0, 1, 2, device=device),\n",
" )\n",
"grid = make_grid(\n",
" traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
Expand Down Expand Up @@ -236,8 +236,8 @@
"# Evaluate the ODE\n",
"with torch.no_grad():\n",
" traj = node.trajectory(\n",
" torch.randn(100, 1, 28, 28).to(device),\n",
" t_span=torch.linspace(0, 1, 1000).to(device),\n",
" torch.randn(100, 1, 28, 28, device=device),\n",
" t_span=torch.linspace(0, 1, 1000, device=device),\n",
" )\n",
"grid = make_grid(\n",
" traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
Expand Down Expand Up @@ -290,8 +290,8 @@
" sde_traj = torchsde.sdeint(\n",
" sde,\n",
" # x0.view(x0.size(0), -1),\n",
" torch.randn(50, 1 * 28 * 28).to(device),\n",
" ts=torch.linspace(0, 1, 2).to(device),\n",
" torch.randn(50, 1 * 28 * 28, device=device),\n",
" ts=torch.linspace(0, 1, 2, device=device),\n",
" dt=0.01,\n",
" )"
]
Expand Down
12 changes: 6 additions & 6 deletions examples/single_cell/single-cell_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@
"with torch.no_grad():\n",
" traj = node.trajectory(\n",
" x0.to(device),\n",
" t_span=torch.linspace(0, n_times - 1, 400).to(device),\n",
" t_span=torch.linspace(0, n_times - 1, 400, device=device),\n",
" ).cpu()"
]
},
Expand Down Expand Up @@ -529,7 +529,7 @@
" sde_traj = torchsde.sdeint(\n",
" sde,\n",
" x0.to(device),\n",
" ts=torch.linspace(0, n_times - 1, 400).to(device),\n",
" ts=torch.linspace(0, n_times - 1, 400, device=device),\n",
" ).cpu()"
]
},
Expand Down Expand Up @@ -594,7 +594,7 @@
"with torch.no_grad():\n",
" traj = node.trajectory(\n",
" x0[2].repeat(20).view(20, 2).to(device),\n",
" t_span=torch.linspace(0, n_times - 1, 400).to(device),\n",
" t_span=torch.linspace(0, n_times - 1, 400, device=device),\n",
" ).cpu()\n",
"# plot_trajectories(traj.cpu().numpy())"
]
Expand All @@ -609,7 +609,7 @@
" sde_traj = torchsde.sdeint(\n",
" sde,\n",
" x0[2].repeat(20).view(20, 2).to(device),\n",
" ts=torch.linspace(0, n_times - 1, 400).to(device),\n",
" ts=torch.linspace(0, n_times - 1, 400, device=device),\n",
" ).cpu()"
]
},
Expand Down Expand Up @@ -731,14 +731,14 @@
"with torch.no_grad():\n",
" traj = node.trajectory(\n",
" x0[1].repeat(15).view(15, 2).to(device),\n",
" t_span=torch.linspace(1, n_times - 1, 300).to(device),\n",
" t_span=torch.linspace(1, n_times - 1, 300, device=device),\n",
" ).cpu()\n",
"\n",
"with torch.no_grad():\n",
" sde_traj = torchsde.sdeint(\n",
" sde,\n",
" x0[1].repeat(15).view(15, 2).to(device),\n",
" ts=torch.linspace(1, n_times - 1, 300).to(device),\n",
" ts=torch.linspace(1, n_times - 1, 300, device=device),\n",
" ).cpu()\n",
"\n",
"traj = traj.detach().cpu().numpy()"
Expand Down
10 changes: 5 additions & 5 deletions runner/src/models/cfm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@

def preprocess_batch(self, X, training=False):
"""Converts a batch of data into matched a random pair of (x0, x1)"""
t_select = torch.zeros(1).to(X.device)
t_select = torch.zeros(1, device=X.device)
if self.is_trajectory:
batch_size, times, dim = X.shape
if not hasattr(self.datamodule, "HAS_JOINT_PLANS"):
Expand All @@ -168,7 +168,7 @@

if training and self.hparams.leaveout_timepoint > 0:
# Select random except for the leftout timepoint
t_select = torch.randint(times - 2, size=(batch_size,)).to(X.device)
t_select = torch.randint(times - 2, size=(batch_size,), device=X.device)

Check warning on line 171 in runner/src/models/cfm_module.py

View check run for this annotation

Codecov / codecov/patch

runner/src/models/cfm_module.py#L171

Added line #L171 was not covered by tests
t_select[t_select >= self.hparams.leaveout_timepoint] += 1
else:
t_select = torch.randint(times - 1, size=(batch_size,))
Expand Down Expand Up @@ -623,12 +623,12 @@

def preprocess_batch(self, X, training=False):
"""Converts a batch of data into matched a random pair of (x0, x1)"""
t_select = torch.zeros(1).to(X.device)
t_select = torch.zeros(1, device=X.device)

Check warning on line 626 in runner/src/models/cfm_module.py

View check run for this annotation

Codecov / codecov/patch

runner/src/models/cfm_module.py#L626

Added line #L626 was not covered by tests
if self.is_trajectory:
batch_size, times, dim = X.shape
if training and self.hparams.leaveout_timepoint > 0:
# Select random except for the leftout timepoint
t_select = torch.randint(times - 2, size=(batch_size,)).to(X.device)
t_select = torch.randint(times - 2, size=(batch_size,), device=X.device)

Check warning on line 631 in runner/src/models/cfm_module.py

View check run for this annotation

Codecov / codecov/patch

runner/src/models/cfm_module.py#L631

Added line #L631 was not covered by tests
t_select[t_select >= self.hparams.leaveout_timepoint] += 1
else:
t_select = torch.randint(times - 1, size=(batch_size,))
Expand Down Expand Up @@ -1011,7 +1011,7 @@
# Randomly sample a batch from the stored data.
idx = torch.randint(self.stored_data.shape[0], size=(X.shape[0],))
X = self.stored_data[idx]
t_select = torch.zeros(1).to(X.device)
t_select = torch.zeros(1, device=X.device)

Check warning on line 1014 in runner/src/models/cfm_module.py

View check run for this annotation

Codecov / codecov/patch

runner/src/models/cfm_module.py#L1014

Added line #L1014 was not covered by tests
return X[:, 0], X[:, 1], t_select
return super().preprocess_batch(X, training)

Expand Down
6 changes: 4 additions & 2 deletions runner/src/models/components/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,10 @@ def timestep_embedding(timesteps, dim, max_period=10000):
"""
half = dim // 2
freqs = th.exp(
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
).to(device=timesteps.device)
-math.log(max_period)
* th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
/ half
)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
Expand Down
8 changes: 5 additions & 3 deletions runner/src/models/components/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
points_real = 50
Y, X, T = np.mgrid[wmin:wmax:points, wmin:wmax:points, 0 : ts - 1 : 7j]
gridpoints = torch.tensor(
np.stack([X.flatten(), Y.flatten()], axis=1), requires_grad=True
np.stack([X.flatten(), Y.flatten()], axis=1), requires_grad=True, device=device
).type(torch.float32)
times = torch.tensor(T.flatten(), requires_grad=True).type(torch.float32)[:, None]
out = model(times.to(device), gridpoints.to(device))
times = torch.tensor(T.flatten(), requires_grad=True, device=device).type(torch.float32)[

Check warning on line 36 in runner/src/models/components/plotting.py

View check run for this annotation

Codecov / codecov/patch

runner/src/models/components/plotting.py#L36

Added line #L36 was not covered by tests
:, None
]
out = model(times, gridpoints)

Check warning on line 39 in runner/src/models/components/plotting.py

View check run for this annotation

Codecov / codecov/patch

runner/src/models/components/plotting.py#L39

Added line #L39 was not covered by tests
out = out.reshape([points_real, points_real, 7, dim])
out = out.cpu().detach().numpy()
# Stream over time
Expand Down
6 changes: 4 additions & 2 deletions torchcfm/models/unet/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,10 @@ def timestep_embedding(timesteps, dim, max_period=10000):
"""
half = dim // 2
freqs = th.exp(
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
).to(device=timesteps.device)
-math.log(max_period)
* th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
/ half
)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
Expand Down
Loading