diff --git a/README.md b/README.md index e6ed8a4..66ae61f 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Using pretrained models not only get high performance, but fastly attach converg # References - \[1\] 'Improved Training of Wasserstein GANs' by Ishaan Gulrajani et. al, https://arxiv.org/abs/1704.00028, (https://github.com/igul222/improved_wgan_training)[code] -- \[2\] 'GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium' by Martin Heusel et. al, https://arxiv.org/abs/1704.00028 +- \[2\] 'GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium' by Martin Heusel et. al, https://arxiv.org/abs/1706.08500, (https://github.com/bioinf-jku/TTUR)[code] # Contact If you run into any problems with this code, please submit a bug report on the Github site of the project. For another inquries pleace contact with me: yaxing@cvc.uab.es diff --git a/tflib/lsun_label.py b/tflib/lsun_label.py index 0b011f5..f51d890 100644 --- a/tflib/lsun_label.py +++ b/tflib/lsun_label.py @@ -16,64 +16,67 @@ 'restaurant':9} def make_generator(path, n_files, batch_size,image_size, IW = False, pharse='train'): - epoch_count = [1] - image_list_main = listdir(path) - image_list = [] - for sub_class in image_list_main: - # pdb.set_trace() - sub_class_path =path + '/'+ sub_class + '/'+ pharse - sub_class_image = listdir(sub_class_path) - image_list.extend([sub_class_path + '/' + i for i in sub_class_image]) + epoch_count = [1] + image_list_main = listdir(path) + image_list = [] + for sub_class in image_list_main: + # pdb.set_trace() + sub_class_path =path + '/'+ sub_class + '/'+ pharse + sub_class_image = listdir(sub_class_path) + image_list.extend([sub_class_path + '/' + i for i in sub_class_image]) - def get_epoch(): - images = np.zeros((batch_size, 3, 64, 64), dtype='int32') - labels = np.zeros((batch_size,), dtype='int32') - files = range(len(image_list)) - random_state = np.random.RandomState(epoch_count[0]) - random_state.shuffle(files) - epoch_count[0] += 1 - for n, i in enumerate(files): - #image = scipy.misc.imread("{}/{}.png".format(path, str(i+1).zfill(len(str(n_files))))) - image = scipy.misc.imread("{}".format(image_list[i])) - label = Label[image_list[i].split('/')[2]] - image = scipy.misc.imresize(image,(image_size,image_size)) - images[n % batch_size] = image.transpose(2,0,1) - labels[n % batch_size] = label - if n > 0 and n % batch_size == 0: - yield (images,labels) - def get_epoch_from_end(): - images = np.zeros((batch_size, 3, 64, 64), dtype='int32') - files = range(n_files) - random_state = np.random.RandomState(epoch_count[0]) - random_state.shuffle(files) - epoch_count[0] += 1 - for n, i in enumerate(files): - #image = scipy.misc.imread("{}/{}.png".format(path, str(i+1).zfill(len(str(n_files))))) + def get_epoch(): + images = np.zeros((batch_size, 3, 64, 64), dtype='int32') + labels = np.zeros((batch_size,), dtype='int32') + files = range(len(image_list)) + random_state = np.random.RandomState(epoch_count[0]) + random_state.shuffle(files) + epoch_count[0] += 1 + for n, i in enumerate(files): + #image = scipy.misc.imread("{}/{}.png".format(path, str(i+1).zfill(len(str(n_files))))) + image = scipy.misc.imread("{}".format(image_list[i])) + label = Label[image_list[i].split('/')[2]] + image = scipy.misc.imresize(image,(image_size,image_size)) + images[n % batch_size] = image.transpose(2,0,1) + labels[n % batch_size] = label + if n > 0 and n % batch_size == 0: + yield (images,labels) + ''' + def get_epoch_from_end(): + images = np.zeros((batch_size, 3, 64, 64), dtype='int32') + files = range(n_files) + random_state = np.random.RandomState(epoch_count[0]) + random_state.shuffle(files) + epoch_count[0] += 1 + for n, i in enumerate(files): + #image = scipy.misc.imread("{}/{}.png".format(path, str(i+1).zfill(len(str(n_files))))) - image = scipy.misc.imread("{}".format(path + image_list[-i-1])) + image = scipy.misc.imread("{}".format(path + image_list[-i-1])) - image = scipy.misc.imresize(image,(image_size,image_size)) - images[n % batch_size] = image.transpose(2,0,1) - if n > 0 and n % batch_size == 0: - yield (images,labels) - return get_epoch_from_end if IW else get_epoch + image = scipy.misc.imresize(image,(image_size,image_size)) + images[n % batch_size] = image.transpose(2,0,1) + #if n > 0 and n % batch_size == 0: + # yield (images,labels) + ''' + return get_epoch def load_from_end(batch_size, data_dir='/home/ishaan/data/imagenet64',image_size = 64, NUM_TRAIN = 7000): - return ( - make_generator(data_dir+'/train/', NUM_TRAIN, batch_size,image_size, IW =True), - make_generator(data_dir+'/val/', 10000, batch_size,image_size, IW =True) - ) + return ( + make_generator(data_dir+'/train/', NUM_TRAIN, batch_size,image_size, IW =True), + make_generator(data_dir+'/val/', 10000, batch_size,image_size, IW =True) + ) def load(batch_size, data_dir='/home/ishaan/data/imagenet64',image_size = 64, NUM_TRAIN = 7000): - return ( - make_generator(data_dir, NUM_TRAIN, batch_size,image_size, pharse='train'), - make_generator(data_dir, 10000, batch_size,image_size, pharse='val') - ) - + return ( + make_generator(data_dir, NUM_TRAIN, batch_size,image_size, pharse='train'), + make_generator(data_dir, 10000, batch_size,image_size, pharse='val') + ) +''' if __name__ == '__main__': - train_gen, valid_gen = load(64) + train_gen, valid_gen = load(64) + t0 = time.time() + for i, batch in enumerate(train_gen(), start=1): + print "{}\t{}".format(str(time.time() - t0), batch[0][0,0,0,0]) + if i == 1000: + break t0 = time.time() - for i, batch in enumerate(train_gen(), start=1): - print "{}\t{}".format(str(time.time() - t0), batch[0][0,0,0,0]) - if i == 1000: - break - t0 = time.time() +''' \ No newline at end of file