[shardformer] fix import (#5788)
ver217 committed Jun 6, 2024
1 parent 5ead00f commit 73e88a5
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions colossalai/shardformer/modeling/llama.py
@@ -8,6 +8,10 @@
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -17,8 +21,6 @@
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
     apply_rotary_pos_emb,
     repeat_kv,
 )
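For context, these two helpers live in transformers' shared attention-mask utilities, and recent transformers releases apparently no longer re-export them from modeling_llama, which is what the old import relied on. Below is a minimal sketch of how the utilities imported above are typically called; the tensor shapes and dummy inputs are illustrative assumptions and are not part of this commit:

# Illustrative sketch (not from the commit): building the additive 4D causal mask
# from a 2D padding mask, using the import path this change switches to.
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch_size, seq_len, hidden_size, past_len = 2, 16, 64, 0
inputs_embeds = torch.zeros(batch_size, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # 2D padding mask

# Expands the padding mask into the 4D mask eager attention expects,
# shaped (batch_size, 1, seq_len, past_len + seq_len).
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_len), inputs_embeds, past_len
)
print(causal_mask.shape)  # expected: torch.Size([2, 1, 16, 16])

# The SDPA variant takes the same arguments but may return None when SDPA can
# rely on its built-in causal handling instead of an explicit mask.
sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask, (batch_size, seq_len), inputs_embeds, past_len
)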
6 changes: 4 additions & 2 deletions colossalai/shardformer/modeling/qwen2.py
@@ -9,13 +9,15 @@
 )
 
 try:
+    from transformers.modeling_attn_mask_utils import (
+        _prepare_4d_causal_attention_mask,
+        _prepare_4d_causal_attention_mask_for_sdpa,
+    )
     from transformers.models.qwen2.modeling_qwen2 import (
         Qwen2Attention,
         Qwen2ForCausalLM,
         Qwen2ForSequenceClassification,
         Qwen2Model,
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
         apply_rotary_pos_emb,
         repeat_kv,
     )
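The qwen2 modeling file keeps its optional-import guard, so the same utility import simply moves inside the existing try block. A minimal sketch of the guarded pattern as a whole follows; the except branch and the HAS_QWEN2 flag are assumptions for illustration, since the diff does not show how the module actually handles an ImportError:

# Sketch of the guarded-import pattern; only the try branch appears in the diff,
# so the except handling below is a hypothetical placeholder.
try:
    from transformers.modeling_attn_mask_utils import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )
    from transformers.models.qwen2.modeling_qwen2 import (
        Qwen2Attention,
        Qwen2ForCausalLM,
        Qwen2ForSequenceClassification,
        Qwen2Model,
        apply_rotary_pos_emb,
        repeat_kv,
    )
    HAS_QWEN2 = True
except ImportError:
    # Older transformers builds may not ship Qwen2; downstream code can check the flag.
    HAS_QWEN2 = False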
