[Models] Add pp for chatglmv2 #9043

Open · wants to merge 13 commits into develop
5 changes: 3 additions & 2 deletions paddlenlp/experimental/transformers/chatglm_v2/modeling.py
@@ -31,6 +31,7 @@
from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2PretrainedModel
from paddlenlp.transformers.chatglm_v2.modeling import (
Embedding,
LayerNorm,
RMSNorm,
RotaryEmbedding,
)
@@ -183,9 +184,9 @@

self.post_layer_norm = config.post_layer_norm
if self.post_layer_norm:
LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm

# Final layer norm before output.
self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)
self.final_layernorm = LayerNormFunc(config)


def get_input_embeddings(self):
return self.embedding.word_embeddings
1 change: 1 addition & 0 deletions paddlenlp/transformers/__init__.py
@@ -265,6 +265,7 @@
from .chatglm.tokenizer import *
from .chatglm_v2.configuration import *
from .chatglm_v2.modeling import *
from .chatglm_v2.modeling_pp import *
from .chatglm_v2.tokenizer import *
from .speecht5.configuration import *
from .speecht5.modeling import *
3 changes: 3 additions & 0 deletions paddlenlp/transformers/chatglm_v2/__init__.py
@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling import *
from .modeling_pp import *
from .tokenizer import *
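With the two `__init__.py` changes above, the new pipeline module is re-exported through the usual wildcard imports. A minimal sketch of what becomes importable once this lands (the class name is taken from the `__all__` of `modeling_pp.py` further down):

```python
# Sketch: the pipeline-parallel class sits next to the existing ChatGLMv2 exports.
from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2ForCausalLMPipe
```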
27 changes: 17 additions & 10 deletions paddlenlp/transformers/chatglm_v2/modeling.py
@@ -150,16 +150,25 @@
return paddle.concat((x_out2, x_pass), axis=-1)


class LayerNorm(nn.LayerNorm):
def __init__(self, config):
self.config = config
super().__init__(config.hidden_size, epsilon=config.layernorm_epsilon)



class RMSNorm(nn.Layer):
def __init__(self, hidden_size, config: ChatGLMv2Config, epsilon=None):
def __init__(self, config: ChatGLMv2Config):
super().__init__()
self.hidden_size = hidden_size
self.hidden_size = config.hidden_size
self.weight = paddle.create_parameter(
shape=[self.hidden_size],
dtype=paddle.get_default_dtype(),
default_initializer=nn.initializer.Constant(1.0),
)
self.epsilon = 1e-5 if epsilon is None else epsilon
self.epsilon = 1e-5 if config.layernorm_epsilon is None else config.layernorm_epsilon

if config.sequence_parallel:
mark_as_sequence_parallel_parameter(self.weight)


if config.sequence_parallel:
mark_as_sequence_parallel_parameter(self.weight)
@@ -487,18 +496,16 @@
self.config = config
self.fp32_residual_connection = config.fp32_residual_connection

LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)
self.input_layernorm = LayerNormFunc(config)

# Self attention.
self.self_attention = SelfAttention(config, layer_number)
self.hidden_dropout = config.hidden_dropout

# Layernorm on the attention output
self.post_attention_layernorm = LayerNormFunc(
config.hidden_size, epsilon=config.layernorm_epsilon, config=config
)
self.post_attention_layernorm = LayerNormFunc(config)

# MLP
self.mlp = MLP(config)
@@ -585,9 +592,9 @@
self.layers = nn.LayerList([build_layer(i + 1) for i in range(self.num_hidden_layers)])

if self.post_layer_norm:
LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)
self.final_layernorm = LayerNormFunc(config)

def _get_layer(self, layer_number):
return self.layers[layer_number]
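The refactor above moves `hidden_size` and `layernorm_epsilon` out of the norm constructors and into the config, so `LayerNorm`, `RMSNorm`, and the pipeline variants can all be built the same way. A minimal sketch of the new call pattern (the config values below are illustrative, not taken from this PR):

```python
from paddlenlp.transformers import ChatGLMv2Config
from paddlenlp.transformers.chatglm_v2.modeling import LayerNorm, RMSNorm

# Illustrative values; a ChatGLMv2Config carries hidden_size,
# layernorm_epsilon and the rmsnorm switch used below.
config = ChatGLMv2Config(hidden_size=4096, layernorm_epsilon=1e-5, rmsnorm=True)

# Same selection used in the encoder and the experimental inference path above:
# both norm classes now read their sizes from the config.
norm_cls = RMSNorm if config.rmsnorm else LayerNorm
final_layernorm = norm_cls(config)
```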
292 changes: 292 additions & 0 deletions paddlenlp/transformers/chatglm_v2/modeling_pp.py
@@ -0,0 +1,292 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
except:
pass


from paddlenlp.transformers.model_utils import PipelinePretrainedModel

from .modeling import (
ChatGLMv2Config,
Chatglmv2LMHead,
ChatGLMv2PretrainedModel,
ChatGLMv2PretrainingCriterion,
Embedding,
GLMBlock,
RMSNorm,
)

__all__ = ["ChatGLMv2ForCausalLMPipe"]


def get_hcg():
return fleet.get_hybrid_communicate_group()



def get_attr(layer, name):
if getattr(layer, name, None) is not None:
return getattr(layer, name, None)

else:
return get_attr(layer._layer, name)



def parse_args(args):
if isinstance(args, tuple):
if len(args) == 6:
hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = args
elif len(args) == 5:
hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache = args
use_cache = None
elif len(args) == 4:
hidden_states, attention_mask, position_ids, rotary_pos_emb = args
kv_cache = None
use_cache = None
elif len(args) == 3:
hidden_states, attention_mask, position_ids = args
rotary_pos_emb = None
kv_cache = None
use_cache = None
elif len(args) == 2:
hidden_states, attention_mask = args
position_ids = None
rotary_pos_emb = None
kv_cache = None
use_cache = None

else:
hidden_states = args
attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = None, None, None, None, None


if position_ids is not None:
position_ids.stop_gradient = True


if attention_mask is not None:
attention_mask.stop_gradient = True


if rotary_pos_emb is not None:
rotary_pos_emb.stop_gradient = True


if kv_cache is not None:
kv_cache.stop_gradient = True


if use_cache is not None:
use_cache.stop_gradient = True


return hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache



def return_args(
hidden_states, attention_mask=None, position_ids=None, rotary_pos_emb=None, kv_cache=None, use_cache=None
):
ret = (hidden_states,)


if attention_mask is not None:
ret += (attention_mask.clone(),)
if position_ids is not None:
ret += (position_ids.clone(),)
if rotary_pos_emb is not None:
ret += (rotary_pos_emb.clone(),)
if kv_cache is not None:
ret += (kv_cache.clone(),)
if use_cache is not None:
ret += (use_cache.clone(),)


if len(ret) == 1:
ret = ret[0]


return ret



def forward_impl(self, seq_len: int, n_elem: int, base: int = 10000):
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (base ** (paddle.arange(0, n_elem, 2, dtype="float32") / n_elem))


# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = paddle.arange(0, seq_len, dtype=theta.dtype)


# Calculate the product of position index and $\theta_i$
idx_theta = paddle.outer(seq_idx, theta).astype(self.default_dtype)


cache = paddle.stack([paddle.cos(idx_theta), paddle.sin(idx_theta)], axis=-1)


# this is to mimic the behaviour of complex32, else we will get different results
if self.default_dtype in (paddle.float16, paddle.bfloat16, paddle.int8):
cache = cache.astype(self.default_dtype)

# cache = cache.bfloat16() if dtype == paddle.bfloat16 else cache.astype("float16")
return cache



class EmbeddingPipe(Embedding):
"""Extends Embedding to forward attention_mask through the pipeline."""

def __init__(self, config: ChatGLMv2Config):
super().__init__(config)
self.default_dtype = paddle.get_default_dtype()


@property
def embedding_weight(self):
return get_attr(self.word_embeddings, "weight")


def forward(self, args):
input_ids, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args)
input_ids.stop_gradient = True
inputs_embeds = super().forward(input_ids=input_ids)
batch_size, seq_length = input_ids.shape


if self.config.sequence_parallel:
seq_length, batch_size, hidden_size = inputs_embeds.shape
inputs_embeds = paddle.reshape_(inputs_embeds, [batch_size * seq_length, hidden_size])

# [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
inputs_embeds = ScatterOp.apply(inputs_embeds)


if attention_mask is None:
attention_mask = paddle.ones((batch_size, 1, seq_length, seq_length), dtype="bool")
if len(attention_mask.shape) == 2:

# from Tokenizer
attention_mask = (

attention_mask.unsqueeze(axis=[1, 2]).expand([batch_size, 1, seq_length, seq_length]).astype("bool")
)
elif len(attention_mask.shape) == 3:

# [batch_size,tgt_length, src_length] -> [batch_size, 1, tgt_length, src_length]
attention_mask = attention_mask.unsqueeze(1).astype("bool")
elif len(attention_mask.shape) == 4:
attention_mask = attention_mask.astype("bool")


causal_mask = paddle.tril(paddle.ones([batch_size, 1, seq_length, seq_length])).astype("bool")
attention_mask = attention_mask & causal_mask
zero = paddle.zeros(attention_mask.shape, dtype=inputs_embeds.dtype)
neg_inf = paddle.full_like(attention_mask, paddle.finfo(inputs_embeds.dtype).min, dtype=inputs_embeds.dtype)
attention_mask = paddle.where(attention_mask, zero, neg_inf)

# Rotary positional embeddings
self.max_sequence_length = self.config.max_sequence_length
rotary_dim = (

self.config.hidden_size // self.config.num_attention_heads
if self.config.kv_channels is None
else self.config.kv_channels
)
rotary_pos_emb = forward_impl(self, self.max_sequence_length, rotary_dim // 2)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]

else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]


rotary_pos_emb = rotary_pos_emb.transpose([1, 0, 2, 3])


return return_args(inputs_embeds, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache)



class GLMBlockPipe(GLMBlock):
"""Extends GLMBlock to forward attention_mask through the pipeline."""

def forward(self, args):
hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args)
hidden_states, kv_cache = super().forward(hidden_states, attention_mask, rotary_pos_emb, kv_cache, use_cache)
return return_args(hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache)



class RMSNormPipe(RMSNorm):
def forward(self, args):
hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args)
hidden_states = super().forward(hidden_states)
return hidden_states



class Chatglmv2LMHeadPipe(Chatglmv2LMHead):
def __init__(self, config):
super(Chatglmv2LMHeadPipe, self).__init__(config)



class ChatGLMv2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
"""ChatGLMv2ForPretraining adapted for pipeline parallelism.

The largest change is flattening the ChatGLMv2Model class so we can express it as a
sequence of layers including embedding, transformer layers, and output.
"""

config_class = ChatGLMv2Config

get_masks = ChatGLMv2PretrainedModel.get_masks
_get_tensor_parallel_mappings = ChatGLMv2PretrainedModel._get_tensor_parallel_mappings
init_weights = ChatGLMv2PretrainedModel.init_weights
get_position_ids = ChatGLMv2PretrainedModel.get_position_ids
_get_name_mappings = ChatGLMv2PretrainedModel._get_name_mappings

# NO base_model_prefix !!!!

def __init__(self, config):
self.config = config


virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1)


hcg = get_hcg()
tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1)
tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0)


config.tensor_parallel_degree = tensor_parallel_degree
config.tensor_parallel_rank = tensor_parallel_rank


self.add_sequential_layer(

LayerDesc(EmbeddingPipe, config=config),
"embedding",
)
for i in range(config.num_hidden_layers):
self.add_sequential_layer(

LayerDesc(GLMBlockPipe, config=config, layer_number=i),
f"encoder.layers.{i}",
)

self.add_sequential_layer(

LayerDesc(RMSNormPipe, config=config),
"encoder.final_layernorm",
)
self.add_sequential_layer(

LayerDesc(Chatglmv2LMHeadPipe, config=config),
"output_layer",
)

recompute_interval = 0

# if self.config.recompute and recompute_granularity == "full":
# assert pp_recompute_interval <= config.num_hidden_layers // (
# virtual_pp_degree * get_hcg().topology().get_dim_size("pipe")
# ), "pp recompute interval should smaller than num layers of each pp chunk"
# recompute_interval = pp_recompute_interval

seg_method = "layer:GLMBlock"
if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0:
seg_method = "uniform"


PipelineLayer.__init__(

self,
layers=self.get_sequential_layers(),
loss_fn=ChatGLMv2PretrainingCriterion(config),
topology=get_hcg().topology(),
seg_method=seg_method,
recompute_interval=recompute_interval,
recompute_ctx={
"mp_group": get_hcg().get_model_parallel_group(),
"offload": False,
"partition": False,
},
num_virtual_pipeline_stages=virtual_pp_degree,
)
self.apply(self._init_weights)

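To round out the picture, here is a minimal, hedged sketch of how `ChatGLMv2ForCausalLMPipe` is typically brought up under fleet hybrid parallelism. The launch command, parallel degrees, and pretrained name are assumptions for illustration and are not part of this PR.

```python
# Assumed launch: python -m paddle.distributed.launch --devices "0,1" run_pipe.py
# (two ranks so that pp_degree=2 below has one device per pipeline stage).
import paddle.distributed.fleet as fleet

from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2ForCausalLMPipe

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,  # the constructor reads the hybrid group via get_hcg()
}
fleet.init(is_collective=True, strategy=strategy)

# Pretrained name is illustrative; any ChatGLMv2Config works here.
config = ChatGLMv2Config.from_pretrained("THUDM/chatglm2-6b")

# EmbeddingPipe, the GLMBlockPipe stack, RMSNormPipe and Chatglmv2LMHeadPipe are
# registered as sequential LayerDescs and split across the pipeline stages.
model = ChatGLMv2ForCausalLMPipe(config)

# For training, the PipelineLayer is normally wrapped by fleet so that
# forward/backward steps are scheduled stage by stage.
model = fleet.distributed_model(model)
```

As in the constructor above, layers are segmented with `seg_method="layer:GLMBlock"`, falling back to `"uniform"` when `num_hidden_layers` is not divisible by the pipeline degree.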