-
Notifications
You must be signed in to change notification settings - Fork 6
/
beam_search.py
executable file
·121 lines (98 loc) · 3.55 KB
/
beam_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Beam search implementation in PyTorch.
#
#
# hyp1#-hyp1---hyp1 -hyp1
# \ /
# hyp2 \-hyp2 /-hyp2#hyp2
# / \
# hyp3#-hyp3---hyp3 -hyp3
# ========================
#
# Takes care of beams, back pointers, and scores.
# Code borrowed from https://github.com/MaximumEntropy/Seq2Seq-PyTorch/blob/master/beam_search.py,
# who borrowed it from PyTorch OpenNMT example
# https://github.com/pytorch/examples/blob/master/OpenNMT/onmt/Beam.py
# :-)
import torch
class Beam(object):
"""Ordered beam of candidate outputs. Fixed length."""
def __init__(self, size, steps, cuda=False):
"""Initialize params."""
self.size = size
self.done = False
self.pad = -1
self.steps = steps
self.current_step = 0
self.tt = torch.cuda if cuda else torch
# The score for each translation on the beam.
self.scores = self.tt.FloatTensor(size).zero_()
# The backpointers at each time-step.
self.prevKs = []
# The outputs at each time-step.