Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Trainer] PaddleNLP trainer and finetune ernie-1.0 pretrain. #1761

Merged
merged 19 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions examples/language_model/ernie-1.0/finetune/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Default Args for all dataset
# You can overwrite the configs in each dataset.
DefaultArgs:
learning_rate: 0.00005
num_train_epochs: 3
batch_size: 64
max_seq_length: 128
weight_decay: 0.01
logging_steps: 10
eval_steps: 200
minimum_eval_times: 20
max_steps: -1
warmup_steps: 0
metric: "Accuracy"
split: "train dev"

# Datasets which used for sequence classfication
SequenceClassification:
clue afqmc:
num_train_epochs: 4
clue tnews:
num_train_epochs: 4
clue iflytek:
num_train_epochs: 8
clue ocnli:
num_train_epochs: 8
clue cmnli:
learning_rate: 1e-4, 5e-5, 1e-5
num_train_epochs: 3
clue wsc:
num_train_epochs: 50
clue csl:
num_train_epochs: 10
max_seq_length: 256
batch_size: 32
xnli_cn:
learning_rate: 0.0001
num_train_epochs: 3
batch_size: 256
chnsenticorp_v2:
learning_rate: 0.00005
batch_size: 16
num_train_epochs: 8

# Datasets which used for token classfication
TokenClassification:
peoples_daily_ner:
learning_rate: 0.00005
num_train_epochs: 8
batch_size: 16
msra_ner:
num_train_epochs: 3

# Datasets which used for question answersing
QuestionAnswering:
cmrc2018:
learning_rate: 0.00005
num_train_epochs: 5
batch_size: 32
max_seq_length: 512
dureader_nlp:
num_train_epochs: 1
batch_size: 12
max_seq_length: 384
dureader_robust:
num_train_epochs: 1
batch_size: 12
max_seq_length: 384
dlbp:
num_train_epochs: 1
batch_size: 12
max_seq_length: 384
271 changes: 271 additions & 0 deletions examples/language_model/ernie-1.0/finetune/question_answering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
# Copyright 2020-present the HuggingFace Inc. team.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import json
import os
import sys
from functools import partial

import numpy as np
import paddle
import paddlenlp as ppnlp
from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.utils.log import logger
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看library增加一下新的logging模块,是否需要新增logging模块,是否在logger模块上进行优化升级了?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除 新的logging模块, trainer 这边之前主要有些日志分级控制、重定向文件输出等能力,后续可以升级

from paddlenlp.trainer import Trainer
from paddlenlp.trainer.trainer_utils import PredictionOutput

sys.path.insert(0, os.path.abspath("."))
from utils import Dict


class QuestionAnsweringTrainer(Trainer):
def __init__(self,
*args,
eval_examples=None,
post_process_function=None,
**kwargs):
super().__init__(*args, **kwargs)
self.eval_examples = eval_examples
self.post_process_function = post_process_function

def evaluate(self,
eval_dataset=None,
eval_examples=None,
ignore_keys=None,
metric_key_prefix: str="eval"):
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
eval_dataloader = self.get_eval_dataloader(eval_dataset)
eval_examples = self.eval_examples if eval_examples is None else eval_examples

# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
eval_loop = self.evaluation_loop
try:
output = eval_loop(
eval_dataloader,
description="Evaluation",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys, )
finally:
self.compute_metrics = compute_metrics

if self.post_process_function is not None and self.compute_metrics is not None:
eval_preds = self.post_process_function(eval_examples, eval_dataset,
output.predictions)
metrics = self.compute_metrics(eval_preds)

# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

self.log(metrics)
else:
metrics = {}

self.control = self.callback_handler.on_evaluate(self.args, self.state,
self.control, metrics)
return metrics

def predict(self,
predict_dataset,
predict_examples,
ignore_keys=None,
metric_key_prefix: str="test"):
predict_dataloader = self.get_test_dataloader(predict_dataset)

# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
eval_loop = self.evaluation_loop
try:
output = eval_loop(
predict_dataloader,
description="Prediction",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys, )
finally:
self.compute_metrics = compute_metrics

if self.post_process_function is None or self.compute_metrics is None:
return output

predictions = self.post_process_function(
predict_examples, predict_dataset, output.predictions, "predict")
metrics = self.compute_metrics(predictions)

# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

return PredictionOutput(
predictions=predictions.predictions,
label_ids=predictions.label_ids,
metrics=metrics)


def qa_collator(tokenizer, args):
train_batchify_fn = lambda samples, fn=Dict({
"input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
"token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
"start_positions": Stack(dtype="int64"),
"end_positions": Stack(dtype="int64")
}): fn(samples)

return train_batchify_fn


class CrossEntropyLossForSQuAD(paddle.nn.Layer):
def __init__(self):
super(CrossEntropyLossForSQuAD, self).__init__()

def forward(self, y, label):
start_logits, end_logits = y
start_position, end_position = label
start_position = paddle.unsqueeze(start_position, axis=-1)
end_position = paddle.unsqueeze(end_position, axis=-1)
start_loss = paddle.nn.functional.cross_entropy(
input=start_logits, label=start_position)
end_loss = paddle.nn.functional.cross_entropy(
input=end_logits, label=end_position)
loss = (start_loss + end_loss) / 2
return loss


def prepare_train_features(examples, tokenizer, args):
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
# NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is
# that HugggingFace uses ArrowTable as basic data structure, while we use list of dictionary instead.
contexts = examples['context']
questions = examples['question']

tokenized_examples = tokenizer(
questions,
contexts,
stride=args.doc_stride,
max_seq_len=args.max_seq_length)

# Since one example might give us several features if it has a long context, we need a map from a feature to
# its corresponding example. This key gives us just that.
sample_mapping = tokenized_examples.pop("overflow_to_sample")
# The offset mappings will give us a map from token to character position in the original context. This will
# help us compute the start_positions and end_positions.
offset_mapping = tokenized_examples.pop("offset_mapping")

# Let's label those examples!
tokenized_examples["start_positions"] = []
tokenized_examples["end_positions"] = []

for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)

# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples['token_type_ids'][i]

# One example can give several spans, this is the index of the example containing this span of text.
sample_index = sample_mapping[i]
answers = examples['answers'][sample_index]
# If no answers are given, set the cls_index as answer.
if len(answers["answer_start"]) == 0:
tokenized_examples["start_positions"].append(cls_index)
tokenized_examples["end_positions"].append(cls_index)
else:
# Start/end character index of the answer in the text.
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])

# Start token index of the current span in the text.
token_start_index = 0
while sequence_ids[token_start_index] != 1:
token_start_index += 1

# End token index of the current span in the text.
token_end_index = len(input_ids) - 1
while sequence_ids[token_end_index] != 1:
token_end_index -= 1
token_end_index -= 1

# Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
if not (offsets[token_start_index][0] <= start_char and
offsets[token_end_index][1] >= end_char):
tokenized_examples["start_positions"].append(cls_index)
tokenized_examples["end_positions"].append(cls_index)
else:
# Otherwise move the token_start_index and token_end_index to the two ends of the answer.
# Note: we could go after the last offset if the answer is the last word (edge case).
while token_start_index < len(offsets) and offsets[
token_start_index][0] <= start_char:
token_start_index += 1
tokenized_examples["start_positions"].append(token_start_index -
1)
while offsets[token_end_index][1] >= end_char:
token_end_index -= 1
tokenized_examples["end_positions"].append(token_end_index + 1)

return tokenized_examples


def prepare_validation_features(examples, tokenizer, args):
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
# in one example possible giving several features when a context is long, each of those features having a
# context that overlaps a bit the context of the previous feature.
#NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is
# that HugggingFace uses ArrowTable as basic data structure, while we use list of dictionary instead.
contexts = examples['context']
questions = examples['question']

tokenized_examples = tokenizer(
questions,
contexts,
stride=args.doc_stride,
max_seq_len=args.max_seq_length,
return_attention_mask=True)

# Since one example might give us several features if it has a long context, we need a map from a feature to
# its corresponding example. This key gives us just that.
sample_mapping = tokenized_examples.pop("overflow_to_sample")

# For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
# corresponding example_id and we will store the offset mappings.
tokenized_examples["example_id"] = []

for i in range(len(tokenized_examples["input_ids"])):
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples['token_type_ids'][i]
context_index = 1

# One example can give several spans, this is the index of the example containing this span of text.
sample_index = sample_mapping[i]
tokenized_examples["example_id"].append(examples["id"][sample_index])

# Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
# position is part of the context or not.
tokenized_examples["offset_mapping"][i] = [
(o if sequence_ids[k] == context_index else None)
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
]

return tokenized_examples
Loading