-
Notifications
You must be signed in to change notification settings - Fork 1
/
balanced_set.py
108 lines (100 loc) · 3.76 KB
/
balanced_set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from parse_movies_example import load_all_movies
from random import shuffle
import numpy as np
import string
# Keys of a movie dict, listed in the same order in which the fields appear
# in a '||'-separated record (see balanced_set). Not referenced by any code
# visible in this file.
FIELDS = ['title', 'year', 'identifier', 'episode', 'summary']
# Per-decade movie quota; split_balanced routes half of this many movies per
# decade to the test file.
NUM_MOVIES_PER_DECADE = 6000
def balanced_set(filename):
    """
    Generator to yield movies from a text file generated by
    gen_balanced_set.

    Each non-blank line of the file is one record made of five
    '||'-separated fields, in this order: title, year, identifier,
    episode, summary. The format of this text file differs from
    plot.list.gz; per the original author's note, the movies yielded here
    are meant to match those of load_all_movies in parse_movies_example.

    Yields:
        dict with keys 'title', 'year' (converted to int), 'identifier',
        'episode' and 'summary' (all other values are strings).

    Raises:
        IndexError: a non-blank line has fewer than five fields.
        ValueError: the year field is not an integer literal.
    """
    # The with-block closes the file once the generator is exhausted or
    # closed, instead of leaving the handle to the garbage collector.
    with open(filename, 'r') as infile:
        for line in infile:
            # The line terminator delimits records; it is not summary text.
            record = line.rstrip('\r\n')
            if not record.strip():
                # Tolerate blank lines (e.g. a trailing one) in the file.
                continue
            # Split at most four times so that a '||' occurring inside the
            # summary stays part of the summary instead of being dropped.
            fields = record.split('||', 4)
            yield {'title': fields[0],
                   'year': int(fields[1]),
                   'identifier': fields[2],
                   'episode': fields[3],
                   'summary': fields[4]}
def split_balanced(infile, trainfile, testfile):
    """
    Split a balanced data set into a test file and a training file.

    Movies are read from infile with balanced_set. For every decade key
    from 1930 through 2010, the first NUM_MOVIES_PER_DECADE // 2 movies
    encountered are written to testfile; all later movies of that decade
    are written to trainfile. Records are written one per line as
    '||'-separated fields (title, year, identifier, episode, summary),
    with runs of whitespace in the summary collapsed to single spaces.

    A movie whose record cannot be built or written is printed and
    skipped; it does not use up test quota.

    Raises:
        KeyError: a movie's 'year' is not one of the decade keys
            1930, 1940, ..., 2010.
    """
    decades = [1930 + 10 * i for i in range(9)]
    # Floor division keeps the quota an int on both Python 2 and 3.
    test_quota = dict.fromkeys(decades, NUM_MOVIES_PER_DECADE // 2)

    # Context managers guarantee both outputs are flushed and closed.
    with open(trainfile, 'w') as train, open(testfile, 'w') as test:
        for movie in balanced_set(infile):
            year = movie['year']
            goes_to_test = test_quota[year] > 0
            out = test if goes_to_test else train
            try:
                record = '||'.join([movie['title'],
                                    str(year),
                                    movie['identifier'],
                                    movie['episode'],
                                    ' '.join(movie['summary'].split())])
                # Record and terminator go out in one call so a failure
                # cannot leave a bare blank line in the output.
                out.write(record + '\n')
            except Exception:
                # Best-effort: report the offending movie and carry on, but
                # let KeyboardInterrupt / SystemExit propagate.
                print(movie)
                continue
            if goes_to_test:
                test_quota[year] -= 1
def gen_balanced_set(infile, outfile):
    """
    Generate a new 'balanced' data set from the full dataset and
    write it to a text file.

    All movies from load_all_movies(infile) are loaded into memory and
    shuffled, so the output differs on every call. At most
    NUM_MOVIES_PER_DECADE movies are kept for each decade key from 1930
    through 2010. Each kept movie is written as one line of
    '||'-separated fields (title, year, identifier, episode, summary),
    with runs of whitespace in the summary collapsed to single spaces;
    balanced_set reads this format back.

    This is not meant to be called often: generate the file once and
    iterate over it with balanced_set. As such, this routine isn't all
    that efficient.

    Raises:
        KeyError: a movie's 'year' is not one of the decade keys
            1930, 1940, ..., 2010.
    """
    movies = list(load_all_movies(infile))
    shuffle(movies)

    decades = [1930 + 10 * i for i in range(9)]
    # Use the shared module constant rather than repeating the literal.
    remaining = dict.fromkeys(decades, NUM_MOVIES_PER_DECADE)

    # The with-block ensures the output is flushed and closed.
    with open(outfile, 'w') as out:
        for movie in movies:
            year = movie['year']
            if remaining[year] == 0:
                # This decade's quota is already filled.
                continue
            out.write('||'.join([movie['title'],
                                 str(year),
                                 movie['identifier'],
                                 movie['episode'],
                                 ' '.join(movie['summary'].split())]) + '\n')
            remaining[year] -= 1
def list_all_words(filename, outfile):
    """
    Write every distinct word found in the movie summaries to outfile.

    Movies are read with load_all_movies(filename). Each summary is passed
    through clean_str and split on whitespace, and the resulting words are
    collected in a set. The words are then written one per line, in the
    set's iteration order.
    """
    vocabulary = set()
    for movie in load_all_movies(filename):
        vocabulary.update(clean_str(movie['summary']).split())

    with open(outfile, 'w') as out:
        out.writelines(word + '\n' for word in vocabulary)
def clean_str(instr, punc_to_whitespace=False):
    """
    Helper to return string with punctuation and capital letters removed.

    Args:
        instr: the string to clean.
        punc_to_whitespace: when true, every character in
            string.punctuation is replaced by a single space instead of
            being deleted.

    Returns:
        The lower-cased string with punctuation deleted or replaced.
    """
    lowered = instr.lower()

    if str is bytes and isinstance(lowered, str):
        # Python 2 byte string: use the byte-table form of translate.
        if punc_to_whitespace:
            table = string.maketrans(string.punctuation,
                                     ' ' * len(string.punctuation))
            return lowered.translate(table)
        return lowered.translate(None, string.punctuation)

    # Text string (Python 3 str or Python 2 unicode): translate takes a
    # mapping from code point to replacement code point, or None to delete.
    replacement = ord(' ') if punc_to_whitespace else None
    table = dict.fromkeys(map(ord, string.punctuation), replacement)
    return lowered.translate(table)
if __name__ == '__main__':
    # Regenerate the balanced file, then report how many movies it holds in
    # total and per decade.
    gen_balanced_set('plot.list.gz', 'balanced.txt')
    movies = balanced_set('balanced.txt')

    # Floor division keeps the bucket index an int on Python 2 and 3 alike.
    to_index = lambda year: (year - 1930) // 10
    to_year = lambda index: 10 * index + 1930

    # One bucket per decade, 1930 through 2010.
    counts = [0] * 9
    for m in movies:
        counts[to_index(m['year'])] += 1

    print('total: ' + str(sum(counts)))
    for i, count in enumerate(counts):
        print(str(to_year(i)) + ': ' + str(count))