From 22de327cd2e927cfc09930143d56f3fe314f2b27 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=AA=91=E9=A9=AC=E5=B0=8F=E7=8C=AB?= <1435130236@qq.com>
Date: Wed, 11 Jan 2023 16:06:27 +0800
Subject: [PATCH] [Bug fixes] update attribute map handler (#4421)

* update config

* update legacy config

* add comment for change

* trigger cla
---
 paddlenlp/transformers/configuration_utils.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/paddlenlp/transformers/configuration_utils.py b/paddlenlp/transformers/configuration_utils.py
index fde4640c33ca..4f7f31d8068e 100644
--- a/paddlenlp/transformers/configuration_utils.py
+++ b/paddlenlp/transformers/configuration_utils.py
@@ -195,8 +195,12 @@ def convert_to_legacy_config(attribute_map: Dict[str, str], config: Dict[str, An
             args.append(init_arg)
         config["init_args"] = args
 
+    # TODO(wj-Mcat): improve compatibility between old local configs and the new PretrainedConfig, e.g.:
+    # { "init_args": [], "init_class": "", "num_classes": 12 }
     for standard_field, paddle_field in attribute_map.items():
-        config[paddle_field] = config.pop(standard_field, None) or config.pop(paddle_field, None)
+        value = config.pop(standard_field, None) or config.pop(paddle_field, None)
+        if value is not None:
+            config[paddle_field] = value
 
     return config
 
@@ -729,16 +733,6 @@ def from_pretrained(
 
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
 
-        # do standard config map: there are some old-school pretrained-config not refactored.
-        config_dict = convert_to_legacy_config(cls.attribute_map, config_dict)
-
-        config_dict = flatten_model_config(config_dict)
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
         return cls.from_dict(config_dict, **kwargs)
 
     @classmethod
@@ -859,12 +853,18 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
             [`PretrainedConfig`]: The configuration object instantiated from those parameters.
         """
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
-        # Those arguments may be passed along for our internal telemetry.
-        # We remove them so they don't appear in `return_unused_kwargs`.
 
-        # convert local config to legacy config
+        # apply the standard config map: some legacy pretrained configs have not been refactored yet.
         config_dict = convert_to_legacy_config(cls.attribute_map, config_dict)
 
+        config_dict = flatten_model_config(config_dict)
+
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )
+
         config = cls(**config_dict)
 
         if hasattr(config, "pruned_heads"):
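
The first hunk changes how convert_to_legacy_config applies the attribute map. Before the patch the loop always wrote the mapped key, so a field absent under both its standard and its paddle name was written back as an explicit None, and that None could then shadow the class default once the dict was splatted into cls(**config_dict). Below is a minimal standalone sketch of the before/after behavior; the num_labels -> num_classes mapping is illustrative, not taken from a specific PaddleNLP config.

    from typing import Any, Dict

    # Illustrative mapping: standard (HF-style) field name -> paddle field name.
    attribute_map = {"num_labels": "num_classes"}

    def convert_old(attribute_map: Dict[str, str], config: Dict[str, Any]) -> Dict[str, Any]:
        # Pre-patch behavior: always assigns, so a field missing under both
        # names is written back as an explicit None.
        for standard_field, paddle_field in attribute_map.items():
            config[paddle_field] = config.pop(standard_field, None) or config.pop(paddle_field, None)
        return config

    def convert_new(attribute_map: Dict[str, str], config: Dict[str, Any]) -> Dict[str, Any]:
        # Post-patch behavior: only assigns when one of the two names had a value.
        for standard_field, paddle_field in attribute_map.items():
            value = config.pop(standard_field, None) or config.pop(paddle_field, None)
            if value is not None:
                config[paddle_field] = value
        return config

    print(convert_old(attribute_map, {"hidden_size": 768}))
    # {'hidden_size': 768, 'num_classes': None}  <- spurious None, later passed
    #                                               to cls(**config_dict) where
    #                                               it can shadow the default
    print(convert_new(attribute_map, {"hidden_size": 768}))
    # {'hidden_size': 768}                       <- key left unset, default survives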
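
The other two hunks move the legacy-field mapping, flatten_model_config, and the model_type mismatch warning out of from_pretrained and into from_dict. Because from_pretrained ends by delegating to from_dict, that path still gets the normalization, and code that builds a config from a raw dict now gets it as well. The sketch below shows the resulting control flow under that assumption; ConfigSketch and the stubbed helper are invented for illustration and are not the real PaddleNLP classes.

    import logging
    from typing import Any, Dict

    logger = logging.getLogger(__name__)

    def flatten_model_config(config: Dict[str, Any]) -> Dict[str, Any]:
        # Stub: the real helper in configuration_utils.py flattens nested configs.
        return config

    class ConfigSketch:
        # Invented stand-in for a PretrainedConfig subclass.
        model_type = "bert"

        def __init__(self, **kwargs: Any) -> None:
            self.__dict__.update(kwargs)

        @classmethod
        def from_dict(cls, config_dict: Dict[str, Any]) -> "ConfigSketch":
            # Post-patch home of the normalization (the real method also calls
            # convert_to_legacy_config with cls.attribute_map first).
            config_dict = flatten_model_config(dict(config_dict))
            if config_dict.get("model_type", cls.model_type) != cls.model_type:
                logger.warning(
                    "Instantiating a %s config from a %s config dict; this may yield errors.",
                    cls.model_type, config_dict["model_type"],
                )
            return cls(**config_dict)

        @classmethod
        def from_pretrained(cls, config_dict: Dict[str, Any]) -> "ConfigSketch":
            # Post-patch: no normalization here; everything funnels through from_dict.
            return cls.from_dict(config_dict)

    config = ConfigSketch.from_dict({"model_type": "ernie", "hidden_size": 768})
    # -> warns about the ernie/bert mismatch, then builds the config object

Centralizing the normalization in from_dict leaves a single place where legacy fields are rewritten, which keeps the two entry points from drifting apart.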