-
Notifications
You must be signed in to change notification settings - Fork 0
/
loader.py
121 lines (95 loc) · 4.01 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import pandas as pd
import os
# from torchtext.data import Dataset, BucketIterator
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch
#Modified from https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/image_captioning/get_loader.py
class Vocabulary:
    """Maps nucleotide-sequence characters to integer token ids and back.

    Ids 0-4 are reserved for the special tokens [PAD]/[START]/[END]/[UNK]/[MASK];
    the bases a, c, g, u, t are appended by build_vocabulary().
    """
    def __init__(self, freq_thres=1):
        # idx -> token and token -> idx lookup tables, seeded with the specials.
        self.itos = {0: "[PAD]", 1: "[START]", 2: "[END]", 3: "[UNK]", 4: "[MASK]"}
        self.stoi = {"[PAD]": 0, "[START]": 1, "[END]": 2, "[UNK]": 3, "[MASK]": 4}
        # NOTE(review): kept for interface compatibility; never read below.
        self.freq_threshold = freq_thres
    def __len__(self):
        return len(self.itos)
    #returns the numeric token value of a given string token
    def get_idx(self, token):
        """Return the integer id of a string token (KeyError if unknown)."""
        return self.stoi[token]
    #returns the alphanumeric token value given the token idx
    def get_token(self, token):
        """Return the string token for an integer id (KeyError if unknown)."""
        return self.itos[token]
    @staticmethod
    def tokenizer_seq(fasta_seq):
        """Split a sequence string into a list of single-character tokens."""
        return [str(x) for x in list(fasta_seq)]
    def build_vocabulary(self, sentence_list=None):
        """Add the fixed nucleotide alphabet (a, c, g, u, t) to the vocabulary.

        BUGFIX: the original signature took no arguments, yet the call site
        (SequenceDataset.__init__) passes a list of sequences, which raised
        TypeError. The alphabet is fixed, so the argument is accepted and
        ignored for backward compatibility.
        """
        idx = len(self.itos)
        for offset, base in enumerate("acgut"):
            self.stoi[base] = idx + offset
            self.itos[idx + offset] = base
    def numericalize(self, fasta_seq):
        """Convert a sequence string (case-insensitive) to a list of token ids.

        Characters outside the vocabulary map to the [UNK] id.
        """
        tokenized_seq = self.tokenizer_seq(fasta_seq.lower())
        return [
            self.stoi.get(token, self.stoi["[UNK]"])
            for token in tokenized_seq
        ]
class SequenceDataset(Dataset):
    """Dataset of (miRNA, target mRNA, relative score) rows loaded from a CSV.

    Expected column order (no header row):
    miRNA, mRNA, miRNA_Seq, mRNA_Seq, Relative_score.
    """
    def __init__(self, filename, freq_threshold=5):
        self.df = pd.read_csv(filename, header=None)
        # Get Sequences (miRNA and Target mRNA)
        # Dataset Column Positions - miRNA, mRNA, miRNA_Seq, mRNA_Seq, Relative_score
        self.mirna = self.df.iloc[:, 2]       # miRNA sequence strings
        self.mrna = self.df.iloc[:, 3]        # target mRNA sequence strings
        self.rel_score = self.df.iloc[:, -1]  # relative score (last column)
        # Initialize vocabulary and build vocab. BUGFIX: the original passed
        # the concatenated sequences to build_vocabulary(), which takes no
        # arguments and raised TypeError; the alphabet is fixed, so no
        # argument is needed.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary()
    def __len__(self):
        return len(self.df)
    def numericalize_seq(self, seq):
        """Return token ids for *seq* wrapped in [START] ... [END] markers."""
        numericalized_seq = [self.vocab.stoi["[START]"]]
        numericalized_seq += self.vocab.numericalize(seq)
        numericalized_seq.append(self.vocab.stoi["[END]"])
        return numericalized_seq
    def get_vocabulary(self):
        """Return the token -> id mapping of the underlying vocabulary."""
        return self.vocab.stoi
    def __getitem__(self, index):
        """Return (miRNA ids, mRNA ids, score) as tensors for row *index*."""
        # .iloc gives positional access, robust even if the DataFrame index
        # is not a plain RangeIndex (e.g. after filtering rows upstream).
        mirna = torch.tensor(self.numericalize_seq(self.mirna.iloc[index]))
        mrna = torch.tensor(self.numericalize_seq(self.mrna.iloc[index]))
        score = torch.tensor(self.rel_score.iloc[index])
        return mirna, mrna, score
class CollateSequences:
    """collate_fn that pads miRNA and mRNA sequences to a common length per batch."""
    def __init__(self, pad_idx):
        # Token id used to fill the shorter sequences in a batch.
        self.pad_idx = pad_idx
    def __call__(self, batch):
        """Pad the first two fields of every sample; pass scores through as a list."""
        mirna_seqs, mrna_seqs, scores = [], [], []
        for mirna_seq, mrna_seq, score in batch:
            mirna_seqs.append(mirna_seq)
            mrna_seqs.append(mrna_seq)
            scores.append(score)
        padded_mirna = pad_sequence(mirna_seqs, batch_first=True, padding_value=self.pad_idx)
        padded_mrna = pad_sequence(mrna_seqs, batch_first=True, padding_value=self.pad_idx)
        return padded_mirna, padded_mrna, scores
# Returns a ready Loader and the Dataset Class for the Sequence
def get_loader(
    seq_csv,
    batch_size=5,
    num_workers=8,
    shuffle=True,
    pin_memory=True
):
    """Build a padded-batch DataLoader over a sequence CSV.

    Args:
        seq_csv: path to the CSV consumed by SequenceDataset.
        batch_size, num_workers, shuffle, pin_memory: forwarded to DataLoader.

    Returns:
        (loader, dataset) tuple so callers can reach the vocabulary too.
    """
    dataset = SequenceDataset(filename=seq_csv)
    # BUGFIX: the vocabulary stores the padding token as "[PAD]"; the
    # original "<PAD>" lookup raised KeyError on every call.
    pad_idx = dataset.vocab.stoi["[PAD]"]
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=CollateSequences(pad_idx=pad_idx)
    )
    return loader, dataset