Chinese chatbot with seq2seq

A Chinese chatbot built with neural machine translation models in PyTorch, including basic seq2seq, seq2seq with attention, a pointer-generator network, convolutional seq2seq, and a transformer.

  • Builds a Chinese chatbot from the seq2seq family of neural network models. The data comes from the Xiaohuangji (小黄鸡) corpus.
  • Each line of data is processed at the character level; no word segmentation is performed. The dataset, vocabulary, and related artifacts are built with torchtext (see the sketch below).
  • Mixed-precision training is done with apex.
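
A minimal sketch of the character-level preprocessing described above, using the torchtext 0.6 API. The file layout matches the Use section below, but options such as min_freq and batch_size are illustrative assumptions, not the repo's exact code:

from torchtext.data import Field, BucketIterator
from torchtext.datasets import TranslationDataset

# tokenize=list splits each line into single characters (no word segmentation)
SRC = Field(tokenize=list, init_token='<sos>', eos_token='<eos>')
TRG = Field(tokenize=list, init_token='<sos>', eos_token='<eos>')

# parallel files: data/chat_source.src (inputs), data/chat_source.trg (responses)
train_data = TranslationDataset(path='data/chat_source',
                                exts=('.src', '.trg'),
                                fields=(SRC, TRG))
SRC.build_vocab(train_data, min_freq=1)  # character vocabulary (assumed min_freq)
TRG.build_vocab(train_data, min_freq=1)
train_iter = BucketIterator(train_data, batch_size=32,
                            sort_key=lambda x: len(x.src))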

Model

1. seq2seq

2. seq2seq_attention

3. seq2seq_attention with pointer generator

4. seq2seq with Convolutional Neural Network

5. seq2seq with transformer

  • Attention Is All You Need
  • Encoder: multi-head self-attention -> feedforward
  • Decoder: multi-head self-attention -> multi-head encoder attention -> feedforward
  • (figure: transformer encoder-decoder architecture)
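
A minimal sketch of the encoder/decoder stacking in model 5, built from PyTorch 1.5's nn.Transformer. The repo's own implementation and dimensions may differ, and positional encoding is omitted here for brevity:

import torch.nn as nn

class TransformerChatbot(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # each encoder layer: multi-head self-attention -> feedforward
        # each decoder layer: self-attention -> encoder attention -> feedforward
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_layers,
                                          num_decoder_layers=num_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, src, trg):
        # src, trg: (seq_len, batch) character indices; mask future target positions
        trg_mask = self.transformer.generate_square_subsequent_mask(
            trg.size(0)).to(trg.device)
        out = self.transformer(self.embed(src), self.embed(trg), tgt_mask=trg_mask)
        return self.out(out)  # (trg_len, batch, vocab_size) logits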

Use

  • parameter settings: resource/config.cfg
  • training data: data/chat_source.src, data/chat_source.trg
  • model save path: model/
  • vocabulary dictionary: vocab/vocab.pk

Train or run interactive prediction with any of the models:
python seq2seq.py -type train
python seq2seq.py -type predict
python seq2seq_attention.py -type train
python seq2seq_attention.py -type predict
python seq2seq_pointernet.py -type train
python seq2seq_pointernet.py -type predict
python seq2seq_convolution.py -type train
python seq2seq_convolution.py -type predict
python seq2seq_transformer.py -type train
python seq2seq_transformer.py -type predict
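
Each script selects its mode with a -type flag. A hypothetical sketch of that shared entry-point pattern, where train() and predict() stand in for the scripts' actual functions:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-type', choices=['train', 'predict'], required=True)
args = parser.parse_args()

if args.type == 'train':
    train()    # hypothetical: read data/, build vocab/vocab.pk, save to model/
else:
    predict()  # hypothetical: load model/ and vocab/vocab.pk, answer queries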

Note

Problems caused by using Apex:

The overall loss becomes larger and very unstable, and results get worse. Gradient overflow can occur:
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
...
ZeroDivisionError: float division by zero

The following workarounds prevent the gradient overflow:

1. In apex's amp.initialize(model, optimizer, opt_level=...), change opt_level from O2 to O1; if that still fails, fall back to O0 (the letter O followed by a zero). See the sketch after this list.
2. Reducing the batch size from 32 to 16 noticeably alleviates the problem. Switching to O0 can also run out of memory, where a smaller batch size helps as well.
3. Lower the learning rate.
4. Adding ReLU helps preserve gradients and guards against vanishing gradients.
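
A sketch combining fixes 1-3: opt_level='O1', a smaller batch size, and a lowered learning rate. amp.initialize and amp.scale_loss are the real apex calls; build_model, compute_loss, and the exact values are illustrative assumptions:

import torch
from apex import amp

model = build_model().cuda()                               # hypothetical helper
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # lowered lr (fix 3)
# fix 1: O1 mixed precision instead of O2 (fall back to O0 if it still overflows)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

for batch in train_iter:  # iterator built with batch_size=16 (fix 2)
    optimizer.zero_grad()
    loss = compute_loss(model, batch)                      # hypothetical helper
    # apex scales the loss so fp16 gradients stay in representable range
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()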

Requirements

  • GPU & CUDA
  • Python 3.6.5
  • PyTorch 1.5
  • torchtext 0.6
  • apex 0.1

References

Based on the following implementations
