init version for paddlenlp trainer.
ZHUI committed Mar 18, 2022
1 parent e4f2f02 commit 68dea62
Showing 7 changed files with 3,125 additions and 117 deletions.
1 change: 0 additions & 1 deletion examples/language_model/ernie-1.0/finetune/finetune.py
@@ -280,7 +280,6 @@ def do_train(args):
        tokenizer,
        args,
        test_ds=all_ds["test"])
-
    trainer.train()
    trainer.eval()
@@ -23,10 +23,10 @@
import numpy as np

import paddlenlp
-from paddlenlp.data import Stack, Tuple, Pad, Dict
+from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.utils.log import logger

-from trainer_base import TrainerBase
+from trainer_base import TrainerBase, Trainer


@@ -46,7 +46,12 @@ def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
    if is_test:
        return input_ids, token_type_ids
    label = np.array([example["label"]], dtype="int64")
-    return input_ids, token_type_ids, label
+    # return input_ids, token_type_ids, label
+    return {
+        "input_ids": input_ids,
+        "token_type_ids": token_type_ids,
+        "labels": label
+    }
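With `is_test=False`, `convert_example` now yields a named feature dict rather than a positional tuple, so downstream collation can address columns by name. A minimal sketch of mapping it over a dataset; the tokenizer and dataset chosen here are illustrative assumptions, not part of the diff:

```python
from functools import partial

from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import ErnieTokenizer

# Assumed setup: any tokenizer/dataset pair with a "label" field would do.
tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
train_ds = load_dataset("clue", "tnews", splits="train")

# map() applies convert_example per example, producing feature dicts.
train_ds = train_ds.map(partial(convert_example, tokenizer=tokenizer))
print(train_ds[0].keys())  # input_ids, token_type_ids, labels
```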


def seq_trans_fn(example, tokenizer, args):
@@ -130,6 +135,39 @@ def clue_batchify_fn(tokenizer, args):
return batchify_fn


+class Dict(object):
+    def __init__(self, fn):
+        assert isinstance(fn, dict), 'Input pattern not understood. The input of Dict must be a dict with key of input column name and value of collate_fn. ' \
+            'Received fn=%s' % (str(fn))
+
+        self._fn = fn
+
+        for col_name, ele_fn in self._fn.items():
+            assert callable(
+                ele_fn
+            ), 'Batchify functions must be callable! type(fn[%s]) = %s' % (
+                col_name, str(type(ele_fn)))
+
+    def __call__(self, data):
+        ret = {}
+        for col_name, ele_fn in self._fn.items():
+            result = ele_fn([ele[col_name] for ele in data])
+            ret[col_name] = result
+
+        return ret
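This local `Dict` mirrors the removed `paddlenlp.data.Dict` import but keys the collated output by column name instead of returning a tuple. A rough usage sketch with made-up samples:

```python
from paddlenlp.data import Pad, Stack

# Column-wise batchify functions, keyed by feature name.
batchify = Dict({
    "input_ids": Pad(axis=0, pad_val=0),
    "token_type_ids": Pad(axis=0, pad_val=0),
    "labels": Stack(dtype="int64"),
})

batch = batchify([
    {"input_ids": [1, 5, 9, 2], "token_type_ids": [0, 0, 0, 0], "labels": [1]},
    {"input_ids": [1, 7, 2], "token_type_ids": [0, 0, 0], "labels": [0]},
])
# batch["input_ids"] has shape (2, 4); the shorter row is padded with 0.
```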


+def clue_batchify_fn_dict(tokenizer, args):
+    batchify_fn = lambda samples, fn=Dict({
+        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
+        "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
+        "labels": Stack(dtype="int64" if args.label_list else "float32")  # label
+    }): fn(samples)
+
+    return batchify_fn
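Since `clue_batchify_fn_dict` returns a plain collate function, it can be handed straight to a `paddle.io.DataLoader`. A hedged sketch, assuming `tokenizer`, `args`, and a mapped `train_ds` are in scope:

```python
import paddle

batchify_fn = clue_batchify_fn_dict(tokenizer, args)
train_dl = paddle.io.DataLoader(
    dataset=train_ds,        # dataset of feature dicts from convert_example
    batch_size=32,
    collate_fn=batchify_fn)  # pads/stacks each named column into arrays
```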


@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader, mode="dev"):
"""
@@ -278,7 +316,7 @@ def train(self):
best_dev_acc, corr_test_acc))


-class SeqTrainer(ClueTrainer):
+class SeqTrainer2(ClueTrainer):
    def dataloader_inner(self):
        trans_fn = partial(
            seq_trans_fn, tokenizer=self.tokenizer, args=self.args)
@@ -290,3 +328,44 @@ def dataloader_inner(self):
            self.dev_ds, "dev", self.args.batch_size, batchify_fn, trans_fn)
        self.test_dl = self.create_dataloader(
            self.test_ds, "dev", self.args.batch_size, batchify_fn, trans_fn)


+class SeqTrainer(Trainer):
+    def __init__(self, train_ds, dev_ds, model, tokenizer, args, *arg,
+                 **kwargs):
+
+        trans_fn = partial(seq_trans_fn, tokenizer=tokenizer, args=args)
+        batchify_fn = clue_batchify_fn_dict(tokenizer, args)
+
+        train_ds = train_ds.map(trans_fn)
+        dev_ds = dev_ds.map(trans_fn)
+
+        loss_fct = paddle.nn.loss.CrossEntropyLoss(
+        ) if train_ds.label_list else paddle.nn.loss.MSELoss()
+
+        def compute_metrics(p):
+            preds = p.predictions[0] if isinstance(p.predictions,
+                                                   tuple) else p.predictions
+            probs = F.softmax(preds, axis=1)
+            metric = Accuracy()
+            metric.reset()
+            result = metric.compute(preds, p.label_ids)
+            metric.update(result)
+            accu = metric.accumulate()
+            metric.reset()
+            return {"eval_accuracy": accu}
+
+            # return {
+            #     "accuracy": (preds == p.label_ids).astype(np.float32).mean()
+            #     .item()
+            # }
+
+        super().__init__(
+            model,
+            loss_fct,
+            args,
+            batchify_fn,
+            train_ds,
+            dev_ds,
+            tokenizer,
+            compute_metrics=compute_metrics)
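Note that in `compute_metrics` the softmax output `probs` is computed but never used; `paddle.metric.Accuracy` ranks the raw logits itself (`F` and `Accuracy` are presumably imported elsewhere in this file). A standalone sketch of the same metric bookkeeping on dummy tensors:

```python
import paddle
from paddle.metric import Accuracy

logits = paddle.to_tensor([[0.2, 0.8], [0.9, 0.1]])  # (batch, num_classes)
labels = paddle.to_tensor([[1], [0]])                # gold class ids

metric = Accuracy()
metric.reset()
correct = metric.compute(logits, labels)  # per-sample correctness
metric.update(correct)
print(metric.accumulate())                # 1.0 for this toy batch
```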