# PK-Chat: Pointer Network Guided Knowledge Driven Generative Dialogue Model

[paper link](https://arxiv.org/abs/2304.00592)
- python >= 3.6
- paddlepaddle == 1.6.1
- numpy
- nltk
- tqdm
- visualdl >= 1.3.0 (optional)
- regex
We recommend installing the Python packages with:

```bash
pip install -r requirement.txt
```
You can download the pre-trained PK-Chat dialogue model here:
- PK-chat model, uncased [model](https://pan.baidu.com/s/1bNzGnnRPMfT4jkD_UsNSYg?pwd=pa7h): 12-layers, 768-hidden, 12-heads, 132M parameters
```bash
mv /path/to/model.tar.gz .
tar xzf model.tar.gz
```
We also provide instructions to fine-tune the PK-Chat model on different conversation datasets (chit-chat, knowledge-grounded dialogue and conversational question answering).
Download the data from the link. The tar file contains three processed datasets: `DailyDialog`, `PersonaChat` and `DSTC7_AVSD`.
```bash
mv /path/to/data.tar.gz .
tar xzf data.tar.gz
```
Our model supports two kinds of data formats for dialogue context: `multi` and `multi_knowledge`.

- `multi`: multi-turn dialogue context.

  ```
  u_1 __eou__ u_2 __eou__ ... u_n \t r
  ```

- `multi_knowledge`: multi-turn dialogue context with background knowledge.

  ```
  k_1 __eou__ k_2 __eou__ ... k_m \t u_1 __eou__ u_2 __eou__ ... u_n \t r
  ```
If you want to use this model on other datasets, you can process your data accordingly.
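As a minimal sketch of how a new dataset could be written in the `multi_knowledge` format (the example content, the output file name, and the exact whitespace around the separators are assumptions, not part of the repository), one training line can be assembled like this:

```python
# Hypothetical sketch: write one example in the multi_knowledge format described above.
# Fields are tab-separated; knowledge pieces and dialogue turns are joined with __eou__.
knowledge = ["the eiffel tower is in paris", "it was completed in 1889"]
context = ["where is the eiffel tower ?", "it is in paris ."]
response = "yes , and it was completed in 1889 ."

line = "\t".join([
    " __eou__ ".join(knowledge),  # k_1 __eou__ ... __eou__ k_m
    " __eou__ ".join(context),    # u_1 __eou__ ... __eou__ u_n
    response,                     # r
])

# File name is an assumption; adjust it to whatever your training script expects.
with open("my_dataset.train", "a", encoding="utf-8") as f:
    f.write(line + "\n")
```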
Fine-tuning the pre-trained model on different datasets (`${DATASET}`).

```bash
# DailyDialog / PersonaChat / DSTC7_AVSD / ACE_Dialog_topic
DATASET=ACE_Dialog_topic
sh scripts/${DATASET}/train.sh
```
After training, you can find the output folder `outputs/${DATASET}` (by default). It contains `best.model` (the checkpoint with the best results on the validation dataset), `hparams.json` (hyper-parameters of the training script) and `trainer.log` (the training log).
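For a quick sanity check after training, you can load `hparams.json` and print the recorded hyper-parameters. This is only a sketch: the file name and path come from the output description above, but the keys the file contains are not documented here.

```python
import json

# Path follows the default output layout described above; replace the dataset name as needed.
with open("outputs/ACE_Dialog_topic/hparams.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

# Print the saved hyper-parameters in a stable order.
for key, value in sorted(hparams.items()):
    print(f"{key}: {value}")
```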
Fine-tuning the pre-trained model on multiple GPUs.
Note: You need to install the NCCL library and set up the environment variable `LD_LIBRARY_PATH` properly.
```bash
sh scripts/ACE_Dialog_topic/multi_gpu_train.sh
```
Fine-tuning our pre-trained model usually takes about 10 epochs to converge with learning rate = 1e-5, and about 2-3 epochs with learning rate = 5e-5.
| GPU Memory | batch size | max len |
|---|---|---|
| 16G | 6 | 256 |
| 32G | 12 | 256 |
Running inference on the test dataset.

```bash
# DailyDialog / PersonaChat / DSTC7_AVSD / ACE_Dialog_topic
DATASET=ACE_Dialog_topic
sh scripts/${DATASET}/infer.sh
```
After inference, you can find the output folder `outputs/${DATASET}.infer` (by default). It contains `infer_0.result.json` (the inference result), `hparams.json` (hyper-parameters of the inference script) and `trainer.log` (the inference log).
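To inspect the generated responses, you can load `infer_0.result.json`. The path follows the default output layout above; the structure of the JSON content is an assumption, so adjust the printing to the actual fields.

```python
import json

# Load the inference result for one dataset (replace the dataset name as needed).
with open("outputs/ACE_Dialog_topic.infer/infer_0.result.json", "r", encoding="utf-8") as f:
    results = json.load(f)

# Print a few entries to see what the inference run produced (structure is an assumption).
sample = results[:3] if isinstance(results, list) else [results]
for item in sample:
    print(item)
```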
If you want to use top-k sampling (beam search by default), you can follow the example script:
```bash
sh scripts/DailyDialog/topk_infer.sh
```
If you find PK-Chat useful in your work, please cite the following paper:
```bibtex
@misc{deng2023pkchat,
  title={PK-Chat: Pointer Network Guided Knowledge Driven Generative Dialogue Model},
  author={Cheng Deng and Bo Tong and Luoyi Fu and Jiaxin Ding and Dexing Cao and Xinbing Wang and Chenghu Zhou},
  year={2023},
  eprint={2304.00592},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```
For help or issues using PK-chat, please submit a GitHub issue.
For personal communication related to PK-chat, please contact Bo Tong (bool_tbb@alumni.sjtu.edu.cn) or Cheng Deng (davendw@sjtu.edu.cn).