Skip to content

Commit

Permalink
#11 Implement BERTScore-based module
Browse files Browse the repository at this point in the history
  • Loading branch information
Huffon committed May 13, 2021
1 parent 30bf928 commit 4251a38
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 2 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ Triple Score: 0.5
Avg. ROUGE-1: 0.4415584415584415
Avg. ROUGE-2: 0.3287671232876712
Avg. ROUGE-L: 0.4415584415584415

BERTScore Score
Precision: 0.9151781797409058
Recall: 0.9141832590103149
F1: 0.9150083661079407
```

<br>
Expand Down Expand Up @@ -241,6 +246,22 @@ Simple but effective word-level overlap ROUGE score

<br>

### BERTScore Module

```python
>>> from factsumm import FactSumm
>>> factsumm = FactSumm()
>>> factsumm.calculate_bert_score(article, summary)
BERTScore Score
Precision: 0.9151781797409058
Recall: 0.9141832590103149
F1: 0.9150083661079407
```

[BERTScore](https://github.com/Tiiiger/bert_score) can be used to calculate the similarity between each source sentence and the summary sentence

<br>

### Citation

If you apply this library to any project, please cite:
Expand Down
47 changes: 46 additions & 1 deletion factsumm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sumeval.metrics.rouge import RougeCalculator

from factsumm.utils.level_entity import load_ie, load_ner, load_rel
from factsumm.utils.level_sentence import load_qa, load_qg
from factsumm.utils.level_sentence import load_bert_score, load_qa, load_qg
from factsumm.utils.utils import Config, qags_score

os.environ["TOKENIZERS_PARALLELISM"] = "false"
Expand All @@ -26,6 +26,7 @@ def __init__(
rel_model: str = None,
qg_model: str = None,
qa_model: str = None,
bert_score_model: str = None,
):
self.config = Config()
self.segmenter = pysbd.Segmenter(language="en", clean=False)
Expand All @@ -36,6 +37,7 @@ def __init__(
self.rel = rel_model if rel_model is not None else self.config.REL_MODEL
self.qg = qg_model if qg_model is not None else self.config.QG_MODEL
self.qa = qa_model if qa_model is not None else self.config.QA_MODEL
self.bert_score = bert_score_model if bert_score_model is not None else self.config.BERT_SCORE_MODEL
self.ie = None

def build_perm(
Expand Down Expand Up @@ -321,25 +323,68 @@ def extract_triples(self, source: str, summary: str, verbose: bool = False):

return triple_score

def calculate_bert_score(self, source: str, summary: str):
"""
Calculate BERTScore
See also https://arxiv.org/abs/2005.03754
Args:
source (str): original source
summary (str): generated summary
"""
add_dummy = False

if isinstance(self.bert_score, str):
self.bert_score = load_bert_score(self.bert_score)

source_lines = self._segment(source)
summary_lines = [summary, "dummy"]

scores = self.bert_score(summary_lines, source_lines)
filtered_scores = list()

for score in scores:
score = score.tolist()
score.pop(-1)
filtered_scores.append(sum(score) / len(score))

print(
f"BERTScore Score\nPrecision: {filtered_scores[0]}\nRecall: {filtered_scores[1]}\nF1: {filtered_scores[2]}"
)

return filtered_scores

def __call__(self, source: str, summary: str, verbose: bool = False):
source_ents, summary_ents, fact_score = self.extract_facts(
source,
summary,
verbose,
)

qags_score = self.extract_qas(
source,
summary,
source_ents,
summary_ents,
verbose,
)

triple_score = self.extract_triples(source, summary, verbose)

rouge_1, rouge_2, rouge_l = self.calculate_rouge(source, summary)

bert_scores = self.calculate_bert_score(source, summary)

return {
"fact_score": fact_score,
"qa_score": qags_score,
"triple_score": triple_score,
"rouge": (rouge_1, rouge_2, rouge_l),
"bert_score": {
"precision": bert_scores[0],
"recall": bert_scores[1],
"f1": bert_scores[2],
},
}
17 changes: 16 additions & 1 deletion factsumm/utils/level_sentence.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List

from bert_score import BERTScorer
from rich import print
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

Expand Down Expand Up @@ -111,4 +112,18 @@ def answer_question(context: str, qa_pairs: List):
return answer_question


# TODO: NLI, FactCC
def load_bert_score(model: str):
"""
Load BERTScore model from HuggingFace hub
Args:
model (str): model name to be loaded
Returns:
function: BERTScore score function
"""
print("Loading BERTScore Pipeline...")

scorer = BERTScorer(model_type=model, lang="en", rescale_with_baseline=True)
return scorer.score
1 change: 1 addition & 0 deletions factsumm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Config:
QG_MODEL: str = "mrm8488/t5-base-finetuned-question-generation-ap"
QA_MODEL: str = "deepset/roberta-base-squad2"
SUMM_MODEL: str = "sshleifer/distilbart-cnn-12-6"
BERT_SCORE_MODEL: str = "microsoft/deberta-base-mnli"


def grouped_entities(entities: List[Dict]):
Expand Down

0 comments on commit 4251a38

Please sign in to comment.