Skip to content

1st Place Solution for Zhihu Machine Learning Challenge . Implementation of various text-classification models.(知乎看山杯第一名解决方案)

License

Notifications You must be signed in to change notification settings

chenyuntc/PyTorchText

Repository files navigation

中文用户请查看readme-zh.md

This is the solution for Zhihu Machine Learning Challenge 2017. We won the champion out of 963 teams.

1. Setup

  • install PyTorch from pytorch.org (Python 2, CUDA)
  • install other depencies:
    pip2 install -r requirements.txt

You may need tf.contrib.keras.preprocessing.sequence.pad_sequences for data preprocessing.

  • start visdom for visualization:
    python2 -m visdom.server

2. Data Preprocessing

Modify the data path in the related file

2.1 wordvector file -> numpy file

python scripts/data_process/embedding2matrix.py main char_embedding.txt char_embedding.npz 
python scripts/data_process/embedding2matrix.py main word_embedding.txt word_embedding.npz 

2.2 question set -> numpy file

it's memory consuming , make sure you have memory larger than 32G.

python scripts/data_process/question2array.py main question_train_set.txt train.npz
python scripts/data_process/question2array.py main question_eval_set.txt test.npz

2.3 label -> json

python scripts/data_process/label2id.py main question_topic_train_set.txt labels.json

2.4 validation data

python scripts/data_process/get_val.py 

3. Training

modify config.py for model path

Path to the models we used:

  • CNN:models/MultiCNNTextBNDeep.py
  • RNN(LSTM):models/LSTMText.py
  • RCNN: models/RCNN.py
  • inception: models/CNNText_inception.py
  • FastText: models/FastText3.py

3.1 Trian model without data augumentation

# LSTM char
python2 main.py main --max_epoch=5 --plot_every=100 --env='lstm_char' --weight=1 --model='LSTMText'  --batch-size=128  --lr=0.001 --lr2=0 --lr_decay=0.5 --decay_every=10000  --type_='char'   --zhuge=True --linear-hidden-size=2000 --hidden-size=256 --kmax-pooling=3   --num-layers=3  --augument=False

# LSTM word
python2 main.py main --max_epoch=5 --plot_every=100 --env='lstm_word' --weight=1 --model='LSTMText'  --batch-size=128  --lr=0.001 --lr2=0.0000 --lr_decay=0.5 --decay_every=10000  --type_='word'   --zhuge=True --linear-hidden-size=2000 --hidden-size=320 --kmax-pooling=2  --augument=False

#  RCNN char
python2 main.py main --max_epoch=5 --plot_every=100 --env='rcnn_char' --weight=1 --model='RCNN'  --batch-size=128  --lr=0.001 --lr2=0 --lr_decay=0.5 --decay_every=5000  --title-dim=1024 --content-dim=1024  --type_='char' --zhuge=True --kernel-size=3 --kmax-pooling=2 --linear-hidden-size=2000 --debug-file='/tmp/debugrcnn' --hidden-size=256 --num-layers=3 --augument=False

# RCNN word
main.py main --max_epoch=5 --plot_every=100 --env='RCNN-word' --weight=1 --model='RCNN'  --zhuge=True --num-workers=4 --batch-size=128 --model-path=None --lr2=0 --lr=1e-3 --lr-decay=0.8  --decay-every=5000  --title-dim=1024 --content-dim=512  --kernel-size=3 --debug-file='/tmp/debugrc'  --kmax-pooling=1 --type_='word' --augument=False
# CNN word
 python main.py main --max_epoch=5 --plot_every=100 --env='MultiCNNText' --weight=1 --model='MultiCNNTextBNDeep'  --batch-size=64  --lr=0.001 --lr2=0.000 --lr_decay=0.8 --decay_every=10000  --title-dim=250 --content-dim=250    --weight-decay=0 --type_='word' --debug-file='/tmp/debug'  --linear-hidden-size=2000 --zhuge=True  --augument=False

# inception word
python2 main.py main --max_epoch=5 --plot_every=100 --env='inception-word' --weight=1 --model='CNNText_inception'  --zhuge=True --num-workers=4 --batch-size=512 --model-path=None --lr2=0 --lr=1e-3 --lr-decay=0.8  --decay-every=2500 --title-dim=1200 --content-dim=1200 --type_='word' --augument=False                                                   
# inception char
python2 main.py main --max_epoch=5 --plot_every=100 --env='inception-char' --weight=1 --model='CNNText_inception'  --zhuge=True --num-workers=4 --batch-size=512 --model-path=None --lr2=0 --lr=1e-3 --lr-decay=0.8  --decay-every=2500 --title-dim=1200 --content-dim=1200 --type_='char'   --augument=False

# FastText3 word
python2 main.py main --max_epoch=5 --plot_every=100 --env='fasttext3-word' --weight=5 --model='FastText3' --zhuge=True --num-workers=4 --batch-size=512  --lr2=1e-4 --lr=1e-3 --lr-decay=0.8  --decay-every=2500 --linear_hidden_size=2000 --type_='word'  --debug-file=/tmp/debugf --augument=False                           

In most cases, the score could be boosted by finetune. for example:

python2 main.py main --max_epoch=2 --plot_every=100 --env='LSTMText-word-ft' --model='LSTMText'  --zhuge=True --num-workers=4 --batch-size=256 --model-path=None --lr2=5e-5 --lr=5e-5 --decay-every=5000 --type_='word'  --model-path='checkpoints/LSTMText_word_0.409196378421'                       

3.2 train models with data augumentation

Add --augument in the training command.

3.3 scores

model score
CNN_word 0.4103
RNN_word 0.4119
RCNN_word 0.4115
Inceptin_word 0.4109
FastText_word 0.4091
RNN_char 0.4031
RCNN_char 0.4037
Inception_char 0.4024
RCNN_word_aug 0.41344
CNN_word_aug 0.41051
RNN_word_aug 0.41368
Incetpion_word_aug 0.41254
FastText3_word_aug 0.40853
CNN_char_aug 0.38738
RCNN_char_aug 0.39854

with model ensemble, it can get up to 0.433.

4 Test and Submit

4.1 Test

  • model: include LSTMText,RCNN,MultiCNNTextBNDeep,FastText3,CNNText_inception
  • model-path: path to the pretrained model
  • result-path: where to save the model
  • val: test the val set or the test set..
# LSTM
python2 test.1.py main --model='LSTMText'  --batch-size=512  --model-path='checkpoints/LSTMText_word_0.411994005382' --result-path='/data_ssd/zhihu/result/LSTMText0.4119_word_test.pth'  --val=False --zhuge=True

python2 test.1.py main --model='LSTMText'  --batch-size=256 --type_=char --model-path='checkpoints/LSTMText_char_0.403192339135' --result-path='/data_ssd/zhihu/result/LSTMText0.4031_char_test.pth'  --val=False --zhuge=True
 
#RCNN
python2 test.1.py main --model='RCNN'  --batch-size=512  --model-path='checkpoints/RCNN_word_0.411511574999' --result-path='/data_ssd/zhihu/result/RCNN_0.4115_word_test.pth'  --val=False --zhuge=True

python2 test.1.py main --model='RCNN'  --batch-size=512  --model-path='checkpoints/RCNN_char_0.403710422571' --result-path='/data_ssd/zhihu/result/RCNN_0.4037_char_test.pth'  --val=False --zhuge=True

# DeepText

python2 test.1.py main --model='MultiCNNTextBNDeep'  --batch-size=512  --model-path='checkpoints/MultiCNNTextBNDeep_word_0.410330780091' --result-path='/data_ssd/zhihu/result/DeepText0.4103_word_test.pth'  --val=False --zhuge=True
# more to go ...

4.2 ensemble

See notebooks/val_ensemble.ipynb and notebooks/test_ensemble.ipynb for more detail

5 Main files

  • main.py: main(for training)
  • config.py: config file
  • test.1.py: for test
  • data/: for data loader
  • scripts/: for data preprocessing
  • utils/ : including calculate score and wrapper for visualization.
  • models/: models
    • models/BasicModel: Base model for models.
    • models/MultiCNNTextBNDeep: CNN
    • models/LSTMText: RNN
    • models/RCNN: RCNN
    • models/CNNText_inception Inception
    • models/MultiModelALLmodels/MultiModelAll2
    • other model
  • rep.py: code for reproducing.
  • del/: methods fail or not used.
  • notebooks/: notebooks.

Pretrained model

https://pan.baidu.com/s/1mjVtJGs passwd: tayb

About

1st Place Solution for Zhihu Machine Learning Challenge . Implementation of various text-classification models.(知乎看山杯第一名解决方案)

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy