Skip to content

Commit

Permalink
add dwpose loss
Browse files Browse the repository at this point in the history
  • Loading branch information
yzd-v committed Aug 23, 2023
1 parent 9eef741 commit a2bd26a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
43 changes: 25 additions & 18 deletions mmpose/models/losses/fea_dis_loss.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpose.registry import MODELS


@MODELS.register_module()
class FeaLoss(nn.Module):
"""PyTorch version of feature-based distillation from DWPose Modified from
the official implementation.
"""PyTorch version of feature-based distillation
<https://github.com/IDEA-Research/DWPose>
Args:
student_channels(int): Number of channels in the student's feature map.
teacher_channels(int): Number of channels in the teacher's feature map.
teacher_channels(int): Number of channels in the teacher's feature map.
alpha_fea (float, optional): Weight of dis_loss. Defaults to 0.00007
"""
def __init__(self,
name,
use_this,
student_channels,
teacher_channels,
alpha_fea=0.00007,
):

def __init__(
self,
name,
use_this,
student_channels,
teacher_channels,
alpha_fea=0.00007,
):
super(FeaLoss, self).__init__()
self.alpha_fea = alpha_fea

if teacher_channels != student_channels:
self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
self.align = nn.Conv2d(
student_channels,
teacher_channels,
kernel_size=1,
stride=1,
padding=0)
else:
self.align = None

def forward(self,
preds_S,
preds_T):
def forward(self, preds_S, preds_T):
"""Forward function.
Args:
preds_S(Tensor): Bs*C*H*W, student's feature map
preds_T(Tensor): Bs*C*H*W, teacher's feature map
Expand All @@ -51,6 +58,6 @@ def get_dis_loss(self, preds_S, preds_T):
loss_mse = nn.MSELoss(reduction='sum')
N, C, H, W = preds_T.shape

dis_loss = loss_mse(preds_S, preds_T)/N*self.alpha_fea
dis_loss = loss_mse(preds_S, preds_T) / N * self.alpha_fea

return dis_loss
return dis_loss
31 changes: 18 additions & 13 deletions mmpose/models/losses/logit_dis_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpose.registry import MODELS


@MODELS.register_module()
class KDLoss(nn.Module):
"""PyTorch version of logit-based distillation from DWPose Modified from
the official implementation.
""" PyTorch version of KD for Pose """
<https://github.com/IDEA-Research/DWPose>
Args:
weight (float, optional): Weight of dis_loss. Defaults to 1.0
"""

def __init__(self,
name,
use_this,
weight=1.0,
):
def __init__(
self,
name,
use_this,
weight=1.0,
):
super(KDLoss, self).__init__()

self.log_softmax = nn.LogSoftmax(dim=1)
Expand All @@ -30,17 +38,14 @@ def forward(self, pred, pred_t, beta, target_weight):
num_joints = ls_x.size(1)
loss = 0

loss += (
self.loss(ls_x, lt_x, beta, target_weight))
loss += (
self.loss(ls_y, lt_y, beta, target_weight))
loss += (self.loss(ls_x, lt_x, beta, target_weight))
loss += (self.loss(ls_y, lt_y, beta, target_weight))

return loss / num_joints

def loss(self, logit_s, logit_t, beta, weight):

N = logit_s.shape[0]
Bins = logit_s.shape[-1]

if len(logit_s.shape) == 3:
K = logit_s.shape[1]
Expand All @@ -56,4 +61,4 @@ def loss(self, logit_s, logit_t, beta, weight):
loss_all = loss_all.reshape(N, K).sum(dim=1).mean()
loss_all = self.weight * loss_all

return loss_all
return loss_all

0 comments on commit a2bd26a

Please sign in to comment.