diff --git a/paddlenlp/transformers/configuration_utils.py b/paddlenlp/transformers/configuration_utils.py
index fde4640c33ca..4f7f31d8068e 100644
--- a/paddlenlp/transformers/configuration_utils.py
+++ b/paddlenlp/transformers/configuration_utils.py
@@ -195,8 +195,12 @@ def convert_to_legacy_config(attribute_map: Dict[str, str], config: Dict[str, An
             args.append(init_arg)
         config["init_args"] = args
 
+    # TODO(wj-Mcat): improve compatibility between old local configs and the new PretrainedConfig, e.g.:
+    # { "init_args": [], "init_class": "", "num_classes": 12 }
     for standard_field, paddle_field in attribute_map.items():
-        config[paddle_field] = config.pop(standard_field, None) or config.pop(paddle_field, None)
+        value = config.pop(standard_field, None) or config.pop(paddle_field, None)
+        if value is not None:
+            config[paddle_field] = value
     return config
 
 
@@ -729,16 +733,6 @@ def from_pretrained(
         """
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
 
-        # do standard config map: there are some old-school pretrained-config not refactored.
-        config_dict = convert_to_legacy_config(cls.attribute_map, config_dict)
-
-        config_dict = flatten_model_config(config_dict)
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
         return cls.from_dict(config_dict, **kwargs)
 
     @classmethod
@@ -859,12 +853,18 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
             [`PretrainedConfig`]: The configuration object instantiated from those parameters.
         """
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
-        # Those arguments may be passed along for our internal telemetry.
-        # We remove them so they don't appear in `return_unused_kwargs`.
-        # convert local config to legacy config
+        # Apply the standard config mapping: some old-school pretrained configs have not been refactored yet.
        config_dict = convert_to_legacy_config(cls.attribute_map, config_dict)
 
+        config_dict = flatten_model_config(config_dict)
+
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )
+
         config = cls(**config_dict)
 
         if hasattr(config, "pruned_heads"):
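
A minimal sketch of the behavior change in convert_to_legacy_config. The standalone convert_legacy helper and the num_labels/num_classes mapping below are illustrative, not taken from the patch; the loop body mirrors the patched code.

    from typing import Any, Dict


    def convert_legacy(attribute_map: Dict[str, str], config: Dict[str, Any]) -> Dict[str, Any]:
        # Mirrors the patched loop: only write the paddle-style key when a
        # value actually exists under either name, so absent fields are no
        # longer materialized as explicit None entries.
        for standard_field, paddle_field in attribute_map.items():
            value = config.pop(standard_field, None) or config.pop(paddle_field, None)
            if value is not None:
                config[paddle_field] = value
        return config


    # Hypothetical mapping for illustration only.
    attribute_map = {"num_labels": "num_classes"}

    print(convert_legacy(attribute_map, {}))                  # {} (the old code produced {"num_classes": None})
    print(convert_legacy(attribute_map, {"num_labels": 12}))  # {"num_classes": 12}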