diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index fda6b6f9174b5..4ae82398713ae 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -405,9 +405,9 @@ def load_state_dict( assert isinstance( state_dict, dict ), "The state_dict should be a dictionary." - state_dict = flatten_state_dict(state_dict) - if len(state_dict) > 0: - for val in state_dict.values(): + flat_state_dict, mapping = flatten_state_dict(state_dict) + if len(flat_state_dict) > 0: + for val in flat_state_dict.values(): assert isinstance( val, paddle.Tensor ), f"Only support dygraph Tensor now, but is {val}" @@ -423,7 +423,7 @@ def load_state_dict( paddle.distributed.barrier(process_group) rank_to_files = get_rank_to_files( - path, state_dict, process_group, use_dist + path, flat_state_dict, process_group, use_dist ) if len(rank_to_files) <= 0: return @@ -434,16 +434,18 @@ def load_state_dict( ) # read_items: [ReadItem(local_tensor_index, rank, cur_offsets, storage_offsets, lengths)], # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank. 
- read_items = get_read_items(path, state_dict, process_group, use_dist) + read_items = get_read_items( + path, flat_state_dict, process_group, use_dist + ) storage_file_to_state_dict = {} logger.debug( - f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}" + f"before load, state_dict:{flat_state_dict},\n load_infos:{load_infos},\n read_items:{read_items}" ) state_dict_in_cpu = [] - for k, v in state_dict.items(): + for k, v in flat_state_dict.items(): if v.place.is_cpu_place(): state_dict_in_cpu.append(k) - state_dict[k] = v.cuda() + flat_state_dict[k] = v.cuda() for item in read_items: assert ( item.local_tensor_index in load_infos @@ -484,15 +486,17 @@ def load_state_dict( # The read item rank need to be assigned if item.rank == paddle.distributed.get_rank(): assert ( - item.local_tensor_index.tensor_key in state_dict - ), f"item:{item}, state_dict:{state_dict}" + item.local_tensor_index.tensor_key in flat_state_dict + ), f"item:{item}, state_dict:{flat_state_dict}" cur_local_tensor = ( - state_dict[ + flat_state_dict[ item.local_tensor_index.tensor_key ]._local_value() if use_dist - and state_dict[item.local_tensor_index.tensor_key].is_dist() - else state_dict[item.local_tensor_index.tensor_key] + and flat_state_dict[ + item.local_tensor_index.tensor_key + ].is_dist() + else flat_state_dict[item.local_tensor_index.tensor_key] ) cur_offsets = item.cur_offset cur_lengths = item.lengths @@ -513,7 +517,9 @@ def load_state_dict( else: cur_chunk_tensor = paddle.zeros( item.lengths, - dtype=state_dict[item.local_tensor_index.tensor_key].dtype, + dtype=flat_state_dict[ + item.local_tensor_index.tensor_key + ].dtype, ) if src_rank == item.rank: @@ -530,6 +536,6 @@ def load_state_dict( cur_chunk_tensor, src=src_rank, group=process_group ) - for k, v in state_dict.items(): + for k, v in flat_state_dict.items(): if k in state_dict_in_cpu: state_dict[k] = v.cpu() diff --git a/python/paddle/distributed/checkpoint/metadata.py 
b/python/paddle/distributed/checkpoint/metadata.py index 4eb5d559a9c0c..d1f3a3fdb66c0 100644 --- a/python/paddle/distributed/checkpoint/metadata.py +++ b/python/paddle/distributed/checkpoint/metadata.py @@ -40,3 +40,4 @@ class LocalTensorIndex: class Metadata: state_dict_metadata: Dict[str, List[LocalTensorMetadata]] = None storage_metadata: Dict[LocalTensorIndex, str] = None + flat_mapping: Dict[str, Tuple[str]] = None diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index b2c380c66ba2f..86047e637e360 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -13,7 +13,6 @@ # limitations under the License. import os -from typing import List import paddle from paddle.distributed.communication.group import is_initialized @@ -50,7 +49,7 @@ def check_file_name(file_name, process_group): def merge_state_dict_metadata(global_state_dict_metadata): assert isinstance( - global_state_dict_metadata, List + global_state_dict_metadata, list ), "The global_state_dict should be a list." out = {} for state_dict in global_state_dict_metadata: @@ -64,7 +63,7 @@ def merge_state_dict_metadata(global_state_dict_metadata): return out -def dedup_storage_metadata(global_storage_metadata): +def dedup_key_in_dict(global_storage_metadata): out = {} for storage_metadata in global_storage_metadata: for key, val in storage_metadata.items(): @@ -74,6 +73,34 @@ def dedup_storage_metadata(global_storage_metadata): return out +def dedup_tensor( + local_state_dict, local_storage_metadata, global_storage_metadata +): + """ + Dedup the replicated tensor in local state_dict. + + Args: + local_state_dict(Dict[str, paddle.Tensor]): The state_dict of current rank. + local_storage_metadata(Dict[LocalTensorIndex, str]): The storage metadata of current rank. + global_storage_metadata(Dict[LocalTensorIndex, str]): The final storage metadata of all ranks. 
+ + Examples: + In rank0, local_state_dict:{"w1": t1_0, "w2": t2}, local_storage_metadata:{LocalTensorIndex("w1", (0,0)): "0_0.distcp", LocalTensorIndex("w2", (0,0)): "0_0.distcp"}, + in rank1, local_state_dict:{"w1": t1_1, "w2": t2}, local_storage_metadata:{LocalTensorIndex("w1", (1,0)): "1_0.distcp", LocalTensorIndex("w2", (0,0)): "1_0.distcp"}, + global_storage_metadata:{LocalTensorIndex("w1", (0,0)): "0_0.distcp", LocalTensorIndex("w1", (1,0)): "1_0.distcp", LocalTensorIndex("w2", (0, 0)): "0_0.distcp"}. + w2 is replicated in rank0 and rank1. We save it in rank0 as default thus need to remove it in other ranks. + Finally, the local_state_dict:{"w1": t1_1, "w2": t2} in rank1 update to {"w1": t1_1}. + """ + + for tensor_index, file_name in global_storage_metadata.items(): + rank = int(file_name.split(".")[0].split("_")[0]) + if ( + tensor_index in local_storage_metadata + and rank != paddle.distributed.get_rank() + ): + local_state_dict.pop(tensor_index.tensor_key) + + def save_state_dict( state_dict, path, @@ -107,9 +134,9 @@ def save_state_dict( assert isinstance( state_dict, dict ), "The state_dict should be a dictionary." 
- state_dict = flatten_state_dict(state_dict) - if len(state_dict) > 0: - for val in state_dict.values(): + flat_state_dict, mapping = flatten_state_dict(state_dict) + if len(flat_state_dict) > 0: + for val in flat_state_dict.values(): assert isinstance( val, paddle.Tensor ), "Only support dygraph Tensor now, support static DistributedTensor later" @@ -134,12 +161,12 @@ def save_state_dict( if use_dist: check_file_name(file_name, process_group) # the parameter_name and order in state_dict should be the same - check_state_dict(state_dict, process_group) + check_state_dict(flat_state_dict, process_group) metadata = Metadata() local_state_dict = {} local_state_dict_metadata = {} local_storage_metadata = {} - for key, val in state_dict.items(): + for key, val in flat_state_dict.items(): if isinstance(val, paddle.Tensor): # Case1: not initialized means this tensor is placed in another mesh which do not contain this rank if not val._is_initialized(): @@ -178,6 +205,7 @@ def save_state_dict( ] = file_name global_state_dict_metadata = [] global_storage_metadata = [] + global_flatten_mapping = [] if use_dist: paddle.distributed.all_gather_object( global_state_dict_metadata, @@ -187,19 +215,24 @@ def save_state_dict( paddle.distributed.all_gather_object( global_storage_metadata, local_storage_metadata, process_group ) + paddle.distributed.all_gather_object( + global_flatten_mapping, mapping, process_group + ) else: global_state_dict_metadata.append(local_state_dict_metadata) global_storage_metadata.append(local_storage_metadata) + global_flatten_mapping.append(mapping) metadata.state_dict_metadata = merge_state_dict_metadata( global_state_dict_metadata ) - metadata.storage_metadata = dedup_storage_metadata( - global_storage_metadata - ) + metadata.storage_metadata = dedup_key_in_dict(global_storage_metadata) + metadata.flat_mapping = dedup_key_in_dict(global_flatten_mapping) if coordinator_rank == paddle.distributed.get_rank(): logger.debug(f"metadata:{metadata}") 
paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata")) logger.debug(f"local_state_dict:{local_state_dict}") - # TODO(pangengzheng): del the replicated tensor in local_state_dict, now different might save the replicated tensor + dedup_tensor( + local_state_dict, local_storage_metadata, metadata.storage_metadata + ) paddle.save(local_state_dict, os.path.join(path, file_name)) diff --git a/python/paddle/distributed/checkpoint/utils.py b/python/paddle/distributed/checkpoint/utils.py index cb0f069984c3a..d592d6ebcb97b 100644 --- a/python/paddle/distributed/checkpoint/utils.py +++ b/python/paddle/distributed/checkpoint/utils.py @@ -63,5 +63,47 @@ def compute_local_shape_and_global_offset( def flatten_state_dict(state_dict): - # TODO, {"model": {"w0": xxx}} -> {model.w0: xxx} + """ + Flatten the nested dict to a flat dict. + {"model": {"w0": xxx}} -> {model.w0: xxx} + """ + flatten_state_dict = {} + mapping = {} + + def _flatten(key, value): + if isinstance(value, dict): + for k, v in value.items(): + assert isinstance(k, str), f"The key should be str, but is {k}" + _flatten(key + (k,), v) + elif isinstance(value, paddle.Tensor): + flatten_key_str = ".".join(key) + flatten_state_dict[flatten_key_str] = value + mapping[flatten_key_str] = key + else: + raise ValueError( + f"The value should be dict or paddle.Tensor, but is {value}" + ) + + _flatten((), state_dict) + + return flatten_state_dict, mapping + + +def unflatten_state_dict(flat_state_dict, mapping): + """ + Unflatten the flat dict to a nested dict. 
+ {model.w0: xxx} -> {"model": {"w0": xxx}} + """ + state_dict = {} + for key, value in flat_state_dict.items(): + key_tuple = mapping[key] + assert isinstance( + key_tuple, tuple + ), f"The key should be tuple, but is {key_tuple}" + tmp = state_dict + for i in range(len(key_tuple) - 1): + key = key_tuple[i] + tmp = tmp.setdefault(key, {}) + tmp[key_tuple[-1]] = value + return state_dict diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 3a64f2095f30a..134b164409a95 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -406,35 +406,7 @@ def set_state_dict(self, state_dict): tensor.set_xpu_scale_value( state_dict.get(var_tmp.name + ".SCALE_VALUE", -1.0) ) - - model_np = np.array(tensor) - - load_para = state_dict[var_tmp.name] - - if isinstance(load_para, Variable): - load_para_np = np.array(load_para) - elif isinstance(load_para, core.eager.Tensor): - load_para_np = np.array(load_para) - elif isinstance(load_para, np.ndarray): - load_para_np = load_para - else: - raise RuntimeError( - f"State dict type {str(type(load_para))} not supprt" - ) - - assert ( - model_np.shape == load_para_np.shape - ), "Parameter shape not match, Dygraph Parameter [ {} ] need tensor with shape {} but load tensor with shape {}".format( - model_np.name, model_np.shape, load_para_np.shape - ) - - assert ( - model_np.dtype == load_para_np.dtype - ), "Parameter dtype not match, Dygraph Parameter [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( - model_np.name, model_np.dtype, load_para_np.dtype - ) - - tensor.set(load_para_np, framework._current_expected_place()) + var.set_value(state_dict[var_tmp.name]) def get_opti_var_name_list(self): return self._opti_name_list diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 774dc3d2023b9..a735762cce658 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -194,6 +194,9 @@ 
if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_gpt_with_prim MODULES test_gpt_with_prim) set_tests_properties(test_gpt_with_prim PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 200) + py_test_modules(test_dist_checkpoint_utils MODULES test_dist_checkpoint_utils) + set_tests_properties(test_dist_checkpoint_utils + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) py_test_modules(test_semi_auto_parallel_unshard_dtensor MODULES test_semi_auto_parallel_unshard_dtensor) set_tests_properties(test_semi_auto_parallel_unshard_dtensor diff --git a/test/auto_parallel/semi_auto_parallel_checkpoint_dedup_tensor.py b/test/auto_parallel/semi_auto_parallel_checkpoint_dedup_tensor.py new file mode 100644 index 0000000000000..7f8884156aa7e --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_checkpoint_dedup_tensor.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+
+
+class TestSaveStateDict:
+    def __init__(self):
+        self._ckpt_path = os.getenv("ckpt_path")
+
+    def test_dedup_tensor(self):
+        w1 = paddle.arange(32).reshape([4, 8])
+        w2 = paddle.arange(32, 36).reshape([2, 2])
+        mesh = dist.ProcessMesh([0, 1])
+        dist_w1 = dist.shard_tensor(w1, mesh, [dist.Replicate()])
+        dist_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(0)])
+        state_dict = {"w1": dist_w1, "w2": dist_w2}
+        # w1 is replicated in rank0 and rank1, it will only save in rank0.
+        # Therefore, rank0 save state_dict:{"w1": dist_w1, "w2": dist_w2}, rank1 save state_dict:{"w2": dist_w2}
+        dist.save_state_dict(state_dict, self._ckpt_path)
+        paddle.distributed.barrier()
+        # check
+        expect_local_state_dict = {}
+        for k, v in state_dict.items():
+            if k == "w1" and paddle.distributed.get_rank() != 0:
+                continue
+            expect_local_state_dict[k] = v._local_value()
+        data_file_path = os.path.join(
+            self._ckpt_path, f"{paddle.distributed.get_rank()}_0.distcp"
+        )
+        metadata_file_path = os.path.join(self._ckpt_path, "0.metadata")
+        assert os.path.exists(data_file_path) and os.path.exists(
+            metadata_file_path
+        )
+        local_state_dict = paddle.load(data_file_path)
+        metadata = paddle.load(metadata_file_path)
+
+        for k, local_tensor in local_state_dict.items():
+            assert k in expect_local_state_dict
+            expect_tensor = expect_local_state_dict[k]
+            np.testing.assert_equal(expect_tensor.numpy(), local_tensor.numpy())
+        for tensor_index, file_name in metadata.storage_metadata.items():
+            rank = int(file_name.split(".")[0].split("_")[0])
+            if tensor_index.tensor_key == "w1":
+                assert rank == 0
+
+    def run_test_case(self):
+        self.test_dedup_tensor()
+
+
+if __name__ == '__main__':
+    TestSaveStateDict().run_test_case()
diff --git a/test/auto_parallel/semi_auto_parallel_checkpoint_flatten_mapping.py b/test/auto_parallel/semi_auto_parallel_checkpoint_flatten_mapping.py
new file mode 100644
index 
0000000000000..c8cfdb22d8598 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_checkpoint_flatten_mapping.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import paddle +import paddle.distributed as dist + + +class TestSemiautoSaveLoad: + def __init__(self): + self._ckpt_path = os.getenv("ckpt_path") + + def test_flatten_mapping(self): + if paddle.distributed.get_rank() == 0: + state_dict = { + "model": { + "a": paddle.to_tensor([1, 2]), + "b": paddle.to_tensor([3, 4]), + }, + "optimizer": { + "c": paddle.to_tensor([5, 6]), + "d": paddle.to_tensor([7, 8]), + }, + } + else: + state_dict = { + "model": { + "a": paddle.to_tensor([10, 20]), + "b": paddle.to_tensor([30, 40]), + }, + "optimizer": { + "c": paddle.to_tensor([50, 60]), + "d": paddle.to_tensor([70, 80]), + }, + } + expected_mapping = { + "model.a": ("model", "a"), + "model.b": ("model", "b"), + "optimizer.c": ("optimizer", "c"), + "optimizer.d": ("optimizer", "d"), + } + dist.save_state_dict(state_dict, self._ckpt_path) + metadata_path = os.path.join(self._ckpt_path, "0.metadata") + assert os.path.exists(metadata_path) + metadata = paddle.load(metadata_path) + assert len(metadata.flat_mapping) == len( + expected_mapping + ), f"expect {len(expected_mapping)}, but got {len(metadata.flat_mapping)}" + for key in metadata.flat_mapping: + assert ( + key in expected_mapping + ), f"expect {key} in flatten_mapping, 
but not found" + assert ( + metadata.flat_mapping[key] == expected_mapping[key] + ), f"expect {metadata.flat_mapping[key]} == {expected_mapping[key]}, but not equal" + + def run_test_case(self): + self.test_flatten_mapping() + + +if __name__ == '__main__': + TestSemiautoSaveLoad().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py b/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py index f4d22a16c41bd..0153d3bd21216 100644 --- a/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py +++ b/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py @@ -179,6 +179,61 @@ def test_shard_optimizer_master_params(self): assert v.is_dist() assert v.shape[-1] == v._local_shape[-1] * 2 + # save load + ckpt_state_dict = opt.state_dict() + dist.save_state_dict(ckpt_state_dict, self._ckpt_path) + paddle.distributed.barrier() + expected_local_state_dict = {} + expected_local_state_dict.setdefault("master_weights", {}) + need_load_state_dict = {} + need_load_state_dict.setdefault("master_weights", {}) + for k, v in ckpt_state_dict.items(): + if k == "LR_Scheduler": + continue + elif k == "master_weights": + assert isinstance(v, dict), v + for mk, mv in v.items(): + expected_local_state_dict[k][mk] = mv._local_value().clone() + need_load_state_dict[k][mk] = paddle.zeros_like(mv) + else: + expected_local_state_dict[k] = v._local_value().clone() + need_load_state_dict[k] = paddle.zeros_like(v) + opt.set_state_dict(need_load_state_dict) + after_set_state_dict = opt.state_dict() + for k, v in after_set_state_dict.items(): + if k == "master_weights": + assert isinstance(v, dict), v + for mk, mv in v.items(): + assert ( + mv.numpy().sum() == 0.0 + ), f"state_dict {k} in master_weights is not zero" + assert ( + need_load_state_dict[k][mk].numpy().sum() == 0.0 + ), f"state_dict {k} in master_weights is not zero" + else: + assert v.numpy().sum() == 0.0, f"state_dict {k} is not zero" + assert k in need_load_state_dict, f"state_dict 
{k} is not found" + assert ( + need_load_state_dict[k].numpy().sum() == 0.0 + ), f"state_dict {k} is not zero" + dist.load_state_dict(need_load_state_dict, self._ckpt_path) + opt.set_state_dict(need_load_state_dict) + new_state_dict = opt.state_dict() + assert "master_weights" in new_state_dict, new_state_dict + for k, v in new_state_dict.items(): + assert k in expected_local_state_dict + if k == "master_weights": + for mk, mv in v.items(): + np.testing.assert_equal( + mv._local_value().numpy(), + expected_local_state_dict[k][mk].numpy(), + ) + else: + np.testing.assert_equal( + v._local_value().numpy(), + expected_local_state_dict[k].numpy(), + ) + def test_shard_optimizer_params_group(self): paddle.seed(self._seed) linear = paddle.nn.Linear(10, 10) diff --git a/test/auto_parallel/test_dist_checkpoint_utils.py b/test/auto_parallel/test_dist_checkpoint_utils.py new file mode 100644 index 0000000000000..5a51f73f0fa56 --- /dev/null +++ b/test/auto_parallel/test_dist_checkpoint_utils.py @@ -0,0 +1,105 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tempfile +import unittest + +import collective.test_communication_api_base as test_base +import numpy as np + +import paddle +from paddle.distributed.checkpoint.utils import ( + flatten_state_dict, + unflatten_state_dict, +) + + +class TestDistCheckpointUtils(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=2, timeout=120, nnode=1) + self._default_envs = {} + self._changeable_envs = {"backend": ["gpu"]} + + def test_flatten_mapping(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + ckpt_path_tmp = tempfile.TemporaryDirectory() + ckpt_path = ckpt_path_tmp.name + envs["ckpt_path"] = ckpt_path + self.run_test_case( + "semi_auto_parallel_checkpoint_flatten_mapping.py", + user_defined_envs=envs, + ) + ckpt_path_tmp.cleanup() + + def test_dedup_tensor(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + ckpt_path_tmp = tempfile.TemporaryDirectory() + ckpt_path = ckpt_path_tmp.name + envs["ckpt_path"] = ckpt_path + self.run_test_case( + "semi_auto_parallel_checkpoint_dedup_tensor.py", + user_defined_envs=envs, + ) + ckpt_path_tmp.cleanup() + + def test_flatten_state_dict(self): + state_dict = { + "model": { + "a.0": paddle.to_tensor([1, 2]), + "b": paddle.to_tensor([3, 4]), + }, + "optimizer": { + "c": paddle.to_tensor([5, 6]), + "d.2": paddle.to_tensor([7, 8]), + }, + } + expected_flat_state_dict = { + "model.a.0": paddle.to_tensor([1, 2]), + "model.b": paddle.to_tensor([3, 4]), + "optimizer.c": paddle.to_tensor([5, 6]), + "optimizer.d.2": paddle.to_tensor([7, 8]), + } + flat_state_dict, mapping = flatten_state_dict(state_dict) + self.assertTrue(len(expected_flat_state_dict) == len(flat_state_dict)) + for k, v in flat_state_dict.items(): + self.assertTrue(isinstance(v, paddle.Tensor)) + self.assertTrue(k in expected_flat_state_dict) + np.testing.assert_equal( + v.numpy(), 
expected_flat_state_dict[k].numpy() + ) + recover_state_dict = unflatten_state_dict(flat_state_dict, mapping) + + def check_state_dict(d1, d2): + self.assertTrue(len(d1) == len(d2)) + self.assertTrue(type(d1) == type(d2)) + if isinstance(d1, dict): + for k in d1: + self.assertTrue(k in d2) + check_state_dict(d1[k], d2[k]) + elif isinstance(d1, paddle.Tensor): + np.testing.assert_equal(d1.numpy(), d2.numpy()) + else: + raise ValueError(f"Invalid type of state_dict:{d1} != {d2}") + + check_state_dict(recover_state_dict, state_dict) + + +if __name__ == "__main__": + unittest.main()