-
Notifications
You must be signed in to change notification settings - Fork 7
/
lvu_dataset.py
85 lines (68 loc) · 2.96 KB
/
lvu_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import os
import numpy as np
import random
import pandas as pd
from torch.utils.data import Dataset, DataLoader
def set_seed(seed):
    """Seed the global RNGs of Python, NumPy and PyTorch with ``seed``.

    When at least one CUDA device is visible, the generators of all CUDA
    devices are seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.device_count():
        torch.cuda.manual_seed_all(seed)
# Seed every RNG at import time so the random span sampling below is
# reproducible across runs.
set_seed(1112)

# Root directory of the LVU data; every dataset path in this module is built
# from it.
DATA_ROOT = '/playpen-storage/mmiemon/lvu/data'

# Per-video durations indexed by 'videoid', loaded once at import time.
# NOTE(review): units presumably match the time axis of the stored features
# (one step per second, given the `l_secs` naming) - confirm.
duration_data = pd.read_csv(f'{DATA_ROOT}/CMD/metadata/durations.csv').set_index('videoid')
class CustomDataset(Dataset):
    """Span-level dataset over precomputed per-video features for one LVU task.

    Each data row of ``{DATA_ROOT}/lvu_1.0/<long_term_task>/<split>.csv``
    (the header row is skipped) describes one video. Videos whose feature file
    is missing are skipped with a message. Every remaining video is expanded
    into spans of ``args.l_secs`` time steps starting at 0, 1, ...,
    ``duration - l_secs``. ``self.videos``, ``self.starts`` and
    ``self.labels`` hold one entry per span.

    For the ``'train'`` split, ``len()`` is the number of unique videos and
    ``__getitem__`` ignores ``idx``, drawing a span uniformly at random from
    all spans. Videos with more spans are therefore drawn more often. For any
    other split, every span is enumerated deterministically.

    Args:
        args: namespace providing ``long_term_task``, ``l_secs`` and
            ``feature_type``.
        split: split name used to locate the CSV file (e.g. ``'train'``).
    """

    # Offsets subtracted from the regression targets to make them zero-mean.
    # NOTE(review): presumably training-set means - confirm before reuse.
    _VIEW_COUNT_MEAN = 11.76425435683139
    _LIKE_RATIO_MEAN = 9.138220535629456

    # Per-time-step token grid expected in the 'vit_spatial' feature files.
    _NUM_TOKENS = 197
    _FEATURE_DIM = 1024

    def __init__(self, args, split):
        self.args = args
        self.split = split
        self.videos = []
        self.labels = []
        self.starts = []
        csv_file = f'{DATA_ROOT}/lvu_1.0/{args.long_term_task}/{split}.csv'
        with open(csv_file, 'r') as f:
            f.readline()  # skip the header row
            for line in f:
                items = line.split()
                if not items:
                    # Tolerate blank lines instead of crashing on items[-2].
                    continue
                # The video id is the second-to-last whitespace-separated field.
                video_id = items[-2]
                if not os.path.exists(f'{DATA_ROOT}/vit_features_spatial/{video_id}.npy'):
                    print("Features not found for video : ", video_id)
                    continue
                label = self._parse_label(args.long_term_task, items)
                duration = duration_data.loc[video_id]['duration']
                for start in self._span_starts(duration, args.l_secs):
                    self.videos.append(video_id)
                    self.starts.append(start)
                    self.labels.append(label)
        print('Total videos in ', split, len(set(self.videos)))
        print('Total spans ', split, len(self.videos))

    @classmethod
    def _parse_label(cls, task, items):
        """Return the target for one CSV row already split into ``items``.

        'view_count': log of the first field, minus a fixed offset.
        'like_ratio': 10 * likes / (likes + dislikes) from the first two
        fields, minus a fixed offset.
        Any other task: the first field as an integer class label.
        """
        if task == 'view_count':
            return float(np.log(float(items[0]))) - cls._VIEW_COUNT_MEAN
        if task == 'like_ratio':
            like, dislike = float(items[0]), float(items[1])
            # NOTE(review): raises ZeroDivisionError if both counts are 0.
            return like / (like + dislike) * 10.0 - cls._LIKE_RATIO_MEAN
        return int(items[0])

    @staticmethod
    def _span_starts(duration, l_secs):
        """Return span start offsets ``[0, 1, ..., duration - l_secs]``.

        Always contains at least ``0``, even when the video is shorter than
        ``l_secs``. Such spans are zero-padded in ``__getitem__``.
        """
        # int() guards against pandas yielding a float duration, which
        # range() would reject.
        return [0, *range(1, int(duration) - l_secs + 1)]

    def __len__(self):
        if self.split == 'train':
            # One item per unique video; spans are sampled in __getitem__.
            return len(set(self.videos))
        return len(self.videos)

    def __getitem__(self, idx):
        """Return ``(video_id, features, label)`` for one span.

        ``features`` is a float64 array of shape ``(l_secs * 197, 1024)``.
        Time steps beyond the end of the stored features are zero-filled.

        Raises:
            NotImplementedError: if ``args.feature_type`` is not
                ``'vit_spatial'``.
        """
        feature_type = self.args.feature_type
        if feature_type != 'vit_spatial':
            # Support for other feature types can be added similarly.
            raise NotImplementedError(f'Unsupported feature_type: {feature_type!r}')
        if self.split == 'train':
            # Ignore idx: draw a random span from the full span-level index.
            idx = random.randint(0, len(self.videos) - 1)
        start = self.starts[idx]
        l_secs = self.args.l_secs
        video_features = np.load(f'{DATA_ROOT}/vit_features_spatial/{self.videos[idx]}.npy')
        x = np.zeros((l_secs, self._NUM_TOKENS, self._FEATURE_DIM))
        # Copy the available time steps; the remainder stays zero (padding).
        end = min(start + l_secs, video_features.shape[0])
        if end > start:
            x[:end - start] = video_features[start:end]
        x = x.reshape(-1, self._FEATURE_DIM)
        return self.videos[idx], x, self.labels[idx]