import os
import copy
import utils
import torch
import math
import random
import warnings
import numpy as np
import torch.nn as nn
import torch.optim as optim
from constants import *
from transformers import AutoTokenizer
from models import CorefModel
from argparse import ArgumentParser
from data import prepare_dataset, combine_datasets
from utils import evaluate, RunningAverage, prepare_configs

# Main Functions
def train(config_name):
    # Prepare the config, the tokenizer, and the model
    configs = prepare_configs(config_name)
    tokenizer = AutoTokenizer.from_pretrained(configs['transformer'], do_basic_tokenize=False)
    model = CorefModel(configs)
    if PRETRAINED_MODEL:
        checkpoint = torch.load(PRETRAINED_MODEL)
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        print('Reloaded the model from the pretrained checkpoint')
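        # Note: strict=False tolerates missing/unexpected keys, so a checkpoint
        # whose parameters only partially overlap with CorefModel still loads.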

    # Prepare datasets
    ace05_dataset = prepare_dataset(ACE05, tokenizer)
    kbp2016_dataset = prepare_dataset(KBP2016, tokenizer)
    kbp2017_dataset = prepare_dataset(KBP2017, tokenizer)
    kbp_dataset = combine_datasets([kbp2016_dataset, kbp2017_dataset])
    ontonote_dataset = prepare_dataset(ONTONOTE, tokenizer)
    dataset = combine_datasets([ontonote_dataset, ace05_dataset, kbp_dataset])
    print('Number of train examples: {}'.format(len(dataset.examples[TRAIN])))
    print('Number of dev examples: {}'.format(len(dataset.examples[DEV])))
    print('Number of test examples: {}'.format(len(dataset.examples[TEST])))
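    # The three corpora are concatenated into one combined dataset, so the
    # "(aggregated)" dev/test scores below are computed across all of them.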

    # Evaluation (if there exists a checkpoint)
    best_dev_f1 = 0
    if PRETRAINED_MODEL:
        with torch.no_grad():
            print('Evaluation on the (aggregated) dev set')
            dev_f1 = evaluate(model, dataset, DEV)
            print('Evaluation on the (aggregated) test set')
            evaluate(model, dataset, TEST)
            # Individual test sets
            print('Evaluation on the OntoNotes test set')
            evaluate(model, ontonote_dataset, TEST)
            print('Evaluation on the ACE05 test set')
            evaluate(model, ace05_dataset, TEST)
            print('Evaluation on the KBP test set')
            evaluate(model, kbp_dataset, TEST)
        best_dev_f1 = dev_f1
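    # This initial pass makes the checkpoint's dev F1 the baseline, so the
    # training loop below only saves models that improve on it.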

    # Prepare the optimizer and the scheduler
    num_train_docs = len(dataset.examples[TRAIN])
    num_epoch_steps = math.ceil(num_train_docs / configs['batch_size'])
    num_train_steps = int(num_epoch_steps * configs['epochs'])
    num_warmup_steps = int(num_train_steps * 0.1)  # warm up over the first 10% of updates
    optimizer = model.get_optimizer(num_warmup_steps, num_train_steps)
    print('Prepared the optimizer and the scheduler', flush=True)
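    # Assumption: get_optimizer wires the warmup/decay scheduler into the
    # optimizer itself, since optimizer.step() is the only stepping call below.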

    # Start training
    accumulated_loss = RunningAverage()
    iters, batch_loss = 0, 0
    for i in range(configs['epochs']):
        print('Starting epoch {}'.format(i+1), flush=True)
        train_indices = list(range(num_train_docs))
        random.shuffle(train_indices)
        for train_idx in train_indices:
            iters += 1
            tensorized_example = dataset.tensorized_examples[TRAIN][train_idx]
            iter_loss = model(*tensorized_example)[0]
            iter_loss /= configs['batch_size']
            iter_loss.backward()
            batch_loss += iter_loss.item()
            # Documents are processed one at a time; gradients are accumulated
            # and one optimizer step is taken every batch_size documents
            if iters % configs['batch_size'] == 0:
                accumulated_loss.update(batch_loss)
                torch.nn.utils.clip_grad_norm_(model.parameters(), configs['max_grad_norm'])
                optimizer.step()
                optimizer.zero_grad()
                batch_loss = 0

            # Report loss
            if iters % configs['report_frequency'] == 0:
                print('{} Average Loss = {}'.format(iters, accumulated_loss()), flush=True)
                accumulated_loss = RunningAverage()

        # Evaluation after each epoch
        with torch.no_grad():
            print('Evaluation on the (aggregated) dev set')
            dev_f1 = evaluate(model, dataset, DEV)
            print('Evaluation on the (aggregated) test set')
            evaluate(model, dataset, TEST)
            # Individual test sets
            print('Evaluation on the OntoNotes test set')
            evaluate(model, ontonote_dataset, TEST)
            print('Evaluation on the ACE05 test set')
            evaluate(model, ace05_dataset, TEST)
            print('Evaluation on the KBP test set')
            evaluate(model, kbp_dataset, TEST)

        # Save the model if it achieves a better dev F1 score
        if dev_f1 > best_dev_f1:
            best_dev_f1 = dev_f1
            save_path = os.path.join(configs['saved_path'], 'model.pt')
            torch.save({'model_state_dict': model.state_dict()}, save_path)
            print('Saved the model', flush=True)
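        # Each save overwrites model.pt at the same path, so only the best
        # checkpoint (by aggregated dev F1) is kept on disk.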

if __name__ == '__main__':
    # Parse arguments
    parser = ArgumentParser()
    parser.add_argument('-c', '--config_name', default='spanbert_large')
    args = parser.parse_args()

    # Start training
    train(args.config_name)
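# Usage (assuming 'spanbert_large' is a config name known to prepare_configs):
#   python runner.py --config_name spanbert_large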