diff --git a/gestureCNN.py b/gestureCNN.py index ae6c9402..e5672840 100644 --- a/gestureCNN.py +++ b/gestureCNN.py @@ -142,7 +142,7 @@ def modlistdir(path, pattern = None): # Load CNN model -def loadCNN(): +def loadCNN(bTraining = False): global get_output model = Sequential() @@ -173,21 +173,22 @@ def loadCNN(): # Model conig details model.get_config() - - #List all the weight files available in current directory - WeightFileName = modlistdir('.','.hdf5') - if len(WeightFileName) == 0: - print('Error: No pretrained weight file found. Please either train the model or download one from the https://github.com/asingh33/CNNGestureRecognizer') - return 0 - else: - print('Found these weight files - {}'.format(WeightFileName)) - #Load pretrained weights - w = int(input("Which weight file to load (enter the INDEX of it, which starts from 0): ")) - fname = WeightFileName[int(w)] - print("loading ", fname) - model.load_weights(fname) - - layer = model.layers[11] + if not bTraining : + #List all the weight files available in current directory + WeightFileName = modlistdir('.','.hdf5') + if len(WeightFileName) == 0: + print('Error: No pretrained weight file found. Please either train the model or download one from the https://github.com/asingh33/CNNGestureRecognizer') + return 0 + else: + print('Found these weight files - {}'.format(WeightFileName)) + #Load pretrained weights + w = int(input("Which weight file to load (enter the INDEX of it, which starts from 0): ")) + fname = WeightFileName[int(w)] + print("loading ", fname) + model.load_weights(fname) + + # refer the last layer here + layer = model.layers[-1] get_output = K.function([model.layers[0].input, K.learning_phase()], [layer.output,]) diff --git a/trackgesture.py b/trackgesture.py index 9c87aa45..13e52f49 100644 --- a/trackgesture.py +++ b/trackgesture.py @@ -206,7 +206,7 @@ def Main(): mod = myNN.loadCNN() break elif ans == 2: - mod = myNN.loadCNN() + mod = myNN.loadCNN(True) myNN.trainModel(mod) input("Press any key to continue") break