Commit

[lazy] fix lazy cls init (#5720)
* fix

* fix

* fix

* fix

* fix

* remove kernel install

* rebase

revert

fix

* fix

* fix
flybird11111 committed May 17, 2024
1 parent 2011b13 commit 9d83c6d
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -140,7 +140,7 @@ jobs:
       - name: Install Colossal-AI
         run: |
-          BUILD_EXT=1 pip install -v -e .
+          pip install -v -e .
           pip install -r requirements/requirements-test.txt
       - name: Store Colossal-AI Cache
23 changes: 23 additions & 0 deletions colossalai/lazy/pretrained.py
@@ -1,3 +1,4 @@
+import copy
 import os
 from typing import Callable, Optional, Union

@@ -74,6 +75,24 @@ def new_from_pretrained(
     subfolder = kwargs.pop("subfolder", "")
     commit_hash = kwargs.pop("_commit_hash", None)
     variant = kwargs.pop("variant", None)
+
+    kwargs.pop("state_dict", None)
+    kwargs.pop("from_tf", False)
+    kwargs.pop("from_flax", False)
+    kwargs.pop("output_loading_info", False)
+    kwargs.pop("trust_remote_code", None)
+    kwargs.pop("low_cpu_mem_usage", None)
+    kwargs.pop("device_map", None)
+    kwargs.pop("max_memory", None)
+    kwargs.pop("offload_folder", None)
+    kwargs.pop("offload_state_dict", False)
+    kwargs.pop("load_in_8bit", False)
+    kwargs.pop("load_in_4bit", False)
+    kwargs.pop("quantization_config", None)
+    kwargs.pop("adapter_kwargs", {})
+    kwargs.pop("adapter_name", "default")
+    kwargs.pop("use_flash_attention_2", False)
+
     use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
 
     if len(kwargs) > 0:
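
The keys popped above are loader-only arguments accepted by transformers' PreTrainedModel.from_pretrained (TF/Flax conversion, quantization, offloading, adapters, flash attention). Here they are popped and discarded so that the len(kwargs) > 0 branch below only ever sees genuine leftover keyword arguments. A minimal, self-contained sketch of that pop-then-check pattern; the helper name strip_loader_kwargs is hypothetical and not part of the commit, while the key set is taken from the hunk above:

# Hypothetical helper mirroring the pattern in new_from_pretrained above:
# consume loader-only keys so only the remaining kwargs reach the later check.
LOADER_ONLY_KEYS = {
    "state_dict", "from_tf", "from_flax", "output_loading_info", "trust_remote_code",
    "low_cpu_mem_usage", "device_map", "max_memory", "offload_folder",
    "offload_state_dict", "load_in_8bit", "load_in_4bit", "quantization_config",
    "adapter_kwargs", "adapter_name", "use_flash_attention_2",
}

def strip_loader_kwargs(kwargs: dict) -> dict:
    """Remove loader-only keys in place and return them."""
    return {k: kwargs.pop(k) for k in list(kwargs) if k in LOADER_ONLY_KEYS}

kwargs = {"load_in_8bit": True, "torch_dtype": "float16"}
dropped = strip_loader_kwargs(kwargs)
assert dropped == {"load_in_8bit": True}
assert kwargs == {"torch_dtype": "float16"}  # only this survives to the len(kwargs) > 0 check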
@@ -108,6 +127,10 @@ def new_from_pretrained(
             **kwargs,
         )
     else:
+        config = copy.deepcopy(config)
+        kwarg_attn_imp = kwargs.pop("attn_implementation", None)
+        if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
+            config._attn_implementation = kwarg_attn_imp
         model_kwargs = kwargs
 
     if commit_hash is None:
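
The new else: branch covers callers that pass an already-constructed PretrainedConfig: the config is deep-copied so the caller's object is never mutated, and an explicit attn_implementation keyword takes precedence over whatever the config carries. A standalone sketch of that copy-then-override pattern (GPT2Config is only a convenient example, and this assumes a transformers version in which PretrainedConfig exposes the _attn_implementation property):

import copy

from transformers import GPT2Config

def override_attn_implementation(config, kwargs):
    """Copy the config and let an explicit attn_implementation kwarg win over it."""
    config = copy.deepcopy(config)  # never mutate the caller's config instance
    kwarg_attn_imp = kwargs.pop("attn_implementation", None)
    if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
        config._attn_implementation = kwarg_attn_imp
    return config

cfg = GPT2Config()
new_cfg = override_attn_implementation(cfg, {"attn_implementation": "sdpa"})
assert new_cfg._attn_implementation == "sdpa"
assert cfg is not new_cfg  # the original config object is untouched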

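To tie the two changes together, here is a minimal usage sketch. It assumes that LazyInitContext routes from_pretrained through the patched new_from_pretrained shown above and that a small checkpoint such as "gpt2" can be resolved; it is an illustration, not part of the commit:

from colossalai.lazy import LazyInitContext
from transformers import AutoModelForCausalLM

# Inside the lazy context, parameters are created lazily instead of being
# materialized up front; loader-only arguments such as low_cpu_mem_usage are
# now stripped by new_from_pretrained rather than lingering as leftover kwargs.
with LazyInitContext():
    model = AutoModelForCausalLM.from_pretrained("gpt2", low_cpu_mem_usage=False)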