-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist_memory_support.py
106 lines (80 loc) · 3.39 KB
/
mnist_memory_support.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from src.preprocess import *
from src.clustering_tree import *
import matplotlib.pyplot as plt
from ot.bregman import barycenter_sinkhorn
from multiprocessing import Pool
import os
import re
import argparse
def peak_memory():
    """
    Return the peak memory that the calling process has used so far (MB).

    The value is taken from the ``VmHWM`` entry of ``/proc/<pid>/status``
    (a line such as ``VmHWM: 862168 kB``); the first run of digits on that
    line is divided by 1000 to convert kB to MB.

    Raises ``ValueError`` when the status file has no ``VmHWM`` entry.
    """
    # NOTE(review): the kernel's "kB" in this file is presumably 1024-byte
    # units, so the reported figure is an approximation of MB - confirm.
    pid = os.getpid()
    with open(f'/proc/{pid}/status') as status_file:
        for entry in status_file:
            if entry.startswith('VmHWM:'):
                digits = re.search('[0-9]+', entry)
                return int(digits[0]) / 1000.
    # The whole file was scanned without meeting a VmHWM line.
    raise ValueError('Not Found')
def calc_tree_bar(func_args):
    """
    Compute a barycenter with ``tree.barycenter`` and return the peak memory
    of the process that ran it (MB).

    func_args[0] is the tree object. The input distributions are not taken
    from func_args; they are loaded from "exp/tmp/<method>_a_list.pk", where
    <method> is the module-level ``args.method`` set in the ``__main__``
    block. The function therefore only works in a process where ``args``
    is defined.
    """
    tree = func_args[0]
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open("exp/tmp/" + args.method + "_a_list.pk", "rb") as f:
        a_list = pickle.load(f)
    # The barycenter itself is not used afterwards; it is computed so that
    # its memory cost is included in the peak measured below.
    barycenter = tree.barycenter(a_list)
    mem_size = peak_memory()
    print(f"Peak Memory : {mem_size} MB")
    return mem_size
def calc_tree_naive_bar(func_args):
    """
    Compute a barycenter with ``tree.barycenter_naive`` and return the peak
    memory of the process that ran it (MB).

    func_args[0] is the tree object. The input distributions are not taken
    from func_args; they are loaded from "exp/tmp/<method>_a_list.pk", where
    <method> is the module-level ``args.method`` set in the ``__main__``
    block. The function therefore only works in a process where ``args``
    is defined.
    """
    tree = func_args[0]
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open("exp/tmp/" + args.method + "_a_list.pk", "rb") as f:
        a_list = pickle.load(f)
    # The barycenter itself is not used afterwards; it is computed so that
    # its memory cost is included in the peak measured below.
    barycenter = tree.barycenter_naive(a_list)
    mem_size = peak_memory()
    print(f"Peak Memory : {mem_size} MB")
    return mem_size
def calc_ibp_bar(func_args):
    """
    Compute a barycenter with ``barycenter_sinkhorn`` (regularisation 0.01)
    and return the peak memory of the process that ran it (MB).

    func_args is accepted only so the function can be used with ``Pool.map``;
    its content is ignored. The cost matrix and the distributions are loaded
    from "exp/tmp/<method>_M.pk" and "exp/tmp/<method>_A.pk", where <method>
    is the module-level ``args.method`` set in the ``__main__`` block. The
    function therefore only works in a process where ``args`` is defined.
    """
    # Use context managers so both file handles are closed deterministically
    # instead of being left to the garbage collector.
    with open("exp/tmp/" + args.method + "_M.pk", "rb") as f:
        M = pickle.load(f)
    with open("exp/tmp/" + args.method + "_A.pk", "rb") as f:
        a_list = pickle.load(f)
    # The barycenter itself is not used afterwards; it is computed so that
    # its memory cost is included in the peak measured below.
    barycenter = barycenter_sinkhorn(a_list, M, 0.01)
    mem_size = peak_memory()
    print(f"Peak Memory : {mem_size} MB")
    return mem_size
if __name__ == "__main__":
    # Image side lengths: 28 scaled by sqrt(k), so the number of pixels is
    # roughly k times that of a 28x28 image for k = 1/2, 1, 2, ..., 6.
    n_pixels = [int(28 * math.sqrt(scale)) for scale in (1 / 2, 1, 2, 3, 4, 5, 6)]
    n_classes = list(range(10))

    parser = argparse.ArgumentParser()
    # Reject unknown methods up front; otherwise no worker is dispatched in
    # the loop below and `result` is referenced before assignment.
    parser.add_argument('method', type=str, choices=["FastPSD", "PSD", "IBP"],
                        help="FastPSD, PSD, or IBP")
    # `args` must remain a module-level name assigned before the pools are
    # created: the calc_* worker functions read `args.method` directly.
    args = parser.parse_args()

    # The scratch pickles and the final result are written below exp/.
    os.makedirs("exp/tmp", exist_ok=True)

    # One single-worker pool per (image size, class) run, so that every
    # measurement is taken in its own process (peak_memory reports the
    # high-water mark of the calling process). All pools are created here,
    # before any data is loaded in the parent.
    process_list = [Pool(1) for _ in range(len(n_pixels) * len(n_classes))]

    # Peak memory per run, keyed by the image side length as a string.
    mem_list = {str(n_pixel): [] for n_pixel in n_pixels}

    n_process = 0
    for n_pixel in n_pixels:
        for n_class in n_classes:
            pool = process_list[n_process]
            mnist = MNIST(n_pixels=n_pixel)

            if args.method == "IBP":
                a_list = np.array(mnist.get_data(n_class)[:1000]).T
                M = mnist.compute_M()
                # The worker loads its inputs from these files, so they must
                # be completely written and closed before it is dispatched.
                with open("exp/tmp/" + args.method + "_M.pk", "wb") as f:
                    pickle.dump(M, f)
                with open("exp/tmp/" + args.method + "_A.pk", "wb") as f:
                    pickle.dump(a_list, f)
                result = pool.map(calc_ibp_bar, [(None,)])
            else:
                a_list = mnist.get_data(n_class)[:1000]
                support = mnist.get_support()
                tree = ClusteringTree(support)
                # Same as above: close the file before the worker reads it.
                with open("exp/tmp/" + args.method + "_a_list.pk", "wb") as f:
                    pickle.dump(a_list, f)
                # `choices` guarantees the method is FastPSD or PSD here.
                if args.method == "FastPSD":
                    worker = calc_tree_bar
                else:
                    worker = calc_tree_naive_bar
                result = pool.map(worker, [(tree,)])

            # The pool has served its single task; shut it down and wait for
            # the worker process to exit before starting the next run.
            pool.close()
            pool.join()

            mem_list[str(n_pixel)].append(result[0])
            print(mem_list)
            n_process += 1

    with open("exp/mnist_memory_support_" + args.method + ".pk", "wb") as f:
        pickle.dump(mem_list, f)