Skip to content

Commit

Permalink
IssueFix: Traning new gesture got broken due to my earlier commit. It…
Browse files Browse the repository at this point in the history
…s fixed with this commit.
  • Loading branch information
asingh33 committed Nov 21, 2019
1 parent 0fb93e9 commit c119198
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
33 changes: 17 additions & 16 deletions gestureCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def modlistdir(path, pattern = None):


# Load CNN model
def loadCNN():
def loadCNN(bTraining = False):
global get_output
model = Sequential()

Expand Down Expand Up @@ -173,21 +173,22 @@ def loadCNN():
# Model conig details
model.get_config()


#List all the weight files available in current directory
WeightFileName = modlistdir('.','.hdf5')
if len(WeightFileName) == 0:
print('Error: No pretrained weight file found. Please either train the model or download one from the https://github.com/asingh33/CNNGestureRecognizer')
return 0
else:
print('Found these weight files - {}'.format(WeightFileName))
#Load pretrained weights
w = int(input("Which weight file to load (enter the INDEX of it, which starts from 0): "))
fname = WeightFileName[int(w)]
print("loading ", fname)
model.load_weights(fname)

layer = model.layers[11]
if not bTraining :
#List all the weight files available in current directory
WeightFileName = modlistdir('.','.hdf5')
if len(WeightFileName) == 0:
print('Error: No pretrained weight file found. Please either train the model or download one from the https://github.com/asingh33/CNNGestureRecognizer')
return 0
else:
print('Found these weight files - {}'.format(WeightFileName))
#Load pretrained weights
w = int(input("Which weight file to load (enter the INDEX of it, which starts from 0): "))
fname = WeightFileName[int(w)]
print("loading ", fname)
model.load_weights(fname)

# refer the last layer here
layer = model.layers[-1]
get_output = K.function([model.layers[0].input, K.learning_phase()], [layer.output,])


Expand Down
2 changes: 1 addition & 1 deletion trackgesture.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def Main():
mod = myNN.loadCNN()
break
elif ans == 2:
mod = myNN.loadCNN()
mod = myNN.loadCNN(True)
myNN.trainModel(mod)
input("Press any key to continue")
break
Expand Down

0 comments on commit c119198

Please sign in to comment.