Skip to content

Commit

Permalink
[PSCore]Fix test fleet base 2 (#38588)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmxdream committed Dec 30, 2021
1 parent 15cbf81 commit 04496d8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 46 deletions.
8 changes: 8 additions & 0 deletions python/paddle/distributed/fleet/base/fleet_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,8 @@ def init_server(self, *args, **kwargs):
"""
self._runtime_handle._init_server(*args, **kwargs)

@is_non_distributed_check
@inited_runtime_handler
def load_model(self, path, mode):
"""
load fleet model from path
Expand Down Expand Up @@ -699,6 +701,8 @@ def stop_worker(self):
"""
self._runtime_handle._stop_worker()

@is_non_distributed_check
@inited_runtime_handler
def save(self, dirname, feed=[], fetch=[], **configs):
inference = True

Expand Down Expand Up @@ -742,6 +746,8 @@ def save(self, dirname, feed=[], fetch=[], **configs):
self._runtime_handle._save_persistables(
executor, dirname, main_program=None, mode=increment_mode)

@is_non_distributed_check
@inited_runtime_handler
def save_inference_model(self,
executor,
dirname,
Expand Down Expand Up @@ -777,6 +783,8 @@ def save_inference_model(self,
executor, dirname, feeded_var_names, target_vars, main_program,
export_for_deployment, mode)

@is_non_distributed_check
@inited_runtime_handler
def save_persistables(self, executor, dirname, main_program=None, mode=0):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ def test_gradient_merge_optimizer(self):
self.assertEqual(sends, 0)
self.assertEqual(sgds, 0)

fleet.init_worker()
time.sleep(8)
fleet.stop_worker()


if __name__ == "__main__":
unittest.main()
45 changes: 3 additions & 42 deletions python/paddle/fluid/tests/unittests/test_fleet_base_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ class TestFleetBase(unittest.TestCase):
def setUp(self):
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36000"
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
"127.0.0.1:36001,127.0.0.2:36001"
os.environ["PADDLE_TRAINERS_NUM"] = "1"
#os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
# "127.0.0.1:36001,127.0.0.2:36001"

def test_ps_minimize(self):
import paddle
Expand Down Expand Up @@ -78,45 +78,6 @@ def test_ps_minimize(self):
fleet.load_model(path="/tmp", mode=0)
fleet.load_model(path="/tmp", mode=1)

self.assertRaises(
Exception,
fleet.save_inference_model,
dirname='/tmp/',
feeded_var_names=['x', 'y'],
target_vars=[avg_cost],
executor="exe")

self.assertRaises(
Exception,
fleet.save_inference_model,
dirname='/tmp/',
feeded_var_names=['x', 'y'],
target_vars=[avg_cost],
executor=exe,
main_program=compiled_prog)

self.assertRaises(
Exception,
fleet.save_inference_model,
dirname='afs:/tmp/',
feeded_var_names=['x', 'y'],
target_vars=[avg_cost],
executor=exe,
main_program=compiled_prog)

self.assertRaises(
Exception, fleet.save_persistables, executor=pe, dirname='/tmp/')

self.assertRaises(
Exception, fleet.save_persistables, executor="exe", dirname='/tmp/')

self.assertRaises(
Exception,
fleet.save_persistables,
executor=exe,
dirname='/tmp/',
main_program=compiled_prog)


if __name__ == "__main__":
unittest.main()

0 comments on commit 04496d8

Please sign in to comment.