diff --git a/examples/language_model/ernie-1.0/finetune/finetune.py b/examples/language_model/ernie-1.0/finetune/finetune.py
index 320c9ff77aa3..79810dc5e28d 100644
--- a/examples/language_model/ernie-1.0/finetune/finetune.py
+++ b/examples/language_model/ernie-1.0/finetune/finetune.py
@@ -307,6 +307,30 @@ def do_train():
         (ModelArguments, DataTrainingArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    # Log on each process the small summary:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
+        +
+        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if os.path.isdir(
+            training_args.output_dir
+    ) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(
+                os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome.")
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+
     paddle.set_device(training_args.device)
     rank = paddle.distributed.get_rank()
     if paddle.distributed.get_world_size() > 1:
@@ -374,22 +398,33 @@ def do_train():
         training_args,
         test_ds=all_ds["test"])
 
-    resume_from_checkpoint = training_args.resume_from_checkpoint
-    if training_args.resume_from_checkpoint is None:
-        resume_from_checkpoint = True
-    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    checkpoint = None
+    if training_args.resume_from_checkpoint is not None:
+        checkpoint = training_args.resume_from_checkpoint
+    elif last_checkpoint is not None:
+        checkpoint = last_checkpoint
+
+    train_result = trainer.train(resume_from_checkpoint=checkpoint)
     metrics = train_result.metrics
 
     trainer.save_model()  # Saves the tokenizer too for easy upload
-    # trainer.save_infer_model() -> 部署, onnx, slim, 量化后可否加速
-
     trainer.log_metrics("train", metrics)
     trainer.save_metrics("train", metrics)
     trainer.save_state()
 
-    # trainer.train()
-    # trainer.eval()
+    eval_metrics = trainer.evaluate()
+    trainer.log_metrics("eval", eval_metrics)
+    test_ret = trainer.predict(trainer.test_ds)
+    trainer.log_metrics("test", test_ret.metrics)
+
+    input_spec = [
+        paddle.static.InputSpec(
+            shape=[None, None], dtype="int64"),  # input_ids
+        paddle.static.InputSpec(
+            shape=[None, None], dtype="int64")  # segment_ids
+    ]
+    trainer.export_model(input_spec=input_spec, load_best_model=True)
 
 
 def print_arguments(args):
diff --git a/examples/language_model/ernie-1.0/finetune/sequence_classification.py b/examples/language_model/ernie-1.0/finetune/sequence_classification.py
index a7dd0170608c..487c0ab4715f 100644
--- a/examples/language_model/ernie-1.0/finetune/sequence_classification.py
+++ b/examples/language_model/ernie-1.0/finetune/sequence_classification.py
@@ -340,6 +340,10 @@ def __init__(self, train_ds, dev_ds, model, tokenizer, data_args,
         train_ds = train_ds.map(trans_fn)
         dev_ds = dev_ds.map(trans_fn)
 
+        if "test_ds" in kwargs.keys():
+            test_ds = kwargs["test_ds"]
+            self.test_ds = test_ds.map(trans_fn)
+
         loss_fct = paddle.nn.loss.CrossEntropyLoss(
         ) if train_ds.label_list else paddle.nn.loss.MSELoss()
diff --git a/paddlenlp/datasets/chnsenticorp_v2.py b/paddlenlp/datasets/chnsenticorp_v2.py
index 5e6056a200fa..908f558eb02c 100644
--- a/paddlenlp/datasets/chnsenticorp_v2.py
+++ b/paddlenlp/datasets/chnsenticorp_v2.py
@@ -60,7 +60,7 @@ def _get_data(self, mode, **kwargs):
     def _read(self, filename, split):
         """Reads data."""
         with open(filename, 'r', encoding='utf-8') as f:
-            head = None
+            head = True
             for line in f:
                 data = line.strip().split("\t")
                 if not head:
diff --git a/paddlenlp/trainer/trainer_args.py b/paddlenlp/trainer/trainer_args.py
index 68575e1feb96..d62144a516c3 100644
--- a/paddlenlp/trainer/trainer_args.py
+++ b/paddlenlp/trainer/trainer_args.py
@@ -280,8 +280,6 @@ class TrainingArguments:
            The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
            `"comet_ml"`, `"mlflow"`, `"tensorboard"` and `"wandb"`. Use `"all"` to report to all integrations
            installed, `"none"` for no integrations.
-        dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
-            Whether you want to pin memory in data loaders or not. Will default to `True`.
         skip_memory_metrics (`bool`, *optional*, defaults to `True`):
            Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows down
            the training and evaluation speed.
@@ -544,9 +542,6 @@ class TrainingArguments:
            "The list of integrations to report the results and logs to."
        })
-    dataloader_pin_memory: bool = field(
-        default=True,
-        metadata={"help": "Whether or not to pin memory for DataLoader."})
     skip_memory_metrics: bool = field(
        default=True,
        metadata={
diff --git a/paddlenlp/trainer/trainer_base.py b/paddlenlp/trainer/trainer_base.py
index f3a642074571..424739b352db 100644
--- a/paddlenlp/trainer/trainer_base.py
+++ b/paddlenlp/trainer/trainer_base.py
@@ -484,7 +484,7 @@ def train(
                                            WEIGHTS_NAME)
             if os.path.exists(best_model_path):
                 # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = paddle.load(best_model_path, map_location="cpu")
+                state_dict = paddle.load(best_model_path)
                 # If the model is on the GPU, it still works!
                 self._set_state_dict_in_model(state_dict)
             else:
@@ -535,7 +535,7 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
                 dataset=self.train_dataset,
                 shuffle=True,
                 batch_size=self.args.per_device_train_batch_size,
-                drop_last=False)
+                drop_last=self.args.dataloader_drop_last)
         else:
             return DistributedBatchSampler(
                 self.train_dataset,
@@ -543,7 +543,7 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
                 shuffle=True,
                 num_replicas=self.args.world_size,
                 rank=self.args.process_index,
-                drop_last=False)
+                drop_last=self.args.dataloader_drop_last)
 
     def _set_state_dict_in_model(self, state_dict):
         load_result = self.model.set_state_dict(state_dict)
@@ -556,11 +556,9 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch,
 
             # all_gather + mean() to get average loss over all processes
             tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
-            # tr_loss_scalar = tr_loss.mean().item()
 
             # reset tr_loss to zero
             tr_loss.subtract_(tr_loss)
-            # tr_loss.zero_()
 
             logs["loss"] = round(tr_loss_scalar / (
                 self.state.global_step - self._globalstep_last_logged), 4)
@@ -602,29 +600,25 @@ def get_train_dataloader(self):
 
         return DataLoader(
             train_dataset,
-            # batch_size=self.args.train_batch_size,
             batch_sampler=train_sampler,
             collate_fn=self.data_collator,
-            # drop_last=self.args.dataloader_drop_last,
-            num_workers=self.args.dataloader_num_workers,
-            # pin_memory=self.args.dataloader_pin_memory,
-        )
+            num_workers=self.args.dataloader_num_workers, )
 
     def _get_eval_sampler(self, eval_dataset: Dataset):
         if self.args.world_size <= 1:
-            return DistributedBatchSampler(
+            return paddle.io.BatchSampler(
                 eval_dataset,
-                # num_replicas=self.args.world_size,
-                # rank=self.args.process_index,
                 batch_size=self.args.eval_batch_size,
                 shuffle=False,
-                # seed=self.args.seed,
-            )
+                drop_last=False, )
         else:
             return DistributedBatchSampler(
                 eval_dataset,
+                num_replicas=self.args.world_size,
+                rank=self.args.process_index,
                 batch_size=self.args.eval_batch_size,
-                shuffle=False)
+                shuffle=False,
+                drop_last=False, )
 
     def get_eval_dataloader(self,
                             eval_dataset: Optional[Dataset]=None) -> DataLoader:
@@ -646,13 +640,9 @@ def get_eval_dataloader(self,
 
         return DataLoader(
             eval_dataset,
-            # batch_size=self.args.train_batch_size,
             batch_sampler=eval_sampler,
             collate_fn=self.data_collator,
-            # drop_last=self.args.dataloader_drop_last,
-            num_workers=self.args.dataloader_num_workers,
-            # pin_memory=self.args.dataloader_pin_memory,
-        )
+            num_workers=self.args.dataloader_num_workers, )
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         """
@@ -671,11 +661,9 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         # We use the same batch_size as for eval.
         return DataLoader(
             test_dataset,
-            sampler=test_sampler,
-            batch_size=self.args.eval_batch_size,
+            batch_sampler=test_sampler,
             collate_fn=self.data_collator,
-            drop_last=self.args.dataloader_drop_last,
-            pin_memory=self.args.dataloader_pin_memory, )
+            drop_last=self.args.dataloader_drop_last, )
 
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
@@ -909,6 +897,47 @@ def save_model(self, output_dir: Optional[str]=None):
 
         if self.args.should_save:
             self._save(output_dir)
 
+    def export_model(self,
+                     input_spec=None,
+                     load_best_model=False,
+                     output_dir: Optional[str]=None):
+
+        if output_dir is None:
+            output_dir = self.args.output_dir
+
+        if load_best_model and self.state.best_model_checkpoint is not None:
+            if self.args.local_rank != -1:
+                dist.barrier()
+
+            logger.info(
+                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
+            )
+
+            best_model_path = os.path.join(self.state.best_model_checkpoint,
+                                           WEIGHTS_NAME)
+            if os.path.exists(best_model_path):
+                # We load the model state dict on the CPU to avoid an OOM error.
+                state_dict = paddle.load(best_model_path)
+                # If the model is on the GPU, it still works!
+                self._set_state_dict_in_model(state_dict)
+            else:
+                logger.warning(
+                    f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+                    "on multiple nodes, you should activate `--save_on_each_node`."
+                )
+
+        model = unwrap_model(self.model)
+        model.eval()
+
+        # Convert to static graph with specific input description
+        model = paddle.jit.to_static(model, input_spec=input_spec)
+
+        # Save in static graph model.
+        save_path = os.path.join(output_dir, "inference", "infer")
+        logger.info("Exporting inference model to %s" % save_path)
+        paddle.jit.save(model, save_path)
+        logger.info("Inference model exported.")
+
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
@@ -1072,7 +1101,6 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
                     os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
-            map_location = self.args.device
             self.optimizer.set_state_dict(
                 paddle.load(os.path.join(checkpoint, OPTIMIZER_NAME)))
             self.lr_scheduler.set_state_dict(
@@ -1492,11 +1520,6 @@ def _pad_across_processes(self, tensor, pad_index=-100):
             new_tensor[:, :old_size[1]] = tensor
         return new_tensor
 
-    def eval(self, *args, **kwargs):
-        """
-        """
-        pass
-
     def print_config(self):
         """
         """
diff --git a/paddlenlp/trainer/trainer_callback.py b/paddlenlp/trainer/trainer_callback.py
index f4cec4d6108f..2d33ac0525ab 100644
--- a/paddlenlp/trainer/trainer_callback.py
+++ b/paddlenlp/trainer/trainer_callback.py
@@ -517,6 +517,13 @@ def on_step_end(self,
 
         # End training
         if state.global_step >= state.max_steps:
             control.should_training_stop = True
+            # Log and save on end
+            if args.logging_strategy == IntervalStrategy.STEPS:
+                control.should_log = True
+            if args.evaluation_strategy == IntervalStrategy.STEPS:
+                control.should_evaluate = True
+            if args.save_strategy == IntervalStrategy.STEPS:
+                control.should_save = True
 
         return control
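Usage note (not part of the patch itself): the trainer.export_model(...) call added in the finetune.py hunk writes a static-graph copy of the fine-tuned model to <output_dir>/inference/infer via paddle.jit.save. Below is a minimal sketch of how such an exported model might be loaded back for prediction; the output_dir value, the "ernie-1.0" tokenizer name, and the example sentence are illustrative assumptions, not values taken from this diff.

    import os

    import paddle
    from paddlenlp.transformers import ErnieTokenizer

    # Assumed to match the --output_dir used when running finetune.py.
    output_dir = "./tmp/chnsenticorp"
    infer_path = os.path.join(output_dir, "inference", "infer")

    # paddle.jit.load restores the static-graph model saved by paddle.jit.save.
    model = paddle.jit.load(infer_path)
    model.eval()

    # Assumed tokenizer; it must match the checkpoint that was fine-tuned.
    tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
    encoded = tokenizer("这家餐厅的服务很好")  # any ChnSentiCorp-style Chinese sentence
    input_ids = paddle.to_tensor([encoded["input_ids"]], dtype="int64")
    segment_ids = paddle.to_tensor([encoded["token_type_ids"]], dtype="int64")

    # Inputs follow the input_spec order used at export time: input_ids, segment_ids.
    logits = model(input_ids, segment_ids)
    print(paddle.argmax(logits, axis=-1).numpy())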