Add zh_dureader_model_v2 && add train function (#31)
* 1. add zh_dureader_v2; 2. add train function

* update readme and example

* add train config

* Update README.md

* fix readme typo

* Fixed wrong config path, added train data format note

* fix zh_dureader models http address

* faiss example fit predict iterator

* add dureader.para

Co-authored-by: sfwydyc <dingyuchen@baidu.com>
Co-authored-by: Xing Yiran <procedure2012@hotmail.com>
Co-authored-by: MrTuo <595983351@qq.com>
4 people committed May 6, 2022
1 parent 5cb5240 commit 1746b93
Showing 34 changed files with 38,397 additions and 434 deletions.
51 changes: 51 additions & 0 deletions README.md
@@ -14,6 +14,8 @@ In recent years, the dense retrievers based on pre-trained language models have
* ***Easy-to-use***: By integrating this toolkit with [JINA](https://jina.ai/), 🚀RocketQA can help developers build an end-to-end retrieval system and question answering system with several lines of code. <img src="https://github.com/PaddlePaddle/RocketQA/blob/main/RocketQA_flow.png" alt="" align=center />

## News
* April 29, 2022: A training function was added to the RocketQA toolkit.
* April 29, 2022: The baseline models of **DuReader<sub>retrieval</sub>** (both the cross encoder and the dual encoder) are now available among the RocketQA models. You can load them directly with `rocketqa.load_model` (see the sketch after this list).
* March 30, 2022: The baseline of **DuReader<sub>retrieval</sub>** [leaderboard](https://aistudio.baidu.com/aistudio/competition/detail/157/0/introduction) was released. [[code/model]](https://github.com/PaddlePaddle/RocketQA/tree/main/research/DuReader-Retrieval-Baseline)
* March 30, 2022: We released **DuReader<sub>retrieval</sub>**, a large-scale Chinese benchmark for passage retrieval. The dataset contains over 90K questions and 8M passages from Baidu Search. [[paper]](https://arxiv.org/abs/2203.10232) [[data]](https://github.com/baidu/DuReader/tree/master/DuReader-Retrieval)
* December 3, 2021: The toolkit of the dense retriever RocketQA was released, including the first Chinese dense retrieval model trained on DuReader.
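
As a minimal loading sketch (the dual-encoder name `zh_dureader_de_v2` is an assumption mirroring the `zh_dureader_ce_v2` model used in the examples below; the other arguments follow the same `load_model` calls):

```python
import rocketqa

# A sketch: load the DuReader-retrieval baseline models with rocketqa.load_model.
# "zh_dureader_de_v2" is assumed to be the dual-encoder counterpart of "zh_dureader_ce_v2".
dual_encoder = rocketqa.load_model(model="zh_dureader_de_v2", use_cuda=True, device_id=0, batch_size=32)
cross_encoder = rocketqa.load_model(model="zh_dureader_ce_v2", use_cuda=True, device_id=0, batch_size=32)
```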
@@ -117,13 +119,21 @@ Given a list of paragraphs and their corresponding titles (optional), returns th

Given a list of queries and paragraphs (and titles), returns their matching scores (dot product between two representation vectors).

#### [`model.train(train_set: str, epoch: int, save_model_path: str, args)`](https://github.com/sfwydyc/my-rocketqa/blob/594877ca505053cb67c6b9b689dbbf237f074ac4/rocketqa/encoder/dual_encoder.py#L247)

Given the hyperparameters `train_set`, `epoch` and `save_model_path`, you can train your own dual encoder model or finetune our models. Other settings like `save_steps` and `learning_rate` can also be set in `args`. Please see `examples/example.py` for details.
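
A minimal sketch of finetuning a dual encoder with `train()` (the model name `zh_dureader_de_v2`, the data path, and the column layout of the dual-encoder training file are assumptions; `examples/example.py` is the authoritative reference):

```python
import rocketqa

# A sketch: finetune a dual encoder starting from a released checkpoint.
# The model name and training-data path below are assumptions, not files shipped here.
dual_encoder = rocketqa.load_model(model="zh_dureader_de_v2", use_cuda=True, device_id=0, batch_size=32)

# train_set, epoch and save_model_path are positional; extra settings go through args.
dual_encoder.train('./examples/data/dual.train.tsv', 2, 'de_models',
                   save_steps=1000, learning_rate=1e-5, log_folder='log_de')
```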

### Cross encoder
The cross encoder returned by `load_model()` supports the following functions:

#### [`model.matching(query: List[str], para: List[str], title: List[str])`](https://github.com/PaddlePaddle/RocketQA/blob/3a99cf2720486df8cc54acc0e9ce4cbcee993413/rocketqa/predict/cross_encoder.py#L129)

Given a list of queries and paragraphs (and titles), returns their matching scores (the probability that the paragraph answers the query).

#### [`model.train(train_set: str, epoch: int, save_model_path: str, args)`](https://github.com/sfwydyc/my-rocketqa/blob/594877ca505053cb67c6b9b689dbbf237f074ac4/rocketqa/encoder/dual_encoder.py#L247)

Given the hyperparameters `train_set`, `epoch` and `save_model_path`, you can train your own cross encoder model or finetune our models. Other settings like `save_steps` and `learning_rate` can also be set in `args`. Please see `examples/example.py` for details.


### Examples

@@ -149,7 +159,48 @@ p_embs = dual_encoder.encode_para(para=para_list)
dot_products = dual_encoder.matching(query=query_list, para=para_list)
```

#### Train Your Own Model
To train your own models, use the `train()` function with your dataset and parameters. For the training data format, refer to `./examples/data/cross.train.tsv`, which contains four tab-separated columns: query, title, para, and label (0 or 1). A sketch of one such row follows the example below.

```python
import rocketqa

# init cross encoder, and set device and batch_size
cross_encoder = rocketqa.load_model(model="zh_dureader_ce_v2", use_cuda=True, device_id=0, batch_size=32)

# finetune cross encoder based on "zh_dureader_ce_v2"
cross_encoder.train('./examples/data/cross.train.tsv', 2, 'ce_models', save_steps=1000, learning_rate=1e-5, log_folder='log_ce')

```
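
For illustration, a minimal sketch of writing one such row (the query, title and para strings are invented; only the 4-column, tab-separated layout comes from the note above):

```python
# A sketch of one 4-column, tab-separated training row.
# The text content is invented for illustration; only the layout follows the README note.
row = "\t".join([
    "RocketQA是什么",                             # query
    "RocketQA简介",                               # title
    "RocketQA是一个基于稠密检索的开源问答工具包。",    # para
    "1",                                          # label: 1 = relevant, 0 = not relevant
])

with open("my.train.tsv", "a", encoding="utf-8") as f:  # hypothetical output path
    f.write(row + "\n")
```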

#### Run Your Own Model
To run your own models, set the `model` parameter of `load_model()` to the path of a JSON config file.

```python
import rocketqa

# init cross encoder
cross_encoder = rocketqa.load_model(model="./examples/ce_models/config.json", use_cuda=True, device_id=0, batch_size=16)

# compute relevance of query and para
relevance = cross_encoder.matching(query=query_list, para=para_list)
```

The config is a JSON file like this:
```
{
    "model_type": "cross_encoder",
    "max_seq_len": 384,
    "model_conf_path": "zh_config.json",
    "model_vocab_path": "zh_vocab.txt",
    "model_checkpoint_path": ${YOUR_MODEL},
    "for_cn": true,
    "share_parameter": 0
}
```
Folder `examples` provides more details.


## Citations

If you find RocketQA v1 models helpful, feel free to cite our publication [RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/pdf/2010.08191.pdf)
9 changes: 9 additions & 0 deletions examples/ce_models/config.json
@@ -0,0 +1,9 @@
{
    "model_type": "cross_encoder",
    "max_seq_len": 384,
    "model_conf_path": "zh_config.json",
    "model_vocab_path": "zh_vocab.txt",
    "model_checkpoint_path": "step_36000",
    "for_cn": true,
    "share_parameter": 0
}
12 changes: 12 additions & 0 deletions examples/ce_models/zh_config.json
@@ -0,0 +1,12 @@
{
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "relu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "initializer_range": 0.02,
    "max_position_embeddings": 513,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "type_vocab_size": 2,
    "vocab_size": 18000
}
