Skip to content

Commit

Permalink
[Bug fixes] update attribute map handler (#4421)
Browse files Browse the repository at this point in the history
* update config

* update legacy config

* add comment for change

* trigger cla
  • Loading branch information
wj-Mcat committed Jan 11, 2023
1 parent ed1f5ac commit 22de327
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions paddlenlp/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,12 @@ def convert_to_legacy_config(attribute_map: Dict[str, str], config: Dict[str, An
args.append(init_arg)
config["init_args"] = args

# TODO(wj-Mcat): to improve compatibility for: old local config and new PretrainedConfig, eg:
# { "init_args": [], "init_class": "", "num_classes": 12 }
for standard_field, paddle_field in attribute_map.items():
config[paddle_field] = config.pop(standard_field, None) or config.pop(paddle_field, None)
value = config.pop(standard_field, None) or config.pop(paddle_field, None)
if value is not None:
config[paddle_field] = value
return config


Expand Down Expand Up @@ -729,16 +733,6 @@ def from_pretrained(

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

# do standard config map: there are some old-school pretrained-config not refactored.
config_dict = convert_to_legacy_config(cls.attribute_map, config_dict)

config_dict = flatten_model_config(config_dict)
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)

@classmethod
Expand Down Expand Up @@ -859,12 +853,18 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
[`PretrainedConfig`]: The configuration object instantiated from those parameters.
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# Those arguments may be passed along for our internal telemetry.
# We remove them so they don't appear in `return_unused_kwargs`.

# convert local config to legacy config
# do standard config map: there are some old-school pretrained-config not refactored.
config_dict = convert_to_legacy_config(cls.attribute_map, config_dict)

config_dict = flatten_model_config(config_dict)

if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

config = cls(**config_dict)

if hasattr(config, "pruned_heads"):
Expand Down

0 comments on commit 22de327

Please sign in to comment.