Skip to content

Commit

Permalink
Fix an incorrect usage of is_local argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
xinghai-sun committed Aug 14, 2017
1 parent dd92a02 commit 5a7c92d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion deep_speech_2/cloud/pcloud_submit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ MEAN_STD_FILE="../mean_std.npz"
CLOUD_DATA_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/data"
CLOUD_MODEL_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/model"
# Configure cloud resources
NUM_CPU=12
NUM_CPU=8
NUM_GPU=8
NUM_NODE=1
MEMORY="10Gi"
Expand Down
6 changes: 5 additions & 1 deletion deep_speech_2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def train(self,
gradient_clipping,
num_passes,
output_model_dir,
is_local=True,
num_iterations_print=100):
"""Train the model.
Expand All @@ -65,6 +66,8 @@ def train(self,
:param num_iterations_print: Number of training iterations for printing
a training loss.
:type rnn_iteratons_print: int
:param is_local: Set to False if running with pserver with multi-nodes.
:type is_local: bool
:param output_model_dir: Directory for saving the model (every pass).
:type output_model_dir: basestring
"""
Expand All @@ -79,7 +82,8 @@ def train(self,
trainer = paddle.trainer.SGD(
cost=self._loss,
parameters=self._parameters,
update_equation=optimizer)
update_equation=optimizer,
is_local=is_local)

# create event handler
def event_handler(event):
Expand Down
8 changes: 3 additions & 5 deletions deep_speech_2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,13 @@ def train():
gradient_clipping=400,
num_passes=args.num_passes,
num_iterations_print=args.num_iterations_print,
output_model_dir=args.output_model_dir)
output_model_dir=args.output_model_dir,
is_local=args.is_local)


def main():
utils.print_arguments(args)
paddle.init(
use_gpu=args.use_gpu,
trainer_count=args.trainer_count,
is_local=args.is_local)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
train()


Expand Down

0 comments on commit 5a7c92d

Please sign in to comment.