-
Notifications
You must be signed in to change notification settings - Fork 0
/
adv_alstm.py
165 lines (127 loc) · 6.99 KB
/
adv_alstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# -*- coding: utf-8 -*-
# @author : RenKai (intern in HIGGS ASSET)
# @time : 4/24/24 7:15 PM
#
# pylint: disable=no-member
""" The Comparison Methods 8: ADV-ALSTM.
Ref. https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_alstm.py#L294
Ref. https://zhuanlan.zhihu.com/p/566172868
NOTE:
- In fact the ALSTM_Net in ADV-ALSTM is totally same with ALSTM_Net in ALSTM.
- The difference is training way
"""
import logging
import torch
from torch import nn
from typing import Dict
class ALSTM_Net(nn.Module):
    """ The 2 Layer ALSTM. hidden_size=64.

    Pipeline: feature mapping (Linear+Tanh) -> multi-layer LSTM -> temporal
    attention pooling -> concat(last hidden state, attention-pooled state)
    -> linear head. `get_adv()` additionally produces an adversarial
    prediction from the concatenated representation `e` (ADV-ALSTM training).
    """

    def __init__(
            self, input_size: int, hidden_size: int = 64, num_layers: int = 2,
            dropout: float = 0.0, epsilon: float = 1e-2, device: torch.device = torch.device("cpu")
    ):
        """ The init function of ALSTM Net.

        :param input_size: input size for each time step (K^g1 * D after reshape in `forward`)
        :param hidden_size: hidden size of the LSTM
        :param num_layers: the num of LSTM layers
        :param dropout: the dropout ratio (used both in the LSTM stack and the attention net)
        :param epsilon: the perturbation scale when generating the adversarial example in `get_adv`
        :param device: the computing device
        """

        super(ALSTM_Net, self).__init__()
        self.epsilon = epsilon  # scale of the adversarial perturbation r_adv
        self.device = device

        # ---- Log the info of ALSTM ---- #
        logging.info(f"|||| Using ADV-ALSTM Now ! input_size={input_size}, hidden_size={hidden_size}, num_layers={num_layers}, dropout_ratio={dropout}||||")

        # ---- Part 1. Feature Encoding Net ---- #
        # Maps each time step's raw feature vector to hidden_size.
        self.feature_mapping_net = nn.Sequential().to(device=device)
        self.feature_mapping_net.add_module("fc_in", nn.Linear(in_features=input_size, out_features=hidden_size).to(device=device))
        self.feature_mapping_net.add_module("act", nn.Tanh().to(device=device))

        # ---- Part 2. LSTM module ---- #
        self.lstm = nn.LSTM(
            input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout,
        ).to(device=device)

        # ---- Part 3. Attention Net ---- #
        # Scores each time step; Softmax(dim=1) normalizes over the time axis.
        self.attention_net = nn.Sequential().to(device=device)
        self.attention_net.add_module("att_fc_in", nn.Linear(in_features=hidden_size, out_features=int(hidden_size / 2)).to(device=device))
        self.attention_net.add_module("att_dropout", torch.nn.Dropout(dropout).to(device=device))
        self.attention_net.add_module("att_act", nn.Tanh().to(device=device))
        self.attention_net.add_module("att_fc_out", nn.Linear(in_features=int(hidden_size / 2), out_features=1, bias=False).to(device=device))
        self.attention_net.add_module("att_softmax", nn.Softmax(dim=1).to(device=device))

        # ---- Part 4. The output fully connect layer ---- #
        # Input is the concat of last hidden state and attention-pooled state, hence 2 * hidden_size.
        self.fc_out = nn.Linear(in_features=hidden_size * 2, out_features=1).to(device=device)

    def forward(self, mul_granularity_input: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """ The forward function of ALSTM Net.

        NOTE: only the coarsest granularity "g1" is consumed; other keys are ignored.

        :param mul_granularity_input: the input multi granularity, a dict with the format:
            {
                "g1": feature_g1,
                "g2": feature_g2,
                ...,
                "gG": feature_gG
            }
            where feature_g1 has shape (bs, T, K^g1, D)

        returns: output, a dict with format:
            {
                "pred": the prediction result, shape=(bs, 1),
                "e": the concatenated representation fed to `fc_out`, shape=(bs, 2*hidden_size)
            }
        """

        # ---- Step 1. Get the feature ---- #
        # g1 feature (coarsest), shape=(bs, T, K^g1, D)
        feature_g1 = mul_granularity_input["g1"].to(dtype=torch.float32, device=self.device)
        # get the feature shape
        bs, T, K_g1, d = feature_g1.shape[0], feature_g1.shape[1], feature_g1.shape[2], feature_g1.shape[3]

        # ---- Step 2. Preprocess the input for encoding ---- #
        feature_g1 = feature_g1.reshape(bs, T, K_g1 * d)  # reshape, shape=(bs, T, K^g1*D)

        # ---- Step 3. Encoding the feature ---- #
        mapping_feature_g1 = self.feature_mapping_net(feature_g1)  # shape=(bs, T, hidden_size)

        # ---- Step 4. Using the lstm to do the sequence encoding ---- #
        lstm_out_g1, _ = self.lstm(mapping_feature_g1)  # (bs, T, hidden_size)

        # ---- Step 5. Computing the attention score and weighting ---- #
        attention_score_g1 = self.attention_net(lstm_out_g1)  # shape=(bs, T, 1)
        lstm_out_att_g1 = torch.mul(lstm_out_g1, attention_score_g1)  # use the attention score to weight, shape=(bs, T, hidden_size)
        lstm_out_att_g1 = torch.sum(lstm_out_att_g1, dim=1)  # sum the hidden feature of each step, shape=(bs, hidden_size)

        # ---- Step 6. FC to get the prediction ---- #
        # get the last step hidden g1
        last_step_hidden_g1 = lstm_out_g1[:, -1, :]  # shape=(bs, hidden_size)
        # get the e_s of g1 (concat of last hidden and attention-pooled hidden)
        e = torch.cat((last_step_hidden_g1, lstm_out_att_g1), dim=1)  # shape=(bs, 2 * hidden_size)
        # get the prediction
        y = self.fc_out(e)  # shape=(bs, 1)

        # ---- Step 7. Return the output ---- #
        # `e` is returned alongside `pred` so the trainer can call `get_adv(e, ...)`.
        output = {"pred": y, "e": e}
        return output

    def get_adv(self, e: torch.Tensor, y_true: torch.Tensor, weight: torch.Tensor, criterion) -> torch.Tensor:
        """ Get the e_adv for adv training.

        Builds an adversarial perturbation of `e` along the loss gradient and
        re-runs the output head on the perturbed representation.

        :param e: the e of raw ALSTM, shape=(bs, 2 * hidden_size); must be part of
            an active autograd graph (produced by `forward` during training),
            otherwise `torch.autograd.grad` will raise
        :param y_true: the true label of time series prediction, shape=(bs, 1)
        :param weight: the weight indicates item meaningful or meaningless, shape=(bs, 1)
        :param criterion: the loss function; called as `criterion(y_true, y_pred, weight)`
            and expected to return a scalar.
            NOTE(review): argument order (target first) is assumed — confirm against the trainer.

        return:
            - y_adv: the adv prediction, shape=(bs, 1)
        """

        # ---- Step 1. Compute the g^s ---- #
        # NOTE(review): `retain_grad()` is not needed for `torch.autograd.grad`
        # (it only populates `.grad` on backward); harmless but likely redundant.
        e.retain_grad()  # keep the grad of e
        y_pred = self.fc_out(e)  # get the y_pred for loss, shape=(bs, 1)
        loss = criterion(y_true, y_pred, weight)  # compute the loss, a number
        g = torch.autograd.grad(outputs=loss, inputs=e, grad_outputs=None)[0]  # number to metrix grad, get a metrix, shape=(bs, 2*hidden_size)

        # ---- Step 2. Compute the ||g|| ---- #
        # NOTE(review): this is sqrt of the GLOBAL L2 norm over the whole batch,
        # while the ADV-ALSTM paper normalizes by the per-sample ||g|| (no extra
        # sqrt). Presumably intentional here — TODO confirm against the reference.
        g_norm = torch.sqrt(torch.norm(g, p=2))  # a number

        # ---- Step 3. Get the y_adv ---- #
        if g_norm == 0:  # if g_norm == 0, then `y_adv = y_pred`
            y_adv = y_pred  # get the y_adv, shape=(bs, 1)
        else:  # else based on the function in paper to compute y_adv
            r_adv = self.epsilon * g / g_norm  # get the r_adv, shape=(bs, 2*hidden_size)
            # detach so the perturbation itself is treated as a constant w.r.t. autograd
            e_adv = e + r_adv.detach()  # get the e_adv, shape=(bs, 2*hidden_size)
            y_adv = self.fc_out(e_adv)  # get the y_adv, shape=(bs, 1)
        return y_adv
if __name__ == "__main__":  # A demo of ADV-ALSTM
    # Toy multi-granularity input; the net's forward only consumes key "g1",
    # the finer granularities are included to show the expected dict format.
    batch_size, time_steps, D = 16, 4, 1
    mg_input = {
        "g1": torch.ones((batch_size, time_steps, 1, D)),
        "g2": torch.ones((batch_size, time_steps, 2, D)),
        "g3": torch.ones((batch_size, time_steps, 6, D)),
        "g4": torch.ones((batch_size, time_steps, 24, D)),
        "g5": torch.ones((batch_size, time_steps, 96, D))
    }
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # input_size = K^g1 * D = 1 * 1 = 1 (forward reshapes g1 to (bs, T, K^g1*D))
    model = ALSTM_Net(input_size=1, device=dev)
    out = model(mg_input)
    print(out["pred"].shape)  # expected: torch.Size([16, 1])