diff --git a/examples/language_model/ernie-1.0/finetune/config.yml b/examples/language_model/ernie-1.0/finetune/config.yml new file mode 100644 index 000000000000..fd7555ee78dc --- /dev/null +++ b/examples/language_model/ernie-1.0/finetune/config.yml @@ -0,0 +1,72 @@ +# Default args for all datasets. +# You can override these configs for each dataset below. +DefaultArgs: + learning_rate: 0.00005 + num_train_epochs: 3 + batch_size: 64 + max_seq_length: 128 + weight_decay: 0.01 + logging_steps: 10 + eval_steps: 200 + minimum_eval_times: 20 + max_steps: -1 + warmup_steps: 0 + metric: "Accuracy" + split: "train dev" + +# Datasets used for sequence classification +SequenceClassification: + clue afqmc: + num_train_epochs: 4 + clue tnews: + num_train_epochs: 4 + clue iflytek: + num_train_epochs: 8 + clue ocnli: + num_train_epochs: 8 + clue cmnli: + learning_rate: 1e-4, 5e-5, 1e-5 + num_train_epochs: 3 + clue wsc: + num_train_epochs: 50 + clue csl: + num_train_epochs: 10 + max_seq_length: 256 + batch_size: 32 + xnli_cn: + learning_rate: 0.0001 + num_train_epochs: 3 + batch_size: 256 + chnsenticorp_v2: + learning_rate: 0.00005 + batch_size: 16 + num_train_epochs: 8 + +# Datasets used for token classification +TokenClassification: + peoples_daily_ner: + learning_rate: 0.00005 + num_train_epochs: 8 + batch_size: 16 + msra_ner: + num_train_epochs: 3 + +# Datasets used for question answering +QuestionAnswering: + cmrc2018: + learning_rate: 0.00005 + num_train_epochs: 5 + batch_size: 32 + max_seq_length: 512 + dureader_nlp: + num_train_epochs: 1 + batch_size: 12 + max_seq_length: 384 + dureader_robust: + num_train_epochs: 1 + batch_size: 12 + max_seq_length: 384 + dlbp: + num_train_epochs: 1 + batch_size: 12 + max_seq_length: 384 \ No newline at end of file diff --git a/examples/language_model/ernie-1.0/finetune/question_answering.py b/examples/language_model/ernie-1.0/finetune/question_answering.py new file mode 100644 index 000000000000..12a2822eb2c2 --- /dev/null +++ b/examples/language_model/ernie-1.0/finetune/question_answering.py @@ -0,0 +1,271 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
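The per-dataset blocks above override DefaultArgs; the merge itself happens in finetune/utils.py further down in this diff. A minimal sketch of those override semantics, assuming the YAML above is saved as config.yml:

import copy

import yaml

with open("config.yml") as f:
    config = yaml.safe_load(f)

defaults = config["DefaultArgs"]
merged = {}
for task in ("SequenceClassification", "TokenClassification", "QuestionAnswering"):
    for name, overrides in config[task].items():
        # Each dataset inherits every default and overrides only its own keys.
        args = copy.deepcopy(defaults)
        args.update(overrides or {})
        merged[name] = args

print(merged["clue csl"]["max_seq_length"])  # 256, overridden above
print(merged["clue csl"]["weight_decay"])    # 0.01, inherited from DefaultArgs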
+ +import time +import json +import os +import sys +from functools import partial + +import numpy as np +import paddle +import paddlenlp as ppnlp +from paddlenlp.data import Pad, Stack, Tuple +from paddlenlp.utils.log import logger +from paddlenlp.trainer import Trainer +from paddlenlp.trainer.trainer_utils import PredictionOutput + +sys.path.insert(0, os.path.abspath(".")) +from utils import Dict + + +class QuestionAnsweringTrainer(Trainer): + def __init__(self, + *args, + eval_examples=None, + post_process_function=None, + **kwargs): + super().__init__(*args, **kwargs) + self.eval_examples = eval_examples + self.post_process_function = post_process_function + + def evaluate(self, + eval_dataset=None, + eval_examples=None, + ignore_keys=None, + metric_key_prefix: str="eval"): + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + eval_examples = self.eval_examples if eval_examples is None else eval_examples + + # Temporarily disable metric computation, we will do it in the loop here. + compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.evaluation_loop + try: + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is not None and self.compute_metrics is not None: + eval_preds = self.post_process_function(eval_examples, eval_dataset, + output.predictions) + metrics = self.compute_metrics(eval_preds) + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + self.log(metrics) + else: + metrics = {} + + self.control = self.callback_handler.on_evaluate(self.args, self.state, + self.control, metrics) + return metrics + + def predict(self, + predict_dataset, + predict_examples, + ignore_keys=None, + metric_key_prefix: str="test"): + predict_dataloader = self.get_test_dataloader(predict_dataset) + + # Temporarily disable metric computation, we will do it in the loop here. 
+ compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.evaluation_loop + try: + output = eval_loop( + predict_dataloader, + description="Prediction", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is None or self.compute_metrics is None: + return output + + predictions = self.post_process_function( + predict_examples, predict_dataset, output.predictions, "predict") + metrics = self.compute_metrics(predictions) + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return PredictionOutput( + predictions=predictions.predictions, + label_ids=predictions.label_ids, + metrics=metrics) + + +def qa_collator(tokenizer, args): + train_batchify_fn = lambda samples, fn=Dict({ + "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id), + "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id), + "start_positions": Stack(dtype="int64"), + "end_positions": Stack(dtype="int64") + }): fn(samples) + + return train_batchify_fn + + +class CrossEntropyLossForSQuAD(paddle.nn.Layer): + def __init__(self): + super(CrossEntropyLossForSQuAD, self).__init__() + + def forward(self, y, label): + start_logits, end_logits = y + start_position, end_position = label + start_position = paddle.unsqueeze(start_position, axis=-1) + end_position = paddle.unsqueeze(end_position, axis=-1) + start_loss = paddle.nn.functional.cross_entropy( + input=start_logits, label=start_position) + end_loss = paddle.nn.functional.cross_entropy( + input=end_logits, label=end_position) + loss = (start_loss + end_loss) / 2 + return loss + + +def prepare_train_features(examples, tokenizer, args): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + # NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is + # that HugggingFace uses ArrowTable as basic data structure, while we use list of dictionary instead. + contexts = examples['context'] + questions = examples['question'] + + tokenized_examples = tokenizer( + questions, + contexts, + stride=args.doc_stride, + max_seq_len=args.max_seq_length) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample") + # The offset mappings will give us a map from token to character position in the original context. This will + # help us compute the start_positions and end_positions. + offset_mapping = tokenized_examples.pop("offset_mapping") + + # Let's label those examples! + tokenized_examples["start_positions"] = [] + tokenized_examples["end_positions"] = [] + + for i, offsets in enumerate(offset_mapping): + # We will label impossible answers with the index of the CLS token. 
+ input_ids = tokenized_examples["input_ids"][i] + cls_index = input_ids.index(tokenizer.cls_token_id) + + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_examples['token_type_ids'][i] + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + answers = examples['answers'][sample_index] + # If no answers are given, set the cls_index as answer. + if len(answers["answer_start"]) == 0: + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Start/end character index of the answer in the text. + start_char = answers["answer_start"][0] + end_char = start_char + len(answers["text"][0]) + + # Start token index of the current span in the text. + token_start_index = 0 + while sequence_ids[token_start_index] != 1: + token_start_index += 1 + + # End token index of the current span in the text. + token_end_index = len(input_ids) - 1 + while sequence_ids[token_end_index] != 1: + token_end_index -= 1 + token_end_index -= 1 + + # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and + offsets[token_end_index][1] >= end_char): + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). + while token_start_index < len(offsets) and offsets[ + token_start_index][0] <= start_char: + token_start_index += 1 + tokenized_examples["start_positions"].append(token_start_index - + 1) + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_examples["end_positions"].append(token_end_index + 1) + + return tokenized_examples + + +def prepare_validation_features(examples, tokenizer, args): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + #NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is + # that HugggingFace uses ArrowTable as basic data structure, while we use list of dictionary instead. + contexts = examples['context'] + questions = examples['question'] + + tokenized_examples = tokenizer( + questions, + contexts, + stride=args.doc_stride, + max_seq_len=args.max_seq_length, + return_attention_mask=True) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample") + + # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the + # corresponding example_id and we will store the offset mappings. + tokenized_examples["example_id"] = [] + + for i in range(len(tokenized_examples["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). 
+ sequence_ids = tokenized_examples['token_type_ids'][i] + context_index = 1 + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + tokenized_examples["example_id"].append(examples["id"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. + tokenized_examples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_examples["offset_mapping"][i]) + ] + + return tokenized_examples diff --git a/examples/language_model/ernie-1.0/finetune/run_ner.py b/examples/language_model/ernie-1.0/finetune/run_ner.py new file mode 100644 index 000000000000..0cbb5e464fb5 --- /dev/null +++ b/examples/language_model/ernie-1.0/finetune/run_ner.py @@ -0,0 +1,214 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import yaml +from functools import partial +import distutils.util +import os.path as osp +from typing import Optional + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from datasets import load_metric +import paddlenlp +from paddlenlp.trainer import ( + PdArgumentParser, + TrainingArguments, + Trainer, ) +from paddlenlp.datasets import load_dataset +from paddlenlp.transformers import ( + AutoTokenizer, + AutoModelForTokenClassification, ) +from paddlenlp.utils.log import logger + +sys.path.insert(0, os.path.abspath(".")) +from token_classification import ner_trans_fn, ner_collator +from utils import ( + ALL_DATASETS, + DataTrainingArguments, + ModelArguments, ) + + +def do_train(): + parser = PdArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + paddle.set_device(training_args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir( + training_args.output_dir + ) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len( + os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome.") + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. 
To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # set_seed(args) + data_args.dataset = data_args.dataset.strip() + if data_args.dataset not in ALL_DATASETS: + raise ValueError("Not found dataset {}".format(data_args.dataset)) + + # Use yaml config to rewrite all args. + config = ALL_DATASETS[data_args.dataset] + for args in (model_args, data_args, training_args): + for arg in vars(args): + if arg in config.keys(): + setattr(args, arg, config[arg]) + + training_args.per_device_train_batch_size = config["batch_size"] + training_args.per_device_eval_batch_size = config["batch_size"] + + dataset_config = data_args.dataset.split(" ") + all_ds = load_dataset( + dataset_config[0], + None if len(dataset_config) <= 1 else dataset_config[1], ) + + label_list = getattr(all_ds['train'], "label_list", None) + data_args.label_list = label_list + data_args.ignore_label = -100 + data_args.no_entity_id = len(data_args.label_list) - 1 + + num_classes = 1 if all_ds["train"].label_list == None else len(all_ds[ + 'train'].label_list) + + # Define tokenizer, model, loss function. + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) + model = AutoModelForTokenClassification.from_pretrained( + model_args.model_name_or_path, num_classes=num_classes) + + class criterion(nn.Layer): + def __init__(self): + super(criterion, self).__init__() + self.loss_fn = paddle.nn.loss.CrossEntropyLoss( + ignore_index=data_args.ignore_label) + + def forward(self, *args, **kwargs): + return paddle.mean(self.loss_fn(*args, **kwargs)) + + loss_fct = criterion() + + # Define dataset pre-process function + trans_fn = partial(ner_trans_fn, tokenizer=tokenizer, args=data_args) + + # Define data collector + batchify_fn = ner_collator(tokenizer, data_args) + + # Dataset pre-process + train_dataset = all_ds["train"].map(trans_fn) + eval_dataset = all_ds["dev"].map(trans_fn) + test_dataset = all_ds["test"].map(trans_fn) + + # Define the metrics of tasks. 
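+ # NOTE: ignore_label (-100) marks padding and special tokens. The
+ # criterion above skips them via ignore_index, and compute_metrics below
+ # must filter them out before scoring.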
+ # Metrics + metric = load_metric("seqeval") + + def compute_metrics(p): + predictions, labels = p + predictions = np.argmax(predictions, axis=2) + + # Remove ignored index (special tokens) + true_predictions = [ + [label_list[p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [label_list[l] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + results = metric.compute( + predictions=true_predictions, references=true_labels) + return { + "precision": results["overall_precision"], + "recall": results["overall_recall"], + "f1": results["overall_f1"], + "accuracy": results["overall_accuracy"], + } + + trainer = Trainer( + model=model, + criterion=loss_fct, + args=training_args, + data_collator=batchify_fn, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + compute_metrics=compute_metrics, ) + + # Log model and data config + trainer.print_config(model_args, "Model") + trainer.print_config(data_args, "Data") + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + trainer.save_model() # Saves the tokenizer too for easy upload + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluate and tests model + eval_metrics = trainer.evaluate() + trainer.log_metrics("eval", eval_metrics) + + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) + if test_ret.label_ids is None: + paddle.save( + test_ret.predictions, + os.path.join(training_args.output_dir, "test_results.pdtensor"), ) + + # export inference model + input_spec = [ + paddle.static.InputSpec( + shape=[None, None], dtype="int64"), # input_ids + paddle.static.InputSpec( + shape=[None, None], dtype="int64") # segment_ids + ] + trainer.export_model( + input_spec=input_spec, + load_best_model=True, + output_dir=model_args.export_model_dir) + + +if __name__ == "__main__": + do_train() diff --git a/examples/language_model/ernie-1.0/finetune/run_qa.py b/examples/language_model/ernie-1.0/finetune/run_qa.py new file mode 100644 index 000000000000..7ccaf482fb2e --- /dev/null +++ b/examples/language_model/ernie-1.0/finetune/run_qa.py @@ -0,0 +1,241 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
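The compute_metrics above (in run_ner.py) drops every position labeled with the -100 ignore index before handing sequences to seqeval. A toy illustration of that filtering, using a made-up label set rather than the real one:

import numpy as np

label_list = ["B-PER", "I-PER", "O"]  # toy labels, not the actual dataset's
# Predictions already argmax'ed to label ids; -100 marks ignored positions.
predictions = np.array([[0, 1, 2, 2]])
labels = np.array([[0, 1, -100, 2]])

true_predictions = [
    [label_list[p] for (p, l) in zip(pred, lab) if l != -100]
    for pred, lab in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(pred, lab) if l != -100]
    for pred, lab in zip(predictions, labels)
]
print(true_predictions)  # [['B-PER', 'I-PER', 'O']]
print(true_labels)       # [['B-PER', 'I-PER', 'O']]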
+ +import os +import sys +import yaml +from functools import partial +import distutils.util +import os.path as osp +from typing import Optional + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from datasets import load_metric, load_dataset + +import paddlenlp +from paddlenlp.trainer import ( + PdArgumentParser, + TrainingArguments, + Trainer, ) +from paddlenlp.trainer.trainer_utils import EvalPrediction + +# from paddlenlp.datasets import load_dataset +from paddlenlp.transformers import ( + AutoTokenizer, + AutoModelForQuestionAnswering, ) +from paddlenlp.utils.log import logger +from paddlenlp.metrics.squad import squad_evaluate, compute_prediction + +sys.path.insert(0, os.path.abspath(".")) +from question_answering import ( + QuestionAnsweringTrainer, + CrossEntropyLossForSQuAD, + prepare_train_features, + prepare_validation_features, + qa_collator, ) + +from utils import ( + ALL_DATASETS, + DataTrainingArguments, + ModelArguments, ) + + +def do_train(): + parser = PdArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + paddle.set_device(training_args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir( + training_args.output_dir + ) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len( + os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome.") + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # set_seed(args) + data_args.dataset = data_args.dataset.strip() + if data_args.dataset not in ALL_DATASETS: + raise ValueError("Not found dataset {}".format(data_args.dataset)) + + # Use yaml config to rewrite all args. + config = ALL_DATASETS[data_args.dataset] + for args in (model_args, data_args, training_args): + for arg in vars(args): + if arg in config.keys(): + setattr(args, arg, config[arg]) + + training_args.per_device_train_batch_size = config["batch_size"] + training_args.per_device_eval_batch_size = config["batch_size"] + + dataset_config = data_args.dataset.split(" ") + raw_datasets = load_dataset( + dataset_config[0], + None if len(dataset_config) <= 1 else dataset_config[1], + cache_dir=model_args.cache_dir) + + label_list = getattr(raw_datasets['train'], "label_list", None) + data_args.label_list = label_list + + # Define tokenizer, model, loss function. 
+ tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) + model = AutoModelForQuestionAnswering.from_pretrained( + model_args.model_name_or_path) + + loss_fct = CrossEntropyLossForSQuAD() + + train_dataset = raw_datasets["train"] + eval_examples = raw_datasets["validation"] + predict_examples = raw_datasets["test"] + + column_names = raw_datasets["train"].column_names + # Dataset pre-process + train_dataset = train_dataset.map( + partial( + prepare_train_features, tokenizer=tokenizer, args=data_args), + batched=True, + num_proc=4, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on train dataset", ) + + eval_dataset = eval_examples.map( + partial( + prepare_validation_features, tokenizer=tokenizer, args=data_args), + batched=True, + num_proc=4, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on validation dataset", ) + + predict_dataset = predict_examples.map( + partial( + prepare_validation_features, tokenizer=tokenizer, args=data_args), + batched=True, + num_proc=4, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on prediction dataset", ) + + # Define data collector + data_collator = qa_collator(tokenizer, data_args) + + # Post-processing: + def post_processing_function(examples, features, predictions, stage="eval"): + # Post-processing: we match the start logits and end logits to answers in the original context. + predictions, all_nbest_json, scores_diff_json = compute_prediction( + examples=examples, + features=features, + predictions=predictions, + n_best_size=data_args.n_best_size, + max_answer_length=data_args.max_answer_length, + null_score_diff_threshold=data_args.null_score_diff_threshold, ) + # Format the result to the format the metric expects. + formatted_predictions = [{ + "id": k, + "prediction_text": v + } for k, v in predictions.items()] + references = [{ + "id": ex["id"], + "answers": ex["answers"] + } for ex in examples] + return EvalPrediction( + predictions=formatted_predictions, label_ids=references) + + # Define the metrics of tasks. 
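+ # NOTE (illustrative values): the squad metric consumes predictions like
+ # {"id": "q1", "prediction_text": "ans"} and references like
+ # {"id": "q1", "answers": {"text": ["ans"], "answer_start": [0]}},
+ # i.e. exactly the shape post_processing_function returns above.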
+ # Metrics + metric = load_metric("squad") + + def compute_metrics(p: EvalPrediction): + return metric.compute(predictions=p.predictions, references=p.label_ids) + + trainer = QuestionAnsweringTrainer( + model=model, + criterion=loss_fct, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + eval_examples=eval_examples, + data_collator=data_collator, + post_process_function=post_processing_function, + tokenizer=tokenizer, + compute_metrics=compute_metrics, ) + + # Log model and data config + trainer.print_config(model_args, "Model") + trainer.print_config(data_args, "Data") + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + trainer.save_model() # Saves the tokenizer too for easy upload + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluate and tests model + eval_metrics = trainer.evaluate() + trainer.log_metrics("eval", eval_metrics) + + test_ret = trainer.predict(predict_dataset, predict_examples) + trainer.log_metrics("predict", test_ret.metrics) + if test_ret.label_ids is None: + paddle.save( + test_ret.predictions, + os.path.join(training_args.output_dir, "test_results.pdtensor"), ) + + # export inference model + input_spec = [ + paddle.static.InputSpec( + shape=[None, None], dtype="int64"), # input_ids + paddle.static.InputSpec( + shape=[None, None], dtype="int64") # segment_ids + ] + trainer.export_model( + input_spec=input_spec, + load_best_model=True, + output_dir=model_args.export_model_dir) + + +if __name__ == "__main__": + do_train() diff --git a/examples/language_model/ernie-1.0/finetune/run_seq_cls.py b/examples/language_model/ernie-1.0/finetune/run_seq_cls.py new file mode 100644 index 000000000000..e0440e4f1daa --- /dev/null +++ b/examples/language_model/ernie-1.0/finetune/run_seq_cls.py @@ -0,0 +1,194 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
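run_qa.py above ends by exporting a static-graph inference model through trainer.export_model. Assuming export_model saves with paddle.jit.save, which is typical for Paddle deployment, and using a hypothetical export path, the artifact could be reloaded for inference roughly like this:

import paddle

# Hypothetical path prefix under the export_model_dir passed above.
model = paddle.jit.load("./export/model")
model.eval()

# The two int64 tensors mirror the InputSpec used at export time.
input_ids = paddle.to_tensor([[1, 101, 102, 103, 2]], dtype="int64")
token_type_ids = paddle.zeros_like(input_ids)
start_logits, end_logits = model(input_ids, token_type_ids)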
+ +import os +import sys +import yaml +from functools import partial +import distutils.util +import os.path as osp +from typing import Optional + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.metric import Accuracy + +import paddlenlp +from paddlenlp.trainer import (PdArgumentParser, TrainingArguments, Trainer) +from paddlenlp.datasets import load_dataset +from paddlenlp.transformers import ( + AutoTokenizer, + AutoModelForSequenceClassification, ) +from paddlenlp.utils.log import logger + +sys.path.insert(0, os.path.abspath(".")) +from sequence_classification import seq_trans_fn, clue_trans_fn +from utils import ( + ALL_DATASETS, + DataTrainingArguments, + ModelArguments, + defaut_collator, ) + + +def do_train(): + parser = PdArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + paddle.set_device(training_args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir( + training_args.output_dir + ) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len( + os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome.") + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # set_seed(args) + data_args.dataset = data_args.dataset.strip() + if data_args.dataset not in ALL_DATASETS: + raise ValueError("Not found dataset {}".format(data_args.dataset)) + + # Use yaml config to rewrite all args. + config = ALL_DATASETS[data_args.dataset] + for args in (model_args, data_args, training_args): + for arg in vars(args): + if arg in config.keys(): + setattr(args, arg, config[arg]) + + training_args.per_device_train_batch_size = config["batch_size"] + training_args.per_device_eval_batch_size = config["batch_size"] + + dataset_config = data_args.dataset.split(" ") + raw_datasets = load_dataset( + dataset_config[0], + None if len(dataset_config) <= 1 else dataset_config[1], ) + + data_args.label_list = getattr(raw_datasets['train'], "label_list", None) + num_classes = 1 if raw_datasets["train"].label_list == None else len( + raw_datasets['train'].label_list) + + # Define tokenizer, model, loss function. 
+ tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, num_classes=num_classes) + loss_fct = nn.loss.CrossEntropyLoss( + ) if data_args.label_list else nn.loss.MSELoss() + + # Define dataset pre-process function + if "clue" in data_args.dataset: + trans_fn = partial(clue_trans_fn, tokenizer=tokenizer, args=data_args) + else: + trans_fn = partial(seq_trans_fn, tokenizer=tokenizer, args=data_args) + + # Define data collector + batchify_fn = defaut_collator(tokenizer, data_args) + + # Dataset pre-process + train_dataset = raw_datasets["train"].map(trans_fn) + eval_dataset = raw_datasets["dev"].map(trans_fn) + test_dataset = raw_datasets["test"].map(trans_fn) + + # Define the metrics of tasks. + def compute_metrics(p): + preds = p.predictions[0] if isinstance(p.predictions, + tuple) else p.predictions + + preds = paddle.to_tensor(preds) + label = paddle.to_tensor(p.label_ids) + + probs = F.softmax(preds, axis=1) + metric = Accuracy() + metric.reset() + result = metric.compute(preds, label) + metric.update(result) + accu = metric.accumulate() + metric.reset() + return {"accuracy": accu} + + trainer = Trainer( + model=model, + criterion=loss_fct, + args=training_args, + data_collator=batchify_fn, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + compute_metrics=compute_metrics, ) + + # Log model and data config + trainer.print_config(model_args, "Model") + trainer.print_config(data_args, "Data") + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + trainer.save_model() # Saves the tokenizer too for easy upload + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluate and tests model + eval_metrics = trainer.evaluate() + trainer.log_metrics("eval", eval_metrics) + + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) + if test_ret.label_ids is None: + paddle.save( + test_ret.predictions, + os.path.join(training_args.output_dir, "test_results.pdtensor"), ) + + # export inference model + input_spec = [ + paddle.static.InputSpec( + shape=[None, None], dtype="int64"), # input_ids + paddle.static.InputSpec( + shape=[None, None], dtype="int64") # segment_ids + ] + trainer.export_model( + input_spec=input_spec, + load_best_model=True, + output_dir=model_args.export_model_dir) + + +if __name__ == "__main__": + do_train() diff --git a/examples/language_model/ernie-1.0/finetune/sequence_classification.py b/examples/language_model/ernie-1.0/finetune/sequence_classification.py new file mode 100644 index 000000000000..7edf621836b0 --- /dev/null +++ b/examples/language_model/ernie-1.0/finetune/sequence_classification.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import numpy as np + + +def convert_example(example, tokenizer, max_seq_length=512, is_test=False): + is_test = True + if 'label' in example.keys(): + is_test = False + + if "text_b" in example.keys(): + text = example["text_a"] + text_pair = example["text_b"] + else: + text = example["text"] + text_pair = None + + encoded_inputs = tokenizer( + text=text, text_pair=text_pair, max_seq_len=max_seq_length) + input_ids = encoded_inputs["input_ids"] + token_type_ids = encoded_inputs["token_type_ids"] + + if is_test: + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + } + else: + label = np.array([example["label"]], dtype="int64") + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "labels": label + } + + +# Data pre-process function for the CLUE benchmark datasets +def convert_clue(example, + label_list, + tokenizer=None, + max_seq_length=512, + **kwargs): + """Convert a CLUE example into the necessary features.""" + is_test = False + if 'label' not in example.keys(): + is_test = True + + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + example['label'] = np.array(example["label"], dtype=label_dtype) + label = example['label'] + # Convert raw text to feature + if 'keyword' in example: # CSL + sentence1 = " ".join(example['keyword']) + example = { + 'sentence1': sentence1, + 'sentence2': example['abst'], + 'label': example['label'] + } + elif 'target' in example: # wsc + text, query, pronoun, query_idx, pronoun_idx = example['text'], example[ + 'target']['span1_text'], example['target']['span2_text'], example[ + 'target']['span1_index'], example['target']['span2_index'] + text_list = list(text) + assert text[pronoun_idx:(pronoun_idx + len(pronoun) + )] == pronoun, "pronoun: {}".format(pronoun) + assert text[query_idx:(query_idx + len(query) + )] == query, "query: {}".format(query) + if pronoun_idx > query_idx: + text_list.insert(query_idx, "_") + text_list.insert(query_idx + len(query) + 1, "_") + text_list.insert(pronoun_idx + 2, "[") + text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") + else: + text_list.insert(pronoun_idx, "[") + text_list.insert(pronoun_idx + len(pronoun) + 1, "]") + text_list.insert(query_idx + 2, "_") + text_list.insert(query_idx + len(query) + 2 + 1, "_") + text = "".join(text_list) + example['sentence'] = text + + if tokenizer is None: + return example + if 'sentence' in example: + example = tokenizer(example['sentence'], max_seq_len=max_seq_length) + elif 'sentence1' in example: + example = tokenizer( + example['sentence1'], + text_pair=example['sentence2'], + max_seq_len=max_seq_length) + + if not is_test: + return { + "input_ids": example['input_ids'], + "token_type_ids": example['token_type_ids'], + "labels": label + } + else: + return { + "input_ids": example['input_ids'], + "token_type_ids": example['token_type_ids'] + } + + +def seq_trans_fn(example, tokenizer, args): + return convert_example( + example, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, ) + + +def 
clue_trans_fn(example, tokenizer, args): + return convert_clue( + example, + tokenizer=tokenizer, + label_list=args.label_list, + max_seq_length=args.max_seq_length) diff --git a/examples/language_model/ernie-1.0/finetune/token_classification.py b/examples/language_model/ernie-1.0/finetune/token_classification.py new file mode 100644 index 000000000000..d001edb7e8ed --- /dev/null +++ b/examples/language_model/ernie-1.0/finetune/token_classification.py @@ -0,0 +1,72 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import random +import time +import math +import sys +from functools import partial + +import numpy as np +import paddle + +import paddlenlp as ppnlp +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.metrics import ChunkEvaluator +from paddlenlp.datasets import load_dataset +from paddlenlp.data import Stack, Tuple, Pad +from paddlenlp.utils.log import logger + +# from paddlenlp.trainer.trainer_base import TrainerBase +sys.path.insert(0, os.path.abspath(".")) +from utils import Dict + + +def tokenize_and_align_labels(example, tokenizer, no_entity_id, + max_seq_len=512): + labels = example['labels'] + example = example['tokens'] + tokenized_input = tokenizer( + example, + is_split_into_words=True, + max_seq_len=max_seq_len, ) + + # -2 for [CLS] and [SEP] + if len(tokenized_input['input_ids']) - 2 < len(labels): + labels = labels[:len(tokenized_input['input_ids']) - 2] + tokenized_input['labels'] = [no_entity_id] + labels + [no_entity_id] + tokenized_input['labels'] += [no_entity_id] * ( + len(tokenized_input['input_ids']) - len(tokenized_input['labels'])) + + return tokenized_input + + +def ner_collator(tokenizer, args): + batchify_fn = lambda samples, fn=Dict({ + 'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int32'), # input + 'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int32'), # segment + 'labels': Pad(axis=0, pad_val=args.ignore_label, dtype='int64') # label + }): fn(samples) + + return batchify_fn + + +def ner_trans_fn(example, tokenizer, args): + return tokenize_and_align_labels( + example, + tokenizer=tokenizer, + no_entity_id=args.no_entity_id, + max_seq_len=args.max_seq_length) diff --git a/examples/language_model/ernie-1.0/finetune/utils.py b/examples/language_model/ernie-1.0/finetune/utils.py new file mode 100644 index 000000000000..149808446039 --- /dev/null +++ b/examples/language_model/ernie-1.0/finetune/utils.py @@ -0,0 +1,198 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional +import copy +import yaml +import os.path as osp + +from paddlenlp.data import Stack, Tuple, Pad + +TASKS = [ + "SequenceClassification", + "TokenClassification", + "QuestionAnswering", +] + +config = yaml.load( + open(osp.join(osp.abspath("."), "./config.yml"), 'r'), + Loader=yaml.FullLoader) +default_args = config["DefaultArgs"] + +ALL_DATASETS = {} + +for task_type in TASKS: + task = config[task_type] + for data_name in task.keys(): + new_args = task[data_name] + new_args = {} if new_args is None else new_args + final_args = copy.deepcopy(default_args) + final_args.update(new_args) + final_args["model"] = "AutoModelFor{}".format(task_type) + ALL_DATASETS[data_name] = final_args + + +class Dict(object): + def __init__(self, fn): + assert isinstance(fn, (dict)), 'Input pattern not understood. The input of Dict must be a dict with key of input column name and value of collate_fn ' \ + 'Received fn=%s' % (str(fn)) + + self._fn = fn + + for col_name, ele_fn in self._fn.items(): + assert callable( + ele_fn + ), 'Batchify functions must be callable! type(fn[%d]) = %s' % ( + col_name, str(type(ele_fn))) + + def __call__(self, data): + + ret = {} + if len(data) <= 0: + return ret + + for col_name, ele_fn in self._fn.items(): + # skip unused col_name, such as labels in test mode. + if col_name not in data[0].keys(): + continue + result = ele_fn([ele[col_name] for ele in data]) + ret[col_name] = result + + return ret + + +def defaut_collator(tokenizer, args): + """ Defaut collator for sequences classification + + Args: + tokenizer (PretrainedTokenizer): tokenizer of PretrainedModel + args : data argument, need label list. + + Returns: + batchify_fn (function): collator + """ + batchify_fn = lambda samples, fn=Dict({ + 'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id), # input_ids + "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type_ids + "labels": Stack(dtype="int64" if args.label_list else "float32") # labels + }): fn(samples) + + return batchify_fn + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + Using `PdArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + dataset: str = field( + default=None, + metadata={ + "help": "The name of the dataset to use (via the datasets library)." + }) + + max_seq_length: int = field( + default=128, + metadata={ + "help": + "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, ) + + # Additional configs for QA task. + doc_stride: int = field( + default=128, + metadata={ + "help": + "When splitting up a long document into chunks, how much stride to take between chunks." + }, ) + + n_best_size: int = field( + default=20, + metadata={ + "help": + "The total number of n-best predictions to generate in the nbest_predictions.json output file." 
+ }, ) + + max_query_length: int = field( + default=64, + metadata={"help": "Max query length."}, ) + + max_answer_length: int = field( + default=30, + metadata={"help": "Max answer length."}, ) + + do_lower_case: bool = field( + default=False, + metadata={ + "help": + "Whether to lower case the input text. Should be True for uncased models and False for cased models." + }, ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets"}) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={ + "help": "The number of processes to use for the preprocessing." + }, ) + null_score_diff_threshold: float = field( + default=0.0, + metadata={ + "help": + "The threshold used to select the null answer: if the best answer has a score that is less than " + "the score of the null answer minus this threshold, the null answer is selected for this example. " + "Only useful when `version_2_with_negative=True`." + }, ) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field(metadata={ + "help": + "Path to pretrained model or model identifier from https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html" + }) + config_name: Optional[str] = field( + default=None, + metadata={ + "help": + "Pretrained config name or path if not the same as model_name" + }) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": + "Pretrained tokenizer name or path if not the same as model_name" + }) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": + "Path to directory to store the pretrained models downloaded from huggingface.co" + }, ) + export_model_dir: Optional[str] = field( + default=None, + metadata={ + "help": + "Path to directory to store the pretrained models downloaded from huggingface.co" + }, ) diff --git a/paddlenlp/__init__.py b/paddlenlp/__init__.py index 4aea54cd30c1..7e2eae95b3b6 100644 --- a/paddlenlp/__init__.py +++ b/paddlenlp/__init__.py @@ -31,6 +31,7 @@ from . import losses from . import experimental from .taskflow import Taskflow +from . import trainer import paddle paddle.disable_signal_handler() diff --git a/paddlenlp/datasets/__init__.py b/paddlenlp/datasets/__init__.py index ab21ef27c162..81153ed72234 100644 --- a/paddlenlp/datasets/__init__.py +++ b/paddlenlp/datasets/__init__.py @@ -16,6 +16,7 @@ from .chnsenticorp import * from .cmrc2018 import * from .drcd import * +from .drcd_cn import * from .dureader_robust import * from .glue import * from .lcqmc import * @@ -34,6 +35,7 @@ from .seabsa16 import * from .cote import * from .clue import * +from .nlpcc_dbqa import * from .nlpcc14_sc import * from .nlpcc13_evsam05_thu import * from .nlpcc13_evsam05_hit import * diff --git a/paddlenlp/datasets/chnsenticorp_v2.py b/paddlenlp/datasets/chnsenticorp_v2.py new file mode 100644 index 000000000000..908f558eb02c --- /dev/null +++ b/paddlenlp/datasets/chnsenticorp_v2.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import os + +from paddle.dataset.common import md5file +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . import DatasetBuilder + +__all__ = ['ChnSentiCorpV2'] + + +class ChnSentiCorpV2(DatasetBuilder): + """ + ChnSentiCorp, a Chinese sentiment analysis corpus collected by Tan Songbo + (ICT, Chinese Academy of Sciences) for opinion mining. + + """ + + URL = "https://paddlenlp.bj.bcebos.com/datasets/data-chnsenticorp.tar.gz" + MD5 = "e336e76d7be4ecd5479083d5b8f771e4" + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) + SPLITS = { + 'train': META_INFO( + os.path.join('chnsenticorp', 'train', 'part.0'), + '3fac2659547f1ddf90d223b8ed31f22f'), + 'dev': META_INFO( + os.path.join('chnsenticorp', 'dev', 'part.0'), + 'a3a853bfb3af4a592fc4df24b56c88a7'), + 'test': META_INFO( + os.path.join('chnsenticorp', 'test', 'part.0'), + '6bfc8f35f523d2fdf12648d9d02778ff'), + } + + def _get_data(self, mode, **kwargs): + """Downloads dataset.""" + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or (data_hash and + not md5file(fullname) == data_hash): + get_path_from_url(self.URL, default_root, self.MD5) + + return fullname + + def _read(self, filename, split): + """Reads data.""" + with open(filename, 'r', encoding='utf-8') as f: + head = True + for line in f: + data = line.strip().split("\t") + if head: + # Skip the header line. + head = False + else: + text, label = data + yield {"text": text, "label": label} + + def get_labels(self): + """ + Return labels of the ChnSentiCorp dataset. + """ + return ["0", "1"] diff --git a/paddlenlp/datasets/drcd_cn.py b/paddlenlp/datasets/drcd_cn.py new file mode 100644 index 000000000000..42d78c24f7f5 --- /dev/null +++ b/paddlenlp/datasets/drcd_cn.py @@ -0,0 +1,88 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import os + +from paddle.dataset.common import md5file +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . import DatasetBuilder + +__all__ = ['DRCD_CN'] + + +class DRCD_CN(DatasetBuilder): + ''' + Delta Reading Comprehension Dataset is an open domain traditional Chinese + machine reading comprehension (MRC) dataset. 
The dataset contains 10,014 + paragraphs from 2,108 Wikipedia articles and 30,000+ questions generated + by annotators. + + This dataset translates the original Traditional Chinese text into Simplified Chinese. + ''' + + URL = "https://bj.bcebos.com/paddlenlp/datasets/drcd_cn.tar.gz" + MD5 = "8ceed5076c4f59d7a3666b13851e41fa" + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) + SPLITS = { + 'train': META_INFO( + os.path.join('drcd_cn', 'train.json'), + '5a51ee5a106e16965c85fce364d316d7'), + 'dev': META_INFO( + os.path.join('drcd_cn', 'dev.json'), + 'f352b17cddeed69877ff94d4321817ce'), + 'test': META_INFO( + os.path.join('drcd_cn', 'test.json'), + 'e674a667033c4e8c9ae6d05d95073d02') + } + + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or (data_hash and + not md5file(fullname) == data_hash): + get_path_from_url(self.URL, default_root, self.MD5) + + return fullname + + def _read(self, filename, *args): + with open(filename, "r", encoding="utf8") as f: + input_data = json.load(f)["data"] + for entry in input_data: + title = entry.get("title", "").strip() + for paragraph in entry["paragraphs"]: + context = paragraph["context"].strip() + for qa in paragraph["qas"]: + qas_id = qa["id"] + question = qa["question"].strip() + answer_starts = [ + answer["answer_start"] + for answer in qa.get("answers", []) + ] + answers = [ + answer["text"].strip() + for answer in qa.get("answers", []) + ] + + yield { + 'id': qas_id, + 'title': title, + 'context': context, + 'question': question, + 'answers': answers, + 'answer_starts': answer_starts + } diff --git a/paddlenlp/datasets/dureader_nlp.py b/paddlenlp/datasets/dureader_nlp.py new file mode 100644 index 000000000000..d552504a1953 --- /dev/null +++ b/paddlenlp/datasets/dureader_nlp.py @@ -0,0 +1,83 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import os + +from paddle.dataset.common import md5file +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . import DatasetBuilder + +__all__ = ['DuReaderNLP'] + + +class DuReaderNLP(DatasetBuilder): + ''' + The machine reading comprehension dataset (i.e. DuReader) is designed + to measure the performance of a reading comprehension model. + + This is an internal dataset. You should never use it. 
+ ''' + + URL = 'https://internal/datasets/dureader_nlp.tar.gz' + MD5 = '7372b42aadde59904c291341b73e30a1' + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) + SPLITS = { + 'train': META_INFO( + os.path.join('dureader', 'train.json'), + 'd81648dccca54b48fd9cddecf28815b0'), + 'dev': META_INFO( + os.path.join('dureader', 'dev.json'), + 'd941140d8d5362d9031897ba2004af64'), + } + + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or (data_hash and + not md5file(fullname) == data_hash): + get_path_from_url(self.URL, default_root, self.MD5) + + return fullname + + def _read(self, filename, *args): + with open(filename, "r", encoding="utf8") as f: + input_data = json.load(f)["data"] + for entry in input_data: + title = entry.get("title", "").strip() + for paragraph in entry["paragraphs"]: + context = paragraph["context"].strip() + for qa in paragraph["qas"]: + qas_id = qa["id"] + question = qa["question"].strip() + answer_starts = [ + answer["answer_start"] + for answer in qa.get("answers", []) + ] + answers = [ + answer["text"].strip() + for answer in qa.get("answers", []) + ] + + yield { + 'id': qas_id, + 'title': title, + 'context': context, + 'question': question, + 'answers': answers, + 'answer_starts': answer_starts + } diff --git a/paddlenlp/datasets/lcqmc_v2.py b/paddlenlp/datasets/lcqmc_v2.py new file mode 100644 index 000000000000..4ab60c4b0227 --- /dev/null +++ b/paddlenlp/datasets/lcqmc_v2.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import os + +from paddle.dataset.common import md5file +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . 
import DatasetBuilder
+
+__all__ = ['LCQMC_V2']
+
+
+class LCQMC_V2(DatasetBuilder):
+    """
+    LCQMC: A Large-scale Chinese Question Matching Corpus.
+    For more information, please refer to `https://www.aclweb.org/anthology/C18-1166/`.
+
+    """
+
+    URL = "https://bj.bcebos.com/paddlenlp/datasets/lcqmc_v2.tar.gz"
+    MD5 = "e44825d8e6d5117bc04caf3982cf934f"
+    META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
+    SPLITS = {
+        'train': META_INFO(
+            os.path.join('lcqmc', 'train.tsv'),
+            '2193c022439b038ac12c0ae918b211a1'),
+        'dev': META_INFO(
+            os.path.join('lcqmc', 'dev.tsv'),
+            'c5dcba253cb4105d914964fd8b3c0e94'),
+        'test': META_INFO(
+            os.path.join('lcqmc', 'test.tsv'),
+            '8f4b71e15e67696cc9e112a459ec42bd'),
+    }
+
+    def _get_data(self, mode, **kwargs):
+        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
+        filename, data_hash = self.SPLITS[mode]
+        fullname = os.path.join(default_root, filename)
+        if not os.path.exists(fullname) or (data_hash and
+                                            not md5file(fullname) == data_hash):
+            get_path_from_url(self.URL, default_root, self.MD5)
+
+        return fullname
+
+    def _read(self, filename):
+        """Reads data."""
+        with open(filename, 'r', encoding='utf-8') as f:
+            head = True
+            for line in f:
+                data = line.strip().split("\t")
+                if head:
+                    head = False
+                else:
+                    if len(data) == 3:
+                        query, title, label = data
+                        yield {"query": query, "title": title, "label": label}
+                    elif len(data) == 2:
+                        query, title = data
+                        yield {"query": query, "title": title, "label": ''}
+                    else:
+                        continue
+
+    def get_labels(self):
+        """
+        Return labels of the LCQMC object.
+        """
+        return ["0", "1"]
diff --git a/paddlenlp/datasets/nlpcc_dbqa.py b/paddlenlp/datasets/nlpcc_dbqa.py
new file mode 100644
index 000000000000..0befef2ebe49
--- /dev/null
+++ b/paddlenlp/datasets/nlpcc_dbqa.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import os
+
+from paddle.dataset.common import md5file
+from paddle.utils.download import get_path_from_url
+from paddlenlp.utils.env import DATA_HOME
+from . import DatasetBuilder
+
+__all__ = ['NLPCC_DBQA']
+
+
+class NLPCC_DBQA(DatasetBuilder):
+    """
+    NLPCC2016 DBQA dataset.
+
+    Document-based QA (or DBQA) task:
+    When predicting answers to each question, a DBQA system built by each
+    participating team is limited to selecting sentences as answers from the
+    question's given document.
+
+    For more information: http://tcci.ccf.org.cn/conference/2016/dldoc/evagline2.pdf
+    """
+
+    URL = "https://bj.bcebos.com/paddlenlp/datasets/nlpcc-dbqa.zip"
+    MD5 = "a5f69c2462136ef4d1707e4e2551a57b"
+    META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
+    SPLITS = {
+        'train': META_INFO(
+            os.path.join('nlpcc-dbqa', 'nlpcc-dbqa', 'train.tsv'),
+            '4f84fefce1a8f52c8d9248d1ff5ab9bd'),
+        'dev': META_INFO(
+            os.path.join('nlpcc-dbqa', 'nlpcc-dbqa', 'dev.tsv'),
+            '3831beb0d42c29615d06343538538f53'),
+        'test': META_INFO(
+            os.path.join('nlpcc-dbqa', 'nlpcc-dbqa', 'test.tsv'),
+            'e224351353b1f6a15837008b5d0da703'),
+    }
+
+    def _get_data(self, mode, **kwargs):
+        """Downloads dataset."""
+        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
+        filename, data_hash = self.SPLITS[mode]
+        fullname = os.path.join(default_root, filename)
+        if not os.path.exists(fullname) or (data_hash and
+                                            not md5file(fullname) == data_hash):
+            get_path_from_url(self.URL, default_root, self.MD5)
+
+        return fullname
+
+    def _read(self, filename, split):
+        """Reads data."""
+        with open(filename, 'r', encoding='utf-8') as f:
+            head = None
+            for line in f:
+                data = line.strip().split("\t")
+                if not head:
+                    head = data
+                else:
+                    qid, text_a, text_b, label = data
+                    yield {
+                        "qid": qid,
+                        "text_a": text_a,
+                        "text_b": text_b,
+                        "label": label
+                    }
+
+    def get_labels(self):
+        """
+        Return labels of the NLPCC_DBQA dataset.
+        """
+        return ["0", "1"]
diff --git a/paddlenlp/trainer/__init__.py b/paddlenlp/trainer/__init__.py
new file mode 100644
index 000000000000..f2a3ca873fc6
--- /dev/null
+++ b/paddlenlp/trainer/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .argparser import PdArgumentParser
+from .trainer_args import TrainingArguments
+from .trainer_base import Trainer
\ No newline at end of file
diff --git a/paddlenlp/trainer/argparser.py b/paddlenlp/trainer/argparser.py
new file mode 100644
index 000000000000..d7df0c9406cf
--- /dev/null
+++ b/paddlenlp/trainer/argparser.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import dataclasses +import json +import sys +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError +from copy import copy +from enum import Enum +from inspect import isclass +from pathlib import Path +from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints + +DataClass = NewType("DataClass", Any) +DataClassType = NewType("DataClassType", Any) + + +# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse +def string_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise ArgumentTypeError( + f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)." + ) + + +class PdArgumentParser(ArgumentParser): + """ + This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments. + + The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed) + arguments to the parser after initialization and you'll get the output back after parsing as an additional + namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass. + """ + + dataclass_types: Iterable[DataClassType] + + def __init__(self, + dataclass_types: Union[DataClassType, Iterable[DataClassType]], + **kwargs): + """ + Args: + dataclass_types: + Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args. + kwargs: + (Optional) Passed to `argparse.ArgumentParser()` in the regular way. + """ + # To make the default appear when using --help + if "formatter_class" not in kwargs: + kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter + super().__init__(**kwargs) + if dataclasses.is_dataclass(dataclass_types): + dataclass_types = [dataclass_types] + self.dataclass_types = list(dataclass_types) + for dtype in self.dataclass_types: + self._add_dataclass_arguments(dtype) + + @staticmethod + def _parse_dataclass_field(parser: ArgumentParser, + field: dataclasses.Field): + field_name = f"--{field.name}" + kwargs = field.metadata.copy() + # field.metadata is not used at all by Data Classes, + # it is provided as a third-party extension mechanism. 
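+        # For example, a (hypothetical) dataclass field declared as
+        #     learning_rate: float = field(default=5e-5, metadata={"help": "Initial LR."})
+        # arrives here with `field.type` already resolved to `float` and its
+        # metadata copied into `kwargs`, so `parser.add_argument` receives the
+        # help text alongside the inferred type and default.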
+        if isinstance(field.type, str):
+            raise RuntimeError(
+                "Unresolved type detected; types should have been resolved "
+                "beforehand with the `typing.get_type_hints` method")
+
+        origin_type = getattr(field.type, "__origin__", field.type)
+        if origin_type is Union:
+            if len(field.type.__args__) != 2 or type(
+                    None) not in field.type.__args__:
+                raise ValueError(
+                    "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union`"
+                )
+            if bool not in field.type.__args__:
+                # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
+                field.type = (field.type.__args__[0]
+                              if isinstance(None, field.type.__args__[1]) else
+                              field.type.__args__[1])
+                origin_type = getattr(field.type, "__origin__", field.type)
+
+        # A variable to store kwargs for a boolean field, if needed
+        # so that we can init a `no_*` complement argument (see below)
+        bool_kwargs = {}
+        if isinstance(field.type, type) and issubclass(field.type, Enum):
+            kwargs["choices"] = [x.value for x in field.type]
+            kwargs["type"] = type(kwargs["choices"][0])
+            if field.default is not dataclasses.MISSING:
+                kwargs["default"] = field.default
+            else:
+                kwargs["required"] = True
+        elif field.type is bool or field.type is Optional[bool]:
+            # Copy the current kwargs to use to instantiate a `no_*` complement argument below.
+            # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument
+            bool_kwargs = copy(kwargs)
+
+            # Hack because type=bool in argparse does not behave as we want.
+            kwargs["type"] = string_to_bool
+            if field.type is bool or (field.default is not None and
+                                      field.default is not dataclasses.MISSING):
+                # Default value is False if a bool field carries no default.
+                default = False if field.default is dataclasses.MISSING else field.default
+                # This is the value that will get picked if we don't include --field_name in any way
+                kwargs["default"] = default
+                # This tells argparse we accept 0 or 1 value after --field_name
+                kwargs["nargs"] = "?"
+                # This is the value that will get picked if we do --field_name (without value)
+                kwargs["const"] = True
+        elif isclass(origin_type) and issubclass(origin_type, list):
+            kwargs["type"] = field.type.__args__[0]
+            kwargs["nargs"] = "+"
+            if field.default_factory is not dataclasses.MISSING:
+                kwargs["default"] = field.default_factory()
+            elif field.default is dataclasses.MISSING:
+                kwargs["required"] = True
+        else:
+            kwargs["type"] = field.type
+            if field.default is not dataclasses.MISSING:
+                kwargs["default"] = field.default
+            elif field.default_factory is not dataclasses.MISSING:
+                kwargs["default"] = field.default_factory()
+            else:
+                kwargs["required"] = True
+        parser.add_argument(field_name, **kwargs)
+
+        # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
+        # Order is important for arguments with the same destination!
+        # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down
+        # here and we do not need those changes/additional keys.
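+        # For example, a (hypothetical) field `do_lower_case: bool = field(default=True)`
+        # produced `--do_lower_case` (nargs="?", const=True) above; the block below adds
+        # the complement `--no_do_lower_case` (action="store_false"), both writing to
+        # dest="do_lower_case".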
+        if field.default is True and (field.type is bool or
+                                      field.type is Optional[bool]):
+            bool_kwargs["default"] = False
+            parser.add_argument(
+                f"--no_{field.name}",
+                action="store_false",
+                dest=field.name,
+                **bool_kwargs)
+
+    def _add_dataclass_arguments(self, dtype: DataClassType):
+        if hasattr(dtype, "_argument_group_name"):
+            parser = self.add_argument_group(dtype._argument_group_name)
+        else:
+            parser = self
+
+        try:
+            type_hints: Dict[str, type] = get_type_hints(dtype)
+        except NameError:
+            raise RuntimeError(
+                f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
+                f"removing the line `from __future__ import annotations`, which opts in to Postponed "
+                f"Evaluation of Annotations (PEP 563)")
+
+        for field in dataclasses.fields(dtype):
+            if not field.init:
+                continue
+            field.type = type_hints[field.name]
+            self._parse_dataclass_field(parser, field)
+
+    def parse_args_into_dataclasses(
+            self,
+            args=None,
+            return_remaining_strings=False,
+            look_for_args_file=True,
+            args_filename=None) -> Tuple[DataClass, ...]:
+        """
+        Parse command-line args into instances of the specified dataclass types.
+
+        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
+        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
+
+        Args:
+            args:
+                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
+            return_remaining_strings:
+                If true, also return a list of remaining argument strings.
+            look_for_args_file:
+                If true, will look for a ".args" file with the same base name as the entry point script for this
+                process, and will append its potential content to the command line args.
+            args_filename:
+                If not None, will use this file instead of the ".args" file specified in the previous argument.
+
+        Returns:
+            Tuple consisting of:
+
+                - the dataclass instances in the same order as they were passed to the initializer.
+                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
+                  after initialization.
+                - the potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
+        """
+        if args_filename or (look_for_args_file and len(sys.argv)):
+            if args_filename:
+                args_file = Path(args_filename)
+            else:
+                args_file = Path(sys.argv[0]).with_suffix(".args")
+
+            if args_file.exists():
+                fargs = args_file.read_text().split()
+                args = fargs + args if args is not None else fargs + sys.argv[
+                    1:]
+                # in case of duplicate arguments the last one has precedence,
+                # so the command-line args are appended after the file args.
+        namespace, remaining_args = self.parse_known_args(args=args)
+        outputs = []
+        for dtype in self.dataclass_types:
+            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
+            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
+            for k in keys:
+                delattr(namespace, k)
+            obj = dtype(**inputs)
+            outputs.append(obj)
+        if len(namespace.__dict__) > 0:
+            # additional namespace.
+            outputs.append(namespace)
+        if return_remaining_strings:
+            return (*outputs, remaining_args)
+        else:
+            if remaining_args:
+                raise ValueError(
+                    f"Some specified arguments are not used by the PdArgumentParser: {remaining_args}"
+                )
+
+            return (*outputs, )
+
+    def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
+        """
+        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
+        dataclass types.
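+
+        Example (a minimal sketch; assumes `config.json` holds `TrainingArguments` fields):
+
+            parser = PdArgumentParser(TrainingArguments)
+            (training_args,) = parser.parse_json_file(json_file="config.json")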
+        """
+        data = json.loads(Path(json_file).read_text())
+        outputs = []
+        for dtype in self.dataclass_types:
+            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
+            inputs = {k: v for k, v in data.items() if k in keys}
+            obj = dtype(**inputs)
+            outputs.append(obj)
+        return (*outputs, )
+
+    def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
+        """
+        Alternative helper method that does not use `argparse` at all, instead using a dict to populate the dataclass
+        types.
+        """
+        outputs = []
+        for dtype in self.dataclass_types:
+            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
+            inputs = {k: v for k, v in args.items() if k in keys}
+            obj = dtype(**inputs)
+            outputs.append(obj)
+        return (*outputs, )
diff --git a/paddlenlp/trainer/trainer_args.py b/paddlenlp/trainer/trainer_args.py
new file mode 100644
index 000000000000..9a03ff536127
--- /dev/null
+++ b/paddlenlp/trainer/trainer_args.py
@@ -0,0 +1,837 @@
+# Copyright 2020-present the HuggingFace Inc. team.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import json
+import logging
+import math
+import os
+import warnings
+from dataclasses import asdict, dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from .trainer_utils import (
+    SchedulerType,
+    IntervalStrategy,
+    EvaluationStrategy,
+    OptimizerNames, )
+
+from paddlenlp.utils.log import logger
+import paddle
+
+
+def default_logdir() -> str:
+    """
+    Same default log dir as the TensorBoard `SummaryWriter`:
+    `runs/CURRENT_DATETIME_HOSTNAME`.
+    """
+    import socket
+    from datetime import datetime
+
+    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
+    return os.path.join("runs", current_time + "_" + socket.gethostname())
+
+
+@dataclass
+class TrainingArguments:
+    """
+    TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
+    itself**.
+
+    Using [`PdArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        output_dir (`str`):
+            The output directory where the model predictions and checkpoints will be written.
+        overwrite_output_dir (`bool`, *optional*, defaults to `False`):
+            If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
+            points to a checkpoint directory.
+        do_train (`bool`, *optional*, defaults to `False`):
+            Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used
+            by your training/evaluation scripts instead. See the [example
+            scripts](https://github.com/huggingface/transformers/tree/master/examples) for more details.
+        do_eval (`bool`, *optional*):
+            Whether to run evaluation on the validation set or not.
Will be set to `True` if `evaluation_strategy` is + different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your + training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/master/examples) for more details. + do_predict (`bool`, *optional*, defaults to `False`): + Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's + intended to be used by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/master/examples) for more details. + evaluation_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`): + The evaluation strategy to adopt during training. Possible values are: + + - `"no"`: No evaluation is done during training. + - `"steps"`: Evaluation is done (and logged) every `eval_steps`. + - `"epoch"`: Evaluation is done at the end of each epoch. + + prediction_loss_only (`bool`, *optional*, defaults to `False`): + When performing evaluation and generating predictions, only returns the loss. + per_device_train_batch_size (`int`, *optional*, defaults to 8): + The batch size per GPU core/CPU for training. + per_device_eval_batch_size (`int`, *optional*, defaults to 8): + The batch size per GPU core/CPU for evaluation. + gradient_accumulation_steps (`int`, *optional*, defaults to 1): + Number of updates steps to accumulate the gradients for, before performing a backward/update pass. + + + + When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, + evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples. + + + + eval_accumulation_steps (`int`, *optional*): + Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If + left unset, the whole predictions are accumulated on GPU before being moved to the CPU (faster but + requires more memory). + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate for [`AdamW`] optimizer. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`] + optimizer. + adam_beta1 (`float`, *optional*, defaults to 0.9): + The beta1 hyperparameter for the [`AdamW`] optimizer. + adam_beta2 (`float`, *optional*, defaults to 0.999): + The beta2 hyperparameter for the [`AdamW`] optimizer. + adam_epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon hyperparameter for the [`AdamW`] optimizer. + max_grad_norm (`float`, *optional*, defaults to 1.0): + Maximum gradient norm (for gradient clipping). + num_train_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform (if not an integer, will perform the decimal part percents of + the last epoch before stopping training). + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. + In case of using a finite iterable dataset the training may stop before reaching the set number of steps + when all data is exhausted + lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`): + The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values. 
+        warmup_ratio (`float`, *optional*, defaults to 0.0):
+            Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
+        warmup_steps (`int`, *optional*, defaults to 0):
+            Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
+        log_level (`str`, *optional*, defaults to `passive`):
+            Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug',
+            'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and lets the
+            application set the level.
+        log_level_replica (`str`, *optional*, defaults to `passive`):
+            Logger log level to use on replicas. Same choices as `log_level`.
+        log_on_each_node (`bool`, *optional*, defaults to `True`):
+            In multinode distributed training, whether to log using `log_level` once per node, or only on the main
+            node.
+        logging_dir (`str`, *optional*):
+            [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
+            *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.
+        logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+            The logging strategy to adopt during training. Possible values are:
+
+            - `"no"`: No logging is done during training.
+            - `"epoch"`: Logging is done at the end of each epoch.
+            - `"steps"`: Logging is done every `logging_steps`.
+
+        logging_first_step (`bool`, *optional*, defaults to `False`):
+            Whether to log and evaluate the first `global_step` or not.
+        logging_steps (`int`, *optional*, defaults to 500):
+            Number of update steps between two logs if `logging_strategy="steps"`.
+        logging_nan_inf_filter (`bool`, *optional*, defaults to `True`):
+            Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is `nan`
+            or `inf` is filtered and the average loss of the current logging window is taken instead.
+
+
+
+            `logging_nan_inf_filter` only influences the logging of loss values, it does not change how the gradient
+            is computed or applied to the model.
+
+
+
+        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+            The checkpoint save strategy to adopt during training. Possible values are:
+
+            - `"no"`: No save is done during training.
+            - `"epoch"`: Save is done at the end of each epoch.
+            - `"steps"`: Save is done every `save_steps`.
+        save_steps (`int`, *optional*, defaults to 500):
+            Number of update steps between two checkpoint saves if `save_strategy="steps"`.
+        save_total_limit (`int`, *optional*):
+            If a value is passed, will limit the total number of checkpoints. Deletes the older checkpoints in
+            `output_dir`.
+        save_on_each_node (`bool`, *optional*, defaults to `False`):
+            When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
+            the main one.
+
+            This should not be activated when the different nodes use the same storage as the files will be saved with
+            the same names for each node.
+        no_cuda (`bool`, *optional*, defaults to `False`):
+            Whether to disable CUDA even when it is available.
+        seed (`int`, *optional*, defaults to 42):
+            Random seed that will be set at the beginning of training, to ensure reproducibility across runs.
+        bf16 (`bool`, *optional*, defaults to `False`):
+            Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
Requires Ampere or higher
+            NVIDIA architecture. This is an experimental API and it may change.
+        fp16 (`bool`, *optional*, defaults to `False`):
+            Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
+        fp16_opt_level (`str`, *optional*, defaults to 'O1'):
+            For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on
+            the [Apex documentation](https://nvidia.github.io/apex/amp).
+        local_rank (`int`, *optional*, defaults to -1):
+            Rank of the process during distributed training.
+        xpu_backend (`str`, *optional*):
+            The backend to use for xpu distributed training. Must be one of `"mpi"` or `"ccl"`.
+        dataloader_drop_last (`bool`, *optional*, defaults to `False`):
+            Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
+            or not.
+        eval_steps (`int`, *optional*):
+            Number of update steps between two evaluations if `evaluation_strategy="steps"`. Will default to the same
+            value as `logging_steps` if not set.
+        dataloader_num_workers (`int`, *optional*, defaults to 0):
+            Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
+        past_index (`int`, *optional*, defaults to -1):
+            Some models like [TransformerXL](../model_doc/transformerxl) or [XLNet](../model_doc/xlnet) can make use of
+            the past hidden states for their predictions. If this argument is set to a positive int, the `Trainer` will
+            use the corresponding output (usually index 2) as the past state and feed it to the model at the next
+            training step under the keyword argument `mems`.
+        run_name (`str`, *optional*):
+            A descriptor for the run. Typically used for [wandb](https://www.wandb.com/) and
+            [mlflow](https://www.mlflow.org/) logging.
+        disable_tqdm (`bool`, *optional*):
+            Whether or not to disable the tqdm progress bars and table of metrics produced by
+            [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
+            set to warn or lower (default), `False` otherwise.
+        remove_unused_columns (`bool`, *optional*, defaults to `True`):
+            If using `datasets.Dataset` datasets, whether or not to automatically remove the columns unused by the
+            model forward method.
+        label_names (`List[str]`, *optional*):
+            The list of keys in your dictionary of inputs that correspond to the labels.
+
+            Will eventually default to `["labels"]` except if the model used is one of the `XxxForQuestionAnswering` in
+            which case it will default to `["start_positions", "end_positions"]`.
+        load_best_model_at_end (`bool`, *optional*, defaults to `False`):
+            Whether or not to load the best model found during training at the end of training.
+
+
+
+            When set to `True`, the parameter `save_strategy` needs to be the same as `evaluation_strategy`, and in the
+            case it is "steps", `save_steps` must be a round multiple of `eval_steps`.
+
+
+
+        metric_for_best_model (`str`, *optional*):
+            Use in conjunction with `load_best_model_at_end` to specify the metric to use to compare two different
+            models. Must be the name of a metric returned by the evaluation with or without the prefix `"eval_"`. Will
+            default to `"loss"` if unspecified and `load_best_model_at_end=True` (to use the evaluation loss).
+
+            If you set this value, `greater_is_better` will default to `True`. Don't forget to set it to `False` if
+            your metric is better when lower.
+ greater_is_better (`bool`, *optional*): + Use in conjunction with `load_best_model_at_end` and `metric_for_best_model` to specify if better models + should have a greater metric or not. Will default to: + + - `True` if `metric_for_best_model` is set to a value that isn't `"loss"` or `"eval_loss"`. + - `False` if `metric_for_best_model` is not set, or set to `"loss"` or `"eval_loss"`. + ignore_data_skip (`bool`, *optional*, defaults to `False`): + When resuming training, whether or not to skip the epochs and batches to get the data loading at the same + stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step + can take a long time) but will not yield the same results as the interrupted training would have. + label_smoothing_factor (`float`, *optional*, defaults to 0.0): + The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded + labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor + + label_smoothing_factor/num_labels` respectively. + debug (`str` or list of [`~debug_utils.DebugOption`], *optional*, defaults to `""`): + Enable one or more debug features. This is an experimental feature. + + Possible options are: + + - `"underflow_overflow"`: detects overflow in model's input/outputs and reports the last frames that led to + the event + - `"tpu_metrics_debug"`: print debug metrics on TPU + + The options should be separated by whitespaces. + optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw"`): + The optimizer to use: adamw, or adafactor. + length_column_name (`str`, *optional*, defaults to `"length"`): + Column name for precomputed lengths. If the column exists, grouping by length will use these values rather + than computing them on train startup. Ignored unless `group_by_length` is `True` and the dataset is an + instance of `Dataset`. + report_to (`str` or `List[str]`, *optional*, defaults to `"all"`): + The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, + `"comet_ml"`, `"mlflow"`, `"tensorboard"` and `"wandb"`. Use `"all"` to report to all integrations + installed, `"none"` for no integrations. + skip_memory_metrics (`bool`, *optional*, defaults to `True`): + Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows + down the training and evaluation speed. + resume_from_checkpoint (`str`, *optional*): + The path to a folder with a valid checkpoint for your model. This argument is not directly used by + [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/master/examples) for more details. + """ + + output_dir: str = field( + metadata={ + "help": + "The output directory where the model predictions and checkpoints will be written." + }, ) + overwrite_output_dir: bool = field( + default=False, + metadata={ + "help": + ("Overwrite the content of the output directory. " + "Use this to continue training if output_dir points to a checkpoint directory." 
+ ) + }, ) + + do_train: bool = field( + default=False, metadata={"help": "Whether to run training."}) + do_eval: bool = field( + default=False, + metadata={"help": "Whether to run eval on the dev set."}) + do_predict: bool = field( + default=False, + metadata={"help": "Whether to run predictions on the test set."}) + evaluation_strategy: IntervalStrategy = field( + default="steps", + metadata={"help": "The evaluation strategy to use."}, ) + prediction_loss_only: bool = field( + default=False, + metadata={ + "help": + "When performing evaluation and predictions, only returns the loss." + }, ) + + per_device_train_batch_size: int = field( + default=8, + metadata={"help": "Batch size per GPU core/CPU for training."}) + per_device_eval_batch_size: int = field( + default=8, + metadata={"help": "Batch size per GPU core/CPU for evaluation."}) + + gradient_accumulation_steps: int = field( + default=1, + metadata={ + "help": + "Number of updates steps to accumulate before performing a backward/update pass." + }, ) + eval_accumulation_steps: Optional[int] = field( + default=None, + metadata={ + "help": + "Number of predictions steps to accumulate before moving the tensors to the CPU." + }, ) + + learning_rate: float = field( + default=5e-5, + metadata={"help": "The initial learning rate for AdamW."}) + weight_decay: float = field( + default=0.0, + metadata={"help": "Weight decay for AdamW if we apply some."}) + adam_beta1: float = field( + default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) + adam_beta2: float = field( + default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) + adam_epsilon: float = field( + default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) + max_grad_norm: float = field( + default=1.0, metadata={"help": "Max gradient norm."}) + + num_train_epochs: float = field( + default=3.0, + metadata={"help": "Total number of training epochs to perform."}) + max_steps: int = field( + default=-1, + metadata={ + "help": + "If > 0: set total number of training steps to perform. Override num_train_epochs." + }, ) + lr_scheduler_type: str = field( + default="linear", + metadata={"help": "The scheduler type to use."}, ) + warmup_ratio: float = field( + default=0.0, + metadata={ + "help": "Linear warmup over warmup_ratio fraction of total steps." + }) + warmup_steps: int = field( + default=0, metadata={"help": "Linear warmup over warmup_steps."}) + + log_on_each_node: bool = field( + default=True, + metadata={ + "help": + "When doing a multinode distributed training, whether to log once per node or just once on the main node." + }, ) + logging_dir: Optional[str] = field( + default=None, metadata={"help": "Tensorboard log dir."}) + logging_strategy: IntervalStrategy = field( + default="steps", + metadata={"help": "The logging strategy to use."}, ) + logging_first_step: bool = field( + default=False, metadata={"help": "Log the first global_step"}) + logging_steps: int = field( + default=500, metadata={"help": "Log every X updates steps."}) + + save_strategy: IntervalStrategy = field( + default="steps", + metadata={"help": "The checkpoint save strategy to use."}, ) + save_steps: int = field( + default=500, + metadata={"help": "Save checkpoint every X updates steps."}) + save_total_limit: Optional[int] = field( + default=None, + metadata={ + "help": + ("Limit the total amount of checkpoints. " + "Deletes the older checkpoints in the output_dir. 
Default is unlimited checkpoints."
+             )
+        }, )
+    save_on_each_node: bool = field(
+        default=False,
+        metadata={
+            "help":
+            "When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one"
+        }, )
+    no_cuda: bool = field(
+        default=False,
+        metadata={"help": "Do not use CUDA even when it is available"})
+    seed: int = field(
+        default=42,
+        metadata={
+            "help": "Random seed that will be set at the beginning of training."
+        })
+
+    fp16: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use fp16 (mixed) precision instead of 32-bit"
+        }, )
+    fp16_opt_level: str = field(
+        default="O1",
+        metadata={
+            "help":
+            ("For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
+             "See details at https://nvidia.github.io/apex/amp.html")
+        }, )
+
+    scale_loss: float = field(
+        default=2**15, metadata={"help": "The value of scale_loss for fp16."})
+
+    minimum_eval_times: Optional[int] = field(
+        default=None,
+        metadata={
+            "help":
+            "If set, and fewer than `minimum_eval_times` evaluations would be run with the current `eval_steps`, `eval_steps` is overridden so that at least this many evaluations happen."
+        })
+
+    local_rank: int = field(
+        default=-1, metadata={"help": "For distributed training: local_rank"})
+
+    debug: str = field(
+        default="",
+        metadata={
+            "help": "Whether or not to enable debug mode. Current options: "
+            "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
+            "`tpu_metrics_debug` (print debug metrics on TPU)."
+        }, )
+
+    dataloader_drop_last: bool = field(
+        default=False,
+        metadata={
+            "help":
+            "Drop the last incomplete batch if it is not divisible by the batch size."
+        })
+    eval_steps: int = field(
+        default=200, metadata={"help": "Run an evaluation every X steps."})
+    dataloader_num_workers: int = field(
+        default=0,
+        metadata={
+            "help":
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        }, )
+
+    past_index: int = field(
+        default=-1,
+        metadata={
+            "help":
+            "If >=0, uses the corresponding part of the output as the past state for next step."
+        }, )
+
+    run_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help":
+            "An optional descriptor for the run. Notably used for wandb logging."
+        })
+
+    device: Optional[str] = field(
+        default="gpu",
+        metadata={
+            "help": "The device to use for training, e.g. `gpu` or `cpu`."
+        })
+
+    disable_tqdm: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Whether or not to disable the tqdm progress bars."})
+
+    label_names: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help":
+            "The list of keys in your dictionary of inputs that correspond to the labels."
+        })
+
+    load_best_model_at_end: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help":
+            "Whether or not to load the best model found during training at the end of training."
+        }, )
+    metric_for_best_model: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The metric to use to compare two different models."
+        })
+    greater_is_better: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help":
+            "Whether the `metric_for_best_model` should be maximized or not."
+        })
+    ignore_data_skip: bool = field(
+        default=False,
+        metadata={
+            "help":
+            "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
+        }, )
+    optim: str = field(
+        default="adamw",
+        metadata={"help": "The optimizer to use."}, )
+    report_to: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help":
+            "The list of integrations to report the results and logs to."
+        })
+
+    skip_memory_metrics: bool = field(
+        default=True,
+        metadata={
+            "help":
+            "Whether or not to skip adding of memory profiler reports to metrics."
+        })
+
+    resume_from_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={
+            "help":
+            "The path to a folder with a valid checkpoint for your model."
+        }, )
+
+    def __post_init__(self):
+        # Handle --use_env option in paddle.distributed.launch (local_rank not passed as an arg then).
+        # This needs to happen before any call to self.device or self.n_gpu.
+        env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
+        if env_local_rank != -1 and env_local_rank != self.local_rank:
+            self.local_rank = env_local_rank
+
+        # -1 means "not set"; resolved to concrete levels in `get_process_log_level`.
+        self.log_level = -1
+        self.log_level_replica = -1
+
+        # expand paths; otherwise os.makedirs("~/bar") would create the directory
+        # in the current working directory instead of the actual home
+        # see https://github.com/huggingface/transformers/issues/10628
+        if self.output_dir is not None:
+            self.output_dir = os.path.expanduser(self.output_dir)
+        if self.logging_dir is None and self.output_dir is not None:
+            self.logging_dir = os.path.join(self.output_dir, default_logdir())
+        if self.logging_dir is not None:
+            self.logging_dir = os.path.expanduser(self.logging_dir)
+
+        if self.disable_tqdm is None:
+            self.disable_tqdm = False  # logger.getEffectiveLevel() > logging.WARN
+
+        if isinstance(self.evaluation_strategy, EvaluationStrategy):
+            warnings.warn(
+                "using `EvaluationStrategy` for `evaluation_strategy` is deprecated. Use `IntervalStrategy` instead",
+                FutureWarning, )
+            # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it.
+            self.evaluation_strategy = self.evaluation_strategy.value
+
+        self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
+        self.logging_strategy = IntervalStrategy(self.logging_strategy)
+        self.save_strategy = IntervalStrategy(self.save_strategy)
+
+        self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
+        if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
+            self.do_eval = True
+
+        # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero
+        if self.evaluation_strategy == IntervalStrategy.STEPS and (
+                self.eval_steps is None or self.eval_steps == 0):
+            if self.logging_steps > 0:
+                logger.info(
+                    f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}"
+                )
+                self.eval_steps = self.logging_steps
+            else:
+                raise ValueError(
+                    f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or --logging_steps"
+                )
+
+        # logging_steps must be non-zero for logging_strategy that is other than 'no'
+        if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0:
+            raise ValueError(
+                f"logging strategy {self.logging_strategy} requires non-zero --logging_steps"
+            )
+
+        # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
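+        # For example, with evaluation_strategy="steps" and eval_steps=200,
+        # save_steps=400 passes the check below while save_steps=500 is rejected
+        # (not a round multiple of eval_steps).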
+        if self.load_best_model_at_end:
+            if self.evaluation_strategy != self.save_strategy:
+                raise ValueError(
+                    "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation "
+                    f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}"
+                )
+            if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
+                raise ValueError(
+                    "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation "
+                    f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
+                )
+
+        if self.load_best_model_at_end and self.metric_for_best_model is None:
+            self.metric_for_best_model = "loss"
+        if self.greater_is_better is None and self.metric_for_best_model is not None:
+            self.greater_is_better = self.metric_for_best_model not in [
+                "loss", "eval_loss"
+            ]
+        if self.run_name is None:
+            self.run_name = self.output_dir
+
+        self.optim = OptimizerNames(self.optim)
+
+        if self.warmup_ratio < 0 or self.warmup_ratio > 1:
+            raise ValueError("warmup_ratio must lie in range [0,1]")
+        elif self.warmup_ratio > 0 and self.warmup_steps > 0:
+            logger.info(
+                "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
+            )
+
+        if isinstance(self.debug, str):
+            # NOTE: keep the raw option strings; no `DebugOption` enum is
+            # defined or imported in this module.
+            self.debug = self.debug.split()
+
+    def __str__(self):
+        self_as_dict = asdict(self)
+        self_as_dict = {
+            k: f"<{k.upper()}>" if k.endswith("_token") else v
+            for k, v in self_as_dict.items()
+        }
+
+        attrs_as_str = [f"{k}={v},\n" for k, v in sorted(self_as_dict.items())]
+        return f"{self.__class__.__name__}(\n{''.join(attrs_as_str)})"
+
+    __repr__ = __str__
+
+    @property
+    def train_batch_size(self) -> int:
+        """
+        The actual batch size for training.
+        """
+        train_batch_size = self.per_device_train_batch_size
+        return train_batch_size
+
+    @property
+    def eval_batch_size(self) -> int:
+        """
+        The actual batch size for evaluation.
+        """
+        eval_batch_size = self.per_device_eval_batch_size
+        return eval_batch_size
+
+    @property
+    def current_device(self) -> "paddle.device":
+        """
+        The device used by this process.
+        """
+        return paddle.device.get_device()
+
+    @property
+    def world_size(self):
+        """
+        The number of processes used in parallel.
+        """
+        if self.local_rank != -1:
+            return paddle.distributed.get_world_size()
+        return 1
+
+    @property
+    def process_index(self):
+        """
+        The index of the current process used.
+        """
+        if self.local_rank != -1:
+            return paddle.distributed.get_rank()
+        return 0
+
+    @property
+    def local_process_index(self):
+        """
+        The index of the local process used.
+        """
+        if self.local_rank != -1:
+            return self.local_rank
+        return 0
+
+    @property
+    def should_log(self):
+        """
+        Whether or not the current process should produce log.
+        """
+        if self.log_on_each_node:
+            return self.local_process_index == 0
+        else:
+            return self.process_index == 0
+
+    @property
+    def should_save(self):
+        """
+        Whether or not the current process should write to disk, e.g., to save models and checkpoints.
+        """
+        if self.save_on_each_node:
+            return self.local_process_index == 0
+        else:
+            return self.process_index == 0
+
+    def get_process_log_level(self):
+        """
+        Returns the log level to be used depending on whether this process is the main process of node 0, main process
+        of node non-0, or a non-main process.
+
+        For the main process the log level defaults to `logging.INFO` unless overridden by `log_level` argument.
+
+        For the replica processes the log level defaults to `logging.WARNING` unless overridden by `log_level_replica`
+        argument.
+
+        The choice between the main and replica process settings is made according to the return value of `should_log`.
+        """
+
+        log_level_main_node = logging.INFO if self.log_level == -1 else self.log_level
+        log_level_replica_node = logging.WARNING if self.log_level_replica == -1 else self.log_level_replica
+        return log_level_main_node if self.should_log else log_level_replica_node
+
+    @contextlib.contextmanager
+    def main_process_first(self, local=True, desc="work"):
+        """
+        A context manager for a paddle distributed environment where one needs to do something on the main process,
+        while blocking replicas, and when it's finished releasing the replicas.
+
+        One such use is for `datasets`'s `map` feature which, to be efficient, should be run once on the main process,
+        which upon completion saves a cached version of results and which then automatically gets loaded by the
+        replicas.
+
+        Args:
+            local (`bool`, *optional*, defaults to `True`):
+                if `True`, "first" means the process of rank 0 of each node; if `False`, it means the process of
+                rank 0 of node rank 0. In a multi-node environment with a shared filesystem you most likely will
+                want to use `local=False` so that only the main process of the first node will do the processing.
+                If, however, the filesystem is not shared, then the main process of each node will need to do the
+                processing, which is the default behavior.
+            desc (`str`, *optional*, defaults to `"work"`):
+                a work description to be used in debug logs
+
+        """
+        if self.world_size > 1:
+            if local:
+                is_main_process = self.local_process_index == 0
+                main_process_desc = "main local process"
+            else:
+                is_main_process = self.process_index == 0
+                main_process_desc = "main process"
+
+            try:
+                if not is_main_process:
+                    # tell all replicas to wait
+                    logger.debug(
+                        f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}"
+                    )
+                    paddle.distributed.barrier()
+                yield
+            finally:
+                if is_main_process:
+                    # the wait is over
+                    logger.debug(
+                        f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas"
+                    )
+                    paddle.distributed.barrier()
+        else:
+            yield
+
+    def get_warmup_steps(self, num_training_steps: int):
+        """
+        Get number of steps used for a linear warmup.
+        """
+        warmup_steps = (self.warmup_steps if self.warmup_steps > 0 else
+                        math.ceil(num_training_steps * self.warmup_ratio))
+        return warmup_steps
+
+    def to_dict(self):
+        """
+        Serializes this instance while replacing `Enum` members by their values (for JSON serialization support). It
+        obfuscates the token values by replacing them with placeholders.
+        """
+        d = asdict(self)
+        for k, v in d.items():
+            if isinstance(v, Enum):
+                d[k] = v.value
+            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+                d[k] = [x.value for x in v]
+            if k.endswith("_token"):
+                d[k] = f"<{k.upper()}>"
+        return d
+
+    def to_json_string(self):
+        """
+        Serializes this instance to a JSON string.
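+
+        Example (illustrative only):
+
+            with open("training_args.json", "w") as f:
+                f.write(args.to_json_string())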
+ """ + return json.dumps(self.to_dict(), indent=2) + + def to_sanitized_dict(self) -> Dict[str, Any]: + """ + Sanitized serialization to use with TensorBoard’s hparams + """ + d = self.to_dict() + d = { + ** d, ** { + "train_batch_size": self.train_batch_size, + "eval_batch_size": self.eval_batch_size + } + } + + valid_types = [bool, int, float, str] + valid_types.append(paddle.Tensor) + + return { + k: v if type(v) in valid_types else str(v) + for k, v in d.items() + } diff --git a/paddlenlp/trainer/trainer_base.py b/paddlenlp/trainer/trainer_base.py new file mode 100644 index 000000000000..1287f985f482 --- /dev/null +++ b/paddlenlp/trainer/trainer_base.py @@ -0,0 +1,1544 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import contextlib +import inspect +import math +import os +import random +import re +import shutil +import sys +import time +import warnings +import types +from collections.abc import Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +from tqdm.auto import tqdm +import numpy as np +import paddle +import paddle.nn as nn +import paddle.amp.auto_cast as autocast +import paddle.distributed as dist +from paddle.io import ( + Dataset, + DataLoader, + DistributedBatchSampler, ) +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.transformers.model_utils import PretrainedModel, unwrap_model +from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer +from paddlenlp.utils.log import logger + +from .trainer_args import (TrainingArguments, ) +from .trainer_utils import ( + IntervalStrategy, + EvaluationStrategy, + TrainOutput, + EvalPrediction, + PredictionOutput, + EvalLoopOutput, + speed_metrics, + OptimizerNames, + PREFIX_CHECKPOINT_DIR, + get_last_checkpoint, ) +from .trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, ) +from .utils.helper import ( + distributed_concat, + nested_concat, + nested_detach, + nested_numpify, + nested_truncate, ) + +DEFAULT_CALLBACKS = [DefaultFlowCallback] + + +class DataCollator: + def __init__(self, *args, **kwargs): + pass + + +class DataCollatorWithPadding: + def __init__(self, *args, **kwargs): + pass + + +# Name of the files used for checkpointing +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" + +OPTIMIZER_NAME = "optimizer.pdopt" +SCHEDULER_NAME = "scheduler.pdparams" +SCALER_NAME = "scaler.pdparams" + +WEIGHTS_NAME = "model_state.pdparams" +CONFIG_NAME = "model_config.json" + + +def set_seed(seed): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(seed) + np.random.seed(seed) + # Maybe different op seeds(for dropout) for different procs is better. 
By:
+    #   `paddle.seed(args.seed + paddle.distributed.get_rank())`
+    paddle.seed(seed)
+    # TODO: cuda state seed
+
+
+class Trainer:
+    """
+    Trainer is a simple but feature-complete training and eval loop for PaddlePaddle, optimized for PaddleNLP.
+
+    Args:
+        model ([`PretrainedModel`] or `paddle.nn.Layer`, *optional*):
+            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
+
+
+
+            [`Trainer`] is optimized to work with the [`PretrainedModel`] provided by the library. You can still use
+            your own models defined as `paddle.nn.Layer` as long as they work the same way as the PaddleNLP models.
+
+
+
+        args ([`TrainingArguments`], *optional*):
+            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
+            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
+        data_collator (`DataCollator`, *optional*):
+            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
+            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
+            [`DataCollatorWithPadding`] otherwise.
+        train_dataset (`paddle.io.Dataset` or `paddle.io.IterableDataset`, *optional*):
+            The dataset to use for training. If it is a `datasets.Dataset`, columns not accepted by the
+            `model.forward()` method are automatically removed.
+
+            Note that if it's a `paddle.io.IterableDataset` with some randomization and you are training in a
+            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
+            `paddle.Generator` for the randomization that must be identical on all processes (and the Trainer will
+            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
+            sets the seed of the RNGs used.
+        eval_dataset (`paddle.io.Dataset`, *optional*):
+            The dataset to use for evaluation. If it is a `datasets.Dataset`, columns not accepted by the
+            `model.forward()` method are automatically removed.
+        tokenizer ([`PretrainedTokenizer`], *optional*):
+            The tokenizer used to preprocess the data. If provided, it will be used to automatically pad the inputs to
+            the maximum length when batching inputs, and it will be saved along with the model to make it easier to
+            rerun an interrupted training or reuse the fine-tuned model.
+        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
+            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
+            a dictionary mapping metric names to metric values.
+        optimizers (`Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler]`, *optional*): A tuple
+            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
+            and a scheduler given by [`LinearDecayWithWarmup`] controlled by `args`.
+
+    Important attributes:
+
+        - **model** -- Always points to the core model. If using a PaddleNLP model, it will be a [`PretrainedModel`]
+          subclass.
+        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
+          original model. This is the model that should be used for the forward pass. For example, the inner model is
+          wrapped in `paddle.nn.DataParallel`. If model hasn't been wrapped, then `self.model_wrapped` is the same
+          as `self.model`.
+        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
+          data parallelism, this means some of the model layers are split on different GPUs).
+        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
+          to `False` if model parallel is used, or if the default
+          `TrainingArguments.place_model_on_device` is overridden to return `False`.
+        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
+          in `train`)
+
+    """
+    from .trainer_utils import log_metrics, metrics_format, save_metrics, save_state
+
+    def __init__(
+            self,
+            model: Union[PretrainedModel, nn.Layer]=None,
+            criterion: Union[nn.Layer]=None,
+            args: TrainingArguments=None,
+            data_collator: Optional[DataCollator]=None,
+            train_dataset: Optional[Dataset]=None,
+            eval_dataset: Optional[Dataset]=None,
+            tokenizer: Optional[PretrainedTokenizer]=None,
+            compute_metrics: Optional[Callable[[EvalPrediction], Dict]]=None,
+            optimizers: Tuple[paddle.optimizer.Optimizer,
+                              paddle.optimizer.lr.LRScheduler]=(None, None), ):
+        if args is None:
+            output_dir = "tmp_trainer"
+            logger.info(
+                f"No `TrainingArguments` passed, using `output_dir={output_dir}`."
+            )
+            args = TrainingArguments(output_dir=output_dir)
+
+        self.args = args
+        self.do_grad_scaling = args.fp16
+
+        # Seed must be set before instantiating the model.
+        set_seed(self.args.seed)
+        if model is None:
+            raise RuntimeError("`Trainer` requires a `model` argument")
+
+        if self.args.should_save:
+            os.makedirs(self.args.output_dir, exist_ok=True)
+
+        # NOTE: assumes a `default_data_collator` helper is available; it is not
+        # defined in this file.
+        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(
+            tokenizer)
+
+        self.data_collator = data_collator if data_collator is not None else default_collator
+        self.train_dataset = train_dataset
+        self.eval_dataset = eval_dataset
+        self.tokenizer = tokenizer
+
+        self.model_wrapped = model
+        self.model = model
+        self.criterion = criterion
+
+        self.compute_metrics = compute_metrics
+        self.optimizer, self.lr_scheduler = optimizers
+
+        self.state = TrainerState()
+        self.control = TrainerControl()
+
+        callbacks = DEFAULT_CALLBACKS
+        self.callback_handler = CallbackHandler(callbacks, self.model,
+                                                self.tokenizer, self.optimizer,
+                                                self.lr_scheduler)
+
+        self.add_callback(ProgressCallback)
+
+        if args.max_steps > 0:
+            logger.info(
+                "max_steps is given, it will override any value given in num_train_epochs"
+            )
+
+        if train_dataset is not None and not isinstance(
+                train_dataset, collections.abc.Sized) and args.max_steps <= 0:
+            raise ValueError(
+                "train_dataset does not implement __len__, max_steps has to be specified"
+            )
+
+        if args.fp16:
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.args.scale_loss)
+            logger.info("Using half precision")
+
+        default_label_names = (["start_positions", "end_positions"] if
+                               "QuestionAnswering" in type(self.model).__name__
+                               else ["labels"])
+        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+
+        self.control = self.callback_handler.on_init_end(self.args, self.state,
+                                                         self.control)
+        self.print_config()
+
+    def add_callback(self, callback):
+        """
+        Add a callback to the current list of [`~TrainerCallback`].
+
+        Args:
+            callback (`type` or [`~TrainerCallback`]):
+                A [`~TrainerCallback`] class or an instance of a [`~TrainerCallback`]. In the
+                first case, will instantiate a member of that class.
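+
+        Example:
+
+            # `PrinterCallback` is imported at the top of this module;
+            # either the class or an instance can be passed.
+            trainer.add_callback(PrinterCallback)
+            trainer.add_callback(PrinterCallback())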
+ """ + self.callback_handler.add_callback(callback) + + def pop_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it. + If the callback is not found, returns `None` (and no error is raised). + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will pop the first member of that class found in the list of callbacks. + Returns: + [`~transformer.TrainerCallback`]: The callback removed, if found. + """ + return self.callback_handler.pop_callback(callback) + + def remove_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`]. + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will remove the first member of that class found in the list of callbacks. + """ + self.callback_handler.remove_callback(callback) + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]]=None, + ignore_keys_for_eval: Optional[List[str]]=None, + **kwargs, ): + + args = self.args + resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint + + model_reloaded = False + + # Load potential model checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(args.output_dir) + if resume_from_checkpoint is None: + raise ValueError( + f"No valid checkpoint found in output directory ({args.output_dir})" + ) + + if resume_from_checkpoint is not None: + if not os.path.isfile( + os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): + raise ValueError( + f"Can't find a valid checkpoint at {resume_from_checkpoint}") + + logger.info(f"Loading model from {resume_from_checkpoint} .") + + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = paddle.load( + os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) + # If the model is on the GPU, it still works! 
+ self._set_state_dict_in_model(state_dict) + + # release memory + del state_dict + + train_dataloader = self.get_train_dataloader() + model = self._wrap_model(self.model_wrapped) + + self.state = TrainerState() + + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + + num_update_steps_per_epoch = len( + train_dataloader) // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + + if args.max_steps > 0: + args.num_training_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0) + num_train_samples = args.max_steps * total_train_batch_size + else: + args.num_training_steps = num_update_steps_per_epoch * args.num_train_epochs + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = len(self.train_dataset) * args.num_train_epochs + + if args.minimum_eval_times is not None and args.minimum_eval_times > 0: + if args.num_training_steps // args.eval_steps < args.minimum_eval_times: + exp_step = args.num_training_steps / args.minimum_eval_times + exp_step = max(int(exp_step - exp_step % 10), 10) + logger.info("Reset eval step by minimum_eval_times to %d" % + exp_step) + args.eval_steps = exp_step + + self.create_optimizer_and_scheduler( + num_training_steps=args.num_training_steps) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + num_examples = len(self.train_dataset) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info( + f" Instantaneous batch size per device = {args.per_device_train_batch_size}" + ) + logger.info( + f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}"
+        )
+        logger.info(
+            f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}"
+        )
+        logger.info(f"  Total optimization steps = {args.num_training_steps}")
+        logger.info(f"  Total num train samples = {num_train_samples}")
+
+        self.state.epoch = 0
+        self.state.max_steps = int(args.num_training_steps)
+        self.state.num_train_epochs = num_train_epochs
+        self.state.is_local_process_zero = self.is_local_process_zero()
+        self.state.is_world_process_zero = self.is_world_process_zero()
+
+        start_time = time.time()
+        epochs_trained = 0
+        steps_trained_in_current_epoch = 0
+        steps_trained_progress_bar = None
+
+        epoch_iterator = train_dataloader
+        steps_in_epoch = len(epoch_iterator)
+
+        self.callback_handler.model = self.model
+        self.callback_handler.optimizer = self.optimizer
+        self.callback_handler.lr_scheduler = self.lr_scheduler
+        self.callback_handler.train_dataloader = train_dataloader
+
+        self.control = self.callback_handler.on_train_begin(args, self.state,
+                                                            self.control)
+
+        tr_loss = paddle.to_tensor(0.0)
+        self._total_loss_scalar = 0.0
+        self._globalstep_last_logged = self.state.global_step
+
+        for epoch in range(epochs_trained, num_train_epochs):
+            step = -1
+
+            self.control = self.callback_handler.on_epoch_begin(
+                args, self.state, self.control)
+
+            for step, inputs in enumerate(epoch_iterator):
+
+                if step % args.gradient_accumulation_steps == 0:
+                    self.control = self.callback_handler.on_step_begin(
+                        args, self.state, self.control)
+
+                if (((step + 1) % args.gradient_accumulation_steps != 0) and
+                        args.local_rank != -1 and
+                        args._no_sync_in_gradient_accumulation):
+                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
+                    with model.no_sync():
+                        tr_loss_step = self.training_step(model, inputs)
+                else:
+                    tr_loss_step = self.training_step(model, inputs)
+
+                tr_loss += tr_loss_step
+
+                if (step + 1) % args.gradient_accumulation_steps == 0 or (
+                        # last step in epoch but step is always smaller than gradient_accumulation_steps
+                        steps_in_epoch <= args.gradient_accumulation_steps and
+                        (step + 1) == steps_in_epoch):
+                    if self.do_grad_scaling:
+                        self.scaler.minimize(self.optimizer, tr_loss)
+                    else:
+                        self.optimizer.step()
+
+                    self.lr_scheduler.step()
+                    self.optimizer.clear_grad()
+
+                    self.state.global_step += 1
+                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
+
+                    self.control = self.callback_handler.on_step_end(
+                        args, self.state, self.control)
+
+                    self._maybe_log_save_evaluate(tr_loss, model, epoch,
+                                                  ignore_keys_for_eval)
+                else:
+                    self.control = self.callback_handler.on_substep_end(
+                        args, self.state, self.control)
+
+                if self.control.should_epoch_stop or self.control.should_training_stop:
+                    break
+
+            if step < 0:
+                logger.warning(
+                    f"There are no samples in your epoch_iterator, stopping training at step"
+                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+                    f" max_steps ({args.max_steps}) higher than the number of available samples."
+                )
+                self.control.should_training_stop = True
+
+            self.control = self.callback_handler.on_epoch_end(args, self.state,
+                                                              self.control)
+            self._maybe_log_save_evaluate(tr_loss, model, epoch,
+                                          ignore_keys_for_eval)
+
+            if self.control.should_training_stop:
+                break
+
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
+
+        logger.info("\nTraining completed. 
\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + if args.local_rank != -1: + dist.barrier() + + logger.info( + f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." + ) + + best_model_path = os.path.join(self.state.best_model_checkpoint, + WEIGHTS_NAME) + if os.path.exists(best_model_path): + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = paddle.load(best_model_path) + # If the model is on the GPU, it still works! + self._set_state_dict_in_model(state_dict) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) + + self._total_loss_scalar += tr_loss.item() + train_loss = self._total_loss_scalar / self.state.global_step + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps) + + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self.log(metrics) + + self.control = self.callback_handler.on_train_end(args, self.state, + self.control) + + return TrainOutput(self.state.global_step, train_loss, metrics) + + def training_step( + self, model: nn.Layer, + inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: + model.train() + inputs = self._prepare_inputs(inputs) + + with self.autocast_smart_context_manager(): + loss = self.compute_loss(model, inputs) + + loss.backward() + + return loss.detach() + + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: + if not isinstance(self.train_dataset, collections.abc.Sized): + return None + + if self.args.world_size <= 1: + return paddle.io.BatchSampler( + dataset=self.train_dataset, + shuffle=True, + batch_size=self.args.per_device_train_batch_size, + drop_last=self.args.dataloader_drop_last) + else: + return DistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + num_replicas=self.args.world_size, + rank=self.args.process_index, + drop_last=self.args.dataloader_drop_last) + + def _set_state_dict_in_model(self, state_dict): + load_result = self.model.set_state_dict(state_dict) + + def _maybe_log_save_evaluate(self, tr_loss, model, epoch, + ignore_keys_for_eval): + if self.control.should_log: + + logs: Dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss.subtract_(tr_loss) + + logs["loss"] = round(tr_loss_scalar / ( + self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + logs["global_step"] = int(self.state.global_step) + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + + self.log(logs) + + metrics = None + if self.control.should_evaluate: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + + if self.control.should_save: + self._save_checkpoint(model, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, + self.control) + + def _get_learning_rate(self): + return self.optimizer.get_lr() + + def get_train_dataloader(self): + """ + Returns the training [`~paddle.io.DataLoader`]. + + Will use no sampler if `self.train_dataset` does not implement `__len__`, a random sampler (adapted to + distributed training if necessary) otherwise. 
+ + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + + train_sampler = self._get_train_sampler() + + return DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, ) + + def _get_eval_sampler(self, eval_dataset: Dataset): + if self.args.world_size <= 1: + return paddle.io.BatchSampler( + eval_dataset, + batch_size=self.args.eval_batch_size, + shuffle=False, + drop_last=False, ) + else: + return DistributedBatchSampler( + eval_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + batch_size=self.args.eval_batch_size, + shuffle=False, + drop_last=False, ) + + def get_eval_dataloader(self, + eval_dataset: Optional[Dataset]=None) -> DataLoader: + """ + Returns the evaluation [`~paddle.io.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`paddle.io.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is an `datasets.Dataset`, columns not accepted by + the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + eval_sampler = self._get_eval_sampler(eval_dataset) + + return DataLoader( + eval_dataset, + batch_sampler=eval_sampler, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, ) + + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: + """ + Returns the test [`~paddle.io.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + test_dataset (`paddle.io.Dataset`, *optional*): + The test dataset to use. If it is an `datasets.Dataset`, columns not accepted by the `model.forward()` + method are automatically removed. It must implement `__len__`. + """ + + test_sampler = self._get_eval_sampler(test_dataset) + + # We use the same batch_size as for eval. + return DataLoader( + test_dataset, + batch_sampler=test_sampler, + collate_fn=self.data_collator, + drop_last=self.args.dataloader_drop_last, ) + + def create_optimizer_and_scheduler(self, num_training_steps: int): + """ + Setup the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or + `create_scheduler`) in a subclass. + """ + self.create_scheduler(num_training_steps=num_training_steps) + self.create_optimizer(self.lr_scheduler) + + def create_optimizer(self, lr_scheduler=None): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. 
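+
+        A minimal sketch of supplying your own optimizer/scheduler pair through the
+        `optimizers` init argument instead (`model` and `args` are assumed to exist):
+
+        ```python
+        import paddle
+
+        scheduler = paddle.optimizer.lr.NoamDecay(d_model=768, warmup_steps=100)
+        optimizer = paddle.optimizer.AdamW(
+            learning_rate=scheduler, parameters=model.parameters())
+        trainer = Trainer(model=model, args=args, optimizers=(optimizer, scheduler))
+        ```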
+ """ + if self.optimizer is None: + decay_parameters = [ + p.name for n, p in self.model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + apply_decay_param_fun = lambda x: x in decay_parameters + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args) + + self.optimizer = optimizer_cls( + learning_rate=self.lr_scheduler + if lr_scheduler is None else lr_scheduler, + apply_decay_param_fun=apply_decay_param_fun, + parameters=self.model.parameters(), + weight_decay=self.args.weight_decay, + grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm), + **optimizer_kwargs) + + return self.optimizer + + @staticmethod + def get_optimizer_cls_and_kwargs( + args: TrainingArguments) -> Tuple[Any, Any]: + """ + Returns the optimizer class and optimizer parameters based on the training arguments. + + Args: + args (`paddlenlp.training_args.TrainingArguments`): + The training arguments for the training session. + + """ + # optimizer_kwargs = {"lr": args.learning_rate} + optimizer_kwargs = {} + adam_kwargs = { + "beta1": args.adam_beta1, + "beta2": args.adam_beta2, + "epsilon": args.adam_epsilon, + } + if args.optim == OptimizerNames.ADAMW: + from paddle.optimizer import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + else: + raise ValueError( + f"Trainer cannot instantiate unsupported optimizer: {args.optim}" + ) + return optimizer_cls, optimizer_kwargs + + def create_scheduler(self, + num_training_steps: int, + optimizer: paddle.optimizer.Optimizer=None): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + """ + + def get_scheduler(lr_scheduler_type, learning_rate, num_warmup_steps, + num_training_steps): + # TODO @ZHUI support others + return LinearDecayWithWarmup(learning_rate, num_training_steps, + num_warmup_steps) + + warmup = self.args.warmup_steps if self.args.warmup_steps > 0 else self.args.warmup_ratio + + if self.lr_scheduler is None: + self.lr_scheduler = get_scheduler( + self.args.lr_scheduler_type, + learning_rate=self.args.learning_rate, + num_warmup_steps=warmup, + num_training_steps=num_training_steps, ) + + return self.lr_scheduler + + def _wrap_model(self, model, training=True): + # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again + if unwrap_model(model) is not model: + return model + + # Note: in paddle.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. + if not training: + return model + + return model + + def _prepare_input( + self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]: + """ + Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. 
+ """ + if isinstance(data, Mapping): + return type(data)( + {k: self._prepare_input(v) + for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(self._prepare_input(v) for v in data) + elif isinstance(data, paddle.Tensor): + kwargs = dict(device=self.args.current_device) + # update data type for pure fp16 + return data + # return data.to(**kwargs) + return data + + def _prepare_inputs(self, inputs: Dict[str, Union[paddle.Tensor, Any]] + ) -> Dict[str, Union[paddle.Tensor, Any]]: + """ + Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and + handling potential state. + """ + inputs = self._prepare_input(inputs) + if self.args.past_index >= 0 and self._past is not None: + inputs["mems"] = self._past + + return inputs + + def autocast_smart_context_manager(self): + """ + A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired + arguments, depending on the situation. + """ + if self.args.fp16: + ctx_manager = autocast( + True, + custom_black_list=[ + "reduce_sum", "c_softmax_with_cross_entropy", + "elementwise_div" + ], + level=self.args.fp16_opt_level) + else: + ctx_manager = contextlib.nullcontext() if sys.version_info >= ( + 3, 7) else contextlib.suppress() + + return ctx_manager + + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + Subclass and override for custom behavior. + """ + if self.criterion is not None and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + + if self.criterion is not None and "start_positions" in inputs and "end_positions" in inputs: + labels = (inputs.pop("start_positions"), + inputs.pop("end_positions")) + else: + labels = None + + outputs = model(**inputs) + + if self.criterion is not None: + loss = self.criterion(outputs, labels) + outputs = (loss, outputs) + + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return (loss, outputs) if return_outputs else loss + + def training_step( + self, model: nn.Layer, + inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Layer`): + The model to train. + inputs (`Dict[str, Union[paddle.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `paddle.Tensor`: The tensor with training loss on this batch. + """ + model.train() + inputs = self._prepare_inputs(inputs) + + with self.autocast_smart_context_manager(): + loss = self.compute_loss(model, inputs) + + if self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + + if self.do_grad_scaling: + self.scaler.scale(loss).backward() + else: + loss.backward() + + return loss.detach() + + def save_model(self, output_dir: Optional[str]=None): + """ + Will save the model, so you can reload it using `from_pretrained()`. 
+ + Will only save from the main process. + """ + + if output_dir is None: + output_dir = self.args.output_dir + + if self.args.should_save: + self._save(output_dir) + + def export_model(self, + input_spec=None, + load_best_model=False, + output_dir: Optional[str]=None): + + if output_dir is None: + output_dir = self.args.output_dir + + if load_best_model and self.state.best_model_checkpoint is not None: + if self.args.local_rank != -1: + dist.barrier() + + logger.info( + f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." + ) + + best_model_path = os.path.join(self.state.best_model_checkpoint, + WEIGHTS_NAME) + if os.path.exists(best_model_path): + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = paddle.load(best_model_path) + # If the model is on the GPU, it still works! + self._set_state_dict_in_model(state_dict) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) + + model = unwrap_model(self.model) + model.eval() + + # Convert to static graph with specific input description + model = paddle.jit.to_static(model, input_spec=input_spec) + + # Save in static graph model. + save_path = os.path.join(output_dir, "inference", "infer") + logger.info("Exporting inference model to %s" % save_path) + paddle.jit.save(model, save_path) + logger.info("Inference model exported.") + + def _save_checkpoint(self, model, metrics=None): + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + run_dir = self.args.output_dir + + output_dir = os.path.join(run_dir, checkpoint_folder) + + self.save_model(output_dir) + + if self.args.should_save: + paddle.save(self.optimizer.state_dict(), + os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + paddle.save(self.lr_scheduler.state_dict(), + os.path.join(output_dir, SCHEDULER_NAME)) + if self.do_grad_scaling: + paddle.save(self.scaler.state_dict(), + os.path.join(output_dir, SCALER_NAME)) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if (self.state.best_metric is None or + self.state.best_model_checkpoint is None or + operator(metric_value, self.state.best_metric)): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + # Save the Trainer state + if self.args.should_save: + self.state.save_to_json( + os.path.join(output_dir, TRAINER_STATE_NAME)) + + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + } + + # TODO: ZHUI save paddle, cudnn seed. + + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may + # not yet exist. 
+ os.makedirs(output_dir, exist_ok=True) + local_rank = self.args.local_rank + + if local_rank == -1: + paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + else: + paddle.save(rng_states, + os.path.join(output_dir, f"rng_state_{local_rank}.pth")) + + # Maybe delete some older checkpoints. + if self.args.should_save: + self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + + def _sorted_checkpoints(self, + output_dir=None, + checkpoint_prefix=PREFIX_CHECKPOINT_DIR, + use_mtime=False) -> List[str]: + ordering_and_checkpoint_path = [] + + glob_checkpoints = [ + str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") + ] + + for path in glob_checkpoints: + if use_mtime: + ordering_and_checkpoint_path.append( + (os.path.getmtime(path), path)) + else: + regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) + if regex_match is not None and regex_match.groups() is not None: + ordering_and_checkpoint_path.append( + (int(regex_match.groups()[0]), path)) + + checkpoints_sorted = sorted(ordering_and_checkpoint_path) + checkpoints_sorted = [ + checkpoint[1] for checkpoint in checkpoints_sorted + ] + # Make sure we don't delete the best model. + if self.state.best_model_checkpoint is not None: + best_model_index = checkpoints_sorted.index( + str(Path(self.state.best_model_checkpoint))) + for i in range(best_model_index, len(checkpoints_sorted) - 2): + checkpoints_sorted[i], checkpoints_sorted[ + i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i] + return checkpoints_sorted + + def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: + if self.args.save_total_limit is None or self.args.save_total_limit <= 0: + return + + # Check if we should delete older checkpoint(s) + checkpoints_sorted = self._sorted_checkpoints( + use_mtime=use_mtime, output_dir=output_dir) + if len(checkpoints_sorted) <= self.args.save_total_limit: + return + + # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which + # we don't do to allow resuming. + save_total_limit = self.args.save_total_limit + if (self.state.best_model_checkpoint is not None and + self.args.save_total_limit == 1 and + checkpoints_sorted[-1] != self.state.best_model_checkpoint): + save_total_limit = 2 + + number_of_checkpoints_to_delete = max( + 0, len(checkpoints_sorted) - save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[: + number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info( + f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit" + ) + shutil.rmtree(checkpoint) + + def _save(self, output_dir: Optional[str]=None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + if not isinstance(self.model, PretrainedModel): + if isinstance(unwrap_model(self.model), PretrainedModel): + if state_dict is None: + state_dict = self.model.state_dict() + # unwrap_model(self.model).save_pretrained( + # output_dir, state_dict=state_dict) + unwrap_model(self.model).save_pretrained(output_dir) + else: + logger.info( + "Trainer.model is not a `PretrainedModel`, only saving its state dict." 
+                )
+                if state_dict is None:
+                    state_dict = self.model.state_dict()
+                paddle.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            self.model.save_pretrained(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+    def _load_optimizer_and_scheduler(self, checkpoint):
+        """If optimizer and scheduler states exist, load them."""
+        if checkpoint is None:
+            return
+
+        if os.path.isfile(os.path.join(
+                checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
+                    os.path.join(checkpoint, SCHEDULER_NAME)):
+            # Load in optimizer and scheduler states
+            self.optimizer.set_state_dict(
+                paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME)))
+            self.lr_scheduler.set_state_dict(
+                paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+            if self.do_grad_scaling and os.path.isfile(
+                    os.path.join(checkpoint, SCALER_NAME)):
+                self.scaler.load_state_dict(
+                    paddle.load(
+                        os.path.join(checkpoint, SCALER_NAME),
+                        return_numpy=True))
+
+    def log(self, logs: Dict[str, float]) -> None:
+        """
+        Log `logs` on the various objects watching training.
+
+        Subclass and override this method to inject custom behavior.
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+        """
+        if self.state.epoch is not None:
+            logs["epoch"] = round(self.state.epoch, 2)
+
+        output = {**logs, "step": self.state.global_step}
+        self.state.log_history.append(output)
+        self.control = self.callback_handler.on_log(self.args, self.state,
+                                                    self.control, logs)
+
+    def evaluate(
+            self,
+            eval_dataset: Optional[Dataset]=None,
+            ignore_keys: Optional[List[str]]=None,
+            metric_key_prefix: str="eval", ) -> Dict[str, float]:
+        """
+        Run evaluation and return the metrics.
+
+        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+        (pass it to the init `compute_metrics` argument).
+
+        You can also subclass and override this method to inject custom behavior.
+
+        Args:
+            eval_dataset (`Dataset`, *optional*):
+                Pass a dataset if you wish to override `self.eval_dataset`. If it is a `datasets.Dataset`, columns not
+                accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
+                method.
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+                "eval_bleu" if the prefix is "eval" (default).
+
+        Returns:
+            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+            dictionary also contains the epoch number which comes from the training state.
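+
+        A minimal usage sketch (the exact metric keys depend on your `compute_metrics`
+        function; `eval_accuracy` below is illustrative):
+
+        ```python
+        metrics = trainer.evaluate()
+        print(metrics["eval_loss"], metrics.get("eval_accuracy"))
+        ```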
+ """ + eval_dataloader = self.get_eval_dataloader(eval_dataset) + start_time = time.time() + + output = self.evaluation_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if self.compute_metrics is None else None, + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix, ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), )) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate( + self.args, self.state, self.control, output.metrics) + + return output.metrics + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool]=None, + ignore_keys: Optional[List[str]]=None, + metric_key_prefix: str="eval", ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + prediction_loss_only = False + + model = self._wrap_model(self.model, training=False) + + batch_size = dataloader.batch_sampler.batch_size + num_samples = self.num_examples(dataloader) + logger.info(f"***** Running {description} *****") + logger.info(f" Num examples = {num_samples}") + logger.info(f" Pre device batch size = {batch_size}") + logger.info(f" Total Batch size = {batch_size * self.args.world_size}") + logger.info(f" Total prediction steps = {len(dataloader)}") + + model.eval() + + self.callback_handler.eval_dataloader = dataloader + # Do this before wrapping. + eval_dataset = dataloader.dataset + + if args.past_index >= 0: + self._past = None + + # Initialize containers + # losses/preds/labels on GPU (accumulated for eval_accumulation_steps) + losses_host = None + preds_host = None + labels_host = None + # losses/preds/labels on CPU (final containers) + all_losses = None + all_preds = None + all_labels = None + # Will be useful when we have an iterable dataset so don't know its length. 
+ + observed_num_examples = 0 + # Main evaluation loop + losses = [] + for step, inputs in enumerate(dataloader): + # Update the observed num examples + # Prediction step + loss, logits, labels = self.prediction_step( + model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + + # Update containers on host + if loss is not None: + # losses = self._nested_gather(loss.repeat(batch_size)) + losses = self._nested_gather( + paddle.tile( + loss, repeat_times=[batch_size, 1])) + losses_host = losses if losses_host is None else paddle.concat( + (losses_host, losses), axis=0) + if labels is not None: + labels = self._pad_across_processes(labels) + labels = self._nested_gather(labels) + labels_host = labels if labels_host is None else nested_concat( + labels_host, labels, padding_index=-100) + if logits is not None: + logits = self._pad_across_processes(logits) + logits = self._nested_gather(logits) + preds_host = logits if preds_host is None else nested_concat( + preds_host, logits, padding_index=-100) + self.control = self.callback_handler.on_prediction_step( + args, self.state, self.control) + + # Gather all remaining tensors and put them back on the CPU + if losses_host is not None: + losses = nested_numpify(losses_host) + all_losses = losses if all_losses is None else np.concatenate( + (all_losses, losses), axis=0) + if preds_host is not None: + logits = nested_numpify(preds_host) + all_preds = logits if all_preds is None else nested_concat( + all_preds, logits, padding_index=-100) + if labels_host is not None: + labels = nested_numpify(labels_host) + all_labels = labels if all_labels is None else nested_concat( + all_labels, labels, padding_index=-100) + + # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of + # samplers has been rounded to a multiple of batch_size, so we truncate. + if all_losses is not None: + all_losses = all_losses[:num_samples] + if all_preds is not None: + all_preds = nested_truncate(all_preds, num_samples) + if all_labels is not None: + all_labels = nested_truncate(all_labels, num_samples) + + model.train() + + # Metrics! + if self.compute_metrics is not None and all_preds is not None and all_labels is not None: + metrics = self.compute_metrics( + EvalPrediction( + predictions=all_preds, label_ids=all_labels)) + else: + metrics = {} + + if all_losses is not None: + metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput( + predictions=all_preds, + label_ids=all_labels, + metrics=metrics, + num_samples=num_samples) + + def predict(self, + test_dataset: Dataset, + ignore_keys: Optional[List[str]]=None, + metric_key_prefix: str="test") -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`Lst[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. 
+ metric_key_prefix (`str`, *optional*, defaults to `"test"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "test_bleu" if the prefix is "test" (default) + + If your predictions or labels have different sequence length (for instance because you're doing dynamic padding + in a token classification task) the predictions will be padded (on the right) to allow for concatenation into + one array. The padding index is -100. + + Returns: *NamedTuple* A namedtuple with the following keys: + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + test_dataloader = self.get_test_dataloader(test_dataset) + start_time = time.time() + + eval_loop = self.evaluation_loop + output = eval_loop( + test_dataloader, + description="Prediction", + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix) + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), )) + + return PredictionOutput( + predictions=output.predictions, + label_ids=output.label_ids, + metrics=output.metrics) + + def prediction_step( + self, + model: nn.Layer, + inputs: Dict[str, Union[paddle.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]]=None, ) -> Tuple[Optional[ + paddle.Tensor], Optional[paddle.Tensor], Optional[ + paddle.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Layer`): + The model to evaluate. + inputs (`Dict[str, Union[paddle.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`Lst[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[paddle.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + has_labels = all(inputs.get(k) is not None for k in self.label_names) + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, + "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. 
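+        # `self.label_names` defaults to ["start_positions", "end_positions"] for
+        # question answering models and to ["labels"] for everything else (see `__init__`).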
+ if has_labels: + labels = nested_detach( + tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with paddle.no_grad(): + if has_labels: + with self.autocast_smart_context_manager(): + loss, outputs = self.compute_loss( + model, inputs, return_outputs=True) + loss = loss.mean().detach() + + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() + if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + with self.autocast_smart_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() + if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) + + def num_examples(self, dataloader: DataLoader) -> int: + """ + Helper to get number of samples in a [`~paddle.io.DataLoader`] by accessing its dataset. + + Will raise an exception if the underlying dataset does not implement method `__len__` + """ + return len(dataloader.dataset) + + def is_local_process_zero(self) -> bool: + """ + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several + machines) main process. + """ + return self.args.local_process_index == 0 + + def is_world_process_zero(self) -> bool: + """ + Whether or not this process is the global main process (when training in a distributed fashion on several + machines, this is only going to be `True` for one process). + """ + return self.args.process_index == 0 + + def _nested_gather(self, tensors, name=None): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if self.args.local_rank != -1: + tensors = distributed_concat(tensors) + return tensors + + # Copied from Accelerate. + def _pad_across_processes(self, tensor, pad_index=-100): + """ + Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so + they can safely be gathered. + """ + if isinstance(tensor, (list, tuple)): + return type(tensor)(self._pad_across_processes( + t, pad_index=pad_index) for t in tensor) + elif isinstance(tensor, dict): + return type(tensor)({ + k: self._pad_across_processes( + v, pad_index=pad_index) + for k, v in tensor.items() + }) + elif not isinstance(tensor, paddle.Tensor): + raise TypeError( + f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." 
+ ) + + if len(tensor.shape) < 2: + return tensor + # Gather all sizes + size = paddle.to_tensor(tensor.shape)[None] + sizes = self._nested_gather(size).cpu() + + max_size = max(s[1] for s in sizes) + if tensor.shape[1] == max_size: + return tensor + + # Then pad to the maximum size + old_size = tensor.shape + new_size = list(old_size) + new_size[1] = max_size + # new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index + new_tensor = paddle.zeros( + tuple(new_size), dtype=tensor.dtype) + pad_index + new_tensor[:, :old_size[1]] = tensor + return new_tensor + + def print_config(self, args=None, key=""): + """ + """ + logger.info("=" * 60) + if args is None: + args = self.args + key = "Training" + + logger.info('{:^40}'.format("{} Configuration Arguments".format(key))) + logger.info('{:30}:{}'.format("paddle commit id", + paddle.version.commit)) + + for a in dir(args): + if (a[:2] != "__"): #don't print double underscore methods + v = getattr(args, a) + if not isinstance(v, types.MethodType): + logger.info('{:30}:{}'.format(a, v)) + + logger.info("") diff --git a/paddlenlp/trainer/trainer_callback.py b/paddlenlp/trainer/trainer_callback.py new file mode 100644 index 000000000000..79b6ca490a31 --- /dev/null +++ b/paddlenlp/trainer/trainer_callback.py @@ -0,0 +1,667 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Callbacks to use with the Trainer class and customize the training loop. +""" +import dataclasses +import json +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +import numpy as np +from tqdm.auto import tqdm + +from .trainer_utils import IntervalStrategy, has_length +from .trainer_args import TrainingArguments +from paddlenlp.utils.log import logger + +# logger = logging.get_logger(__name__) + + +@dataclass +class TrainerState: + """ + A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing + and passed to the [`TrainerCallback`]. + + + + In all this class, one step is to be understood as one update step. When using gradient accumulation, one update + step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update + step requires going through *n* batches. + + + + Args: + epoch (`float`, *optional*): + Only set during training, will represent the epoch the training is at (the decimal part being the + percentage of the current epoch completed). + global_step (`int`, *optional*, defaults to 0): + During training, represents the number of update steps completed. + max_steps (`int`, *optional*, defaults to 0): + The number of update steps to do during the current training. + total_flos (`float`, *optional*, defaults to 0): + The total number of floating operations done by the model since the beginning of training (stored as floats + to avoid overflow). 
+ log_history (`List[Dict[str, float]]`, *optional*): + The list of logs done since the beginning of training. + best_metric (`float`, *optional*): + When tracking the best model, the value of the best metric encountered so far. + best_model_checkpoint (`str`, *optional*): + When tracking the best model, the value of the name of the checkpoint for the best model encountered so + far. + is_local_process_zero (`bool`, *optional*, defaults to `True`): + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on + several machines) main process. + is_world_process_zero (`bool`, *optional*, defaults to `True`): + Whether or not this process is the global main process (when training in a distributed fashion on several + machines, this is only going to be `True` for one process). + is_hyper_param_search (`bool`, *optional*, defaults to `False`): + Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will + impact the way data will be logged in TensorBoard. + """ + + epoch: Optional[float] = None + global_step: int = 0 + max_steps: int = 0 + num_train_epochs: int = 0 + total_flos: float = 0 + log_history: List[Dict[str, float]] = None + best_metric: Optional[float] = None + best_model_checkpoint: Optional[str] = None + is_local_process_zero: bool = True + is_world_process_zero: bool = True + is_hyper_param_search: bool = False + trial_name: str = None + trial_params: Dict[str, Union[str, float, int, bool]] = None + + def __post_init__(self): + if self.log_history is None: + self.log_history = [] + + def save_to_json(self, json_path: str): + """Save the content of this instance in JSON format inside `json_path`.""" + json_string = json.dumps( + dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, json_path: str): + """Create an instance from the content of `json_path`.""" + with open(json_path, "r", encoding="utf-8") as f: + text = f.read() + return cls(**json.loads(text)) + + +@dataclass +class TrainerControl: + """ + A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some + switches in the training loop. + + Args: + should_training_stop (`bool`, *optional*, defaults to `False`): + Whether or not the training should be interrupted. + + If `True`, this variable will not be set back to `False`. The training will just stop. + should_epoch_stop (`bool`, *optional*, defaults to `False`): + Whether or not the current epoch should be interrupted. + + If `True`, this variable will be set back to `False` at the beginning of the next epoch. + should_save (`bool`, *optional*, defaults to `False`): + Whether or not the model should be saved at this step. + + If `True`, this variable will be set back to `False` at the beginning of the next step. + should_evaluate (`bool`, *optional*, defaults to `False`): + Whether or not the model should be evaluated at this step. + + If `True`, this variable will be set back to `False` at the beginning of the next step. + should_log (`bool`, *optional*, defaults to `False`): + Whether or not the logs should be reported at this step. + + If `True`, this variable will be set back to `False` at the beginning of the next step. 
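+
+    A callback can flip these switches and return the modified control object, for
+    example (a sketch using a hypothetical custom callback):
+
+    ```python
+    class SaveEveryStepCallback(TrainerCallback):
+        def on_step_end(self, args, state, control, **kwargs):
+            control.should_save = True
+            return control
+    ```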
+ """ + + should_training_stop: bool = False + should_epoch_stop: bool = False + should_save: bool = False + should_evaluate: bool = False + should_log: bool = False + + def _new_training(self): + """Internal method that resets the variable for a new training.""" + self.should_training_stop = False + + def _new_epoch(self): + """Internal method that resets the variable for a new epoch.""" + self.should_epoch_stop = False + + def _new_step(self): + """Internal method that resets the variable for a new step.""" + self.should_save = False + self.should_evaluate = False + self.should_log = False + + +class TrainerCallback: + """ + A class for objects that will inspect the state of the training loop at some events and take some decisions. At + each of those events the following arguments are available: + + Args: + args ([`TrainingArguments`]): + The training arguments used to instantiate the [`Trainer`]. + state ([`TrainerState`]): + The current state of the [`Trainer`]. + control ([`TrainerControl`]): + The object that is returned to the [`Trainer`] and can be used to make some decisions. + model ([`PreTrainedModel`] or `paddle.nn.Layer`): + The model being trained. + tokenizer ([`PreTrainedTokenizer`]): + The tokenizer used for encoding the data. + optimizer (`paddle.optimizer.Optimizer`): + The optimizer used for the training steps. + lr_scheduler (`paddle.optimizer.lr.LRScheduler`): + The scheduler used for setting the learning rate. + train_dataloader (`paddle.io.DataLoader`, *optional*): + The current dataloader used for training. + eval_dataloader (`paddle.io.DataLoader`, *optional*): + The current dataloader used for training. + metrics (`Dict[str, float]`): + The metrics computed by the last evaluation phase. + + Those are only accessible in the event `on_evaluate`. + logs (`Dict[str, float]`): + The values to log. + + Those are only accessible in the event `on_log`. + + The `control` object is the only one that can be changed by the callback, in which case the event that changes it + should return the modified version. + + The argument `args`, `state` and `control` are positionals for all events, all the others are grouped in `kwargs`. + You can unpack the ones you need in the signature of the event using them. As an example, see the code of the + simple [`~transformer.PrinterCallback`]. + + Example: + + ```python + class PrinterCallback(TrainerCallback): + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) + ```""" + + def on_init_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called at the end of the initialization of the [`Trainer`]. + """ + pass + + def on_train_begin(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called at the beginning of training. + """ + pass + + def on_train_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called at the end of training. + """ + pass + + def on_epoch_begin(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called at the beginning of an epoch. + """ + pass + + def on_epoch_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called at the end of an epoch. 
+ """ + pass + + def on_step_begin(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called at the beginning of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + pass + + def on_substep_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called at the end of an substep during gradient accumulation. + """ + pass + + def on_step_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called at the end of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + pass + + def on_evaluate(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called after an evaluation phase. + """ + pass + + def on_save(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called after a checkpoint save. + """ + pass + + def on_log(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called after logging the last logs. + """ + pass + + def on_prediction_step(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + """ + Event called after a prediction step. + """ + pass + + +class CallbackHandler(TrainerCallback): + """Internal class that just calls the list of callbacks in order.""" + + def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler): + self.callbacks = [] + for cb in callbacks: + self.add_callback(cb) + self.model = model + self.tokenizer = tokenizer + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.train_dataloader = None + self.eval_dataloader = None + + if not any( + isinstance(cb, DefaultFlowCallback) for cb in self.callbacks): + logger.warning( + "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n" + + + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of" + + "callbacks is\n:" + self.callback_list) + + def add_callback(self, callback): + cb = callback() if isinstance(callback, type) else callback + cb_class = callback if isinstance(callback, + type) else callback.__class__ + if cb_class in [c.__class__ for c in self.callbacks]: + logger.warning( + f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. 
The current" + + "list of callbacks is\n:" + self.callback_list) + self.callbacks.append(cb) + + def pop_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return cb + else: + for cb in self.callbacks: + if cb == callback: + self.callbacks.remove(cb) + return cb + + def remove_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return + else: + self.callbacks.remove(callback) + + @property + def callback_list(self): + return "\n".join(cb.__class__.__name__ for cb in self.callbacks) + + def on_init_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + return self.call_event("on_init_end", args, state, control) + + def on_train_begin(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + control.should_training_stop = False + return self.call_event("on_train_begin", args, state, control) + + def on_train_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + return self.call_event("on_train_end", args, state, control) + + def on_epoch_begin(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + control.should_epoch_stop = False + return self.call_event("on_epoch_begin", args, state, control) + + def on_epoch_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + return self.call_event("on_epoch_end", args, state, control) + + def on_step_begin(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + control.should_log = False + control.should_evaluate = False + control.should_save = False + return self.call_event("on_step_begin", args, state, control) + + def on_substep_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + return self.call_event("on_substep_end", args, state, control) + + def on_step_end(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + return self.call_event("on_step_end", args, state, control) + + def on_evaluate(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + metrics): + control.should_evaluate = False + return self.call_event( + "on_evaluate", args, state, control, metrics=metrics) + + def on_save(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + control.should_save = False + return self.call_event("on_save", args, state, control) + + def on_log(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs): + control.should_log = False + return self.call_event("on_log", args, state, control, logs=logs) + + def on_prediction_step(self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl): + return self.call_event("on_prediction_step", args, state, control) + + def call_event(self, event, args, state, control, **kwargs): + for callback in self.callbacks: + result = getattr(callback, event)( + args, + state, + control, + model=self.model, + tokenizer=self.tokenizer, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + train_dataloader=self.train_dataloader, + eval_dataloader=self.eval_dataloader, + **kwargs, ) + # A Callback can skip the return of `control` if it doesn't change it. 
+
+
+class DefaultFlowCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
+    """
+
+    def on_step_end(self,
+                    args: TrainingArguments,
+                    state: TrainerState,
+                    control: TrainerControl,
+                    **kwargs):
+        # Log
+        if state.global_step == 1 and args.logging_first_step:
+            control.should_log = True
+        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % args.logging_steps == 0:
+            control.should_log = True
+
+        # Evaluate
+        if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0:
+            control.should_evaluate = True
+
+        # Save
+        if (args.save_strategy == IntervalStrategy.STEPS and
+                args.save_steps > 0 and
+                state.global_step % args.save_steps == 0):
+            control.should_save = True
+
+        # End training
+        if state.global_step >= state.max_steps:
+            control.should_training_stop = True
+            # Log and save on end
+            if args.logging_strategy == IntervalStrategy.STEPS and state.global_step >= args.logging_steps:
+                control.should_log = True
+            if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step >= args.eval_steps:
+                control.should_evaluate = True
+            if args.save_strategy == IntervalStrategy.STEPS and args.save_steps > 0 and state.global_step >= args.save_steps:
+                control.should_save = True
+
+        return control
+
+    def on_epoch_end(self,
+                     args: TrainingArguments,
+                     state: TrainerState,
+                     control: TrainerControl,
+                     **kwargs):
+        # Log
+        if args.logging_strategy == IntervalStrategy.EPOCH:
+            control.should_log = True
+
+        # Evaluate
+        if args.evaluation_strategy == IntervalStrategy.EPOCH:
+            control.should_evaluate = True
+
+        # Save
+        if args.save_strategy == IntervalStrategy.EPOCH:
+            control.should_save = True
+
+        return control
+
+
+class ProgressCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that displays the progress of training or evaluation.
+    """
+
+    def __init__(self):
+        self.training_bar = None
+        self.prediction_bar = None
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        if state.is_local_process_zero:
+            self.training_bar = tqdm(total=state.max_steps)
+        self.current_step = 0
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if state.is_local_process_zero:
+            self.training_bar.update(state.global_step - self.current_step)
+            self.current_step = state.global_step
+
+    def on_prediction_step(self,
+                           args,
+                           state,
+                           control,
+                           eval_dataloader=None,
+                           **kwargs):
+        if state.is_local_process_zero and has_length(eval_dataloader.dataset):
+            if self.prediction_bar is None:
+                self.prediction_bar = tqdm(
+                    total=len(eval_dataloader), leave=self.training_bar is None)
+            self.prediction_bar.update(1)
+
+    def on_evaluate(self, args, state, control, **kwargs):
+        if state.is_local_process_zero:
+            if self.prediction_bar is not None:
+                self.prediction_bar.close()
+            self.prediction_bar = None
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if state.is_local_process_zero and self.training_bar is not None:
+            _ = logs.pop("total_flos", None)
+            self.training_bar.write(str(logs))
+
+    def on_train_end(self, args, state, control, **kwargs):
+        if state.is_local_process_zero:
+            self.training_bar.close()
+            self.training_bar = None
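For intuition on the flags set above: with a step-based logging strategy and `logging_steps=10`, `on_step_end` raises `should_log` on every tenth step. A minimal hand-driven check, using `SimpleNamespace` stand-ins for the real `TrainingArguments`/`TrainerState` (which carry many more fields); `TrainerControl` and `IntervalStrategy` are the ones defined in this PR:

```python
from types import SimpleNamespace

args = SimpleNamespace(
    logging_first_step=False,
    logging_strategy=IntervalStrategy.STEPS, logging_steps=10,
    evaluation_strategy=IntervalStrategy.NO, eval_steps=200,
    save_strategy=IntervalStrategy.STEPS, save_steps=500)
state = SimpleNamespace(global_step=20, max_steps=1000)
control = TrainerControl()

control = DefaultFlowCallback().on_step_end(args, state, control)
assert control.should_log        # 20 % 10 == 0
assert not control.should_save   # 20 % 500 != 0
```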
+ """ + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) + + +class EarlyStoppingCallback(TrainerCallback): + """ + A [`TrainerCallback`] that handles early stopping. + + Args: + early_stopping_patience (`int`): + Use with `metric_for_best_model` to stop training when the specified metric worsens for + `early_stopping_patience` evaluation calls. + early_stopping_threshold(`float`, *optional*): + Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the + specified metric must improve to satisfy early stopping conditions. ` + + This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric + in [`TrainerState`]. + """ + + def __init__(self, + early_stopping_patience: int=1, + early_stopping_threshold: Optional[float]=0.0): + self.early_stopping_patience = early_stopping_patience + self.early_stopping_threshold = early_stopping_threshold + # early_stopping_patience_counter denotes the number of times validation metrics failed to improve. + self.early_stopping_patience_counter = 0 + + def check_metric_value(self, args, state, control, metric_value): + # best_metric is set by code for load_best_model + operator = np.greater if args.greater_is_better else np.less + if state.best_metric is None or ( + operator(metric_value, state.best_metric) and + abs(metric_value - state.best_metric) > + self.early_stopping_threshold): + self.early_stopping_patience_counter = 0 + else: + self.early_stopping_patience_counter += 1 + + def on_train_begin(self, args, state, control, **kwargs): + assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True" + assert ( + args.metric_for_best_model is not None + ), "EarlyStoppingCallback requires metric_for_best_model is defined" + assert ( + args.evaluation_strategy != IntervalStrategy.NO + ), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch" + + def on_evaluate(self, args, state, control, metrics, **kwargs): + metric_to_check = args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics.get(metric_to_check) + + if metric_value is None: + logger.warning( + f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled" + ) + return + + self.check_metric_value(args, state, control, metric_value) + if self.early_stopping_patience_counter >= self.early_stopping_patience: + control.should_training_stop = True diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py new file mode 100644 index 000000000000..2538b97e7d5f --- /dev/null +++ b/paddlenlp/trainer/trainer_utils.py @@ -0,0 +1,339 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Utilities for the Trainer class. +""" +import datetime +import json +import math +import copy +import functools +import gc +import inspect +import os +import random +import re +import threading +import time +from enum import Enum +from typing import Any, Dict, NamedTuple, Optional, Tuple, Union + +import numpy as np + + +class ExplicitEnum(Enum): + """ + Enum with more explicit error message for missing values. + """ + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" + ) + + +class EvalPrediction(NamedTuple): + """ + Evaluation output (always contains labels), to be used to compute metrics. + + Parameters: + predictions (`np.ndarray`): Predictions of the model. + label_ids (`np.ndarray`): Targets to be matched. + """ + + predictions: Union[np.ndarray, Tuple[np.ndarray]] + label_ids: Union[np.ndarray, Tuple[np.ndarray]] + + +class EvalLoopOutput(NamedTuple): + predictions: Union[np.ndarray, Tuple[np.ndarray]] + label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]] + metrics: Optional[Dict[str, float]] + num_samples: Optional[int] + + +class PredictionOutput(NamedTuple): + predictions: Union[np.ndarray, Tuple[np.ndarray]] + label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]] + metrics: Optional[Dict[str, float]] + + +class TrainOutput(NamedTuple): + global_step: int + training_loss: float + metrics: Dict[str, float] + + +PREFIX_CHECKPOINT_DIR = "checkpoint" +_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") + + +def get_last_checkpoint(folder): + content = os.listdir(folder) + checkpoints = [ + path for path in content + if _re_checkpoint.search(path) is not None and os.path.isdir( + os.path.join(folder, path)) + ] + if len(checkpoints) == 0: + return + return os.path.join( + folder, + max(checkpoints, + key=lambda x: int(_re_checkpoint.search(x).groups()[0]))) + + +class IntervalStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + +class EvaluationStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + +class OptimizerNames(ExplicitEnum): + """ + Stores the acceptable string identifiers for optimizers. + """ + + ADAMW = "adamw" + ADAFACTOR = "adafactor" + + +class BestRun(NamedTuple): + """ + The best run found by an hyperparameter search (see [`~Trainer.hyperparameter_search`]). + + Parameters: + run_id (`str`): + The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending + with run-{run_id}). + objective (`float`): + The objective that was obtained for this run. + hyperparameters (`Dict[str, Any]`): + The hyperparameters picked to get this run. + """ + + run_id: str + objective: float + hyperparameters: Dict[str, Any] + + +def default_compute_objective(metrics: Dict[str, float]) -> float: + """ + The default objective to maximize/minimize when doing an hyperparameter search. It is the evaluation loss if no + metrics are provided to the [`Trainer`], the sum of all metrics otherwise. + + Args: + metrics (`Dict[str, float]`): The metrics returned by the evaluate method. 
+
+
+def default_compute_objective(metrics: Dict[str, float]) -> float:
+    """
+    The default objective to maximize/minimize when doing a hyperparameter search. It is the evaluation loss if no
+    metrics are provided to the [`Trainer`], the sum of all metrics otherwise.
+
+    Args:
+        metrics (`Dict[str, float]`): The metrics returned by the evaluate method.
+
+    Return:
+        `float`: The objective to minimize or maximize.
+    """
+    metrics = copy.deepcopy(metrics)
+    loss = metrics.pop("eval_loss", None)
+    _ = metrics.pop("epoch", None)
+    # Remove speed metrics
+    speed_metrics = [
+        m for m in metrics.keys()
+        if m.endswith("_runtime") or m.endswith("_per_second")
+    ]
+    for sm in speed_metrics:
+        _ = metrics.pop(sm, None)
+    return loss if len(metrics) == 0 else sum(metrics.values())
+
+
+def is_main_process(local_rank):
+    """
+    Whether or not the current process is the main process, based on `local_rank`.
+    """
+    return local_rank in [-1, 0]
+
+
+def total_processes_number(local_rank):
+    """
+    Return the number of processes launched in parallel. Works with `paddle.distributed`.
+    """
+    if local_rank != -1:
+        import paddle
+
+        return paddle.distributed.get_world_size()
+    return 1
+
+
+def speed_metrics(split, start_time, num_samples=None, num_steps=None):
+    """
+    Measure and return speed performance metrics.
+
+    This function requires a time snapshot `start_time` before the operation to be measured starts and this function
+    should be run immediately after the operation to be measured has completed.
+
+    Args:
+    - split: name to prefix metric (like train, eval, test...)
+    - start_time: operation start time
+    - num_samples: number of samples processed
+    - num_steps: number of steps processed
+    """
+    runtime = time.time() - start_time
+    result = {f"{split}_runtime": round(runtime, 4)}
+    if num_samples is not None:
+        samples_per_second = num_samples / runtime
+        result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
+    if num_steps is not None:
+        steps_per_second = num_steps / runtime
+        result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
+    return result
+
+
+class SchedulerType(ExplicitEnum):
+    LINEAR = "linear"
+    COSINE = "cosine"
+    COSINE_WITH_RESTARTS = "cosine_with_restarts"
+    POLYNOMIAL = "polynomial"
+    CONSTANT = "constant"
+    CONSTANT_WITH_WARMUP = "constant_with_warmup"
+
+
+def _secs2timedelta(secs):
+    """
+    Convert seconds to hh:mm:ss.msec, with msecs rounded to 2 decimals.
+    """
+    msec = int(abs(secs - int(secs)) * 100)
+    return f"{datetime.timedelta(seconds=int(secs))}.{msec:02d}"
+
+
+def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
+    """
+    Reformat Trainer metrics values to a human-readable format.
+
+    Args:
+        metrics (`Dict[str, float]`):
+            The metrics returned from train/evaluate/predict
+
+    Returns:
+        metrics (`Dict[str, float]`): The reformatted metrics
+    """
+    metrics_copy = metrics.copy()
+    for k, v in metrics_copy.items():
+        if "_mem_" in k:
+            metrics_copy[k] = f"{v >> 20}MB"
+        elif "_runtime" in k:
+            metrics_copy[k] = _secs2timedelta(v)
+        elif k == "total_flos":
+            metrics_copy[k] = f"{int(v) >> 30}GF"
+        elif type(metrics_copy[k]) == float:
+            metrics_copy[k] = round(v, 4)
+
+    return metrics_copy
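As a concrete illustration of `speed_metrics` (the timings are of course illustrative):

```python
import time

start = time.time()
# ... run an evaluation pass over 1000 samples / 125 batches ...
metrics = speed_metrics("eval", start, num_samples=1000, num_steps=125)
# e.g. {'eval_runtime': 12.3456,
#       'eval_samples_per_second': 81.001,
#       'eval_steps_per_second': 10.125}
```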
+
+
+def log_metrics(self, split, metrics):
+    """
+    Log metrics in a specially formatted way.
+    Under distributed environment this is done only for a process with rank 0.
+
+    Args:
+        split (`str`):
+            Mode/split name: one of `train`, `eval`, `test`
+        metrics (`Dict[str, float]`):
+            The metrics returned from train/evaluate/predict
+    """
+    if not self.is_world_process_zero():
+        return
+
+    print(f"***** {split} metrics *****")
+    metrics_formatted = self.metrics_format(metrics)
+    k_width = max(len(str(x)) for x in metrics_formatted.keys())
+    v_width = max(len(str(x)) for x in metrics_formatted.values())
+    for key in sorted(metrics_formatted.keys()):
+        print(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
+
+
+def save_metrics(self, split, metrics, combined=True):
+    """
+    Save metrics into a json file for that split, e.g. `train_results.json`.
+    Under distributed environment this is done only for a process with rank 0.
+
+    Args:
+        split (`str`):
+            Mode/split name: one of `train`, `eval`, `test`, `all`
+        metrics (`Dict[str, float]`):
+            The metrics returned from train/evaluate/predict
+        combined (`bool`, *optional*, defaults to `True`):
+            Creates combined metrics by updating `all_results.json` with metrics of this call.
+
+    To understand the metrics please read the docstring of [`~Trainer.log_metrics`]. The only difference is that raw
+    unformatted numbers are saved in the current method.
+    """
+    if not self.is_world_process_zero():
+        return
+
+    path = os.path.join(self.args.output_dir, f"{split}_results.json")
+    with open(path, "w") as f:
+        json.dump(metrics, f, indent=4, sort_keys=True)
+
+    if combined:
+        path = os.path.join(self.args.output_dir, "all_results.json")
+        if os.path.exists(path):
+            with open(path, "r") as f:
+                all_metrics = json.load(f)
+        else:
+            all_metrics = {}
+
+        all_metrics.update(metrics)
+        with open(path, "w") as f:
+            json.dump(all_metrics, f, indent=4, sort_keys=True)
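On the rank-0 process, `log_metrics` prints a padded table and `save_metrics` persists the raw values; a run might render as follows (values illustrative, `trainer` assumed to be a built Trainer instance):

```python
trainer.log_metrics("eval", metrics)
# ***** eval metrics *****
#   eval_accuracy           =     0.9132
#   eval_loss               =     0.2841
#   eval_runtime            = 0:00:12.35
#   eval_samples_per_second =     81.001
trainer.save_metrics("eval", metrics)  # writes eval_results.json (+ all_results.json)
```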
+ """ + if not self.is_world_process_zero(): + return + + path = os.path.join(self.args.output_dir, "trainer_state.json") + self.state.save_to_json(path) + + +def has_length(dataset): + """ + Checks if the dataset implements __len__() and it doesn't raise an error + """ + try: + return len(dataset) is not None + except TypeError: + # TypeError: len() of unsized object + return False + + +def get_last_checkpoint(folder): + content = os.listdir(folder) + checkpoints = [ + path for path in content + if _re_checkpoint.search(path) is not None and os.path.isdir( + os.path.join(folder, path)) + ] + if len(checkpoints) == 0: + return + return os.path.join( + folder, + max(checkpoints, + key=lambda x: int(_re_checkpoint.search(x).groups()[0]))) diff --git a/paddlenlp/trainer/utils/__init__.py b/paddlenlp/trainer/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/paddlenlp/trainer/utils/helper.py b/paddlenlp/trainer/utils/helper.py new file mode 100644 index 000000000000..017fc3032282 --- /dev/null +++ b/paddlenlp/trainer/utils/helper.py @@ -0,0 +1,90 @@ +import paddle +import paddle.distributed as dist +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + + +def distributed_concat(tensor: Any, + num_total_examples: Optional[int]=None) -> Any: + try: + if isinstance(tensor, (tuple, list)): + return type(tensor)(distributed_concat(t, num_total_examples) + for t in tensor) + output_tensors = [tensor.clone() for _ in range(dist.get_world_size())] + output_tensors = [ + t if len(t.shape) > 0 else t[None] for t in output_tensors + ] + dist.all_gather(output_tensors, tensor) + concat = paddle.concat(output_tensors, axis=0) + + # truncate the dummy elements added by SequentialDistributedSampler + if num_total_examples is not None: + concat = concat[:num_total_examples] + return concat + except AssertionError: + raise AssertionError("Not currently using distributed training") + + +def paddle_pad_and_concatenate(tensor1, tensor2, padding_index=-100): + """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary.""" + if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]: + return paddle.concat((tensor1, tensor2), axis=0) + + # raise ValueError("Error") + # Let's figure out the new shape + new_shape = (tensor1.shape[0] + tensor2.shape[0], max( + tensor1.shape[1], tensor2.shape[1])) + tuple(tensor1.shape[2:]) + + # Now let's fill the result tensor + # result = tensor1.new_full(new_shape, padding_index) + result = paddle.full(new_shape, padding_index, dtype=tensor1.dtype) + + result[:tensor1.shape[0], :tensor1.shape[1]] = tensor1 + result[tensor1.shape[0]:, :tensor2.shape[1]] = tensor2 + return result + + +def nested_concat(tensors, new_tensors, padding_index=-100): + """ + Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or + nested list/tuples of tensors. + """ + assert type(tensors) == type( + new_tensors + ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}." 
+
+
+def nested_concat(tensors, new_tensors, padding_index=-100):
+    """
+    Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
+    nested list/tuples of tensors.
+    """
+    assert type(tensors) == type(
+        new_tensors
+    ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_concat(
+            t, n, padding_index=padding_index)
+                             for t, n in zip(tensors, new_tensors))
+    elif isinstance(tensors, paddle.Tensor):
+        return paddle_pad_and_concatenate(
+            tensors, new_tensors, padding_index=padding_index)
+    elif isinstance(tensors, np.ndarray):
+        return numpy_pad_and_concatenate(
+            tensors, new_tensors, padding_index=padding_index)
+    else:
+        raise TypeError(
+            f"Unsupported type for concatenation: got {type(tensors)}")
+
+
+def nested_detach(tensors):
+    "Detach `tensors` (even if it's a nested list/tuple of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t) for t in tensors)
+    return tensors.detach()
+
+
+def nested_numpify(tensors):
+    "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_numpify(t) for t in tensors)
+    t = tensors.cpu()
+    if t.dtype == paddle.float16:
+        t = t.cast(paddle.float32)
+    return t.numpy()
+
+
+def nested_truncate(tensors, limit):
+    "Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_truncate(t, limit) for t in tensors)
+    return tensors[:limit]
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index b3cb7f9707bc..689a6f45a4f3 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -37,6 +37,13 @@
 ]
 
 
+def unwrap_model(model, *args, **kwargs):
+    """Return the underlying model if `model` is wrapped in `paddle.DataParallel`."""
+    raw_model = model._layers if isinstance(model,
+                                            paddle.DataParallel) else model
+    return raw_model
+
+
 def register_base_model(cls):
     """
     A decorator for `PretrainedModel` class. It first retrieves the parent class
diff --git a/requirements.txt b/requirements.txt
index f96ce81c8a05..0dfc347567da 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ colorlog
 colorama
 seqeval
 multiprocess
-datasets
\ No newline at end of file
+datasets
+tqdm