Commit

Use reverse mode in the if, forward mode in the else.
Jerry-Jzy committed Dec 2, 2024
1 parent 6c258db commit 5917be1
Showing 1 changed file: deepxde/data/pde_operator.py (34 additions, 34 deletions)
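For reference, which branch runs is decided by DeepXDE's global `config.autodiff` flag. A minimal sketch of flipping it from the default, assuming a recent DeepXDE version with the `set_default_autodiff` config helper and a backend that supports forward-mode AD:

```python
# Minimal sketch: selecting the AD mode that picks the branch in the diff
# below. "reverse" is the default; switching to "forward" makes _losses
# take the forward-mode (else) branch.
import deepxde as dde

dde.config.set_default_autodiff("forward")
```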
```diff
@@ -240,7 +240,40 @@ def __init__(
     def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
         bcs_start = np.cumsum([0] + self.pde.num_bcs)
 
-        if config.autodiff == "forward":  # forward mode AD
+        if config.autodiff == "reverse":  # reverse mode AD
+            losses = []
+            for i in range(num_func):
+                out = outputs[i]
+                # Single output
+                if bkd.ndim(out) == 1:
+                    out = out[:, None]
+                f = []
+                if self.pde.pde is not None:
+                    f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])
+                if not isinstance(f, (list, tuple)):
+                    f = [f]
+                error_f = [fi[bcs_start[-1] :] for fi in f]
+                losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
+
+                for j, bc in enumerate(self.pde.bcs):
+                    beg, end = bcs_start[j], bcs_start[j + 1]
+                    # The same BC points are used for training and testing.
+                    error = bc.error(
+                        self.train_x[1],
+                        inputs[1],
+                        out,
+                        beg,
+                        end,
+                        aux_var=model.net.auxiliary_vars[i][:, None],
+                    )
+                    losses_i.append(loss_fn(bkd.zeros_like(error), error))
+
+                losses.append(losses_i)
+
+            losses = zip(*losses)
+            # Use stack instead of as_tensor to keep the gradients.
+            losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        else:  # forward mode AD
             losses = []
 
             def forward_call(trunk_input):
```
```diff
@@ -274,39 +307,6 @@ def forward_call(trunk_input):
                 error_k = bkd.stack(error_k, axis=0)  # noqa
                 loss_k = loss_fn(bkd.zeros_like(error_k), error_k)  # noqa
                 losses.append(loss_k)
-        else:  # reverse mode AD
-            losses = []
-            for i in range(num_func):
-                out = outputs[i]
-                # Single output
-                if bkd.ndim(out) == 1:
-                    out = out[:, None]
-                f = []
-                if self.pde.pde is not None:
-                    f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])
-                if not isinstance(f, (list, tuple)):
-                    f = [f]
-                error_f = [fi[bcs_start[-1] :] for fi in f]
-                losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
-
-                for j, bc in enumerate(self.pde.bcs):
-                    beg, end = bcs_start[j], bcs_start[j + 1]
-                    # The same BC points are used for training and testing.
-                    error = bc.error(
-                        self.train_x[1],
-                        inputs[1],
-                        out,
-                        beg,
-                        end,
-                        aux_var=model.net.auxiliary_vars[i][:, None],
-                    )
-                    losses_i.append(loss_fn(bkd.zeros_like(error), error))
-
-                losses.append(losses_i)
-
-            losses = zip(*losses)
-            # Use stack instead of as_tensor to keep the gradients.
-            losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
         return losses
 
     def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
```
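Both branches slice the outputs with `bcs_start = np.cumsum([0] + self.pde.num_bcs)`: the first `bcs_start[-1]` points are BC points, and the remainder are PDE residual points. A standalone sketch with made-up sizes:

```python
# Standalone sketch (made-up sizes): how bcs_start partitions the points.
import numpy as np

num_bcs = [5, 3]                      # two BCs with 5 and 3 points each
bcs_start = np.cumsum([0] + num_bcs)  # array([0, 5, 8])

f = np.arange(12.0)                   # pretend PDE residual on 12 points
error_f = f[bcs_start[-1]:]           # rows 8..11: residual at non-BC points

for j in range(len(num_bcs)):
    beg, end = bcs_start[j], bcs_start[j + 1]
    print(f"BC {j}: rows {beg}..{end - 1}")  # the slice bc.error receives
```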
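The tail of the reverse-mode branch builds one loss list per sampled function, then `zip(*losses)` transposes them so each loss term is averaged across functions (stack rather than as_tensor, per the code comment, to keep the gradients). A NumPy stand-in for the backend ops, with made-up values:

```python
# NumPy stand-in (made-up values) for the aggregation in the reverse-mode
# branch: transpose per-function loss lists, then average each term.
import numpy as np

losses = [
    [0.10, 0.20, 0.30],  # function 0: [PDE term, BC 0 term, BC 1 term]
    [0.12, 0.18, 0.36],  # function 1
]
per_term = zip(*losses)  # (PDE terms), (BC 0 terms), (BC 1 terms)
losses = [np.mean(np.stack(term, 0)) for term in per_term]
print([float(v) for v in losses])  # [0.11, 0.19, 0.33]
```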
