Skip to content

Commit

Permalink
docs: ✏️ update configs
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Sep 12, 2024
1 parent a6d4276 commit c9d53cf
Show file tree
Hide file tree
Showing 11 changed files with 22 additions and 21 deletions.
3 changes: 1 addition & 2 deletions baselines/DeepAR/ExchangeRate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
# Model architecture and parameters
MODEL_ARCH = DeepAR
NUM_NODES = 8
MODEL_PARAM = {
'cov_feat_size' : 2,
'embedding_size' : 32,
'hidden_size' : 64,
'num_layers': 3,
'use_ts_id' : True,
'id_feat_size': 32,
'num_nodes': 7
'num_nodes': 8
}
NUM_EPOCHS = 100

Expand Down
2 changes: 1 addition & 1 deletion baselines/STEP/README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
STEP requires `timm` package. You can install it by `pip install timm`.
STEP requires `timm` package. You can install it by `pip install timm==0.6.7`.
2 changes: 1 addition & 1 deletion baselines/STEP/STEP_METR-LA.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
MODEL_ARCH = STEP
MODEL_PARAM = {
"dataset_name": DATA_NAME,
"pre_trained_tsformer_path": "checkpoints/TSFormer/METR-LA_100_2016_12/de9f10ca8535dbe99fb71072aab848ce/TSFormer_best_val_MAE.pt",
"pre_trained_tsformer_path": "checkpoints/TSFormer/METR-LA_100_2016_12/cd176b70ebb4620da5a289ad76355c75/TSFormer_best_val_MAE.pt",
"short_term_len": INPUT_LEN_SHORT,
"long_term_len": INPUT_LEN,
"tsformer_args": {
Expand Down
1 change: 0 additions & 1 deletion baselines/STEP/TSFormer_METR-LA.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import sys
import torch
from easydict import EasyDict
sys.path.append(os.path.abspath(__file__ + '/../../..'))

Expand Down
9 changes: 6 additions & 3 deletions baselines/STEP/arch/tsformer/transformer_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ class TransformerLayers(nn.Module):
def __init__(self, hidden_dim, nlayers, mlp_ratio, num_heads=4, dropout=0.1):
super().__init__()
self.d_model = hidden_dim
encoder_layers = TransformerEncoderLayer(hidden_dim, num_heads, hidden_dim*mlp_ratio, dropout)
# encoder_layers = TransformerEncoderLayer(hidden_dim, num_heads, hidden_dim*mlp_ratio, dropout)
encoder_layers = TransformerEncoderLayer(hidden_dim, num_heads, hidden_dim*mlp_ratio, dropout, batch_first=True)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

def forward(self, src):
B, N, L, D = src.shape
src = src * math.sqrt(self.d_model)
src = src.view(B*N, L, D)
src = src.transpose(0, 1)
# src = src.transpose(0, 1)
# output = self.transformer_encoder(src, mask=None)
# output = output.transpose(0, 1).view(B, N, L, D)
output = self.transformer_encoder(src, mask=None)
output = output.transpose(0, 1).view(B, N, L, D)
output = output.view(B, N, L, D)
return output
3 changes: 2 additions & 1 deletion baselines/STEP/arch/tsformer/tsformer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch import nn
from timm.models.vision_transformer import trunc_normal_

from .patch import PatchEmbedding
from .mask import MaskGenerator
Expand Down Expand Up @@ -66,6 +65,8 @@ def initialize_weights(self):
# positional encoding
nn.init.uniform_(self.positional_encoding.position_embedding, -.02, .02)
# mask token
# import here to fix bugs related to set visible device
from timm.models.vision_transformer import trunc_normal_
trunc_normal_(self.mask_token, std=.02)

def encoding(self, long_term_history, mask=True):
Expand Down
4 changes: 1 addition & 3 deletions baselines/STID/SD.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,8 @@
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.BATCH_SIZE = 32
CFG.TRAIN.DATA.SHUFFLE = True
# Early stopping
CFG.TRAIN.EARLY_STOPPING_PATIENCE = 10 # Early stopping patience. Default: None. If not specified, the early stopping will not be used.

############################## Validation Configuration ##############################
CFG.VAL = EasyDict()
Expand Down
6 changes: 3 additions & 3 deletions baselines/STWave/PEMS07.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,20 @@ def loadGraph(adj_mx, hs, ls):
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.BATCH_SIZE = 16
CFG.TRAIN.DATA.SHUFFLE = True

############################## Validation Configuration ##############################
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
CFG.VAL.DATA.BATCH_SIZE = 64
CFG.VAL.DATA.BATCH_SIZE = 16

############################## Test Configuration ##############################
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
CFG.TEST.DATA.BATCH_SIZE = 64
CFG.TEST.DATA.BATCH_SIZE = 16

############################## Evaluation Configuration ##############################

Expand Down
6 changes: 3 additions & 3 deletions baselines/Triformer/Electricity.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,20 +116,20 @@
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.BATCH_SIZE = 16
CFG.TRAIN.DATA.SHUFFLE = True

############################## Validation Configuration ##############################
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
CFG.VAL.DATA.BATCH_SIZE = 64
CFG.VAL.DATA.BATCH_SIZE = 16

############################## Test Configuration ##############################
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
CFG.TEST.DATA.BATCH_SIZE = 64
CFG.TEST.DATA.BATCH_SIZE = 16

############################## Evaluation Configuration ##############################

Expand Down
2 changes: 1 addition & 1 deletion baselines/Triformer/Weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
# Model architecture and parameters
MODEL_ARCH = Triformer
NUM_NODES = 21
NUM_NODES = 8
MODEL_PARAM = {
"num_nodes": NUM_NODES,
"lag": INPUT_LEN,
Expand Down
5 changes: 3 additions & 2 deletions experiments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

def parse_args():
parser = ArgumentParser(description="Run time series forecasting model in BasicTS framework!")
# parser.add_argument("-c", "--cfg", default="baselines/STID/METR-LA.py", help="training config")
parser.add_argument("-c", "--cfg", default="baselines/STEP/STEP_METR-LA2.py", help="training config")
parser.add_argument("-c", "--cfg", default="baselines/STID/METR-LA.py", help="training config")
# parser.add_argument("-c", "--cfg", default="baselines/STEP/TSFormer_METR-LA.py", help="training config")
# parser.add_argument("-c", "--cfg", default="baselines/STEP/STEP_METR-LA.py", help="training config")
# parser.add_argument("-c", "--cfg", default="baselines/DGCRN/PEMS-BAY.py", help="training config")
# parser.add_argument("-c", "--cfg", default="baselines/DGCRN/example.py", help="training config")
# parser.add_argument("-c", "--cfg", default="examples/complete_config.py", help="training config")
Expand Down

0 comments on commit c9d53cf

Please sign in to comment.