Skip to content

Commit

Permalink
modify code
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry-Jzy committed Dec 18, 2024
1 parent 5ef4f4f commit 66682a4
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions deepxde/data/pde_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def __init__(

def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
bcs_start = np.cumsum([0] + self.pde.num_bcs)

losses = []
# PDE loss
if config.autodiff == "reverse": # reverse mode AD
Expand Down Expand Up @@ -277,10 +278,11 @@ def forward_call(trunk_input):
error_f = [fi[:, bcs_start[-1] :] for fi in f]
# Each error has the shape (N1, ~N2)
losses = [loss_fn(bkd.zeros_like(error), error) for error in error_f]

# BC loss
error_bc = []
for i in range(num_func):
error_i = []
losses_i = []
out = outputs[i]
if bkd.ndim(out) == 1:
out = out[:, None]
Expand All @@ -294,9 +296,9 @@ def forward_call(trunk_input):
end,
aux_var=model.net.auxiliary_vars[i][:, None],
)
error_i.append(loss_fn(bkd.zeros_like(error), error))
losses_i.append(loss_fn(bkd.zeros_like(error), error))

error_bc.append(error_i)
error_bc.append(losses_i)

error_bc = zip(*error_bc)
error_bc = [bkd.reduce_mean(bkd.stack(error, 0)) for error in error_bc]
Expand Down

0 comments on commit 66682a4

Please sign in to comment.