-
Notifications
You must be signed in to change notification settings - Fork 24
/
benchmark_script.py
55 lines (32 loc) · 1.38 KB
/
benchmark_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner
def main(config_path='configs/experiment_config.yaml', experiment_ids=None):
    """Run the models on one dataset - take taxi dataset for example.

    For each experiment id, build a ``Config`` from the shared YAML file,
    build a ``Runner`` from that config, and run it. Experiments run one
    after another in the given order. An exception raised by any run
    propagates and stops the remaining runs, matching the original
    straight-line script.

    Args:
        config_path: Path of the experiment YAML file handed to
            ``Config.build_from_yaml_file``. Defaults to the path the
            original script used.
        experiment_ids: Iterable of experiment ids (entries in the YAML
            file) to run, in order. ``None`` runs the default benchmark
            set: RMTPP, NHP, SAHP, THP, AttNHP, ODETPP.
    """
    if experiment_ids is None:
        experiment_ids = (
            'RMTPP_train',
            'NHP_train',
            'SAHP_train',
            'THP_train',
            'AttNHP_train',  # AttNHP converges slowly.
            'ODETPP_train',
        )

    for experiment_id in experiment_ids:
        # A fresh config/runner per experiment, so no state is shared
        # between runs beyond what the library itself keeps.
        config = Config.build_from_yaml_file(config_path, experiment_id=experiment_id)
        model_runner = Runner.build_from_config(config)
        model_runner.run()
# Script entry point: launch the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()