Skip to content

Commit

Permalink
update the format of the log of ofa-bert (PaddlePaddle#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
ceci3 committed Mar 11, 2021
1 parent d46bb2a commit 9ac5abc
Showing 1 changed file with 35 additions and 17 deletions.
52 changes: 35 additions & 17 deletions examples/model_compression/ofa/run_glue_ofa.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,26 +171,41 @@ def set_seed(args):
# `paddle.seed(args.seed + paddle.distributed.get_rank())`
paddle.seed(args.seed)


@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader, width_mult=1.0):
with paddle.no_grad():
model.eval()
metric.reset()
for batch in data_loader:
input_ids, segment_ids, labels = batch
logits = model(input_ids, segment_ids, attention_mask=[None, None])
if isinstance(logits, tuple):
logits = logits[0]
loss = criterion(logits, labels)
correct = metric.compute(logits, labels)
metric.update(correct)
results = metric.accumulate()
model.eval()
metric.reset()
for batch in data_loader:
input_ids, segment_ids, labels = batch
logits = model(input_ids, segment_ids, attention_mask=[None, None])
if isinstance(logits, tuple):
logits = logits[0]
loss = criterion(logits, labels)
correct = metric.compute(logits, labels)
metric.update(correct)
res = metric.accumulate()
if isinstance(metric, AccuracyAndF1):
print(
"width_mult: %f, eval loss: %f, %s: %s\n" %
(width_mult, loss.numpy(), metric.name(), results),
"width_mult: %f, eval loss: %f, acc: %s, precision: %s, recall: %s, f1: %s, acc and f1: %s, "
% (
width_mult,
loss.numpy(),
res[0],
res[1],
res[2],
res[3],
res[4], ),
end='')
model.train()

elif isinstance(metric, Mcc):
print("width_mult: %f, eval loss: %f, mcc: %s, " % (width_mult, loss.numpy(), res[0]), end='')
elif isinstance(metric, PearsonAndSpearman):
print(
"width_mult: eval loss: %f, pearson: %s, spearman: %s, pearson and spearman: %s, "
% (width_mult, loss.numpy(), res[0], res[1], res[2]),
end='')
else:
print("eval loss: %f, acc: %s, " % (loss.numpy(), res), end='')
model.train()

### monkey patch for bert forward to accept [attention_mask, head_mask] as attention_mask
def bert_forward(self,
Expand Down Expand Up @@ -450,6 +465,7 @@ def do_train(args):
tic_train = time.time()

if global_step % args.save_steps == 0:
tic_eval = time.time()
if args.task_name == "mnli":
evaluate(
teacher_model,
Expand All @@ -470,6 +486,8 @@ def do_train(args):
metric,
dev_data_loader,
width_mult=100)
print("eval done total : %s s" %
(time.time() - tic_eval))
for idx, width_mult in enumerate(args.width_mult_list):
net_config = utils.dynabert_config(ofa_model, width_mult)
ofa_model.set_net_config(net_config)
Expand Down

0 comments on commit 9ac5abc

Please sign in to comment.