-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample-decoder-test.py
43 lines (34 loc) · 2.16 KB
/
sample-decoder-test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
from models import *
def testAttentionDecoder():
    """Smoke-test for AttnDecoderRNN from `models`.

    Builds a decoder, feeds it a single token id together with a zero-filled
    stand-in for encoder outputs, and prints tensor device/shape information
    so a human can eyeball the sizes. No return value; output is via print.
    """
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"  # pinned to CPU so the smoke test runs anywhere
    hidden_size = 1024
    output_size = 10_000  # num of words (vocabulary size)
    max_len = 100

    # Single hard-coded token id; replace in class with dictionary in the above method.
    decoder_input_sample = torch.tensor([[3]], device=device)

    # Real encoder outputs from the sibling encoder test; the hidden state
    # is not used by this decoder test.
    encoder_output_states, _ = testEncoder()

    attention_weights_model = AttnDecoderRNN(hidden_size=hidden_size, output_size=output_size, dropout_p=0.2,
                                             max_length=max_len, device=device)

    # One initHidden() call is enough: the original created two identical
    # hidden states and discarded the first after printing its device.
    decoder_hidden = attention_weights_model.initHidden()
    print(decoder_hidden.device)

    # Zero-filled dummy encoder outputs. NOTE(review): 100 here differs from
    # hidden_size (1024) and from the real encoder output size printed below —
    # confirm which width AttnDecoderRNN actually expects.
    encoder_hidden_size = 100
    sample_encoder_outputs = torch.zeros(max_len, encoder_hidden_size, device=device)
    print("SAMPLE ENCODER OUTPUTS SIZE : ", sample_encoder_outputs.size())
    print("ACTUAL ENCODER OUTPUTS SIZE : ", encoder_output_states.size())

    attention_weights_hidden = attention_weights_model(input_x=decoder_input_sample, hidden_state=decoder_hidden,
                                                       encoder_outputs=sample_encoder_outputs)
    print("ATTENTION DECODER MODEL OUTPUTS SHAPE : ", attention_weights_hidden.shape)