-
Notifications
You must be signed in to change notification settings - Fork 3
/
infill.py
113 lines (99 loc) · 4.14 KB
/
infill.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Imports
import os.path
import pickle
import ilm.tokenize_util
from transformers import GPT2LMHeadModel
from ilm.infer import infill_with_ilm
import gdown
# Variables
MODEL_DIR = 'model/'
MASK_CLS = 'ilm.mask.hierarchical.MaskHierarchical'
# Module-level accumulator reused (cleared then refilled) by every INFILL call.
result = []
tokenizer = ilm.tokenize_util.Tokenizer.GPT2
datamodel = 'model/pytorch_model.bin'
model_location = "https://drive.google.com/uc?id=1-12EFaKNBYD1vlfeZcKnV5PaSqeHNTHX"

# Download the pretrained model weights once; skip if already present on disk.
if os.path.isfile(datamodel):
    # Bug fix: the original line was a bare string expression (a no-op) —
    # the intended print() call was missing.
    print('Model was already downloaded.')
else:
    gdown.download(model_location, datamodel)

# Create context
# The ' _' placeholder marks the blank position to be infilled.
context = 'The sun is shining. _ All the children want to swim.'
class INFILL:
    """Fill a blank (marked by ' _') in a context string with an ILM model.

    The three public methods differ only in the granularity of the infill
    (sentence, word, or n-gram); the shared pipeline lives in ``_infill``.
    All methods clear and refill the module-level ``result`` list and
    return it, preserving the original module's behavior for callers that
    read ``result`` directly.
    """

    def _infill(self, context: str, infill_token: str):
        """Shared infilling pipeline.

        Loads the token-id mapping and the GPT-2 based ILM model from
        MODEL_DIR, replaces the first ' _' blank in *context* with the
        given special token, and generates 5 infills.

        Args:
            context: Input text containing a ' _' blank to fill.
            infill_token: Special token key, e.g. '<|infill_sentence|>'.

        Returns:
            The module-level ``result`` list holding 5 decoded strings.

        Raises:
            ValueError: If *context* contains no ' _' blank
                (from ``list.index``).
        """
        result.clear()
        with open(os.path.join(MODEL_DIR, 'additional_ids_to_tokens.pkl'), 'rb') as f:
            additional_ids_to_tokens = pickle.load(f)
        additional_tokens_to_ids = {v: k for k, v in additional_ids_to_tokens.items()}
        try:
            ilm.tokenize_util.update_tokenizer(additional_ids_to_tokens, tokenizer)
        except ValueError:
            # The global tokenizer already has the special tokens registered
            # (happens on every call after the first).
            print('Already updated')
        # Load model on CPU for every call.
        # NOTE(review): reloading the model per call is expensive; caching it
        # would be a behavior-neutral speedup — left as-is pending confirmation
        # that memory constraints allow it.
        device = 'cpu'
        model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)
        model.eval()
        model.to(device)
        context_ids = ilm.tokenize_util.encode(context, tokenizer)
        # ' _' encodes to a single id that marks the blank position.
        _blank_id = ilm.tokenize_util.encode(' _', tokenizer)[0]
        # Swap the blank for the requested infill-type token
        # (one of sentence, document, mixture, paragraph, ngram, or word).
        context_ids[context_ids.index(_blank_id)] = additional_tokens_to_ids[infill_token]
        generated = infill_with_ilm(
            model,
            additional_tokens_to_ids,
            context_ids,
            num_infills=5)
        for g in generated:
            result.append(str(ilm.tokenize_util.decode(g, tokenizer)))
        return result

    def infilling_sentence(self, context: str):
        """Fill the ' _' blank in *context* with a generated sentence."""
        return self._infill(context, '<|infill_sentence|>')

    def infilling_word(self, context: str):
        """Fill the ' _' blank in *context* with a generated word."""
        return self._infill(context, '<|infill_word|>')

    def infilling_ngram(self, context: str):
        """Fill the ' _' blank in *context* with a generated n-gram."""
        return self._infill(context, '<|infill_ngram|>')