-
Notifications
You must be signed in to change notification settings - Fork 83
/
config.py
59 lines (44 loc) · 2.3 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import os
import logging
import logging.handlers
# Logging levels in increasing order of severity: DEBUG < INFO < WARNING < ERROR < CRITICAL
def get_logger(filename, log_dir='log'):
    """Return the shared ``'logger'`` logger, writing to the console and to a log file.

    Records at DEBUG and above go to ``<log_dir>/<filename>.log`` and to stderr.
    Both outputs use the format
    ``[LEVEL | file:line] time: message``.

    ``logging.getLogger`` returns the same object on every call. Handlers are
    therefore only added when an equivalent one is not already attached.
    Calling this function more than once does not print every record several
    times.

    Args:
        filename: base name of the log file, without the ``.log`` extension.
        log_dir: directory for the log file. It is created if missing, including
            parents. Defaults to ``'log'``, relative to the current working directory.

    Returns:
        The configured ``logging.Logger`` named ``'logger'``.
    """
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(levelname)s | %(filename)s:%(lineno)s] %(asctime)s: %(message)s')

    # exist_ok avoids the race between an isdir() check and mkdir().
    os.makedirs(log_dir, exist_ok=True)
    # FileHandler stores an absolute path in .baseFilename, so compare against one.
    log_path = os.path.abspath(os.path.join(log_dir, filename + '.log'))

    already_has_file = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_has_file:
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so match the exact type for the console handler.
    already_has_stream = any(type(handler) is logging.StreamHandler for handler in logger.handlers)
    if not already_has_stream:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    return logger
def _str2bool(value):
    """Convert a command-line string such as ``'False'``, ``'yes'`` or ``'1'`` to a bool.

    argparse's ``type=bool`` is broken for this purpose. ``bool('False')`` is
    ``True`` because every non-empty string is truthy, so ``--cuda False``
    would silently leave the switch on.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # '' maps to False to stay compatible with the old bool('') behaviour.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def _build_parser():
    """Create the argparse parser holding all training/distributed options."""
    parser = argparse.ArgumentParser('parameters')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='CIFAR10, CIFAR100, MNIST')
    parser.add_argument('--model-name', type=str, default='ResNet26', help='ResNet26, ResNet38, ResNet50')
    parser.add_argument('--img-size', type=int, default=32)
    parser.add_argument('--batch-size', type=int, default=25)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--print-interval', type=int, default=100)
    # Boolean switches accept `--flag` (-> True), `--flag True` and `--flag False`.
    parser.add_argument('--cuda', type=_str2bool, nargs='?', const=True, default=True)
    parser.add_argument('--pretrained-model', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--stem', type=_str2bool, nargs='?', const=True, default=False,
                        help='attention stem: True, conv: False')
    parser.add_argument('--distributed', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--gpu-devices', type=int, nargs='+', default=None)
    parser.add_argument('--gpu', type=int, default=None)
    parser.add_argument('--rank', type=int, default=0, help='current process number')
    parser.add_argument('--world-size', type=int, default=1, help='Total number of processes to be used (number of gpus)')
    parser.add_argument('--dist-backend', type=str, default='nccl')
    parser.add_argument('--dist-url', default='tcp://127.0.0.1:3456', type=str)
    return parser


def get_args(argv=None):
    """Parse the command line and create the training logger.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which matches the previous behaviour.

    Returns:
        A tuple ``(args, logger)``. ``args`` is the parsed ``argparse.Namespace``.
        ``logger`` is the value returned by ``get_logger('train')``; the parsed
        arguments (``vars(args)``) have already been logged to it at INFO level.
    """
    args = _build_parser().parse_args(argv)
    logger = get_logger('train')
    logger.info(vars(args))
    return args, logger