-
Notifications
You must be signed in to change notification settings - Fork 47
/
emotic.py
28 lines (25 loc) · 1019 Bytes
/
emotic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
class Emotic(nn.Module):
    """Fusion head that joins context and body feature vectors.

    Both inputs are flattened to ``(batch, num_*_features)``, concatenated
    along the feature axis, and passed through
    ``Linear -> BatchNorm1d -> ReLU -> Dropout``. The shared hidden
    representation then feeds two independent linear heads.

    Args:
        num_context_features: Number of values per sample in ``x_context``
            once flattened.
        num_body_features: Number of values per sample in ``x_body`` once
            flattened.
        hidden_dim: Width of the shared hidden layer (default 256).
        num_cat: Output size of the ``fc_cat`` head (default 26).
        num_cont: Output size of the ``fc_cont`` head (default 3).
        dropout: Drop probability applied after the ReLU (default 0.5).

    NOTE(review): the head names suggest ``fc_cat`` emits scores for discrete
    categories and ``fc_cont`` regresses continuous values; no activation is
    applied to either output here -- confirm against the loss functions used
    by the training code.
    """

    def __init__(
        self,
        num_context_features: int,
        num_body_features: int,
        *,
        hidden_dim: int = 256,
        num_cat: int = 26,
        num_cont: int = 3,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.num_context_features = num_context_features
        self.num_body_features = num_body_features

        # Attribute names and creation order are kept exactly as before so
        # that existing state_dict checkpoints load and seeded initialisation
        # produces the same weights.
        self.fc1 = nn.Linear(num_context_features + num_body_features, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.d1 = nn.Dropout(p=dropout)
        self.fc_cat = nn.Linear(hidden_dim, num_cat)
        self.fc_cont = nn.Linear(hidden_dim, num_cont)
        self.relu = nn.ReLU()

    def forward(self, x_context: torch.Tensor, x_body: torch.Tensor):
        """Fuse the two feature tensors and run both prediction heads.

        Args:
            x_context: Tensor whose element count is a multiple of
                ``num_context_features``; it is flattened to
                ``(-1, num_context_features)``.
            x_body: Tensor whose element count is a multiple of
                ``num_body_features``; it is flattened to
                ``(-1, num_body_features)``.

        Returns:
            Tuple ``(cat_out, cont_out)`` holding the raw outputs of
            ``fc_cat`` and ``fc_cont``, each with one row per sample.
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # transposed feature maps, are accepted; for contiguous inputs it
        # returns the same view as before without copying.
        context_features = x_context.reshape(-1, self.num_context_features)
        body_features = x_body.reshape(-1, self.num_body_features)

        fuse_features = torch.cat((context_features, body_features), 1)
        fuse_out = self.d1(self.relu(self.bn1(self.fc1(fuse_features))))

        cat_out = self.fc_cat(fuse_out)
        cont_out = self.fc_cont(fuse_out)
        return cat_out, cont_out