add trainer prototype.
ZHUI committed Mar 15, 2022
1 parent 654d45a commit e4f2f02
Showing 6 changed files with 668 additions and 168 deletions.
25 changes: 15 additions & 10 deletions examples/language_model/ernie-1.0/finetune/config.yml
@@ -5,10 +5,10 @@ DefaultArgs:
num_train_epochs: 3
batch_size: 64
max_seq_length: 128
weight_decay: 0.0
weight_decay: 0.01
logging_steps: 10
valid_steps: 100
minimum_valid_times: 20 # If valid_steps would yield fewer than 20 validation runs, valid_steps is adjusted accordingly.
valid_steps: 200
minimum_valid_times: 20
max_steps: -1
warmup_steps: 0
metric: "Accuracy"
@@ -34,25 +34,30 @@ SequenceClassification:
max_seq_length: 256
batch_size: 32
xnli_cn:
learning_rate: 0.00005
learning_rate: 0.0001
num_train_epochs: 3
batch_size: 256
chnsenticorp_v2:
learning_rate: 0.00001
num_train_epochs: 5
learning_rate: 0.00005
batch_size: 16
num_train_epochs: 8

# Datasets used for token classification
TokenClassification:
peoples_daily_ner:
num_train_epochs: 5
learning_rate: 0.00005
num_train_epochs: 8
batch_size: 16
msra_ner:
num_train_epochs: 3

# Datasets used for question answering
QuestionAnswering:
cmrc2018:
num_train_epochs: 1
batch_size: 12
max_seq_length: 384
learning_rate: 0.00005
num_train_epochs: 5
batch_size: 32
max_seq_length: 512
dureader_nlp:
num_train_epochs: 1
batch_size: 12
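The config.yml above follows a defaults-plus-overrides layout: DefaultArgs holds shared hyperparameters, and each task section (SequenceClassification, TokenClassification, QuestionAnswering) lists per-dataset overrides. The merge logic is not part of this diff; the following is a minimal sketch under that assumption, and the helper name load_task_config is hypothetical.

import yaml

def load_task_config(path, task, dataset):
    # Return DefaultArgs overridden by the per-dataset section, if present.
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    args = dict(config.get("DefaultArgs", {}))
    args.update(config.get(task, {}).get(dataset, {}) or {})
    return args

# e.g. chnsenticorp_v2 picks up learning_rate=5e-05, batch_size=16 and
# num_train_epochs=8 while inheriting max_seq_length etc. from DefaultArgs.
cfg = load_task_config("config.yml", "SequenceClassification", "chnsenticorp_v2")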
51 changes: 34 additions & 17 deletions examples/language_model/ernie-1.0/finetune/finetune.py
@@ -13,12 +13,10 @@
# limitations under the License.

import argparse
import logging
import os
import sys
import random
import time
import math
import copy
import yaml
from functools import partial
@@ -27,24 +25,21 @@

import numpy as np
import paddle
from paddle.io import DataLoader
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.metric import Accuracy
from paddlenlp.metrics.squad import squad_evaluate, compute_prediction

import paddlenlp
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from paddlenlp.transformers import AutoModelForTokenClassification
from paddlenlp.transformers import AutoModelForQuestionAnswering
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.transformers import AutoTokenizer
from paddlenlp.utils.log import logger

sys.path.insert(0, os.path.abspath("."))
from sequence_classification import CLUE_TRAINING
from question_answering import QA_TRAINING
from sequence_classification import ClueTrainer, SeqTrainer
from question_answering import MrcTrainer
from token_classification import NerTrainer

ALL_TASKS = {
"SequenceClassification": [],
@@ -77,7 +72,6 @@

def parse_args():
parser = argparse.ArgumentParser()
# Required parameters

parser.add_argument(
"--dataset",
@@ -115,7 +109,7 @@ def parse_args():
help="Batch size per GPU/CPU for training.", )
group.add_argument(
"--weight_decay",
default=0.0,
default=None,
type=float,
help="Weight decay if we apply some.")

@@ -134,6 +128,12 @@ def parse_args():
type=int,
default=200,
help="Save checkpoint every X updates steps.")
group.add_argument(
"--minimum_valid_times",
type=int,
default=None,
help="If the number of evaluations implied by valid_steps is less than minimum_valid_times, valid_steps is overridden accordingly."
)
group.add_argument(
"--max_steps",
default=-1,
@@ -257,12 +257,29 @@ def do_train(args):
if paddle.distributed.get_world_size() > 1:
model = paddle.DataParallel(model)

if 'clue' in args.dataset:
trainer = CLUE_TRAINING(all_ds["train"], all_ds["dev"], model,
tokenizer, args)
elif "Answering" in config["model"]:
trainer = QA_TRAINING(all_ds["train"], all_ds["dev"], model, tokenizer,
args)
if "SequenceClassification" in config["model"]:
if 'clue' in args.dataset:
trainer = ClueTrainer(all_ds["train"], all_ds["dev"], model,
tokenizer, args)
else:
trainer = SeqTrainer(
all_ds["train"],
all_ds["dev"],
model,
tokenizer,
args,
test_ds=all_ds["test"])
elif "QuestionAnswering" in config["model"]:
trainer = MrcTrainer(all_ds["train"], all_ds["dev"], model, tokenizer,
args)
elif 'TokenClassification' in config["model"]:
trainer = NerTrainer(
all_ds["train"],
all_ds["dev"],
model,
tokenizer,
args,
test_ds=all_ds["test"])

trainer.train()
trainer.eval()
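finetune.py now dispatches on the model head named in the config: ClueTrainer or SeqTrainer for sequence classification, MrcTrainer for question answering, and NerTrainer for token classification, each constructed with the datasets, model, tokenizer and args, and each exposing train() and eval(). The shared trainer_base.py is not shown in this commit excerpt, so the sketch below of the interface the dispatch relies on is an assumption, not the actual class.

class TrainerBase:
    # Assumed shared interface; the real attributes and method bodies live in
    # trainer_base.py, which is not part of the excerpt above.
    def __init__(self):
        self.args = None
        self.train_ds = None
        self.dev_ds = None

    def train(self):
        # Run the training loop, calling self.eval() every args.valid_steps steps.
        raise NotImplementedError

    def eval(self):
        # Evaluate on the dev set and report the configured metric.
        raise NotImplementedError

Under this reading, subclasses such as MrcTrainer or NerTrainer only customize batching, the forward/loss computation and metric reporting.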
examples/language_model/ernie-1.0/finetune/question_answering.py
@@ -22,7 +22,7 @@
from paddlenlp.data import Pad, Stack, Tuple, Dict
from paddlenlp.metrics.squad import squad_evaluate, compute_prediction

from sequence_classification import BaseTrainer
from trainer_base import TrainerBase
from paddlenlp.utils.log import logger


@@ -188,7 +188,7 @@ def prepare_validation_features(examples, tokenizer, args):
return tokenized_examples


class QA_TRAINING(BaseTrainer):
class MrcTrainer(TrainerBase):
def __init__(self, train_ds, dev_ds, model, tokenizer, args):
super().__init__()
self.rank = paddle.distributed.get_rank()
@@ -257,5 +257,8 @@ def train(self):
self.lr_scheduler.step()
self.optimizer.clear_grad()

if global_step % self.args.valid_steps == 0:
self.eval()

if global_step == self.args.num_training_steps:
break
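The train() hunk above adds a periodic evaluation: eval() runs whenever global_step is a multiple of valid_steps, and training stops at num_training_steps. How minimum_valid_times overrides valid_steps is not visible in this diff; a plausible sketch, with the helper name resolve_valid_steps invented here for illustration, is:

def resolve_valid_steps(valid_steps, minimum_valid_times, num_training_steps):
    # If the configured interval would yield fewer than minimum_valid_times
    # evaluations over the whole run, shrink the interval so at least that
    # many evaluations happen.
    if not minimum_valid_times:
        return valid_steps
    if num_training_steps // valid_steps < minimum_valid_times:
        valid_steps = max(1, num_training_steps // minimum_valid_times)
    return valid_steps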
