From eda6ceb567f6d1151b03a2c720306da91708a6a5 Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Tue, 27 Aug 2024 03:24:16 +0000 Subject: [PATCH 01/11] fix_chatglmv2_8k --- paddlenlp/transformers/chatglm_v2/modeling.py | 437 +++++++++++++++--- 1 file changed, 383 insertions(+), 54 deletions(-) diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index bbfb6e52f481..4fecef593922 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -13,18 +13,22 @@ # limitations under the License. import math +from functools import partial from typing import Any, Dict, List, Optional, Tuple import paddle import paddle.nn as nn import paddle.nn.functional as F +import paddle.tensor as tensor +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.utils import recompute from paddle.utils import map_structure from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies from ...utils.converter import StateDictNameMapping, init_name_mappings -from .. import PretrainedModel, register_base_model +from .. import PretrainedModel, linear_utils, register_base_model from ..model_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithPast, @@ -32,6 +36,24 @@ ) from .configuration import CHATGLM_V2_PRETRAINED_RESOURCE_FILES_MAP, ChatGLMv2Config +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + mark_as_sequence_parallel_parameter, + ) +except: + pass + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None +try: + from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd +except: + FusedDropoutAdd = None + __all__ = [ "ChatGLMv2Model", "ChatGLMv2PretrainedModel", @@ -39,6 +61,27 @@ ] +def parallel_matmul(lm_output, logit_weights, parallel_output): + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + world_size = hcg.get_model_parallel_world_size() + + if world_size > 1: + # _c_identity is backwards is reduce + input_parallel = paddle.distributed.collective._c_identity(lm_output, group=model_parallel_group) + + logits = paddle.matmul(input_parallel, logit_weights, transpose_y=False) + + if parallel_output: + return logits + + # _c_concat has not grad backwards + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + else: + logits = paddle.matmul(lm_output, logit_weights, transpose_y=True) + return logits + + class RotaryEmbedding(nn.Layer): def __init__(self, dim, original_impl=False): super().__init__() @@ -97,7 +140,7 @@ def apply_rotary_pos_emb(x: paddle.Tensor, rope_cache: paddle.Tensor) -> paddle. 
class RMSNorm(nn.Layer): - def __init__(self, hidden_size, epsilon=None): + def __init__(self, hidden_size, config: ChatGLMv2Config, epsilon=None): super().__init__() self.hidden_size = hidden_size self.weight = paddle.create_parameter( @@ -107,6 +150,9 @@ def __init__(self, hidden_size, epsilon=None): ) self.epsilon = 1e-5 if epsilon is None else epsilon + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.weight) + def forward(self, hidden_states): input_dtype = hidden_states.dtype variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) @@ -131,14 +177,19 @@ def __init__(self, config: ChatGLMv2Config, layer_number): self.num_attention_heads_per_partition = config.num_attention_heads self.hidden_size_per_partition = config.kv_channels * self.num_attention_heads_per_partition self.hidden_size_per_attention_head = self.hidden_size_per_partition // self.num_attention_heads_per_partition - + self.tensor_parallel_degree = config.tensor_parallel_degree + if self.tensor_parallel_degree > 1: + assert ( + self.hidden_size_per_partition % self.tensor_parallel_degree == 0 + ), "hidden_size_per_partition % tensor_parallel_degree must be zero." + self.hidden_size_per_partition = self.hidden_size_per_partition // self.tensor_parallel_degree coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: coeff = self.layer_number self.norm_factor *= coeff self.coeff = coeff - + self.config = config self.attention_dropout = nn.Dropout(config.attention_dropout) def forward(self, query_layer, key_layer, value_layer, attention_mask): @@ -198,6 +249,10 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): new_context_shape = context_layer.shape[:-2] + [self.hidden_size_per_partition] context_layer = context_layer.reshape(new_context_shape) + if self.config.sequence_parallel: + sq, b, hp = context_layer.shape + context_layer = context_layer.reshape([sq * b, hp]) + return context_layer @@ -221,33 +276,86 @@ def __init__(self, config: ChatGLMv2Config, layer_number, device=None): self.num_multi_query_groups_per_partition = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num self.num_attention_heads_per_partition = config.num_attention_heads + self.config = config + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.tensor_parallel_degree = config.tensor_parallel_degree + self.sequence_parallel = config.sequence_parallel + self.use_flash_attention = config.use_flash_attention if flash_attention else False - self.query_key_value = nn.Linear( - config.hidden_size, - config.hidden_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num, - bias_attr=config.add_bias_linear or config.add_qkv_bias, + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + if config.tensor_parallel_degree > 1: + self.query_key_value = ColumnParallelLinear( + config.hidden_size, + config.hidden_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num, + has_bias=config.add_bias_linear or config.add_qkv_bias, + gather_output=False, + ) + self.dense = RowParallelLinear( + config.hidden_size, config.hidden_size, input_is_parallel=True, has_bias=config.add_bias_linear + ) + 
self.num_attention_heads_per_partition = config.num_attention_heads // config.tensor_parallel_degree + assert ( + self.num_multi_query_groups_per_partition % self.tensor_parallel_degree == 0 + ), "`multi_query_group_num` % `tensor_parallel_degree` must equal to `0`" + self.num_multi_query_groups_per_partition = ( + self.num_multi_query_groups_per_partition // self.tensor_parallel_degree + ) + else: + self.query_key_value = nn.Linear( + config.hidden_size, + config.hidden_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num, + bias_attr=config.add_bias_linear or config.add_qkv_bias, + ) + # Output. + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=config.add_bias_linear) + + def _flash_attention(self, q, k, v, attention_mask=None, output_attentions=False): + out, weights = flash_attention( + query=q, + key=k, + value=v, + dropout=self.config.attention_dropout, + causal=q.shape[0] != 1, + return_softmax=output_attentions, ) - # Output. - self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=config.add_bias_linear) + # [bs, seq_len, num_head, head_dim] -> [bs, seq_len, num_head * head_dim] + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + return (out, weights) if output_attentions else out - def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True): - seq_length, batch_size, hidden_size = hidden_states.shape - mixed_x_layer = self.query_key_value(hidden_states) + def _core_attention(self, q, k, v, attention_mask=None, output_attentions=False): + outputs = self.core_attention(q, k, v, attention_mask) + return outputs + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, output_attentions=False + ): + seq_length = self.config.seq_length + mixed_x_layer = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head * self.multi_query_group_num, - self.hidden_size_per_attention_head * self.multi_query_group_num, + self.hidden_size_per_attention_head * self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head * self.num_multi_query_groups_per_partition, ], axis=-1, ) query_layer = query_layer.reshape( - [seq_length, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head] + [seq_length, -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head] + ) + key_layer = key_layer.reshape( + [seq_length, -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head] + ) + value_layer = value_layer.reshape( + [seq_length, -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head] ) - key_layer = key_layer.reshape([seq_length, batch_size, -1, self.hidden_size_per_attention_head]) - value_layer = value_layer.reshape([seq_length, batch_size, -1, self.hidden_size_per_attention_head]) # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: @@ -278,13 +386,30 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, # ================================== # core attention computation # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - + if self.use_flash_attention: + attention_fuc = self._flash_attention + else: + attention_fuc = 
self._core_attention + has_gradient = ( + (not query_layer.stop_gradient) or (not key_layer.stop_gradient) or (not value_layer.stop_gradient) + ) + if self.enable_recompute and self.config.recompute_granularity == "core_attn" and has_gradient: + context_layer = recompute( + attention_fuc, + query_layer, + key_layer, + value_layer, + attention_mask, + output_attentions, + use_reentrant=False, + ) + else: + context_layer = attention_fuc( + query_layer, key_layer, value_layer, attention_mask=attention_mask, output_attentions=output_attentions + ) # ================= # Output. [seq_length, b, h] # ================= - output = self.dense(context_layer) return output, kv_cache @@ -302,14 +427,28 @@ def __init__(self, config: ChatGLMv2Config): self.add_bias = config.add_bias_linear - # Project to 4h due to swiglu doubling the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size * 2, bias_attr=self.add_bias) - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias_attr=self.add_bias, - ) + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + if config.tensor_parallel_degree > 1: + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, config.ffn_hidden_size * 2, has_bias=self.add_bias, gather_output=False + ) + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, config.hidden_size, input_is_parallel=True, has_bias=self.add_bias + ) + else: + # Project to 4h due to swiglu doubling the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size * 2, bias_attr=self.add_bias) + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias_attr=self.add_bias, + ) def forward(self, hidden_states): # [s, b, 4hp] @@ -336,22 +475,27 @@ def __init__(self, config: ChatGLMv2Config, layer_number): super(GLMBlock, self).__init__() self.layer_number = layer_number self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False self.fp32_residual_connection = config.fp32_residual_connection LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon) + self.input_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config) # Self attention. self.self_attention = SelfAttention(config, layer_number) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon) + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, epsilon=config.layernorm_epsilon, config=config + ) # MLP self.mlp = MLP(config) + self.config = config def forward( self, @@ -366,10 +510,21 @@ def forward( # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) + has_gradient = not layernorm_output.stop_gradient # Self attention. 
- attention_output, kv_cache = self.self_attention( - layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache - ) + if self.enable_recompute and has_gradient and self.config.recompute_granularity == "full_attn": + attention_output, kv_cache = recompute( + self.self_attention, + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + ) + else: + attention_output, kv_cache = self.self_attention( + layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache + ) # Residual connection. if self.apply_residual_connection_post_layernorm: @@ -403,7 +558,10 @@ class GLMTransformer(nn.Layer): def __init__(self, config: ChatGLMv2Config): super(GLMTransformer, self).__init__() self.config = config + # Recompute defaults to False and is controlled by Trainer self.enable_recompute = False + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + self.recompute_granularity = config.recompute_granularity self.fp32_residual_connection = config.fp32_residual_connection self.post_layer_norm = config.post_layer_norm @@ -419,7 +577,7 @@ def build_layer(layer_number): if self.post_layer_norm: LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon) + self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config) def _get_layer(self, layer_number): return self.layers[layer_number] @@ -476,7 +634,12 @@ def forward( layer = self._get_layer(index) - if self.enable_recompute and not hidden_states.stop_gradient: + if ( + self.enable_recompute + and not hidden_states.stop_gradient + and index not in self.no_recompute_layers + and self.recompute_granularity == "full" + ): hidden_states, kv_cache = self.recompute_training( layer, hidden_states, @@ -546,6 +709,10 @@ def get_masks(self, input_ids, past_key_values, padding_mask=None): return casual_mask + def init_weights(self, layer): + """Initialization hook""" + return None + def get_position_ids(self, input_ids): batch_size, seq_length = input_ids.shape position_ids = paddle.arange(seq_length, dtype="int64").unsqueeze(0).tile([batch_size, 1]) @@ -610,14 +777,81 @@ def _get_name_mappings(cls, config: ChatGLMv2Config) -> List[StateDictNameMappin ] ) - for mapping in mappings: - mapping[0] = "transformer." + mapping[0] - if len(mapping) > 1 and mapping[1] is not None: - mapping[1] = "chatglm_v2." + mapping[1] + # for mapping in mappings: + # mapping[0] = "transformer." + mapping[0] + # if len(mapping) > 1 and mapping[1] is not None: + # mapping[1] = "chatglm_v2." 
+ mapping[1] init_name_mappings(mappings) return [StateDictNameMapping(*mapping) for mapping in mappings] + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + def split_qkv_weights(tensor_parallel_degree, tensor_parallel_rank, hidden_size, tensor): + split_query_size = hidden_size // tensor_parallel_degree + split_kv_size = (tensor.shape[-1] - hidden_size) // (2 * tensor_parallel_degree) + + query = tensor[..., :hidden_size] + key = tensor[..., hidden_size : hidden_size + split_kv_size * tensor_parallel_degree] + value = tensor[..., tensor.shape[-1] - split_kv_size * tensor_parallel_degree :] + + key_part = key[..., tensor_parallel_rank * split_kv_size : (tensor_parallel_rank + 1) * split_kv_size] + value_part = value[..., tensor_parallel_rank * split_kv_size : (tensor_parallel_rank + 1) * split_kv_size] + query_part = query[ + ..., tensor_parallel_rank * split_query_size : (tensor_parallel_rank + 1) * split_query_size + ] + + return paddle.concat([query_part, key_part, value_part], axis=-1) + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def split_mlp_weights(tensor_parallel_degree, tensor_parallel_rank, tensor): + split_size = tensor.shape[-1] // tensor_parallel_degree // 2 + ffn_fc = tensor[..., : tensor.shape[-1] // 2] + gate = tensor[..., tensor.shape[-1] // 2 :] + ffn_fc_part = ffn_fc[..., tensor_parallel_rank * split_size : (tensor_parallel_rank + 1) * split_size] + gate_part = gate[..., tensor_parallel_rank * split_size : (tensor_parallel_rank + 1) * split_size] + return paddle.concat([ffn_fc_part, gate_part], axis=-1) + + def get_tensor_parallel_split_mappings(num_hidden_layers): + final_actions = {} + base_actions = { + # Column Linear + "output_layer.weight": partial(fn, is_column=True), + "encoder.layers.0.mlp.dense_h_to_4h.weight": partial( + split_mlp_weights, config.tensor_parallel_degree, config.tensor_parallel_rank + ), + "encoder.layers.0.self_attention.query_key_value.bias": partial( + split_qkv_weights, config.tensor_parallel_degree, config.tensor_parallel_rank, config.hidden_size + ), + "encoder.layers.0.self_attention.query_key_value.weight": partial( + split_qkv_weights, config.tensor_parallel_degree, config.tensor_parallel_rank, config.hidden_size + ), + # Row Linear + "embedding.word_embeddings.weight": partial(fn, is_column=False), + "encoder.layers.0.self_attention.dense.weight": partial(fn, is_column=False), + "encoder.layers.0.mlp.dense_4h_to_h.weight": partial(fn, is_column=False), + } + for key, action in base_actions.items(): + if "layers.0." 
in key: + for i in range(num_hidden_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + class Embedding(nn.Layer): """Language model embeddings.""" @@ -626,7 +860,12 @@ def __init__(self, config: ChatGLMv2Config): super(Embedding, self).__init__() self.hidden_size = config.hidden_size - self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size) + if config.tensor_parallel_degree > 1: + self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding( + config.padded_vocab_size, self.hidden_size + ) + else: + self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size) self.fp32_residual_connection = config.fp32_residual_connection def forward(self, input_ids): @@ -646,6 +885,7 @@ class ChatGLMv2Model(ChatGLMv2PretrainedModel): def __init__(self, config: ChatGLMv2Config, empty_init=True): super().__init__(config) self.embedding = Embedding(config) + self.config = config # Rotary positional embeddings self.max_sequence_length = config.max_sequence_length @@ -662,7 +902,13 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True): else: self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2) self.encoder = GLMTransformer(config) - self.output_layer = nn.Linear(config.hidden_size, config.padded_vocab_size, bias_attr=False) + if config.tensor_parallel_degree > 1: + self.output_layer = nn.Linear( + config.hidden_size, config.padded_vocab_size // config.tensor_parallel_degree, bias_attr=False + ) + else: + self.output_layer = nn.Linear(config.hidden_size, config.padded_vocab_size, bias_attr=False) + self.apply(self.init_weights) def get_input_embeddings(self): return self.embedding.word_embeddings @@ -692,6 +938,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) + if self.config.sequence_parallel: + seq_length, batch_size, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [batch_size * seq_length, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) # Rotary positional embeddings @@ -730,11 +982,87 @@ def forward( ) +class ChatGLMv2PretrainingCriterion(nn.Layer): + """ + Criterion for ChatGLMv2. It calculates the final loss. + """ + + def __init__(self, config): + super(ChatGLMv2PretrainingCriterion, self).__init__() + self.config = config + if config.tensor_parallel_degree > 1 and config.tensor_parallel_output: + self.loss_func = fleet.meta_parallel.ParallelCrossEntropy() + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none") + + def forward(self, prediction_scores, masked_lm_labels, loss_mask=None): + """ + Args: + prediction_scores(Tensor): + The logits of masked token prediction. Its data type should be float32 and + its shape is [batch_size, sequence_length, vocab_size]. + masked_lm_labels(Tensor): + The labels of the masked language modeling, the dimensionality of `masked_lm_labels` + is equal to `prediction_scores`. Its data type should be int64 and + its shape is [batch_size, sequence_length, 1]. + loss_mask(Tensor): + Mask used for calculating the loss of the masked language modeling to avoid + calculating some unwanted tokens. + Its data type should be float32 and its shape is [batch_size, sequence_length, 1]. 
+ + Returns: + Tensor: The pretraining loss. Its data type should be float32 and its shape is [1]. + + """ + with paddle.amp.auto_cast(False): + reshaped_logits = prediction_scores.reshape([-1, prediction_scores.shape[-1]]).astype("float32") + reshaped_labels = masked_lm_labels.reshape([-1]) + loss = self.loss_func(reshaped_logits, reshaped_labels) + loss = paddle.sum(loss.reshape([-1]).cast(paddle.float32) * loss_mask.reshape([-1]).cast(paddle.float32)) + loss = loss / loss_mask.sum() + return loss + + +class Chatglmv2LMHead(nn.Layer): + def __init__(self, config: ChatGLMv2Config, embedding_weights=None): + super(Chatglmv2LMHead, self).__init__() + if embedding_weights is not None: + self.decoder_weight = embedding_weights + else: + if config.tensor_parallel_degree > 1: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + if vocab_size != config.vocab_size: + with get_rng_state_tracker().rng_state(): + self.decoder_weight = self.create_parameter( + shape=[vocab_size, config.hidden_size], + dtype=paddle.get_default_dtype(), + ) + else: + self.decoder_weight = self.create_parameter( + shape=[vocab_size, config.hidden_size], dtype=paddle.get_default_dtype() + ) + self.config = config + + def forward(self, hidden_states, return_last_logit): + if return_last_logit: + hidden_states = hidden_states[-1:] + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = paddle.reshape_(hidden_states, [self.config.seq_length, -1, self.config.hidden_size]) + logits = parallel_matmul(hidden_states, self.decoder_weight, self.config.tensor_parallel_output) + return logits.transpose([1, 0, 2]) + + class ChatGLMv2ForCausalLM(ChatGLMv2PretrainedModel): def __init__(self, config: ChatGLMv2Config): super().__init__(config) self.max_sequence_length = config.max_sequence_length self.chatglm_v2 = ChatGLMv2Model(config) + self.criterion = ChatGLMv2PretrainingCriterion(config) + self.config = config def reorder_cache(self, cache: paddle.Tensor, beam_idx): cache = map_structure(lambda x: paddle.index_select(x, beam_idx, axis=1), cache) @@ -826,23 +1154,24 @@ def forward( hidden_states = transformer_outputs[0] + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + seq_length = self.config.seq_length + hidden_states = hidden_states.reshape([seq_length, -1, self.config.hidden_size]) if return_last_logit: hidden_states = hidden_states[-1:] - lm_logits = self.chatglm_v2.output_layer(hidden_states) + if self.config.tensor_parallel_degree > 1: + lm_logits = parallel_matmul( + hidden_states, self.chatglm_v2.output_layer.weight, self.config.tensor_parallel_output + ) + else: + lm_logits = self.chatglm_v2.output_layer(hidden_states) lm_logits = lm_logits.transpose([1, 0, 2]) - + # shape = [batch_size, seq_length, vocab_size] loss = None if labels is not None: - reshaped_logits = lm_logits.reshape([-1, lm_logits.shape[-1]]).astype("float32") - reshaped_labels = labels.reshape([-1]) - - loss_fn = nn.CrossEntropyLoss(reduction="none") - loss_mask = (labels != -100).astype("float32") - loss = loss_fn(reshaped_logits, reshaped_labels) - loss = paddle.sum(loss.reshape([-1]).cast(paddle.float32) * loss_mask.reshape([-1]).cast(paddle.float32)) - loss = loss / loss_mask.sum() - + loss = self.criterion(lm_logits, labels, loss_mask) lm_logits = lm_logits.astype(hidden_states.dtype) loss = loss.astype(hidden_states.dtype) From d46530e4b963d8031dea8ca3f851aea5ba52e4b5 Mon Sep 17 00:00:00 2001 From: SevenSamon 
<1273520759@qq.com> Date: Tue, 27 Aug 2024 06:34:11 +0000 Subject: [PATCH 02/11] update modeling and modeling_pp --- paddlenlp/transformers/chatglm_v2/modeling.py | 37 ++- .../transformers/chatglm_v2/modeling_pp.py | 267 ++++++++++++++++++ 2 files changed, 291 insertions(+), 13 deletions(-) create mode 100644 paddlenlp/transformers/chatglm_v2/modeling_pp.py diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index 4fecef593922..e44a3260313a 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -277,6 +277,7 @@ def __init__(self, config: ChatGLMv2Config, layer_number, device=None): self.multi_query_group_num = config.multi_query_group_num self.num_attention_heads_per_partition = config.num_attention_heads self.config = config + self.seq_length = config.seq_length # Recompute defaults to False and is controlled by Trainer self.enable_recompute = False self.tensor_parallel_degree = config.tensor_parallel_degree @@ -336,7 +337,7 @@ def _core_attention(self, q, k, v, attention_mask=None, output_attentions=False) def forward( self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, output_attentions=False ): - seq_length = self.config.seq_length + # seq_length, batch_size = self.config.seq_length, hidden_states.shape[0]//self.config.seq_length mixed_x_layer = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ @@ -346,16 +347,26 @@ def forward( ], axis=-1, ) - - query_layer = query_layer.reshape( - [seq_length, -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head] - ) - key_layer = key_layer.reshape( - [seq_length, -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head] - ) - value_layer = value_layer.reshape( - [seq_length, -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head] - ) + if self.sequence_parallel: + query_layer = query_layer.reshape( + [self.seq_length, -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head] + ) + key_layer = key_layer.reshape( + [self.seq_length, -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head] + ) + value_layer = value_layer.reshape( + [self.seq_length, -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head] + ) + else: + query_layer = query_layer.reshape( + [0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head] + ) + key_layer = key_layer.reshape( + [0, 0, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head] + ) + value_layer = value_layer.reshape( + [0, 0, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head] + ) # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: @@ -390,6 +401,7 @@ def forward( attention_fuc = self._flash_attention else: attention_fuc = self._core_attention + has_gradient = ( (not query_layer.stop_gradient) or (not key_layer.stop_gradient) or (not value_layer.stop_gradient) ) @@ -477,7 +489,7 @@ def __init__(self, config: ChatGLMv2Config, layer_number): self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm # Recompute defaults to False and is controlled by Trainer self.enable_recompute = False - + self.config = config self.fp32_residual_connection = config.fp32_residual_connection LayerNormFunc = RMSNorm if config.rmsnorm else 
nn.LayerNorm @@ -495,7 +507,6 @@ def __init__(self, config: ChatGLMv2Config, layer_number): # MLP self.mlp = MLP(config) - self.config = config def forward( self, diff --git a/paddlenlp/transformers/chatglm_v2/modeling_pp.py b/paddlenlp/transformers/chatglm_v2/modeling_pp.py new file mode 100644 index 000000000000..8b867aab3101 --- /dev/null +++ b/paddlenlp/transformers/chatglm_v2/modeling_pp.py @@ -0,0 +1,267 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import paddle.distributed.fleet as fleet +from paddle.distributed.fleet.meta_parallel import ( + LayerDesc, + PipelineLayer, + SharedLayerDesc, +) + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp +except: + pass + +from paddlenlp.transformers.model_utils import PipelinePretrainedModel + +from .modeling import ( + ChatGLMv2Config, + Chatglmv2LMHead, + ChatGLMv2PretrainedModel, + ChatGLMv2PretrainingCriterion, + Embedding, + GLMBlock, + RMSNorm, +) + +__all__ = ["ChatGLMv2ForCausalLMPipe"] + + +def get_hcg(): + return fleet.get_hybrid_communicate_group() + + +def get_attr(layer, name): + if getattr(layer, name, None) is not None: + return getattr(layer, name, None) + else: + return get_attr(layer._layer, name) + + +def parse_args(args): + if isinstance(args, tuple): + if len(args) == 3: + hidden_states, attention_mask, position_ids = args + elif len(args) == 2: + hidden_states, attention_mask = args + position_ids = None + elif len(args) == 1: + hidden_states = args + attention_mask, position_ids = None, None + else: + hidden_states = args + attention_mask, position_ids = None, None + + if position_ids is not None: + position_ids.stop_gradient = True + + if attention_mask is not None: + attention_mask.stop_gradient = True + + return hidden_states, attention_mask, position_ids + + +def return_args(hidden_states, attention_mask=None, position_ids=None): + ret = (hidden_states,) + + if attention_mask is not None: + ret += (attention_mask.clone(),) + if position_ids is not None: + ret += (position_ids.clone(),) + if len(ret) == 1: + ret = ret[0] + + return ret + + +def forward_impl(self, seq_len: int, n_elem: int, base: int = 10000): + """Enhanced Transformer with Rotary Position Embedding. + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
+ """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (paddle.arange(0, n_elem, 2, dtype="float32") / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = paddle.arange(0, seq_len, dtype=theta.dtype) + + # Calculate the product of position index and $\theta_i$ + idx_theta = paddle.outer(seq_idx, theta).astype(self.default_dtype) + + cache = paddle.stack([paddle.cos(idx_theta), paddle.sin(idx_theta)], axis=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if self.default_dtype in (paddle.float16, paddle.bfloat16, paddle.int8): + cache = cache.astype(self.default_dtype) + # cache = cache.bfloat16() if dtype == paddle.bfloat16 else cache.astype("float16") + return cache + + +class EmbeddingPipe(Embedding): + """Extends Embedding to forward attention_mask through the pipeline.""" + + @property + def embedding_weight(self): + return get_attr(self.word_embeddings, "weight") + + def forward(self, args): + input_ids, attention_mask, position_ids = parse_args(args) + input_ids.stop_gradient = True + inputs_embeds = super().forward(input_ids=input_ids, position_ids=position_ids) + batch_size, seq_length = input_ids.shape + + if self.config.sequence_parallel: + seq_length, batch_size, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [batch_size * seq_length, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + if attention_mask is None: + attention_mask = paddle.ones((batch_size, 1, seq_length, seq_length), dtype="bool") + if len(attention_mask.shape) == 2: + # from Tokenizer + attention_mask = ( + attention_mask.unsqueeze(axis=[1, 2]).expand([batch_size, 1, seq_length, seq_length]).astype("bool") + ) + elif len(attention_mask.shape) == 3: + # [batch_size,tgt_length, src_length] -> [batch_size, 1, tgt_length, src_length] + attention_mask = attention_mask.unsqueeze(1).astype("bool") + elif len(attention_mask.shape) == 4: + attention_mask = attention_mask.astype("bool") + + causal_mask = paddle.tril(paddle.ones([batch_size, 1, seq_length, seq_length])).astype("bool") + attention_mask = attention_mask & causal_mask + + return return_args(inputs_embeds, attention_mask, position_ids) + + +class GLMBlockPipe(GLMBlock): + """Extends GLMBlock to forward attention_mask through the pipeline.""" + + def forward(self, args): + hiden_states, attention_mask, rotary_pos_emb, kv_cache, use_cache = parse_args(args) + hiden_states, kv_cache = super().forward(hiden_states, attention_mask, rotary_pos_emb, kv_cache, use_cache) + return return_args(hiden_states, kv_cache) + + +class RMSNormPipe(RMSNorm): + def forward(self, args): + hidden_states = parse_args(args) + hidden_states = super().forward(hidden_states) + return hidden_states + + +class Chatglmv2LMHeadPipe(Chatglmv2LMHead): + def __init__(self, config): + super(Chatglmv2LMHeadPipe, self).__init__(config) + + @property + def embedding_weight(self): + return get_attr(self, "weight") + + +class ChatGLMv2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): + """ChatGLMv2ForPretraining adapted for pipeline parallelism. + + The largest change is flattening the ChatGLMv2Model class so we can express it as a + sequence of layers including embedding, transformer layers, and output. 
+ """ + + config_class = ChatGLMv2Config + + get_masks = ChatGLMv2PretrainedModel.get_masks + _get_tensor_parallel_mappings = ChatGLMv2PretrainedModel._get_tensor_parallel_mappings + init_weights = ChatGLMv2PretrainedModel.init_weights + get_position_ids = ChatGLMv2PretrainedModel.get_position_ids + _get_name_mappings = ChatGLMv2PretrainedModel._get_name_mappings + + # NO base_model_prefix !!!! + + def __init__( + self, + config, + pp_recompute_interval=1, + ): + self.config = config + + virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + + hcg = get_hcg() + tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1) + tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0) + + config.tensor_parallel_degree = tensor_parallel_degree + config.tensor_parallel_rank = tensor_parallel_rank + + # Rotary positional embeddings + # self.max_sequence_length = config.max_sequence_length + # rotary_dim = ( + # config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + # ) + # rotary_pos_emb = forward_impl(self.max_sequence_length, rotary_dim // 2) + # if position_ids is not None: + # rotary_pos_emb = rotary_pos_emb[position_ids] + # else: + # rotary_pos_emb = rotary_pos_emb[None, :seq_length] + + # rotary_pos_emb = rotary_pos_emb.transpose([1, 0, 2, 3]) + + self.add_sequential_layer( + SharedLayerDesc( + "chatglmv2_shared_weight", EmbeddingPipe, shared_weight_attr="embedding_weight", config=config + ), + "embedding", + ) + for i in range(config.num_hidden_layers): + self.add_sequential_layer( + LayerDesc(GLMBlockPipe, config=config), + f"chatglmv2.decoder.layers.{i}", + ) + + self.add_sequential_layer(LayerDesc(RMSNormPipe, config=config), "encoder.final_layernorm") + self.add_sequential_layer( + SharedLayerDesc( + "chatglmv2_shared_weight", Chatglmv2LMHeadPipe, shared_weight_attr="embedding_weight", config=config + ), + "embedding.word_embeddings", + ) + + recompute_interval = 0 + # if self.config.recompute and recompute_granularity == "full": + # assert pp_recompute_interval <= config.num_hidden_layers // ( + # virtual_pp_degree * get_hcg().topology().get_dim_size("pipe") + # ), "pp recompute interval should smaller than num layers of each pp chunk" + # recompute_interval = pp_recompute_interval + + seg_method = "layer:GLMBlock" + if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0: + seg_method = "uniform" + + PipelineLayer.__init__( + self, + layers=self.get_sequential_layers(), + loss_fn=ChatGLMv2PretrainingCriterion(config), + topology=get_hcg().topology(), + seg_method=seg_method, + recompute_interval=recompute_interval, + recompute_ctx={ + "mp_group": get_hcg().get_model_parallel_group(), + "offload": False, + "partition": False, + }, + num_virtual_pipeline_stages=virtual_pp_degree, + ) + self.apply(self._init_weights) From b43b650fdbafb354086243af780f1df3568d89fa Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Tue, 27 Aug 2024 09:27:50 +0000 Subject: [PATCH 03/11] update_pp_for_chatglmv2 --- paddlenlp/transformers/chatglm_v2/__init__.py | 4 + .../transformers/chatglm_v2/modeling_pp.py | 90 ++++++++++++------- 2 files changed, 63 insertions(+), 31 deletions(-) diff --git a/paddlenlp/transformers/chatglm_v2/__init__.py b/paddlenlp/transformers/chatglm_v2/__init__.py index 775d34cf85f8..83b91e71d58f 100644 --- a/paddlenlp/transformers/chatglm_v2/__init__.py +++ b/paddlenlp/transformers/chatglm_v2/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 
either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .configuration import * +from .modeling import * +from .modeling_pp import * +from .tokenizer import * diff --git a/paddlenlp/transformers/chatglm_v2/modeling_pp.py b/paddlenlp/transformers/chatglm_v2/modeling_pp.py index 8b867aab3101..8e2f02048f0a 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling_pp.py +++ b/paddlenlp/transformers/chatglm_v2/modeling_pp.py @@ -52,17 +52,29 @@ def get_attr(layer, name): def parse_args(args): if isinstance(args, tuple): - if len(args) == 3: + if len(args) == 6: + hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = args + elif len(args) == 5: + hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache = args + use_cache = None + elif len(args) == 4: + hidden_states, attention_mask, position_ids, rotary_pos_emb = args + kv_cache = None + use_cache = None + elif len(args) == 3: hidden_states, attention_mask, position_ids = args + rotary_pos_emb = None + kv_cache = None + use_cache = None elif len(args) == 2: hidden_states, attention_mask = args position_ids = None - elif len(args) == 1: - hidden_states = args - attention_mask, position_ids = None, None + rotary_pos_emb = None + kv_cache = None + use_cache = None else: hidden_states = args - attention_mask, position_ids = None, None + attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = None, None, None, None, None if position_ids is not None: position_ids.stop_gradient = True @@ -70,16 +82,34 @@ def parse_args(args): if attention_mask is not None: attention_mask.stop_gradient = True - return hidden_states, attention_mask, position_ids + if rotary_pos_emb is not None: + rotary_pos_emb.stop_gradient = True + if kv_cache is not None: + kv_cache.stop_gradient = True -def return_args(hidden_states, attention_mask=None, position_ids=None): + if use_cache is not None: + use_cache.stop_gradient = True + + return hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache + + +def return_args( + hidden_states, attention_mask=None, position_ids=None, rotary_pos_emb=None, kv_cache=None, use_cache=None +): ret = (hidden_states,) if attention_mask is not None: ret += (attention_mask.clone(),) if position_ids is not None: ret += (position_ids.clone(),) + if rotary_pos_emb is not None: + ret += (rotary_pos_emb.clone(),) + if kv_cache is not None: + ret += (kv_cache.clone(),) + if use_cache is not None: + ret += (use_cache.clone(),) + if len(ret) == 1: ret = ret[0] @@ -118,7 +148,7 @@ def embedding_weight(self): return get_attr(self.word_embeddings, "weight") def forward(self, args): - input_ids, attention_mask, position_ids = parse_args(args) + input_ids, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) input_ids.stop_gradient = True inputs_embeds = super().forward(input_ids=input_ids, position_ids=position_ids) batch_size, seq_length = input_ids.shape @@ -145,16 +175,31 @@ def forward(self, args): causal_mask = paddle.tril(paddle.ones([batch_size, 1, seq_length, seq_length])).astype("bool") attention_mask = attention_mask & causal_mask - return return_args(inputs_embeds, attention_mask, position_ids) + # Rotary positional embeddings + self.max_sequence_length = self.config.max_sequence_length + rotary_dim = ( + self.config.hidden_size // self.config.num_attention_heads + if self.config.kv_channels is None + else self.config.kv_channels + ) + rotary_pos_emb = 
forward_impl(self.max_sequence_length, rotary_dim // 2) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + + rotary_pos_emb = rotary_pos_emb.transpose([1, 0, 2, 3]) + + return return_args(inputs_embeds, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache) class GLMBlockPipe(GLMBlock): """Extends GLMBlock to forward attention_mask through the pipeline.""" def forward(self, args): - hiden_states, attention_mask, rotary_pos_emb, kv_cache, use_cache = parse_args(args) - hiden_states, kv_cache = super().forward(hiden_states, attention_mask, rotary_pos_emb, kv_cache, use_cache) - return return_args(hiden_states, kv_cache) + hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) + hidden_states, kv_cache = super().forward(hidden_states, attention_mask, rotary_pos_emb, kv_cache, use_cache) + return return_args(hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache) class RMSNormPipe(RMSNorm): @@ -170,7 +215,7 @@ def __init__(self, config): @property def embedding_weight(self): - return get_attr(self, "weight") + return get_attr(self, "decoder_weight") class ChatGLMv2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): @@ -190,11 +235,7 @@ class ChatGLMv2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): # NO base_model_prefix !!!! - def __init__( - self, - config, - pp_recompute_interval=1, - ): + def __init__(self, config): self.config = config virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) @@ -206,19 +247,6 @@ def __init__( config.tensor_parallel_degree = tensor_parallel_degree config.tensor_parallel_rank = tensor_parallel_rank - # Rotary positional embeddings - # self.max_sequence_length = config.max_sequence_length - # rotary_dim = ( - # config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - # ) - # rotary_pos_emb = forward_impl(self.max_sequence_length, rotary_dim // 2) - # if position_ids is not None: - # rotary_pos_emb = rotary_pos_emb[position_ids] - # else: - # rotary_pos_emb = rotary_pos_emb[None, :seq_length] - - # rotary_pos_emb = rotary_pos_emb.transpose([1, 0, 2, 3]) - self.add_sequential_layer( SharedLayerDesc( "chatglmv2_shared_weight", EmbeddingPipe, shared_weight_attr="embedding_weight", config=config From 75b0d31ecb63589b7c0a875b310f40debb00bf08 Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Tue, 27 Aug 2024 12:22:16 +0000 Subject: [PATCH 04/11] update_pp --- paddlenlp/transformers/__init__.py | 1 + paddlenlp/transformers/chatglm_v2/modeling.py | 10 ++++----- .../transformers/chatglm_v2/modeling_pp.py | 21 +++++++++++++------ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index deddfb976d7d..56fc0a2e3afa 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -265,6 +265,7 @@ from .chatglm.tokenizer import * from .chatglm_v2.configuration import * from .chatglm_v2.modeling import * +from .chatglm_v2.modeling_pp import * from .chatglm_v2.tokenizer import * from .speecht5.configuration import * from .speecht5.modeling import * diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index e44a3260313a..b34f4c96a4ee 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -869,7 
+869,7 @@ class Embedding(nn.Layer): def __init__(self, config: ChatGLMv2Config): super(Embedding, self).__init__() - + self.config = config self.hidden_size = config.hidden_size if config.tensor_parallel_degree > 1: self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding( @@ -1006,7 +1006,7 @@ def __init__(self, config): else: self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none") - def forward(self, prediction_scores, masked_lm_labels, loss_mask=None): + def forward(self, prediction_scores, masked_lm_labels): """ Args: prediction_scores(Tensor): @@ -1026,6 +1026,7 @@ def forward(self, prediction_scores, masked_lm_labels, loss_mask=None): """ with paddle.amp.auto_cast(False): + loss_mask = (masked_lm_labels != -100).astype("float32") reshaped_logits = prediction_scores.reshape([-1, prediction_scores.shape[-1]]).astype("float32") reshaped_labels = masked_lm_labels.reshape([-1]) loss = self.loss_func(reshaped_logits, reshaped_labels) @@ -1057,7 +1058,7 @@ def __init__(self, config: ChatGLMv2Config, embedding_weights=None): ) self.config = config - def forward(self, hidden_states, return_last_logit): + def forward(self, hidden_states, return_last_logit=False): if return_last_logit: hidden_states = hidden_states[-1:] if self.config.sequence_parallel: @@ -1181,8 +1182,7 @@ def forward( # shape = [batch_size, seq_length, vocab_size] loss = None if labels is not None: - loss_mask = (labels != -100).astype("float32") - loss = self.criterion(lm_logits, labels, loss_mask) + loss = self.criterion(lm_logits, labels) lm_logits = lm_logits.astype(hidden_states.dtype) loss = loss.astype(hidden_states.dtype) diff --git a/paddlenlp/transformers/chatglm_v2/modeling_pp.py b/paddlenlp/transformers/chatglm_v2/modeling_pp.py index 8e2f02048f0a..86a74869a272 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling_pp.py +++ b/paddlenlp/transformers/chatglm_v2/modeling_pp.py @@ -143,6 +143,10 @@ def forward_impl(self, seq_len: int, n_elem: int, base: int = 10000): class EmbeddingPipe(Embedding): """Extends Embedding to forward attention_mask through the pipeline.""" + def __init__(self, config: ChatGLMv2Config): + super().__init__(config) + self.default_dtype = paddle.get_default_dtype() + @property def embedding_weight(self): return get_attr(self.word_embeddings, "weight") @@ -150,7 +154,7 @@ def embedding_weight(self): def forward(self, args): input_ids, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) input_ids.stop_gradient = True - inputs_embeds = super().forward(input_ids=input_ids, position_ids=position_ids) + inputs_embeds = super().forward(input_ids=input_ids) batch_size, seq_length = input_ids.shape if self.config.sequence_parallel: @@ -174,7 +178,9 @@ def forward(self, args): causal_mask = paddle.tril(paddle.ones([batch_size, 1, seq_length, seq_length])).astype("bool") attention_mask = attention_mask & causal_mask - + zero = paddle.zeros(attention_mask.shape, dtype=inputs_embeds.dtype) + neg_inf = paddle.full_like(attention_mask, paddle.finfo(inputs_embeds.dtype).min, dtype=inputs_embeds.dtype) + attention_mask = paddle.where(attention_mask, zero, neg_inf) # Rotary positional embeddings self.max_sequence_length = self.config.max_sequence_length rotary_dim = ( @@ -182,7 +188,7 @@ def forward(self, args): if self.config.kv_channels is None else self.config.kv_channels ) - rotary_pos_emb = forward_impl(self.max_sequence_length, rotary_dim // 2) + rotary_pos_emb = forward_impl(self, self.max_sequence_length, rotary_dim // 2) if position_ids is not None: 
rotary_pos_emb = rotary_pos_emb[position_ids] else: @@ -204,7 +210,7 @@ def forward(self, args): class RMSNormPipe(RMSNorm): def forward(self, args): - hidden_states = parse_args(args) + hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) hidden_states = super().forward(hidden_states) return hidden_states @@ -255,11 +261,14 @@ def __init__(self, config): ) for i in range(config.num_hidden_layers): self.add_sequential_layer( - LayerDesc(GLMBlockPipe, config=config), + LayerDesc(GLMBlockPipe, config=config, layer_number=i), f"chatglmv2.decoder.layers.{i}", ) - self.add_sequential_layer(LayerDesc(RMSNormPipe, config=config), "encoder.final_layernorm") + self.add_sequential_layer( + LayerDesc(RMSNormPipe, hidden_size=config.hidden_size, config=config, epsilon=config.layernorm_epsilon), + "encoder.final_layernorm", + ) self.add_sequential_layer( SharedLayerDesc( "chatglmv2_shared_weight", Chatglmv2LMHeadPipe, shared_weight_attr="embedding_weight", config=config From a1af3f39a1a89565026e963d047f0e251d8c5585 Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Wed, 28 Aug 2024 02:01:10 +0000 Subject: [PATCH 05/11] delete_pp --- paddlenlp/transformers/__init__.py | 1 - paddlenlp/transformers/chatglm_v2/__init__.py | 4 - .../transformers/chatglm_v2/modeling_pp.py | 304 ------------------ 3 files changed, 309 deletions(-) delete mode 100644 paddlenlp/transformers/chatglm_v2/modeling_pp.py diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 56fc0a2e3afa..deddfb976d7d 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -265,7 +265,6 @@ from .chatglm.tokenizer import * from .chatglm_v2.configuration import * from .chatglm_v2.modeling import * -from .chatglm_v2.modeling_pp import * from .chatglm_v2.tokenizer import * from .speecht5.configuration import * from .speecht5.modeling import * diff --git a/paddlenlp/transformers/chatglm_v2/__init__.py b/paddlenlp/transformers/chatglm_v2/__init__.py index 83b91e71d58f..775d34cf85f8 100644 --- a/paddlenlp/transformers/chatglm_v2/__init__.py +++ b/paddlenlp/transformers/chatglm_v2/__init__.py @@ -11,7 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .configuration import * -from .modeling import * -from .modeling_pp import * -from .tokenizer import * diff --git a/paddlenlp/transformers/chatglm_v2/modeling_pp.py b/paddlenlp/transformers/chatglm_v2/modeling_pp.py deleted file mode 100644 index 86a74869a272..000000000000 --- a/paddlenlp/transformers/chatglm_v2/modeling_pp.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import paddle -import paddle.distributed.fleet as fleet -from paddle.distributed.fleet.meta_parallel import ( - LayerDesc, - PipelineLayer, - SharedLayerDesc, -) - -try: - from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp -except: - pass - -from paddlenlp.transformers.model_utils import PipelinePretrainedModel - -from .modeling import ( - ChatGLMv2Config, - Chatglmv2LMHead, - ChatGLMv2PretrainedModel, - ChatGLMv2PretrainingCriterion, - Embedding, - GLMBlock, - RMSNorm, -) - -__all__ = ["ChatGLMv2ForCausalLMPipe"] - - -def get_hcg(): - return fleet.get_hybrid_communicate_group() - - -def get_attr(layer, name): - if getattr(layer, name, None) is not None: - return getattr(layer, name, None) - else: - return get_attr(layer._layer, name) - - -def parse_args(args): - if isinstance(args, tuple): - if len(args) == 6: - hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = args - elif len(args) == 5: - hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache = args - use_cache = None - elif len(args) == 4: - hidden_states, attention_mask, position_ids, rotary_pos_emb = args - kv_cache = None - use_cache = None - elif len(args) == 3: - hidden_states, attention_mask, position_ids = args - rotary_pos_emb = None - kv_cache = None - use_cache = None - elif len(args) == 2: - hidden_states, attention_mask = args - position_ids = None - rotary_pos_emb = None - kv_cache = None - use_cache = None - else: - hidden_states = args - attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = None, None, None, None, None - - if position_ids is not None: - position_ids.stop_gradient = True - - if attention_mask is not None: - attention_mask.stop_gradient = True - - if rotary_pos_emb is not None: - rotary_pos_emb.stop_gradient = True - - if kv_cache is not None: - kv_cache.stop_gradient = True - - if use_cache is not None: - use_cache.stop_gradient = True - - return hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache - - -def return_args( - hidden_states, attention_mask=None, position_ids=None, rotary_pos_emb=None, kv_cache=None, use_cache=None -): - ret = (hidden_states,) - - if attention_mask is not None: - ret += (attention_mask.clone(),) - if position_ids is not None: - ret += (position_ids.clone(),) - if rotary_pos_emb is not None: - ret += (rotary_pos_emb.clone(),) - if kv_cache is not None: - ret += (kv_cache.clone(),) - if use_cache is not None: - ret += (use_cache.clone(),) - - if len(ret) == 1: - ret = ret[0] - - return ret - - -def forward_impl(self, seq_len: int, n_elem: int, base: int = 10000): - """Enhanced Transformer with Rotary Position Embedding. - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
- """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base ** (paddle.arange(0, n_elem, 2, dtype="float32") / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = paddle.arange(0, seq_len, dtype=theta.dtype) - - # Calculate the product of position index and $\theta_i$ - idx_theta = paddle.outer(seq_idx, theta).astype(self.default_dtype) - - cache = paddle.stack([paddle.cos(idx_theta), paddle.sin(idx_theta)], axis=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if self.default_dtype in (paddle.float16, paddle.bfloat16, paddle.int8): - cache = cache.astype(self.default_dtype) - # cache = cache.bfloat16() if dtype == paddle.bfloat16 else cache.astype("float16") - return cache - - -class EmbeddingPipe(Embedding): - """Extends Embedding to forward attention_mask through the pipeline.""" - - def __init__(self, config: ChatGLMv2Config): - super().__init__(config) - self.default_dtype = paddle.get_default_dtype() - - @property - def embedding_weight(self): - return get_attr(self.word_embeddings, "weight") - - def forward(self, args): - input_ids, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) - input_ids.stop_gradient = True - inputs_embeds = super().forward(input_ids=input_ids) - batch_size, seq_length = input_ids.shape - - if self.config.sequence_parallel: - seq_length, batch_size, hidden_size = inputs_embeds.shape - inputs_embeds = paddle.reshape_(inputs_embeds, [batch_size * seq_length, hidden_size]) - # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) - inputs_embeds = ScatterOp.apply(inputs_embeds) - - if attention_mask is None: - attention_mask = paddle.ones((batch_size, 1, seq_length, seq_length), dtype="bool") - if len(attention_mask.shape) == 2: - # from Tokenizer - attention_mask = ( - attention_mask.unsqueeze(axis=[1, 2]).expand([batch_size, 1, seq_length, seq_length]).astype("bool") - ) - elif len(attention_mask.shape) == 3: - # [batch_size,tgt_length, src_length] -> [batch_size, 1, tgt_length, src_length] - attention_mask = attention_mask.unsqueeze(1).astype("bool") - elif len(attention_mask.shape) == 4: - attention_mask = attention_mask.astype("bool") - - causal_mask = paddle.tril(paddle.ones([batch_size, 1, seq_length, seq_length])).astype("bool") - attention_mask = attention_mask & causal_mask - zero = paddle.zeros(attention_mask.shape, dtype=inputs_embeds.dtype) - neg_inf = paddle.full_like(attention_mask, paddle.finfo(inputs_embeds.dtype).min, dtype=inputs_embeds.dtype) - attention_mask = paddle.where(attention_mask, zero, neg_inf) - # Rotary positional embeddings - self.max_sequence_length = self.config.max_sequence_length - rotary_dim = ( - self.config.hidden_size // self.config.num_attention_heads - if self.config.kv_channels is None - else self.config.kv_channels - ) - rotary_pos_emb = forward_impl(self, self.max_sequence_length, rotary_dim // 2) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - - rotary_pos_emb = rotary_pos_emb.transpose([1, 0, 2, 3]) - - return return_args(inputs_embeds, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache) - - -class GLMBlockPipe(GLMBlock): - """Extends GLMBlock to forward attention_mask through the pipeline.""" - - def forward(self, args): - hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) - hidden_states, 
kv_cache = super().forward(hidden_states, attention_mask, rotary_pos_emb, kv_cache, use_cache) - return return_args(hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache) - - -class RMSNormPipe(RMSNorm): - def forward(self, args): - hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) - hidden_states = super().forward(hidden_states) - return hidden_states - - -class Chatglmv2LMHeadPipe(Chatglmv2LMHead): - def __init__(self, config): - super(Chatglmv2LMHeadPipe, self).__init__(config) - - @property - def embedding_weight(self): - return get_attr(self, "decoder_weight") - - -class ChatGLMv2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): - """ChatGLMv2ForPretraining adapted for pipeline parallelism. - - The largest change is flattening the ChatGLMv2Model class so we can express it as a - sequence of layers including embedding, transformer layers, and output. - """ - - config_class = ChatGLMv2Config - - get_masks = ChatGLMv2PretrainedModel.get_masks - _get_tensor_parallel_mappings = ChatGLMv2PretrainedModel._get_tensor_parallel_mappings - init_weights = ChatGLMv2PretrainedModel.init_weights - get_position_ids = ChatGLMv2PretrainedModel.get_position_ids - _get_name_mappings = ChatGLMv2PretrainedModel._get_name_mappings - - # NO base_model_prefix !!!! - - def __init__(self, config): - self.config = config - - virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) - - hcg = get_hcg() - tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1) - tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0) - - config.tensor_parallel_degree = tensor_parallel_degree - config.tensor_parallel_rank = tensor_parallel_rank - - self.add_sequential_layer( - SharedLayerDesc( - "chatglmv2_shared_weight", EmbeddingPipe, shared_weight_attr="embedding_weight", config=config - ), - "embedding", - ) - for i in range(config.num_hidden_layers): - self.add_sequential_layer( - LayerDesc(GLMBlockPipe, config=config, layer_number=i), - f"chatglmv2.decoder.layers.{i}", - ) - - self.add_sequential_layer( - LayerDesc(RMSNormPipe, hidden_size=config.hidden_size, config=config, epsilon=config.layernorm_epsilon), - "encoder.final_layernorm", - ) - self.add_sequential_layer( - SharedLayerDesc( - "chatglmv2_shared_weight", Chatglmv2LMHeadPipe, shared_weight_attr="embedding_weight", config=config - ), - "embedding.word_embeddings", - ) - - recompute_interval = 0 - # if self.config.recompute and recompute_granularity == "full": - # assert pp_recompute_interval <= config.num_hidden_layers // ( - # virtual_pp_degree * get_hcg().topology().get_dim_size("pipe") - # ), "pp recompute interval should smaller than num layers of each pp chunk" - # recompute_interval = pp_recompute_interval - - seg_method = "layer:GLMBlock" - if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0: - seg_method = "uniform" - - PipelineLayer.__init__( - self, - layers=self.get_sequential_layers(), - loss_fn=ChatGLMv2PretrainingCriterion(config), - topology=get_hcg().topology(), - seg_method=seg_method, - recompute_interval=recompute_interval, - recompute_ctx={ - "mp_group": get_hcg().get_model_parallel_group(), - "offload": False, - "partition": False, - }, - num_virtual_pipeline_stages=virtual_pp_degree, - ) - self.apply(self._init_weights) From f08edfd6114a2e23c9dcefe1f810e3b2d1dba613 Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Wed, 28 Aug 2024 02:52:19 +0000 Subject: [PATCH 06/11] fix_flash_attention and 
seed_guard --- paddlenlp/transformers/chatglm_v2/modeling.py | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index b34f4c96a4ee..7169f8579204 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import math from functools import partial from typing import Any, Dict, List, Optional, Tuple @@ -61,6 +62,20 @@ ] +def seed_guard_context(name=None): + if ( + not isinstance(paddle.base.framework._current_expected_place(), paddle.core.CPUPlace) + and name in get_rng_state_tracker().states_ + ): + # todo fix it + # ValueError: Length of gpu state list should be equal to the gpu device count + # /usr/local/lib/python3.10/dist-packages/paddle/incubate/framework/random.py:119: ValueError + # return contextlib.nullcontext() + return get_rng_state_tracker().rng_state(name) + else: + return contextlib.nullcontext() + + def parallel_matmul(lm_output, logit_weights, parallel_output): hcg = fleet.get_hybrid_communicate_group() model_parallel_group = hcg.get_model_parallel_group() @@ -227,7 +242,8 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) + with seed_guard_context("local_seed"): + attention_probs = self.attention_dropout(attention_probs) # [batch_size, num_heads, query_length, key_length] # value_layer -> context layer. @@ -318,14 +334,25 @@ def __init__(self, config: ChatGLMv2Config, layer_number, device=None): self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=config.add_bias_linear) def _flash_attention(self, q, k, v, attention_mask=None, output_attentions=False): - out, weights = flash_attention( - query=q, - key=k, - value=v, - dropout=self.config.attention_dropout, - causal=q.shape[0] != 1, - return_softmax=output_attentions, - ) + """ + q: [seq_len, bs, num_head, head_dim] + k: [seq_len, bs, num_head, head_dim] + v: [seq_len, bs, num_head, head_dim] + """ + q = q.transpose([1, 0, 2, 3]) + k = k.transpose([1, 0, 2, 3]) + v = v.transpose([1, 0, 2, 3]) + + with seed_guard_context("local_seed"): + out, weights = flash_attention( + query=q, + key=k, + value=v, + dropout=self.config.attention_dropout, + causal=q.shape[0] != 1, + return_softmax=output_attentions, + trainging=self.training, + ) # [bs, seq_len, num_head, head_dim] -> [bs, seq_len, num_head * head_dim] out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) return (out, weights) if output_attentions else out @@ -543,7 +570,10 @@ def forward( else: residual = hidden_states - layernorm_input = F.dropout(attention_output, p=self.hidden_dropout, training=self.training) + current_seed = "local_seed" if self.config.sequence_parallel else "global_seed" + + with seed_guard_context(current_seed): + layernorm_input = F.dropout(attention_output, p=self.hidden_dropout, training=self.training) layernorm_input = residual + layernorm_input # Layer norm post the self attention. 
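# Illustrative sketch, not part of this patch: seed_guard_context above only
# selects an RNG state that has already been registered under "local_seed" /
# "global_seed"; it does not create one. The helper name, base seed, and
# per-rank offset below are assumptions for illustration only.
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker


def register_rng_states(base_seed: int = 1234):
    # Register the named RNG states that seed_guard_context("local_seed") /
    # seed_guard_context("global_seed") will look up.
    mp_rank = fleet.get_hybrid_communicate_group().get_model_parallel_rank()
    tracker = get_rng_state_tracker()
    if "global_seed" not in tracker.states_:
        # Same seed on every mp rank -> identical dropout masks.
        tracker.add("global_seed", base_seed)
    if "local_seed" not in tracker.states_:
        # Rank-specific seed -> different dropout masks per mp rank.
        tracker.add("local_seed", base_seed + 2048 + mp_rank)
# With these states registered, the guarded dropouts follow the current_seed
# selection above: "local_seed" when activations are sharded under sequence
# parallelism, "global_seed" when they are replicated across mp ranks.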
@@ -558,7 +588,8 @@ def forward( else: residual = layernorm_input - output = F.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + with seed_guard_context(current_seed): + output = F.dropout(mlp_output, p=self.hidden_dropout, training=self.training) output = residual + output return output, kv_cache From 4272b02c9c8f43a8d8e927755086ab4b668e9074 Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Wed, 28 Aug 2024 07:16:29 +0000 Subject: [PATCH 07/11] update_flashattention --- paddlenlp/transformers/chatglm_v2/modeling.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index 7169f8579204..3a8ad0863c1c 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -351,10 +351,17 @@ def _flash_attention(self, q, k, v, attention_mask=None, output_attentions=False dropout=self.config.attention_dropout, causal=q.shape[0] != 1, return_softmax=output_attentions, - trainging=self.training, + training=self.training, ) # [bs, seq_len, num_head, head_dim] -> [bs, seq_len, num_head * head_dim] out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + # [bs, seq_len, num_head * head_dim]-> [seq_len, bs, num_head * head_dim] + out = out.transpose([1, 0, 2]) + + if self.config.sequence_parallel: + sq, bs, hp = out.shape + out = out.reshape([sq * bs, hp]) + return (out, weights) if output_attentions else out def _core_attention(self, q, k, v, attention_mask=None, output_attentions=False): From c6d8f156dbee639521a770e5bc1c7a836b640fa7 Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Wed, 28 Aug 2024 09:43:27 +0000 Subject: [PATCH 08/11] delete_flash_attention --- .../transformers/chatglm_v2/modeling.py | 2 +- paddlenlp/transformers/chatglm_v2/modeling.py | 42 +------------------ 2 files changed, 2 insertions(+), 42 deletions(-) diff --git a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py index c7e9762fb801..74d5fecf6123 100644 --- a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py @@ -185,7 +185,7 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True): if self.post_layer_norm: LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm # Final layer norm before output. 
- self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon) + self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config) def get_input_embeddings(self): return self.embedding.word_embeddings diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index 3a8ad0863c1c..cbab8dfaa8b1 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -20,7 +20,6 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -import paddle.tensor as tensor from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.utils import recompute @@ -46,10 +45,6 @@ except: pass -try: - from paddle.nn.functional.flash_attention import flash_attention -except: - flash_attention = None try: from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd except: @@ -298,7 +293,6 @@ def __init__(self, config: ChatGLMv2Config, layer_number, device=None): self.enable_recompute = False self.tensor_parallel_degree = config.tensor_parallel_degree self.sequence_parallel = config.sequence_parallel - self.use_flash_attention = config.use_flash_attention if flash_attention else False if config.sequence_parallel: ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear @@ -333,37 +327,6 @@ def __init__(self, config: ChatGLMv2Config, layer_number, device=None): # Output. self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=config.add_bias_linear) - def _flash_attention(self, q, k, v, attention_mask=None, output_attentions=False): - """ - q: [seq_len, bs, num_head, head_dim] - k: [seq_len, bs, num_head, head_dim] - v: [seq_len, bs, num_head, head_dim] - """ - q = q.transpose([1, 0, 2, 3]) - k = k.transpose([1, 0, 2, 3]) - v = v.transpose([1, 0, 2, 3]) - - with seed_guard_context("local_seed"): - out, weights = flash_attention( - query=q, - key=k, - value=v, - dropout=self.config.attention_dropout, - causal=q.shape[0] != 1, - return_softmax=output_attentions, - training=self.training, - ) - # [bs, seq_len, num_head, head_dim] -> [bs, seq_len, num_head * head_dim] - out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) - # [bs, seq_len, num_head * head_dim]-> [seq_len, bs, num_head * head_dim] - out = out.transpose([1, 0, 2]) - - if self.config.sequence_parallel: - sq, bs, hp = out.shape - out = out.reshape([sq * bs, hp]) - - return (out, weights) if output_attentions else out - def _core_attention(self, q, k, v, attention_mask=None, output_attentions=False): outputs = self.core_attention(q, k, v, attention_mask) return outputs @@ -431,10 +394,7 @@ def forward( # ================================== # core attention computation # ================================== - if self.use_flash_attention: - attention_fuc = self._flash_attention - else: - attention_fuc = self._core_attention + attention_fuc = self._core_attention has_gradient = ( (not query_layer.stop_gradient) or (not key_layer.stop_gradient) or (not value_layer.stop_gradient) From 0eb86ef056265859f64dee110bfb5ad3d1e57ce6 Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Thu, 29 Aug 2024 03:21:16 +0000 Subject: [PATCH 09/11] add_pp_for_chatglmv2 --- paddlenlp/transformers/__init__.py | 1 + paddlenlp/transformers/chatglm_v2/__init__.py | 3 + paddlenlp/transformers/chatglm_v2/modeling.py | 6 +- .../transformers/chatglm_v2/modeling_pp.py | 
292 ++++++++++++++++++ 4 files changed, 299 insertions(+), 3 deletions(-) create mode 100644 paddlenlp/transformers/chatglm_v2/modeling_pp.py diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index deddfb976d7d..56fc0a2e3afa 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -265,6 +265,7 @@ from .chatglm.tokenizer import * from .chatglm_v2.configuration import * from .chatglm_v2.modeling import * +from .chatglm_v2.modeling_pp import * from .chatglm_v2.tokenizer import * from .speecht5.configuration import * from .speecht5.modeling import * diff --git a/paddlenlp/transformers/chatglm_v2/__init__.py b/paddlenlp/transformers/chatglm_v2/__init__.py index 775d34cf85f8..2bbcc4a933b3 100644 --- a/paddlenlp/transformers/chatglm_v2/__init__.py +++ b/paddlenlp/transformers/chatglm_v2/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .modeling import * +from .modeling_pp import * +from .tokenizer import * diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index cbab8dfaa8b1..20136afcc5c5 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -88,7 +88,7 @@ def parallel_matmul(lm_output, logit_weights, parallel_output): # _c_concat has not grad backwards return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) else: - logits = paddle.matmul(lm_output, logit_weights, transpose_y=True) + logits = paddle.matmul(lm_output, logit_weights, transpose_y=False) return logits @@ -1047,12 +1047,12 @@ def __init__(self, config: ChatGLMv2Config, embedding_weights=None): if vocab_size != config.vocab_size: with get_rng_state_tracker().rng_state(): self.decoder_weight = self.create_parameter( - shape=[vocab_size, config.hidden_size], + shape=[config.hidden_size, vocab_size], dtype=paddle.get_default_dtype(), ) else: self.decoder_weight = self.create_parameter( - shape=[vocab_size, config.hidden_size], dtype=paddle.get_default_dtype() + shape=[config.hidden_size, vocab_size], dtype=paddle.get_default_dtype() ) self.config = config diff --git a/paddlenlp/transformers/chatglm_v2/modeling_pp.py b/paddlenlp/transformers/chatglm_v2/modeling_pp.py new file mode 100644 index 000000000000..83e232a43789 --- /dev/null +++ b/paddlenlp/transformers/chatglm_v2/modeling_pp.py @@ -0,0 +1,292 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import paddle +import paddle.distributed.fleet as fleet +from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp +except: + pass + +from paddlenlp.transformers.model_utils import PipelinePretrainedModel + +from .modeling import ( + ChatGLMv2Config, + Chatglmv2LMHead, + ChatGLMv2PretrainedModel, + ChatGLMv2PretrainingCriterion, + Embedding, + GLMBlock, + RMSNorm, +) + +__all__ = ["ChatGLMv2ForCausalLMPipe"] + + +def get_hcg(): + return fleet.get_hybrid_communicate_group() + + +def get_attr(layer, name): + if getattr(layer, name, None) is not None: + return getattr(layer, name, None) + else: + return get_attr(layer._layer, name) + + +def parse_args(args): + if isinstance(args, tuple): + if len(args) == 6: + hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = args + elif len(args) == 5: + hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache = args + use_cache = None + elif len(args) == 4: + hidden_states, attention_mask, position_ids, rotary_pos_emb = args + kv_cache = None + use_cache = None + elif len(args) == 3: + hidden_states, attention_mask, position_ids = args + rotary_pos_emb = None + kv_cache = None + use_cache = None + elif len(args) == 2: + hidden_states, attention_mask = args + position_ids = None + rotary_pos_emb = None + kv_cache = None + use_cache = None + else: + hidden_states = args + attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = None, None, None, None, None + + if position_ids is not None: + position_ids.stop_gradient = True + + if attention_mask is not None: + attention_mask.stop_gradient = True + + if rotary_pos_emb is not None: + rotary_pos_emb.stop_gradient = True + + if kv_cache is not None: + kv_cache.stop_gradient = True + + if use_cache is not None: + use_cache.stop_gradient = True + + return hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache + + +def return_args( + hidden_states, attention_mask=None, position_ids=None, rotary_pos_emb=None, kv_cache=None, use_cache=None +): + ret = (hidden_states,) + + if attention_mask is not None: + ret += (attention_mask.clone(),) + if position_ids is not None: + ret += (position_ids.clone(),) + if rotary_pos_emb is not None: + ret += (rotary_pos_emb.clone(),) + if kv_cache is not None: + ret += (kv_cache.clone(),) + if use_cache is not None: + ret += (use_cache.clone(),) + + if len(ret) == 1: + ret = ret[0] + + return ret + + +def forward_impl(self, seq_len: int, n_elem: int, base: int = 10000): + """Enhanced Transformer with Rotary Position Embedding. + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
+ """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (paddle.arange(0, n_elem, 2, dtype="float32") / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = paddle.arange(0, seq_len, dtype=theta.dtype) + + # Calculate the product of position index and $\theta_i$ + idx_theta = paddle.outer(seq_idx, theta).astype(self.default_dtype) + + cache = paddle.stack([paddle.cos(idx_theta), paddle.sin(idx_theta)], axis=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if self.default_dtype in (paddle.float16, paddle.bfloat16, paddle.int8): + cache = cache.astype(self.default_dtype) + # cache = cache.bfloat16() if dtype == paddle.bfloat16 else cache.astype("float16") + return cache + + +class EmbeddingPipe(Embedding): + """Extends Embedding to forward attention_mask through the pipeline.""" + + def __init__(self, config: ChatGLMv2Config): + super().__init__(config) + self.default_dtype = paddle.get_default_dtype() + + @property + def embedding_weight(self): + return get_attr(self.word_embeddings, "weight") + + def forward(self, args): + input_ids, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) + input_ids.stop_gradient = True + inputs_embeds = super().forward(input_ids=input_ids) + batch_size, seq_length = input_ids.shape + + if self.config.sequence_parallel: + seq_length, batch_size, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [batch_size * seq_length, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + if attention_mask is None: + attention_mask = paddle.ones((batch_size, 1, seq_length, seq_length), dtype="bool") + if len(attention_mask.shape) == 2: + # from Tokenizer + attention_mask = ( + attention_mask.unsqueeze(axis=[1, 2]).expand([batch_size, 1, seq_length, seq_length]).astype("bool") + ) + elif len(attention_mask.shape) == 3: + # [batch_size,tgt_length, src_length] -> [batch_size, 1, tgt_length, src_length] + attention_mask = attention_mask.unsqueeze(1).astype("bool") + elif len(attention_mask.shape) == 4: + attention_mask = attention_mask.astype("bool") + + causal_mask = paddle.tril(paddle.ones([batch_size, 1, seq_length, seq_length])).astype("bool") + attention_mask = attention_mask & causal_mask + zero = paddle.zeros(attention_mask.shape, dtype=inputs_embeds.dtype) + neg_inf = paddle.full_like(attention_mask, paddle.finfo(inputs_embeds.dtype).min, dtype=inputs_embeds.dtype) + attention_mask = paddle.where(attention_mask, zero, neg_inf) + # Rotary positional embeddings + self.max_sequence_length = self.config.max_sequence_length + rotary_dim = ( + self.config.hidden_size // self.config.num_attention_heads + if self.config.kv_channels is None + else self.config.kv_channels + ) + rotary_pos_emb = forward_impl(self, self.max_sequence_length, rotary_dim // 2) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + + rotary_pos_emb = rotary_pos_emb.transpose([1, 0, 2, 3]) + + return return_args(inputs_embeds, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache) + + +class GLMBlockPipe(GLMBlock): + """Extends GLMBlock to forward attention_mask through the pipeline.""" + + def forward(self, args): + hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) + hidden_states, 
kv_cache = super().forward(hidden_states, attention_mask, rotary_pos_emb, kv_cache, use_cache) + return return_args(hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache) + + +class RMSNormPipe(RMSNorm): + def forward(self, args): + hidden_states, attention_mask, position_ids, rotary_pos_emb, kv_cache, use_cache = parse_args(args) + hidden_states = super().forward(hidden_states) + return hidden_states + + +class Chatglmv2LMHeadPipe(Chatglmv2LMHead): + def __init__(self, config): + super(Chatglmv2LMHeadPipe, self).__init__(config) + + +class ChatGLMv2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): + """ChatGLMv2ForPretraining adapted for pipeline parallelism. + + The largest change is flattening the ChatGLMv2Model class so we can express it as a + sequence of layers including embedding, transformer layers, and output. + """ + + config_class = ChatGLMv2Config + + get_masks = ChatGLMv2PretrainedModel.get_masks + _get_tensor_parallel_mappings = ChatGLMv2PretrainedModel._get_tensor_parallel_mappings + init_weights = ChatGLMv2PretrainedModel.init_weights + get_position_ids = ChatGLMv2PretrainedModel.get_position_ids + _get_name_mappings = ChatGLMv2PretrainedModel._get_name_mappings + + # NO base_model_prefix !!!! + + def __init__(self, config): + self.config = config + + virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + + hcg = get_hcg() + tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1) + tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0) + + config.tensor_parallel_degree = tensor_parallel_degree + config.tensor_parallel_rank = tensor_parallel_rank + + self.add_sequential_layer( + LayerDesc(EmbeddingPipe, config=config), + "embedding", + ) + for i in range(config.num_hidden_layers): + self.add_sequential_layer( + LayerDesc(GLMBlockPipe, config=config, layer_number=i), + f"encoder.layers.{i}", + ) + + self.add_sequential_layer( + LayerDesc(RMSNormPipe, hidden_size=config.hidden_size, config=config, epsilon=config.layernorm_epsilon), + "encoder.final_layernorm", + ) + self.add_sequential_layer( + LayerDesc(Chatglmv2LMHeadPipe, config=config), + "output_layer", + ) + + recompute_interval = 0 + # if self.config.recompute and recompute_granularity == "full": + # assert pp_recompute_interval <= config.num_hidden_layers // ( + # virtual_pp_degree * get_hcg().topology().get_dim_size("pipe") + # ), "pp recompute interval should smaller than num layers of each pp chunk" + # recompute_interval = pp_recompute_interval + + seg_method = "layer:GLMBlock" + if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0: + seg_method = "uniform" + + PipelineLayer.__init__( + self, + layers=self.get_sequential_layers(), + loss_fn=ChatGLMv2PretrainingCriterion(config), + topology=get_hcg().topology(), + seg_method=seg_method, + recompute_interval=recompute_interval, + recompute_ctx={ + "mp_group": get_hcg().get_model_parallel_group(), + "offload": False, + "partition": False, + }, + num_virtual_pipeline_stages=virtual_pp_degree, + ) + self.apply(self._init_weights) From 96d4cb533de2f85ee9a9eedf0676d36776f597dc Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Fri, 20 Sep 2024 07:01:20 +0000 Subject: [PATCH 10/11] update --- .../transformers/chatglm_v2/modeling.py | 5 ++-- paddlenlp/transformers/chatglm_v2/modeling.py | 29 +++++++++---------- .../transformers/chatglm_v2/modeling_pp.py | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git 
a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py index 74d5fecf6123..1651cab30c43 100644 --- a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py @@ -31,6 +31,7 @@ from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2PretrainedModel from paddlenlp.transformers.chatglm_v2.modeling import ( Embedding, + LayerNorm, RMSNorm, RotaryEmbedding, ) @@ -183,9 +184,9 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True): self.post_layer_norm = config.post_layer_norm if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config) + self.final_layernorm = LayerNormFunc(config) def get_input_embeddings(self): return self.embedding.word_embeddings diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index 20136afcc5c5..e460e0962b73 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -149,16 +149,22 @@ def apply_rotary_pos_emb(x: paddle.Tensor, rope_cache: paddle.Tensor) -> paddle. return paddle.concat((x_out2, x_pass), axis=-1) +class LayerNorm(nn.LayerNorm): + def __init__(self, config): + self.config = config + super().__init__(config.hidden_size, epsilon=config.layernorm_epsilon) + + class RMSNorm(nn.Layer): - def __init__(self, hidden_size, config: ChatGLMv2Config, epsilon=None): + def __init__(self, config: ChatGLMv2Config): super().__init__() - self.hidden_size = hidden_size + self.hidden_size = config.hidden_size self.weight = paddle.create_parameter( shape=[self.hidden_size], dtype=paddle.get_default_dtype(), default_initializer=nn.initializer.Constant(1.0), ) - self.epsilon = 1e-5 if epsilon is None else epsilon + self.epsilon = 1e-5 if config.layernorm_epsilon is None else config.layernorm_epsilon if config.sequence_parallel: mark_as_sequence_parallel_parameter(self.weight) @@ -486,18 +492,16 @@ def __init__(self, config: ChatGLMv2Config, layer_number): self.config = config self.fp32_residual_connection = config.fp32_residual_connection - LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config) + self.input_layernorm = LayerNormFunc(config) # Self attention. self.self_attention = SelfAttention(config, layer_number) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc( - config.hidden_size, epsilon=config.layernorm_epsilon, config=config - ) + self.post_attention_layernorm = LayerNormFunc(config) # MLP self.mlp = MLP(config) @@ -584,9 +588,9 @@ def build_layer(layer_number): self.layers = nn.LayerList([build_layer(i + 1) for i in range(self.num_hidden_layers)]) if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output.
- self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config) + self.final_layernorm = LayerNormFunc(config) def _get_layer(self, layer_number): return self.layers[layer_number] @@ -786,11 +790,6 @@ def _get_name_mappings(cls, config: ChatGLMv2Config) -> List[StateDictNameMappin ] ) - # for mapping in mappings: - # mapping[0] = "transformer." + mapping[0] - # if len(mapping) > 1 and mapping[1] is not None: - # mapping[1] = "chatglm_v2." + mapping[1] - init_name_mappings(mappings) return [StateDictNameMapping(*mapping) for mapping in mappings] diff --git a/paddlenlp/transformers/chatglm_v2/modeling_pp.py b/paddlenlp/transformers/chatglm_v2/modeling_pp.py index 83e232a43789..07dc182732af 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling_pp.py +++ b/paddlenlp/transformers/chatglm_v2/modeling_pp.py @@ -256,7 +256,7 @@ def __init__(self, config): ) self.add_sequential_layer( - LayerDesc(RMSNormPipe, hidden_size=config.hidden_size, config=config, epsilon=config.layernorm_epsilon), + LayerDesc(RMSNormPipe, config=config), "encoder.final_layernorm", ) self.add_sequential_layer( From 8fc4c8a19d1f5bbb07f6920cb572bf179086c8a6 Mon Sep 17 00:00:00 2001 From: SevenSamon <1273520759@qq.com> Date: Fri, 20 Sep 2024 07:43:18 +0000 Subject: [PATCH 11/11] update --- paddlenlp/transformers/chatglm_v2/modeling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index e2cbfce08478..08e32e4e55c5 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -802,7 +802,6 @@ def _get_tensor_parallel_mappings(cls, config, is_split=True): from paddlenlp.transformers.conversion_utils import split_or_merge_func - def split_or_merge_qkv_weights(tensor_parallel_degree, tensor_parallel_rank, hidden_size, is_split, tensor): if is_split: return split_qkv_weights(tensor_parallel_degree, tensor_parallel_rank, hidden_size, tensor)
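As a quick sanity check on the lm-head layout introduced earlier in this series (PATCH 09/11 stores Chatglmv2LMHead.decoder_weight as [hidden_size, vocab_size] and switches the plain-matmul branch of parallel_matmul to transpose_y=False), the toy snippet below verifies the resulting logits shape. All sizes and names are made up for illustration, and only the non-distributed code path is exercised.

    import paddle

    hidden_size, vocab_size, num_tokens = 64, 1000, 8            # toy sizes, not from the model
    lm_output = paddle.randn([num_tokens, hidden_size])          # flattened [batch * seq_len, hidden_size]
    decoder_weight = paddle.randn([hidden_size, vocab_size])     # [hidden_size, vocab_size] layout after PATCH 09/11

    # Same orientation as the non-parallel branch of parallel_matmul after the change.
    logits = paddle.matmul(lm_output, decoder_weight, transpose_y=False)
    assert logits.shape == [num_tokens, vocab_size]

This is only a shape check of the weight orientation; the tensor-parallel and pipeline-parallel branches added in this series are not exercised here.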