The PyTorch implementation of AttnTUL.
We conducted extensive experiments on three different types of real-world trajectory datasets: the Gowalla check-in dataset, the Shenzhen private car dataset, and the Geolife personal travel dataset. The processed data used to evaluate our model can be found in the data folder, which contains the three datasets ready for direct use. Due to GitHub's limit on uploaded file size, we store the data on a cloud drive (extraction code: r3pq). You can download it and replace the contents of the data folder.
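After downloading, unpack the archive into the data folder, for example as below (the archive name here is hypothetical; use the file actually provided on the cloud drive):
```
# hypothetical archive name; replace with the file you downloaded
unzip AttnTUL_data.zip -d data/
```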
- python 3.7
- pytorch 1.7.0
- other:
pip install -r requirements.txt
- /code
  - datasets.py: Completes the data loading in PyTorch.
  - layers.py: Contains the specific implementation of some layers in the model.
  - main.py: The entry point of the program, used to train the model.
  - models.py: Contains the whole model.
  - rawprocess.py: Contains the data preprocessing steps, such as the construction of the local and global graphs.
  - utils.py: Common helper methods, including metric calculation and plotting.
- /data: The original data and some preprocessed data required for the experiments are stored here.
  - /shenzhen
  - /gowalla
  - /geolife
- /temp: The folder used to store checkpoints.
- /log: The folder used to store training logs and metric pictures.
You can train and evaluate the model with the following sample command lines:
shenzhen-mini:
```
cd code
python main.py --dataset shenzhen-mini --read_pkl True --grid_size 120 --d_model 128 --n_heads 5 --n_layers 3
```
shenzhen-all:
```
cd code
python main.py --dataset shenzhen-all --read_pkl True --grid_size 120 --d_model 128 --n_heads 5 --n_layers 2
```
gowalla-mini:
```
cd code
python main.py --dataset gowalla-mini --read_pkl False --grid_size 40 --d_model 128 --n_heads 5 --n_layers 3
```
gowalla-all:
```
cd code
python main.py --dataset gowalla-all --read_pkl False --grid_size 40 --d_model 128 --n_heads 5 --n_layers 2
```
Note that we have added some code so that you can see the training process and results in the log file. We repeat each experiment 10 times with a different random seed each time and report the average value, so the averaged results may fluctuate slightly.
Here are some common optional parameter settings:
```
--dataset shenzhen-mini/shenzhen-all/gowalla-mini/gowalla-all/geolife-mini/geolife-all
--read_pkl True/False
--times 1/5/10
--epochs 80
--train_batch 16
--d_model 32/64/128/256/512
--head 2/3/4/5/6
--grid_size 40/80/120/160/200
```
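For example, a single run that combines several of these options could look like the sketch below; the values are illustrative, not the tuned settings reported in the paper:
```
cd code
# illustrative combination of the optional parameters listed above
python main.py --dataset shenzhen-mini --read_pkl True --times 10 --epochs 80 --train_batch 16 --d_model 128 --n_heads 5 --n_layers 3 --grid_size 120
```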
To save time for follow-up researchers, we store the processed data in pkl files. You can use them directly by setting the read_pkl parameter to True, or set it to False to process the original data first (e.g. for gowalla).
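A minimal sketch of this two-step workflow (assuming the preprocessing run with --read_pkl False writes pkl files that later runs can load; see rawprocess.py for the exact behavior):
```
cd code
# first run: preprocess the raw gowalla data (builds the local and global graphs)
python main.py --dataset gowalla-mini --read_pkl False --grid_size 40 --d_model 128 --n_heads 5 --n_layers 3
# later runs: load the processed pkl files directly (assuming they were written by the first run)
python main.py --dataset gowalla-mini --read_pkl True --grid_size 40 --d_model 128 --n_heads 5 --n_layers 3
```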
The source code of some important baselines compared in this paper is available as follows:
Any comments and feedback are appreciated. :)