From b93c57a7e2bf69fde66b42d53acca7027df7bc36 Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Fri, 24 May 2024 03:03:28 +0000 Subject: [PATCH 01/21] DiT FFN fineGrained --- ppdiffusers/ppdiffusers/models/dit_llama.py | 68 ++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 4c1b55f6a..16b8534a0 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -290,6 +290,72 @@ def forward(self, x): return output +class FeedForward_kai(nn.Layer): + def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden + dimension. Defaults to None. + + Attributes: + w1 (nn.Linear): Linear transformation for the first + layer. + w2 (nn.Linear): Linear transformation for the second layer. + w3 (nn.Linear): Linear transformation for the third + layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) + + self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False) + self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) + self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) + + self.first_run = True + self.concat_weight = None + + def compute_activation(self, ffn1_out): + origin_batch_size = ffn1_out.shape[0] + origin_seq_len = ffn1_out.shape[1] + ffn1_out = ffn1_out.reshape([origin_batch_size*origin_seq_len, ffn1_out.shape[-1]]) + res = paddle._C_ops.fused_bias_act( + ffn1_out, + None, + None, + None, + None, + "swiglu", + "default", + -1, + 0, + 0, + 0 + ) + return res.reshape([origin_batch_size, origin_seq_len, res.shape[-1]]) + + def forward(self, x): + if self.first_run: + self.first_run = False + self.concat_weight = paddle.concat([self.w1.weight, self.w3.weight], axis=-1) + del self.w1.weight + del self.w3.weight + + ffn1_out = paddle.matmul(x, self.concat_weight) + ffn1_out = self.compute_activation(ffn1_out) + ffn2_out = paddle.matmul(ffn1_out, self.w2.weight) + return ffn2_out + + class TransformerBlock(nn.Layer): def __init__( self, @@ -339,7 +405,7 @@ def __init__( self.head_dim = dim // n_heads self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn) mlp_hidden_dim = int(dim * mlp_ratio) - self.feed_forward = FeedForward( + self.feed_forward = FeedForward_kai( dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier ) self.layer_id = layer_id From bca34842fe2472aadfab6459e5b65b29aa8845eb Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Fri, 24 May 2024 03:13:10 +0000 Subject: [PATCH 02/21] DiT FFN fineGrained --- ppdiffusers/ppdiffusers/models/dit_llama.py | 40 +-------------------- 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 16b8534a0..615aab025 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -253,44 +253,6 @@ def forward(self, x, freqs_cis): class 
FeedForward(nn.Layer): - def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): - """ - Initialize the FeedForward module. - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple - of this value. - ffn_dim_multiplier (float, optional): Custom multiplier for hidden - dimension. Defaults to None. - - Attributes: - w1 (nn.Linear): Linear transformation for the first - layer. - w2 (nn.Linear): Linear transformation for the second layer. - w3 (nn.Linear): Linear transformation for the third - layer. - - """ - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) - - self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False) - self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) - self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) - - def forward(self, x): - xw1 = F.silu(self.w1(x)) - xw3 = self.w3(x) - output = self.w2(xw1 * xw3) - return output - - -class FeedForward_kai(nn.Layer): def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): """ Initialize the FeedForward module. @@ -405,7 +367,7 @@ def __init__( self.head_dim = dim // n_heads self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn) mlp_hidden_dim = int(dim * mlp_ratio) - self.feed_forward = FeedForward_kai( + self.feed_forward = FeedForward( dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier ) self.layer_id = layer_id From 6771b36c575f23fe40a0a41c1f5d2918e87734dc Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Fri, 31 May 2024 02:04:46 +0000 Subject: [PATCH 03/21] clear fine_grained_FFN --- ppdiffusers/ppdiffusers/models/dit_llama.py | 110 ++++++++++++------ .../ppdiffusers/models/modeling_utils.py | 5 + 2 files changed, 82 insertions(+), 33 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 615aab025..4f986cd47 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -252,6 +252,43 @@ def forward(self, x, freqs_cis): return self.wo(output) +class FeedForward_kai(nn.Layer): + def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) + + self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False) + self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) + + def compute_activation(self, ffn1_out): + origin_batch_size = ffn1_out.shape[0] + origin_seq_len = ffn1_out.shape[1] + ffn1_out = ffn1_out.reshape([origin_batch_size*origin_seq_len, ffn1_out.shape[-1]]) + res = paddle._C_ops.fused_bias_act( + ffn1_out, + None, + None, + None, + None, + "swiglu", + "default", + -1, + 0, + 0, + 0 + ) + return res.reshape([origin_batch_size, origin_seq_len, res.shape[-1]]) + + def forward(self, x): + ffn1_out = self.w13(x) + ffn1_out = self.compute_activation(ffn1_out) + ffn2_out = self.w2(ffn1_out) + return ffn2_out + + class FeedForward(nn.Layer): def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): """ @@ -283,39 +320,11 @@ def 
__init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) - self.first_run = True - self.concat_weight = None - - def compute_activation(self, ffn1_out): - origin_batch_size = ffn1_out.shape[0] - origin_seq_len = ffn1_out.shape[1] - ffn1_out = ffn1_out.reshape([origin_batch_size*origin_seq_len, ffn1_out.shape[-1]]) - res = paddle._C_ops.fused_bias_act( - ffn1_out, - None, - None, - None, - None, - "swiglu", - "default", - -1, - 0, - 0, - 0 - ) - return res.reshape([origin_batch_size, origin_seq_len, res.shape[-1]]) - def forward(self, x): - if self.first_run: - self.first_run = False - self.concat_weight = paddle.concat([self.w1.weight, self.w3.weight], axis=-1) - del self.w1.weight - del self.w3.weight - - ffn1_out = paddle.matmul(x, self.concat_weight) - ffn1_out = self.compute_activation(ffn1_out) - ffn2_out = paddle.matmul(ffn1_out, self.w2.weight) - return ffn2_out + xw1 = F.silu(self.w1(x)) + xw3 = self.w3(x) + output = self.w2(xw1 * xw3) + return output class TransformerBlock(nn.Layer): @@ -367,7 +376,7 @@ def __init__( self.head_dim = dim // n_heads self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn) mlp_hidden_dim = int(dim * mlp_ratio) - self.feed_forward = FeedForward( + self.feed_forward = FeedForward_kai( dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier ) self.layer_id = layer_id @@ -602,3 +611,38 @@ def forward( return (output,) return Transformer2DModelOutput(sample=output) + + @classmethod + def custom_modify_weight(cls, state_dict): + # print("kai==================================") + # print(state_dict.keys()) + import re + w1_pattern = r"layers\.(\d+)\.feed_forward\.w1.weight$" + w3_pattern = r"layers\.(\d+)\.feed_forward\.w3.weight$" + keys_to_add = [] + w1_keys_to_del = [] + w3_keys_to_del = [] + for key in state_dict.keys(): + if re.match(w1_pattern, key): + w1_keys_to_del.append(key) + w3_match = re.match(w3_pattern, key) + if w3_match: + w13_key ='layers.' + w3_match.group(1) + '.feed_forward.w13.weight' + keys_to_add.append(w13_key) + w3_keys_to_del.append(key) + + assert len(keys_to_add) == len(w1_keys_to_del) == len(w3_keys_to_del) + + for ii in range(len(keys_to_add)): + w13_key = keys_to_add[ii] + w1_key = w1_keys_to_del[ii] + w3_key = w3_keys_to_del[ii] + state_dict[w13_key] = paddle.concat([state_dict[w1_key], state_dict[w3_key]], axis=1) + state_dict.pop(w3_key) + state_dict.pop(w1_key) + + # print(state_dict.keys()) + # exit() + + + diff --git a/ppdiffusers/ppdiffusers/models/modeling_utils.py b/ppdiffusers/ppdiffusers/models/modeling_utils.py index e4462d284..3bcb620a6 100644 --- a/ppdiffusers/ppdiffusers/models/modeling_utils.py +++ b/ppdiffusers/ppdiffusers/models/modeling_utils.py @@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model + @classmethod + def custom_modify_weight(cls, state_dict): + pass + @classmethod def _load_pretrained_model( cls, @@ -1130,6 +1134,7 @@ def _find_mismatched_keys( error_msgs.append( f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}." 
) + cls.custom_modify_weight(state_dict) faster_set_state_dict(model_to_load, state_dict) missing_keys = sorted(list(set(expected_keys) - set(loaded_keys))) From 668f1aca045e07063acefab64c192dc79e5884b2 Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Fri, 7 Jun 2024 07:04:35 +0000 Subject: [PATCH 04/21] decorator + fineGrained_qkv_ffn + triton_adaLN_fusedAdaLN --- ppdiffusers/ppdiffusers/models/dit_llama.py | 178 ++++++++++++++------ 1 file changed, 127 insertions(+), 51 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 4f986cd47..5b0a89399 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -25,6 +25,8 @@ from .modeling_utils import ModelMixin from .transformer_2d import Transformer2DModelOutput +from paddle.framework import LayerHelper, in_dynamic_mode +from paddle.incubate.tt import adaptive_layer_norm, fused_adaLN_scale_residual def TypePromote(x, y): TYPE_PROMOTE_DICT = { @@ -120,9 +122,7 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True): self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads - self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) - self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) - self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) + self.qkv = nn.Linear(dim, (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias_attr=False) self.wo = nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) if qk_norm: @@ -184,13 +184,13 @@ def apply_rotary_emb(xq, xk, freqs_cis): Tuple[paddle.Tensor, paddle.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
""" - with paddle.amp.auto_cast(enable=False): - xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) - xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) - freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) - xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) - xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) - return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) + # with paddle.amp.auto_cast(enable=False): + xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) + xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) + freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) + xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) + xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) + return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) def forward(self, x, freqs_cis): """ @@ -205,7 +205,10 @@ def forward(self, x, freqs_cis): """ bsz, seqlen, _ = tuple(x.shape) - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + qkv_out = self.qkv(x) + xq, xk, xv = paddle.split(qkv_out, 3, axis=-1) + dtype = xq.dtype xq = self.q_norm(xq) @@ -263,24 +266,57 @@ def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False) self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) - def compute_activation(self, ffn1_out): + def compute_activation(self, + ffn1_out, + bias=None, + dequant_scales=None, + shift=None, + smooth=None, + act_method="swiglu", + compute_dtype="default", + quant_scale=-1, + quant_round_type=0, + quant_max_bound=0, + quant_min_bound=0): origin_batch_size = ffn1_out.shape[0] origin_seq_len = ffn1_out.shape[1] ffn1_out = ffn1_out.reshape([origin_batch_size*origin_seq_len, ffn1_out.shape[-1]]) - res = paddle._C_ops.fused_bias_act( - ffn1_out, - None, - None, - None, - None, - "swiglu", - "default", - -1, - 0, - 0, - 0 + if in_dynamic_mode(): + out = paddle._C_ops.fused_bias_act( + ffn1_out, + bias, + dequant_scales, + shift, + smooth, + act_method, + compute_dtype, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound + ) + return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]]) + + helper = LayerHelper("fused_bias_act") + out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype) + inputs = {} + inputs["x"] = ffn1_out + attrs = { + "act_method": act_method, + "compute_dtype": compute_dtype, + "quant_scale": quant_scale, + "quant_round_type": quant_round_type, + "quant_max_bound": quant_max_bound, + "quant_min_bound": quant_min_bound, + } + helper.append_op( + type="fused_bias_act", + inputs=inputs, + outputs={"out": out}, + attrs=attrs, ) - return res.reshape([origin_batch_size, origin_seq_len, res.shape[-1]]) + return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]]) + def forward(self, x): ffn1_out = self.w13(x) @@ -387,6 +423,7 @@ def __init__( nn.Silu(), nn.Linear(min(dim, 1024), 6 * dim), ) + self.norm_eps = norm_eps def forward(self, x, freqs_cis, adaln_input=None): """ @@ -407,10 +444,12 @@ def forward(self, x, freqs_cis, adaln_input=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk( 6, axis=1 ) - h = x + gate_msa.unsqueeze(1) * self.attention( - modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis + attention_out = self.attention( + adaptive_layer_norm(x, scale_msa, shift_msa, 
weight=self.attention_norm.weight, epsilon=self.norm_eps), freqs_cis ) - out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp)) + residual_out, adaLN_out = fused_adaLN_scale_residual(x, attention_out, gate_msa, scale_mlp, shift_mlp, + weight=self.ffn_norm.weight, epsilon=self.norm_eps) + out = residual_out + gate_mlp.unsqueeze(1) * self.feed_forward(adaLN_out) else: h = x + self.attention(self.attention_norm(x), freqs_cis) out = h + self.feed_forward(self.ffn_norm(h)) @@ -498,6 +537,7 @@ def __init__( for idx in range(num_layers) ] ) + # del self.layers # 3. Define output layers self.final_layer = FinalLayer(dim, patch_size, self.out_channels) @@ -568,6 +608,21 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): ) return freqs_cis + @paddle.incubate.layers.inference(with_trt=False, + cache_static_model=True, + collect_shape=False) + def transformer_blocks(self, x, adaln_input): + for i, layer in enumerate(self.layers): + if self.gradient_checkpointing and False: + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + else: + x = layer( + x, + self.freqs_cis[: x.shape[1]], + adaln_input, + ) + return x + def forward( self, hidden_states: paddle.Tensor, @@ -593,15 +648,16 @@ def forward( adaln_input = t + y # 2. Blocks - for i, layer in enumerate(self.layers): - if self.gradient_checkpointing: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) - else: - x = layer( - x, - self.freqs_cis[: x.shape[1]], - adaln_input, - ) + x = self.transformer_blocks(x, adaln_input) + # for i, layer in enumerate(self.layers): + # if self.gradient_checkpointing: + # x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + # else: + # x = layer( + # x, + # self.freqs_cis[: x.shape[1]], + # adaln_input, + # ) # 3. Output hidden_states = self.final_layer(x, adaln_input) @@ -614,35 +670,55 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): - # print("kai==================================") - # print(state_dict.keys()) import re - w1_pattern = r"layers\.(\d+)\.feed_forward\.w1.weight$" - w3_pattern = r"layers\.(\d+)\.feed_forward\.w3.weight$" - keys_to_add = [] + w1_pattern = r"layers\.(\d+)\.feed_forward\.w1\.weight$" + w3_pattern = r"layers\.(\d+)\.feed_forward\.w3\.weight$" + wq_pattern = r"layers\.(\d+)\.attention\.wq\.weight$" + wk_pattern = r"layers\.(\d+)\.attention\.wk\.weight$" + wv_pattern = r"layers\.(\d+)\.attention\.wv\.weight$" + + w13_keys_to_add = [] w1_keys_to_del = [] w3_keys_to_del = [] + qkv_keys_to_add = [] + wq_keys_to_del = [] + wk_keys_to_del = [] + wv_keys_to_del = [] + for key in state_dict.keys(): if re.match(w1_pattern, key): w1_keys_to_del.append(key) w3_match = re.match(w3_pattern, key) if w3_match: w13_key ='layers.' + w3_match.group(1) + '.feed_forward.w13.weight' - keys_to_add.append(w13_key) + w13_keys_to_add.append(w13_key) w3_keys_to_del.append(key) + if re.match(wq_pattern, key): + wq_keys_to_del.append(key) + if re.match(wk_pattern, key): + wk_keys_to_del.append(key) + wv_match = re.match(wv_pattern, key) + if(wv_match): + qkv_key = 'layers.' 
+ wv_match.group(1) + '.attention.qkv.weight' + qkv_keys_to_add.append(qkv_key) + wv_keys_to_del.append(key) - assert len(keys_to_add) == len(w1_keys_to_del) == len(w3_keys_to_del) + assert len(w13_keys_to_add) == len(w1_keys_to_del) == len(w3_keys_to_del) \ + == len(qkv_keys_to_add) == len(wq_keys_to_del) == len(wk_keys_to_del) == len(wv_keys_to_del) - for ii in range(len(keys_to_add)): - w13_key = keys_to_add[ii] + for ii in range(len(w13_keys_to_add)): + w13_key = w13_keys_to_add[ii] w1_key = w1_keys_to_del[ii] w3_key = w3_keys_to_del[ii] state_dict[w13_key] = paddle.concat([state_dict[w1_key], state_dict[w3_key]], axis=1) state_dict.pop(w3_key) state_dict.pop(w1_key) - - # print(state_dict.keys()) - # exit() - - + wq_key = wq_keys_to_del[ii] + wk_key = wk_keys_to_del[ii] + wv_key = wv_keys_to_del[ii] + qkv_key = qkv_keys_to_add[ii] + state_dict[qkv_key] = paddle.concat([state_dict[wq_key], state_dict[wk_key], state_dict[wv_key]], 1) + state_dict.pop(wq_key) + state_dict.pop(wk_key) + state_dict.pop(wv_key) From cb8bacb804390b22a181e697a073179fdd4cbc31 Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Wed, 12 Jun 2024 02:55:44 +0000 Subject: [PATCH 05/21] clear up pr ing... --- ppdiffusers/ppdiffusers/models/dit_llama.py | 181 +++++++------------- 1 file changed, 61 insertions(+), 120 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 5b0a89399..575ad1e22 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -107,9 +107,7 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True): n_local_kv_heads (int): Number of local key and value heads. n_rep (int): Number of repetitions for local heads. head_dim (int): Dimension size of each attention head. - wq (nn.Linear): Linear transformation for queries. - wk (nn.Linear): Linear transformation for keys. - wv (nn.Linear): Linear transformation for values. + qkv (nn.Linear): Linear transformation for queries, keys and values. wo (nn.Linear): Linear transformation for output. cache_k (paddle.Tensor): Cached keys for attention. cache_v (paddle.Tensor): Cached values for attention. @@ -255,8 +253,24 @@ def forward(self, x, freqs_cis): return self.wo(output) -class FeedForward_kai(nn.Layer): +class FeedForward(nn.Layer): def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden + dimension. Defaults to None. + + Attributes: + w13 (nn.Linear): Linear transformation for the first layer and the third layer. + w2 (nn.Linear): Linear transformation for the second layer. + + """ super().__init__() hidden_dim = int(2 * hidden_dim / 3) if ffn_dim_multiplier is not None: @@ -265,7 +279,7 @@ def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False) self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) - + def compute_activation(self, ffn1_out, bias=None, @@ -325,44 +339,6 @@ def forward(self, x): return ffn2_out -class FeedForward(nn.Layer): - def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): - """ - Initialize the FeedForward module. - - Args: - dim (int): Input dimension. 
- hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple - of this value. - ffn_dim_multiplier (float, optional): Custom multiplier for hidden - dimension. Defaults to None. - - Attributes: - w1 (nn.Linear): Linear transformation for the first - layer. - w2 (nn.Linear): Linear transformation for the second layer. - w3 (nn.Linear): Linear transformation for the third - layer. - - """ - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) - - self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False) - self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) - self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) - - def forward(self, x): - xw1 = F.silu(self.w1(x)) - xw3 = self.w3(x) - output = self.w2(xw1 * xw3) - return output - - class TransformerBlock(nn.Layer): def __init__( self, @@ -412,7 +388,7 @@ def __init__( self.head_dim = dim // n_heads self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn) mlp_hidden_dim = int(dim * mlp_ratio) - self.feed_forward = FeedForward_kai( + self.feed_forward = FeedForward( dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier ) self.layer_id = layer_id @@ -608,20 +584,20 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): ) return freqs_cis - @paddle.incubate.layers.inference(with_trt=False, - cache_static_model=True, - collect_shape=False) - def transformer_blocks(self, x, adaln_input): - for i, layer in enumerate(self.layers): - if self.gradient_checkpointing and False: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) - else: - x = layer( - x, - self.freqs_cis[: x.shape[1]], - adaln_input, - ) - return x + # @paddle.incubate.layers.inference(with_trt=False, + # cache_static_model=True, + # collect_shape=False) + # def transformer_blocks(self, x, adaln_input): + # for i, layer in enumerate(self.layers): + # if self.gradient_checkpointing and False: + # x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + # else: + # x = layer( + # x, + # self.freqs_cis[: x.shape[1]], + # adaln_input, + # ) + # return x def forward( self, @@ -648,16 +624,16 @@ def forward( adaln_input = t + y # 2. Blocks - x = self.transformer_blocks(x, adaln_input) - # for i, layer in enumerate(self.layers): - # if self.gradient_checkpointing: - # x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) - # else: - # x = layer( - # x, - # self.freqs_cis[: x.shape[1]], - # adaln_input, - # ) + # x = self.transformer_blocks(x, adaln_input) + for i, layer in enumerate(self.layers): + if self.gradient_checkpointing: + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + else: + x = layer( + x, + self.freqs_cis[: x.shape[1]], + adaln_input, + ) # 3. 
Output hidden_states = self.final_layer(x, adaln_input) @@ -670,55 +646,20 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): - import re - w1_pattern = r"layers\.(\d+)\.feed_forward\.w1\.weight$" - w3_pattern = r"layers\.(\d+)\.feed_forward\.w3\.weight$" - wq_pattern = r"layers\.(\d+)\.attention\.wq\.weight$" - wk_pattern = r"layers\.(\d+)\.attention\.wk\.weight$" - wv_pattern = r"layers\.(\d+)\.attention\.wv\.weight$" - - w13_keys_to_add = [] - w1_keys_to_del = [] - w3_keys_to_del = [] - qkv_keys_to_add = [] - wq_keys_to_del = [] - wk_keys_to_del = [] - wv_keys_to_del = [] - - for key in state_dict.keys(): - if re.match(w1_pattern, key): - w1_keys_to_del.append(key) - w3_match = re.match(w3_pattern, key) - if w3_match: - w13_key ='layers.' + w3_match.group(1) + '.feed_forward.w13.weight' - w13_keys_to_add.append(w13_key) - w3_keys_to_del.append(key) - if re.match(wq_pattern, key): - wq_keys_to_del.append(key) - if re.match(wk_pattern, key): - wk_keys_to_del.append(key) - wv_match = re.match(wv_pattern, key) - if(wv_match): - qkv_key = 'layers.' + wv_match.group(1) + '.attention.qkv.weight' - qkv_keys_to_add.append(qkv_key) - wv_keys_to_del.append(key) - - assert len(w13_keys_to_add) == len(w1_keys_to_del) == len(w3_keys_to_del) \ - == len(qkv_keys_to_add) == len(wq_keys_to_del) == len(wk_keys_to_del) == len(wv_keys_to_del) - - for ii in range(len(w13_keys_to_add)): - w13_key = w13_keys_to_add[ii] - w1_key = w1_keys_to_del[ii] - w3_key = w3_keys_to_del[ii] - state_dict[w13_key] = paddle.concat([state_dict[w1_key], state_dict[w3_key]], axis=1) - state_dict.pop(w3_key) - state_dict.pop(w1_key) - - wq_key = wq_keys_to_del[ii] - wk_key = wk_keys_to_del[ii] - wv_key = wv_keys_to_del[ii] - qkv_key = qkv_keys_to_add[ii] - state_dict[qkv_key] = paddle.concat([state_dict[wq_key], state_dict[wk_key], state_dict[wv_key]], 1) - state_dict.pop(wq_key) - state_dict.pop(wk_key) - state_dict.pop(wv_key) + for key in list(state_dict.keys()): + if 'feed_forward.w1.weight' in key: + w1 = state_dict.pop(key) + w3_key = key.replace('w1', 'w3') + w3 = state_dict.pop(w3_key) + w13 = paddle.concat([w1, w3], axis=1) + state_dict[key.replace('w1', 'w13')] = w13 + if 'attention.wq.weight' in key or 'attention.wk.weight' in key or 'attention.wv.weight' in key: + part = key.split('.')[-2] + layer_id = key.split('.')[1] + qkv_key = f'layers.{layer_id}.attention.qkv.weight' + if part == 'wq' and qkv_key not in state_dict: + state_dict[qkv_key] = state_dict.pop(key) + elif part in ('wk', 'wv'): + qkv = state_dict.get(qkv_key) + if qkv is not None: + state_dict[qkv_key] = paddle.concat([qkv, state_dict.pop(key)], axis=1) From bdefd3b5b4bad9469faca60e9c5d95bb2828bc90 Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Wed, 12 Jun 2024 10:05:51 +0000 Subject: [PATCH 06/21] Optional acceleration --- ...nditional_image_generation-large_dit_7b.py | 10 +- ppdiffusers/ppdiffusers/models/dit_llama.py | 194 +++++++++++------- 2 files changed, 129 insertions(+), 75 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py index 170ecfe4b..fa9b40365 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py @@ -11,13 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "4" import paddle from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float32 +dtype = paddle.float16 + +# To speed up this code, call zkk and let him run for you, +# then you will get a speed increase of almost 100%. +os.environ['callZKK']= "True" + with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 575ad1e22..44bd1e32a 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -27,6 +27,7 @@ from paddle.framework import LayerHelper, in_dynamic_mode from paddle.incubate.tt import adaptive_layer_norm, fused_adaLN_scale_residual +import os def TypePromote(x, y): TYPE_PROMOTE_DICT = { @@ -92,7 +93,7 @@ def forward(self, t): class Attention(nn.Layer): - def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True): + def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, callZKK=False): """ Initialize the Attention module. @@ -107,6 +108,9 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True): n_local_kv_heads (int): Number of local key and value heads. n_rep (int): Number of repetitions for local heads. head_dim (int): Dimension size of each attention head. + wq (nn.Linear): Linear transformation for queries. + wk (nn.Linear): Linear transformation for keys. + wv (nn.Linear): Linear transformation for values. qkv (nn.Linear): Linear transformation for queries, keys and values. wo (nn.Linear): Linear transformation for output. cache_k (paddle.Tensor): Cached keys for attention. @@ -120,7 +124,14 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True): self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads - self.qkv = nn.Linear(dim, (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias_attr=False) + self.callZKK = callZKK + if not callZKK: + self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) + self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) + self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) + else: + self.qkv = nn.Linear(dim, (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias_attr=False) + self.wo = nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) if qk_norm: @@ -182,13 +193,21 @@ def apply_rotary_emb(xq, xk, freqs_cis): Tuple[paddle.Tensor, paddle.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
""" - # with paddle.amp.auto_cast(enable=False): - xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) - xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) - freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) - xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) - xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) - return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) + if not os.getenv('callZKK'): + with paddle.amp.auto_cast(enable=False): + xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) + xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) + freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) + xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) + xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) + return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) + else: + xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) + xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) + freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) + xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) + xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) + return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) def forward(self, x, freqs_cis): """ @@ -204,8 +223,11 @@ def forward(self, x, freqs_cis): """ bsz, seqlen, _ = tuple(x.shape) - qkv_out = self.qkv(x) - xq, xk, xv = paddle.split(qkv_out, 3, axis=-1) + if not self.callZKK: + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + else: + qkv_out = self.qkv(x) + xq, xk, xv = paddle.split(qkv_out, 3, axis=-1) dtype = xq.dtype @@ -254,7 +276,7 @@ def forward(self, x, freqs_cis): class FeedForward(nn.Layer): - def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): + def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None, callZKK=False): """ Initialize the FeedForward module. @@ -267,9 +289,11 @@ def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): dimension. Defaults to None. Attributes: - w13 (nn.Linear): Linear transformation for the first layer and the third layer. - w2 (nn.Linear): Linear transformation for the second layer. - + w1 (nn.Linear): Linear transformation for the first layer. + w2 (nn.Linear): Linear transformation for the second layer. + w3 (nn.Linear): Linear transformation for the third layer. + w13 (nn.Linear): Linear transformation for the first and the third layer. 
+ """ super().__init__() hidden_dim = int(2 * hidden_dim / 3) @@ -277,8 +301,13 @@ def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) - self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False) self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) + self.callZKK = callZKK + if not callZKK: + self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False) + self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) + else: + self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False) def compute_activation(self, ffn1_out, @@ -331,12 +360,17 @@ def compute_activation(self, ) return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]]) - def forward(self, x): - ffn1_out = self.w13(x) - ffn1_out = self.compute_activation(ffn1_out) - ffn2_out = self.w2(ffn1_out) - return ffn2_out + if not self.callZKK: + xw1 = F.silu(self.w1(x)) + xw3 = self.w3(x) + output = self.w2(xw1 * xw3) + return output + else: + ffn1_out = self.w13(x) + ffn1_out = self.compute_activation(ffn1_out) + ffn2_out = self.w2(ffn1_out) + return ffn2_out class TransformerBlock(nn.Layer): @@ -352,6 +386,7 @@ def __init__( norm_eps: float, qk_norm: bool, fused_attn: bool, + callZKK=False, ) -> None: """ Initialize a TransformerBlock. @@ -386,10 +421,11 @@ def __init__( super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn) + self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn, callZKK=callZKK) mlp_hidden_dim = int(dim * mlp_ratio) self.feed_forward = FeedForward( - dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier + dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, callZKK=callZKK ) self.layer_id = layer_id self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) @@ -400,6 +436,7 @@ def __init__( nn.Linear(min(dim, 1024), 6 * dim), ) self.norm_eps = norm_eps + self.callZKK = callZKK def forward(self, x, freqs_cis, adaln_input=None): """ @@ -420,12 +457,17 @@ def forward(self, x, freqs_cis, adaln_input=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk( 6, axis=1 ) - attention_out = self.attention( - adaptive_layer_norm(x, scale_msa, shift_msa, weight=self.attention_norm.weight, epsilon=self.norm_eps), freqs_cis - ) - residual_out, adaLN_out = fused_adaLN_scale_residual(x, attention_out, gate_msa, scale_mlp, shift_mlp, - weight=self.ffn_norm.weight, epsilon=self.norm_eps) - out = residual_out + gate_mlp.unsqueeze(1) * self.feed_forward(adaLN_out) + if not self.callZKK: + h = x + gate_msa.unsqueeze(1) * self.attention( + modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis + ) + out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp)) + else: + attention_out = self.attention(adaptive_layer_norm(x, scale_msa, shift_msa, + weight=self.attention_norm.weight, epsilon=self.norm_eps), freqs_cis) + residual_out, adaLN_out = fused_adaLN_scale_residual(x, attention_out, gate_msa, scale_mlp, shift_mlp, + weight=self.ffn_norm.weight, epsilon=self.norm_eps) + out = residual_out + gate_mlp.unsqueeze(1) * self.feed_forward(adaLN_out) else: h = x + self.attention(self.attention_norm(x), freqs_cis) out = h + self.feed_forward(self.ffn_norm(h)) @@ -487,7 +529,6 
@@ def __init__( self.num_classes = num_classes self.learn_sigma = learn_sigma self.qk_norm = qk_norm - self.gradient_checkpointing = True self.fused_attn = True @@ -495,6 +536,7 @@ def __init__( self.t_embedder = TimestepEmbedder(min(dim, 1024)) self.y_embedder = LabelEmbedding(num_classes, min(dim, 1024), class_dropout_prob) + self.callZKK = True if os.getenv('callZKK') else False # 2. Define transformers blocks self.layers = nn.LayerList( [ @@ -509,11 +551,13 @@ def __init__( norm_eps=norm_eps, qk_norm=qk_norm, fused_attn=self.fused_attn, + callZKK=self.callZKK, ) for idx in range(num_layers) ] ) - # del self.layers + + del self.layers # 3. Define output layers self.final_layer = FinalLayer(dim, patch_size, self.out_channels) @@ -584,20 +628,20 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): ) return freqs_cis - # @paddle.incubate.layers.inference(with_trt=False, - # cache_static_model=True, - # collect_shape=False) - # def transformer_blocks(self, x, adaln_input): - # for i, layer in enumerate(self.layers): - # if self.gradient_checkpointing and False: - # x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) - # else: - # x = layer( - # x, - # self.freqs_cis[: x.shape[1]], - # adaln_input, - # ) - # return x + @paddle.incubate.layers.inference(with_trt=False, + cache_static_model=True, + collect_shape=False) + def transformer_blocks(self, x, adaln_input): + for i, layer in enumerate(self.layers): + if self.gradient_checkpointing and False: + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + else: + x = layer( + x, + self.freqs_cis[: x.shape[1]], + adaln_input, + ) + return x def forward( self, @@ -624,17 +668,19 @@ def forward( adaln_input = t + y # 2. Blocks - # x = self.transformer_blocks(x, adaln_input) - for i, layer in enumerate(self.layers): - if self.gradient_checkpointing: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) - else: - x = layer( - x, - self.freqs_cis[: x.shape[1]], - adaln_input, - ) - + if not self.callZKK: + for i, layer in enumerate(self.layers): + if self.gradient_checkpointing: + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + else: + x = layer( + x, + self.freqs_cis[: x.shape[1]], + adaln_input, + ) + else: + x = self.transformer_blocks(x, adaln_input) + # 3. Output hidden_states = self.final_layer(x, adaln_input) output = self.unpatchify(hidden_states) @@ -646,20 +692,22 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): - for key in list(state_dict.keys()): - if 'feed_forward.w1.weight' in key: - w1 = state_dict.pop(key) - w3_key = key.replace('w1', 'w3') - w3 = state_dict.pop(w3_key) - w13 = paddle.concat([w1, w3], axis=1) - state_dict[key.replace('w1', 'w13')] = w13 - if 'attention.wq.weight' in key or 'attention.wk.weight' in key or 'attention.wv.weight' in key: - part = key.split('.')[-2] - layer_id = key.split('.')[1] - qkv_key = f'layers.{layer_id}.attention.qkv.weight' - if part == 'wq' and qkv_key not in state_dict: - state_dict[qkv_key] = state_dict.pop(key) - elif part in ('wk', 'wv'): - qkv = state_dict.get(qkv_key) - if qkv is not None: - state_dict[qkv_key] = paddle.concat([qkv, state_dict.pop(key)], axis=1) + # If you're not invited to zkk, you won't get any performance optimizations. 
+ if os.getenv('callZKK'): + for key in list(state_dict.keys()): + if 'feed_forward.w1.weight' in key: + w1 = state_dict.pop(key) + w3_key = key.replace('w1', 'w3') + w3 = state_dict.pop(w3_key) + w13 = paddle.concat([w1, w3], axis=1) + state_dict[key.replace('w1', 'w13')] = w13 + if 'attention.wq.weight' in key or 'attention.wk.weight' in key or 'attention.wv.weight' in key: + part = key.split('.')[-2] + layer_id = key.split('.')[1] + qkv_key = f'layers.{layer_id}.attention.qkv.weight' + if part == 'wq' and qkv_key not in state_dict: + state_dict[qkv_key] = state_dict.pop(key) + elif part in ('wk', 'wv'): + qkv = state_dict.get(qkv_key) + if qkv is not None: + state_dict[qkv_key] = paddle.concat([qkv, state_dict.pop(key)], axis=1) From 1cc8ca47bbdc1acd1862844ab3d97faa343c4bfb Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Mon, 22 Jul 2024 09:08:34 +0000 Subject: [PATCH 07/21] no reshape --- .../deploy/controlnet/infer_dygraph.py | 1 + .../class_conditional_image_generation-dit.py | 6 +- ...nditional_image_generation-large_dit_3b.py | 31 +++++- ...nditional_image_generation-large_dit_7b.py | 5 +- ppdiffusers/ppdiffusers/models/attention.py | 18 +++- .../ppdiffusers/models/attention_processor.py | 2 +- ppdiffusers/ppdiffusers/models/dit_llama.py | 13 +-- .../ppdiffusers/models/normalization.py | 2 +- .../ppdiffusers/models/transformer_2d.py | 99 +++++++++++++------ .../ppdiffusers/models/unet_2d_condition.py | 47 +++++++-- .../ppdiffusers/pipelines/dit/pipeline_dit.py | 28 ++++++ 11 files changed, 190 insertions(+), 62 deletions(-) diff --git a/ppdiffusers/deploy/controlnet/infer_dygraph.py b/ppdiffusers/deploy/controlnet/infer_dygraph.py index 1c8f22b75..f9672679e 100644 --- a/ppdiffusers/deploy/controlnet/infer_dygraph.py +++ b/ppdiffusers/deploy/controlnet/infer_dygraph.py @@ -168,6 +168,7 @@ def main(args): ), ) pipe.set_progress_bar_config(disable=False) + breakpoint() pipe.change_scheduler(args.scheduler) parse_prompt_type = args.parse_prompt_type diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 71d73ec0d..3985d42a4 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "6" import paddle from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float32 +dtype = paddle.float16 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) set_seed(42) @@ -25,6 +26,5 @@ words = ["golden retriever"] # class_ids [207] class_ids = pipe.get_label_ids(words) - image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] image.save("class_conditional_image_generation-dit-result.png") diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py index 93f02a955..b7f552688 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "6" import paddle from paddlenlp.trainer import set_seed +import datetime from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float32 +dtype = paddle.bfloat16 + +os.environ['callZKK']= "True" + with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-3B-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) @@ -26,5 +31,25 @@ words = ["golden retriever"] # class_ids [207] class_ids = pipe.get_label_ids(words) -image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + +# for kkk in range(3): +# image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + + +paddle.device.cuda.synchronize(0) +starttime = datetime.datetime.now() + +for kk in range(1): + image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] + +paddle.device.cuda.synchronize(0) +endtime = datetime.datetime.now() +duringtime = endtime-starttime +time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 +msg = "total_time_cost: " + str(time_ms/5) + "ms\n\n" +print(msg) +with open("/tyk/PaddleMIX/ppdiffusers/examples/inference/kai/res/time_3B_722.txt", "a") as time_file: + time_file.write(msg) + + image.save("class_conditional_image_generation-large_dit_3b-result.png") diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py index fa9b40365..9be81375b 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -os.environ["CUDA_VISIBLE_DEVICES"] = "4" + import paddle from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float16 +dtype = paddle.bfloat16 # To speed up this code, call zkk and let him run for you, # then you will get a speed increase of almost 100%. 
diff --git a/ppdiffusers/ppdiffusers/models/attention.py b/ppdiffusers/ppdiffusers/models/attention.py index 6612b790f..8abf7cb10 100644 --- a/ppdiffusers/ppdiffusers/models/attention.py +++ b/ppdiffusers/ppdiffusers/models/attention.py @@ -24,6 +24,7 @@ from .lora import LoRACompatibleLinear from .normalization import AdaLayerNorm, AdaLayerNormZero +from paddle.incubate.tt import adaptive_layer_norm, rms_norm def _chunked_feed_forward( ff: nn.Layer, hidden_states: paddle.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None @@ -152,6 +153,8 @@ def __init__( super().__init__() self.only_cross_attention = only_cross_attention + self.norm_eps = norm_eps + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" self.use_ada_layer_norm_single = norm_type == "ada_norm_single" @@ -337,12 +340,17 @@ def forward( ) hidden_states = attn_output + hidden_states - # 4. Feed-forward - if not self.use_ada_layer_norm_single: - norm_hidden_states = self.norm3(hidden_states) + # 4. Feed-forwards + # hidden_states_kai = paddle.clone(hidden_states) + + # if not self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm3(hidden_states) + + # if self.use_ada_layer_norm_zero: + # norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + norm_hidden_states = adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp, epsilon=self.norm_eps) - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self.use_ada_layer_norm_single: norm_hidden_states = self.norm2(hidden_states) diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index b8c82dafa..011c915c1 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -981,7 +981,7 @@ def __call__( return hidden_states - +# kai: this class XFormersAttnProcessor: r""" Processor for implementing memory efficient attention using xFormers. diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 44bd1e32a..9cb49b343 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -321,9 +321,6 @@ def compute_activation(self, quant_round_type=0, quant_max_bound=0, quant_min_bound=0): - origin_batch_size = ffn1_out.shape[0] - origin_seq_len = ffn1_out.shape[1] - ffn1_out = ffn1_out.reshape([origin_batch_size*origin_seq_len, ffn1_out.shape[-1]]) if in_dynamic_mode(): out = paddle._C_ops.fused_bias_act( ffn1_out, @@ -338,7 +335,7 @@ def compute_activation(self, quant_max_bound, quant_min_bound ) - return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]]) + return out helper = LayerHelper("fused_bias_act") out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype) @@ -358,7 +355,7 @@ def compute_activation(self, outputs={"out": out}, attrs=attrs, ) - return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]]) + return out def forward(self, x): if not self.callZKK: @@ -557,7 +554,7 @@ def __init__( ] ) - del self.layers + # del self.layers # 3. 
Define output layers self.final_layer = FinalLayer(dim, patch_size, self.out_channels) @@ -628,8 +625,8 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): ) return freqs_cis - @paddle.incubate.layers.inference(with_trt=False, - cache_static_model=True, + @paddle.jit.to_static(backend="inference", with_trt=False, + cache_static_model=False, collect_shape=False) def transformer_blocks(self, x, adaln_input): for i, layer in enumerate(self.layers): diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py index 15e111ab0..8e2c42113 100644 --- a/ppdiffusers/ppdiffusers/models/normalization.py +++ b/ppdiffusers/ppdiffusers/models/normalization.py @@ -45,7 +45,7 @@ def forward(self, x: paddle.Tensor, timestep: paddle.Tensor) -> paddle.Tensor: x = self.norm(x) * (1 + scale) + shift return x - +# this class AdaLayerNormZero(nn.Layer): r""" Norm layer adaptive layer norm zero (adaLN-Zero). diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index 3a1084d03..2a92a8310 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -239,7 +239,7 @@ def __init__( # 5. PixArt-Alpha blocks. self.adaln_single = None self.use_additional_conditions = False - if norm_type == "ada_norm_single": + if norm_type == "ada_norm_single": # kai: ada_norm_zero self.use_additional_conditions = self.config.sample_size == 128 # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use # additional conditions until we find better name @@ -251,10 +251,35 @@ def __init__( self.gradient_checkpointing = False + # del self.transformer_blocks + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + @paddle.jit.to_static(backend="inference", with_trt=False, + cache_static_model=False, + collect_shape=False) + def deco_transformer_blocks(self, hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels): + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + return hidden_states + + def forward( self, hidden_states: paddle.Tensor, @@ -385,40 +410,48 @@ def forward( encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) - for block in self.transformer_blocks: - if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} - hidden_states = recompute( - create_custom_forward(block), - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - timestep, - cross_attention_kwargs, - class_labels, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states, + hidden_states = self.deco_transformer_blocks(hidden_states, 
attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) + class_labels=class_labels) + r""" + # for block in self.transformer_blocks: + # if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): + # def create_custom_forward(module, return_dict=None): + # def custom_forward(*inputs): + # if return_dict is not None: + # return module(*inputs, return_dict=return_dict) + # else: + # return module(*inputs) + + # return custom_forward + + # ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} + # hidden_states = recompute( + # create_custom_forward(block), + # hidden_states, + # attention_mask, + # encoder_hidden_states, + # encoder_attention_mask, + # timestep, + # cross_attention_kwargs, + # class_labels, + # **ckpt_kwargs, + # ) + # else: + # hidden_states = block( + # hidden_states, + # attention_mask=attention_mask, + # encoder_hidden_states=encoder_hidden_states, + # encoder_attention_mask=encoder_attention_mask, + # timestep=timestep, + # cross_attention_kwargs=cross_attention_kwargs, + # class_labels=class_labels, + # ) + """ # 3. Output if self.is_input_continuous: @@ -482,3 +515,9 @@ def custom_forward(*inputs): return (output,) return Transformer2DModelOutput(sample=output) + + + @classmethod + def custom_modify_weight(cls, state_dict): + print("state_dict", state_dict.keys()) + exit() \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/models/unet_2d_condition.py b/ppdiffusers/ppdiffusers/models/unet_2d_condition.py index 65d95bcaa..9e5fd5e8e 100644 --- a/ppdiffusers/ppdiffusers/models/unet_2d_condition.py +++ b/ppdiffusers/ppdiffusers/models/unet_2d_condition.py @@ -629,6 +629,7 @@ def __init__( self.position_net = PositionNet( positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type ) + #del self.down_blocks @property def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -905,6 +906,43 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True + timestep = timestep.reshape([1]) + + self.tmp_down_intrablock_additional_residuals = down_intrablock_additional_residuals + self.tmp_cross_attention_kwargs = cross_attention_kwargs + self.tmp_forward_upsample_size = forward_upsample_size + self.tmp_upsample_size = upsample_size + + sample = self.run_down_mid_up_blocks(sample, encoder_hidden_states, timestep, timestep_cond, + attention_mask, + encoder_attention_mask, + mid_block_additional_residual, + down_block_additional_residuals) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + @paddle.jit.to_static(backend="paddle_inference", cache_static_model=False, + with_trt=True, + switch_ir_optim=True, + precision_mode="float16", + switch_ir_debug=False) + def run_down_mid_up_blocks(self, sample, encoder_hidden_states, timestep, timestep_cond, + attention_mask, + encoder_attention_mask, + mid_block_additional_residual, down_block_additional_residuals): + + down_intrablock_additional_residuals = self.tmp_down_intrablock_additional_residuals + cross_attention_kwargs = self.tmp_cross_attention_kwargs + forward_upsample_size = self.tmp_forward_upsample_size + upsample_size = self.tmp_upsample_size + # ensure 
attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] @@ -1197,11 +1235,4 @@ def forward( if self.data_format == "NHWC": sample = sample.transpose([0, 3, 1, 2]) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) + return sample diff --git a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py index 8654bbc2e..e7f47e8b1 100644 --- a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py +++ b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py @@ -27,6 +27,7 @@ from ...utils.paddle_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +import datetime class DiTPipeline(DiffusionPipeline): r""" @@ -89,6 +90,18 @@ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: return [self.labels[l] for l in label] + def warm_up_transformer(self, latent_model_input, class_labels_input): + timesteps = 0 + half = latent_model_input[: len(latent_model_input) // 2] + latent_model_input = paddle.concat([half, half], axis=0) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timesteps) + + timesteps = paddle.to_tensor([timesteps], dtype=paddle.int64) + timesteps = timesteps.expand([latent_model_input.shape[0],]) + self.transformer( + latent_model_input, timestep=timesteps, class_labels=class_labels_input + ) + @paddle.no_grad() def __call__( self, @@ -168,6 +181,12 @@ def __call__( # set step values self.scheduler.set_timesteps(num_inference_steps) + + # self.warm_up_transformer(latent_model_input, class_labels_input) + # print("\n--------------------- warm up end\n") + # paddle.device.cuda.synchronize(0) + # starttime = datetime.datetime.now() + for t in self.progress_bar(self.scheduler.timesteps): if guidance_scale > 1: half = latent_model_input[: len(latent_model_input) // 2] @@ -216,6 +235,15 @@ def __call__( # compute previous image: x_t -> x_t-1 latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample + # paddle.device.cuda.synchronize(0) + # endtime = datetime.datetime.now() + # duringtime = endtime-starttime + # time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 + # msg = "total_time_cost: " + str(time_ms) + "ms\n\n" + # print(msg) + # with open("/tyk/PaddleMIX/ppdiffusers/examples/inference/kai/res/time_719.txt", "a") as time_file: + # time_file.write(msg) + if guidance_scale > 1: latents, _ = latent_model_input.chunk(2, axis=0) else: From 3251f134777d44f3385ac67bdd367c5c247a39c8 Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Mon, 22 Jul 2024 09:10:25 +0000 Subject: [PATCH 08/21] Revert "no reshape" This reverts commit 1cc8ca47bbdc1acd1862844ab3d97faa343c4bfb. 
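
Note: reverting "no reshape" restores the reshape around paddle._C_ops.fused_bias_act in
compute_activation. For reference, the fused "swiglu" activation on the concatenated
[w1(x), w3(x)] projection is equivalent to the plain-Paddle sketch below. This is an
illustrative sketch only, not the model code: swiglu_reference is a made-up name, and it
assumes (as the w13 concatenation order implies) that silu is applied to the first half of
the last axis.

    import paddle
    import paddle.nn.functional as F

    def swiglu_reference(ffn1_out):
        # ffn1_out: [batch, seq, 2 * hidden], i.e. concat([x @ w1, x @ w3], axis=-1)
        bsz, seqlen, width = ffn1_out.shape
        flat = ffn1_out.reshape([bsz * seqlen, width])   # collapse batch and sequence dims
        gate, value = paddle.split(flat, 2, axis=-1)     # w1(x) half and w3(x) half
        act = F.silu(gate) * value                       # SwiGLU
        return act.reshape([bsz, seqlen, width // 2])    # restore [batch, seq, hidden]

The reshape to 2-D and back mirrors what compute_activation does before and after calling
the fused kernel.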
--- .../deploy/controlnet/infer_dygraph.py | 1 - .../class_conditional_image_generation-dit.py | 6 +- ...nditional_image_generation-large_dit_3b.py | 31 +----- ...nditional_image_generation-large_dit_7b.py | 5 +- ppdiffusers/ppdiffusers/models/attention.py | 18 +--- .../ppdiffusers/models/attention_processor.py | 2 +- ppdiffusers/ppdiffusers/models/dit_llama.py | 13 ++- .../ppdiffusers/models/normalization.py | 2 +- .../ppdiffusers/models/transformer_2d.py | 99 ++++++------------- .../ppdiffusers/models/unet_2d_condition.py | 47 ++------- .../ppdiffusers/pipelines/dit/pipeline_dit.py | 28 ------ 11 files changed, 62 insertions(+), 190 deletions(-) diff --git a/ppdiffusers/deploy/controlnet/infer_dygraph.py b/ppdiffusers/deploy/controlnet/infer_dygraph.py index f9672679e..1c8f22b75 100644 --- a/ppdiffusers/deploy/controlnet/infer_dygraph.py +++ b/ppdiffusers/deploy/controlnet/infer_dygraph.py @@ -168,7 +168,6 @@ def main(args): ), ) pipe.set_progress_bar_config(disable=False) - breakpoint() pipe.change_scheduler(args.scheduler) parse_prompt_type = args.parse_prompt_type diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py index 3985d42a4..71d73ec0d 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-dit.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -os.environ["CUDA_VISIBLE_DEVICES"] = "6" + import paddle from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float16 +dtype = paddle.float32 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) set_seed(42) @@ -26,5 +25,6 @@ words = ["golden retriever"] # class_ids [207] class_ids = pipe.get_label_ids(words) + image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] image.save("class_conditional_image_generation-dit-result.png") diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py index b7f552688..93f02a955 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py @@ -11,18 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os -os.environ["CUDA_VISIBLE_DEVICES"] = "6" + import paddle from paddlenlp.trainer import set_seed -import datetime from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.bfloat16 - -os.environ['callZKK']= "True" - +dtype = paddle.float32 with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-3B-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) @@ -31,25 +26,5 @@ words = ["golden retriever"] # class_ids [207] class_ids = pipe.get_label_ids(words) - -# for kkk in range(3): -# image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] - - -paddle.device.cuda.synchronize(0) -starttime = datetime.datetime.now() - -for kk in range(1): - image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] - -paddle.device.cuda.synchronize(0) -endtime = datetime.datetime.now() -duringtime = endtime-starttime -time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 -msg = "total_time_cost: " + str(time_ms/5) + "ms\n\n" -print(msg) -with open("/tyk/PaddleMIX/ppdiffusers/examples/inference/kai/res/time_3B_722.txt", "a") as time_file: - time_file.write(msg) - - +image = pipe(class_labels=class_ids, num_inference_steps=25).images[0] image.save("class_conditional_image_generation-large_dit_3b-result.png") diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py index 9be81375b..fa9b40365 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "4" import paddle from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.bfloat16 +dtype = paddle.float16 # To speed up this code, call zkk and let him run for you, # then you will get a speed increase of almost 100%. diff --git a/ppdiffusers/ppdiffusers/models/attention.py b/ppdiffusers/ppdiffusers/models/attention.py index 8abf7cb10..6612b790f 100644 --- a/ppdiffusers/ppdiffusers/models/attention.py +++ b/ppdiffusers/ppdiffusers/models/attention.py @@ -24,7 +24,6 @@ from .lora import LoRACompatibleLinear from .normalization import AdaLayerNorm, AdaLayerNormZero -from paddle.incubate.tt import adaptive_layer_norm, rms_norm def _chunked_feed_forward( ff: nn.Layer, hidden_states: paddle.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None @@ -153,8 +152,6 @@ def __init__( super().__init__() self.only_cross_attention = only_cross_attention - self.norm_eps = norm_eps - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" self.use_ada_layer_norm_single = norm_type == "ada_norm_single" @@ -340,17 +337,12 @@ def forward( ) hidden_states = attn_output + hidden_states - # 4. 
Feed-forwards - # hidden_states_kai = paddle.clone(hidden_states) - - # if not self.use_ada_layer_norm_single: - # norm_hidden_states = self.norm3(hidden_states) - - # if self.use_ada_layer_norm_zero: - # norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - norm_hidden_states = adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp, epsilon=self.norm_eps) + # 4. Feed-forward + if not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self.use_ada_layer_norm_single: norm_hidden_states = self.norm2(hidden_states) diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index 011c915c1..b8c82dafa 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -981,7 +981,7 @@ def __call__( return hidden_states -# kai: this + class XFormersAttnProcessor: r""" Processor for implementing memory efficient attention using xFormers. diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 9cb49b343..44bd1e32a 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -321,6 +321,9 @@ def compute_activation(self, quant_round_type=0, quant_max_bound=0, quant_min_bound=0): + origin_batch_size = ffn1_out.shape[0] + origin_seq_len = ffn1_out.shape[1] + ffn1_out = ffn1_out.reshape([origin_batch_size*origin_seq_len, ffn1_out.shape[-1]]) if in_dynamic_mode(): out = paddle._C_ops.fused_bias_act( ffn1_out, @@ -335,7 +338,7 @@ def compute_activation(self, quant_max_bound, quant_min_bound ) - return out + return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]]) helper = LayerHelper("fused_bias_act") out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype) @@ -355,7 +358,7 @@ def compute_activation(self, outputs={"out": out}, attrs=attrs, ) - return out + return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]]) def forward(self, x): if not self.callZKK: @@ -554,7 +557,7 @@ def __init__( ] ) - # del self.layers + del self.layers # 3. Define output layers self.final_layer = FinalLayer(dim, patch_size, self.out_channels) @@ -625,8 +628,8 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): ) return freqs_cis - @paddle.jit.to_static(backend="inference", with_trt=False, - cache_static_model=False, + @paddle.incubate.layers.inference(with_trt=False, + cache_static_model=True, collect_shape=False) def transformer_blocks(self, x, adaln_input): for i, layer in enumerate(self.layers): diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py index 8e2c42113..15e111ab0 100644 --- a/ppdiffusers/ppdiffusers/models/normalization.py +++ b/ppdiffusers/ppdiffusers/models/normalization.py @@ -45,7 +45,7 @@ def forward(self, x: paddle.Tensor, timestep: paddle.Tensor) -> paddle.Tensor: x = self.norm(x) * (1 + scale) + shift return x -# this + class AdaLayerNormZero(nn.Layer): r""" Norm layer adaptive layer norm zero (adaLN-Zero). 
diff --git a/ppdiffusers/ppdiffusers/models/transformer_2d.py b/ppdiffusers/ppdiffusers/models/transformer_2d.py index 2a92a8310..3a1084d03 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_2d.py +++ b/ppdiffusers/ppdiffusers/models/transformer_2d.py @@ -239,7 +239,7 @@ def __init__( # 5. PixArt-Alpha blocks. self.adaln_single = None self.use_additional_conditions = False - if norm_type == "ada_norm_single": # kai: ada_norm_zero + if norm_type == "ada_norm_single": self.use_additional_conditions = self.config.sample_size == 128 # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use # additional conditions until we find better name @@ -251,35 +251,10 @@ def __init__( self.gradient_checkpointing = False - # del self.transformer_blocks - def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - @paddle.jit.to_static(backend="inference", with_trt=False, - cache_static_model=False, - collect_shape=False) - def deco_transformer_blocks(self, hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - timestep, - cross_attention_kwargs, - class_labels): - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - return hidden_states - - def forward( self, hidden_states: paddle.Tensor, @@ -410,48 +385,40 @@ def forward( encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]]) - hidden_states = self.deco_transformer_blocks(hidden_states, + for block in self.transformer_blocks: + if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} + hidden_states = recompute( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels) - r""" - # for block in self.transformer_blocks: - # if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): - # def create_custom_forward(module, return_dict=None): - # def custom_forward(*inputs): - # if return_dict is not None: - # return module(*inputs, return_dict=return_dict) - # else: - # return module(*inputs) - - # return custom_forward - - # ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} - # hidden_states = recompute( - # create_custom_forward(block), - # hidden_states, - # attention_mask, - # encoder_hidden_states, - # encoder_attention_mask, - # timestep, - # cross_attention_kwargs, - # class_labels, - # **ckpt_kwargs, - # ) - # else: - # hidden_states = block( 
- # hidden_states, - # attention_mask=attention_mask, - # encoder_hidden_states=encoder_hidden_states, - # encoder_attention_mask=encoder_attention_mask, - # timestep=timestep, - # cross_attention_kwargs=cross_attention_kwargs, - # class_labels=class_labels, - # ) - """ + class_labels=class_labels, + ) # 3. Output if self.is_input_continuous: @@ -515,9 +482,3 @@ def forward( return (output,) return Transformer2DModelOutput(sample=output) - - - @classmethod - def custom_modify_weight(cls, state_dict): - print("state_dict", state_dict.keys()) - exit() \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/models/unet_2d_condition.py b/ppdiffusers/ppdiffusers/models/unet_2d_condition.py index 9e5fd5e8e..65d95bcaa 100644 --- a/ppdiffusers/ppdiffusers/models/unet_2d_condition.py +++ b/ppdiffusers/ppdiffusers/models/unet_2d_condition.py @@ -629,7 +629,6 @@ def __init__( self.position_net = PositionNet( positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type ) - #del self.down_blocks @property def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -906,43 +905,6 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - timestep = timestep.reshape([1]) - - self.tmp_down_intrablock_additional_residuals = down_intrablock_additional_residuals - self.tmp_cross_attention_kwargs = cross_attention_kwargs - self.tmp_forward_upsample_size = forward_upsample_size - self.tmp_upsample_size = upsample_size - - sample = self.run_down_mid_up_blocks(sample, encoder_hidden_states, timestep, timestep_cond, - attention_mask, - encoder_attention_mask, - mid_block_additional_residual, - down_block_additional_residuals) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) - - @paddle.jit.to_static(backend="paddle_inference", cache_static_model=False, - with_trt=True, - switch_ir_optim=True, - precision_mode="float16", - switch_ir_debug=False) - def run_down_mid_up_blocks(self, sample, encoder_hidden_states, timestep, timestep_cond, - attention_mask, - encoder_attention_mask, - mid_block_additional_residual, down_block_additional_residuals): - - down_intrablock_additional_residuals = self.tmp_down_intrablock_additional_residuals - cross_attention_kwargs = self.tmp_cross_attention_kwargs - forward_upsample_size = self.tmp_forward_upsample_size - upsample_size = self.tmp_upsample_size - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] @@ -1235,4 +1197,11 @@ def run_down_mid_up_blocks(self, sample, encoder_hidden_states, timestep, timest if self.data_format == "NHWC": sample = sample.transpose([0, 3, 1, 2]) - return sample + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py index e7f47e8b1..8654bbc2e 100644 --- a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py +++ b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py @@ -27,7 +27,6 @@ from ...utils.paddle_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -import datetime class DiTPipeline(DiffusionPipeline): r""" @@ -90,18 +89,6 
@@ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: return [self.labels[l] for l in label] - def warm_up_transformer(self, latent_model_input, class_labels_input): - timesteps = 0 - half = latent_model_input[: len(latent_model_input) // 2] - latent_model_input = paddle.concat([half, half], axis=0) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timesteps) - - timesteps = paddle.to_tensor([timesteps], dtype=paddle.int64) - timesteps = timesteps.expand([latent_model_input.shape[0],]) - self.transformer( - latent_model_input, timestep=timesteps, class_labels=class_labels_input - ) - @paddle.no_grad() def __call__( self, @@ -181,12 +168,6 @@ def __call__( # set step values self.scheduler.set_timesteps(num_inference_steps) - - # self.warm_up_transformer(latent_model_input, class_labels_input) - # print("\n--------------------- warm up end\n") - # paddle.device.cuda.synchronize(0) - # starttime = datetime.datetime.now() - for t in self.progress_bar(self.scheduler.timesteps): if guidance_scale > 1: half = latent_model_input[: len(latent_model_input) // 2] @@ -235,15 +216,6 @@ def __call__( # compute previous image: x_t -> x_t-1 latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample - # paddle.device.cuda.synchronize(0) - # endtime = datetime.datetime.now() - # duringtime = endtime-starttime - # time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0 - # msg = "total_time_cost: " + str(time_ms) + "ms\n\n" - # print(msg) - # with open("/tyk/PaddleMIX/ppdiffusers/examples/inference/kai/res/time_719.txt", "a") as time_file: - # time_file.write(msg) - if guidance_scale > 1: latents, _ = latent_model_input.chunk(2, axis=0) else: From 55b50428d271a5ec18198763fec134ea277b7d07 Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Mon, 22 Jul 2024 09:22:55 +0000 Subject: [PATCH 09/21] no reshape --- ...nditional_image_generation-large_dit_7b.py | 5 ++- ppdiffusers/ppdiffusers/models/dit_llama.py | 35 ++++++++----------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py index fa9b40365..9be81375b 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -os.environ["CUDA_VISIBLE_DEVICES"] = "4" + import paddle from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float16 +dtype = paddle.bfloat16 # To speed up this code, call zkk and let him run for you, # then you will get a speed increase of almost 100%. 
diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 44bd1e32a..20128980b 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -321,9 +321,6 @@ def compute_activation(self, quant_round_type=0, quant_max_bound=0, quant_min_bound=0): - origin_batch_size = ffn1_out.shape[0] - origin_seq_len = ffn1_out.shape[1] - ffn1_out = ffn1_out.reshape([origin_batch_size*origin_seq_len, ffn1_out.shape[-1]]) if in_dynamic_mode(): out = paddle._C_ops.fused_bias_act( ffn1_out, @@ -338,7 +335,7 @@ def compute_activation(self, quant_max_bound, quant_min_bound ) - return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]]) + return out helper = LayerHelper("fused_bias_act") out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype) @@ -358,7 +355,7 @@ def compute_activation(self, outputs={"out": out}, attrs=attrs, ) - return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]]) + return out def forward(self, x): if not self.callZKK: @@ -431,14 +428,11 @@ def __init__( self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) self.ffn_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) - self.adaLN_modulation = nn.Sequential( - nn.Silu(), - nn.Linear(min(dim, 1024), 6 * dim), - ) + self.adaLN_modulation = nn.Linear(min(dim, 1024), 6 * dim) self.norm_eps = norm_eps self.callZKK = callZKK - def forward(self, x, freqs_cis, adaln_input=None): + def forward(self, x, freqs_cis, adaln_silu_input=None): """ Perform a forward pass through the TransformerBlock. @@ -453,8 +447,8 @@ def forward(self, x, freqs_cis, adaln_input=None): feedforward layers. """ - if adaln_input is not None: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk( + if adaln_silu_input is not None: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_silu_input).chunk( 6, axis=1 ) if not self.callZKK: @@ -536,6 +530,7 @@ def __init__( self.t_embedder = TimestepEmbedder(min(dim, 1024)) self.y_embedder = LabelEmbedding(num_classes, min(dim, 1024), class_dropout_prob) + self.adaln_silu = nn.Silu() self.callZKK = True if os.getenv('callZKK') else False # 2. Define transformers blocks self.layers = nn.LayerList( @@ -628,18 +623,18 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): ) return freqs_cis - @paddle.incubate.layers.inference(with_trt=False, + @paddle.jit.to_static(backend="inference", with_trt=False, cache_static_model=True, collect_shape=False) - def transformer_blocks(self, x, adaln_input): + def transformer_blocks(self, x, adaln_silu_out): for i, layer in enumerate(self.layers): if self.gradient_checkpointing and False: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_silu_out) else: x = layer( x, self.freqs_cis[: x.shape[1]], - adaln_input, + adaln_silu_out, ) return x @@ -666,20 +661,20 @@ def forward( t = self.t_embedder(timestep) y = self.y_embedder(class_labels) adaln_input = t + y - + adaln_silu_out = self.adaln_silu(adaln_input) # 2. 
Blocks if not self.callZKK: for i, layer in enumerate(self.layers): if self.gradient_checkpointing: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_silu_out) else: x = layer( x, self.freqs_cis[: x.shape[1]], - adaln_input, + adaln_silu_out, ) else: - x = self.transformer_blocks(x, adaln_input) + x = self.transformer_blocks(x, adaln_silu_out) # 3. Output hidden_states = self.final_layer(x, adaln_input) From 70c6dc0c426023c5a6d80fac4ffd055f98526e97 Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Mon, 22 Jul 2024 09:27:45 +0000 Subject: [PATCH 10/21] no reshape --- ppdiffusers/ppdiffusers/models/dit_llama.py | 30 +++++++++++---------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 20128980b..9cb49b343 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -428,11 +428,14 @@ def __init__( self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) self.ffn_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) - self.adaLN_modulation = nn.Linear(min(dim, 1024), 6 * dim) + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + nn.Linear(min(dim, 1024), 6 * dim), + ) self.norm_eps = norm_eps self.callZKK = callZKK - def forward(self, x, freqs_cis, adaln_silu_input=None): + def forward(self, x, freqs_cis, adaln_input=None): """ Perform a forward pass through the TransformerBlock. @@ -447,8 +450,8 @@ def forward(self, x, freqs_cis, adaln_silu_input=None): feedforward layers. """ - if adaln_silu_input is not None: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_silu_input).chunk( + if adaln_input is not None: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk( 6, axis=1 ) if not self.callZKK: @@ -530,7 +533,6 @@ def __init__( self.t_embedder = TimestepEmbedder(min(dim, 1024)) self.y_embedder = LabelEmbedding(num_classes, min(dim, 1024), class_dropout_prob) - self.adaln_silu = nn.Silu() self.callZKK = True if os.getenv('callZKK') else False # 2. Define transformers blocks self.layers = nn.LayerList( @@ -552,7 +554,7 @@ def __init__( ] ) - del self.layers + # del self.layers # 3. Define output layers self.final_layer = FinalLayer(dim, patch_size, self.out_channels) @@ -624,17 +626,17 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): return freqs_cis @paddle.jit.to_static(backend="inference", with_trt=False, - cache_static_model=True, + cache_static_model=False, collect_shape=False) - def transformer_blocks(self, x, adaln_silu_out): + def transformer_blocks(self, x, adaln_input): for i, layer in enumerate(self.layers): if self.gradient_checkpointing and False: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_silu_out) + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) else: x = layer( x, self.freqs_cis[: x.shape[1]], - adaln_silu_out, + adaln_input, ) return x @@ -661,20 +663,20 @@ def forward( t = self.t_embedder(timestep) y = self.y_embedder(class_labels) adaln_input = t + y - adaln_silu_out = self.adaln_silu(adaln_input) + # 2. 
Blocks if not self.callZKK: for i, layer in enumerate(self.layers): if self.gradient_checkpointing: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_silu_out) + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) else: x = layer( x, self.freqs_cis[: x.shape[1]], - adaln_silu_out, + adaln_input, ) else: - x = self.transformer_blocks(x, adaln_silu_out) + x = self.transformer_blocks(x, adaln_input) # 3. Output hidden_states = self.final_layer(x, adaln_input) From 5d3d29f5cfe01b3a85ee683418d2f689f7f194d5 Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Thu, 25 Jul 2024 11:21:23 +0000 Subject: [PATCH 11/21] fuse_repo triton kernel --- ppdiffusers/ppdiffusers/models/dit_llama.py | 79 +++++++++++++-------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 9cb49b343..d88631a00 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -26,7 +26,7 @@ from .transformer_2d import Transformer2DModelOutput from paddle.framework import LayerHelper, in_dynamic_mode -from paddle.incubate.tt import adaptive_layer_norm, fused_adaLN_scale_residual +from paddle.incubate.tt import adaptive_layer_norm, fused_adaLN_scale_residual, fused_rotary_emb import os def TypePromote(x, y): @@ -193,15 +193,7 @@ def apply_rotary_emb(xq, xk, freqs_cis): Tuple[paddle.Tensor, paddle.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ - if not os.getenv('callZKK'): - with paddle.amp.auto_cast(enable=False): - xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) - xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) - freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) - xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) - xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) - return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) - else: + with paddle.amp.auto_cast(enable=False): xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) @@ -225,21 +217,22 @@ def forward(self, x, freqs_cis): if not self.callZKK: xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - else: - qkv_out = self.qkv(x) - xq, xk, xv = paddle.split(qkv_out, 3, axis=-1) - - dtype = xq.dtype - - xq = self.q_norm(xq) - xk = self.k_norm(xk) + + dtype = xq.dtype + + xq = self.q_norm(xq) + xk = self.k_norm(xk) - xq = xq.reshape([bsz, seqlen, self.n_local_heads, self.head_dim]) - xk = xk.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) - xv = xv.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) + xq = xq.reshape([bsz, seqlen, self.n_local_heads, self.head_dim]) + xk = xk.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) + xv = xv.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) - xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - xq, xk = xq.cast(dtype), xk.cast(dtype) + xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = xq.cast(dtype), xk.cast(dtype) + else: + qkv_out = self.qkv(x) + xq, xk, xv = fused_rotary_emb(qkv_out, self.q_norm.weight, self.q_norm.bias, + self.k_norm.weight, self.k_norm.bias, freqs_cis) if dtype in 
[paddle.float16, paddle.bfloat16]: output, _ = flash_attention( @@ -341,6 +334,14 @@ def compute_activation(self, out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype) inputs = {} inputs["x"] = ffn1_out + if bias is not None: + inputs["bias"] = bias + if dequant_scales is not None: + inputs["dequant_scales"] = dequant_scales + if shift is not None: + inputs["shift"] = shift + if smooth is not None: + inputs["smooth"] = smooth attrs = { "act_method": act_method, "compute_dtype": compute_dtype, @@ -619,15 +620,32 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): freqs = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].cast("float32") / dim) t = paddle.arange(end=end) input_0, vec2_0 = TypePromote(t, freqs) + # [end, dim // 2] freqs = paddle.outer(input_0, vec2_0).cast("float32") - freqs_cis = paddle.complex( - paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs) - ) + + if os.getenv('callZKK'): + # [end, dim // 2, 4] + freqs_cis = paddle.stack([ + paddle.cos(freqs), + -paddle.sin(freqs), + paddle.sin(freqs), + paddle.cos(freqs)], axis=-1) + # [end, dim, 2] + freqs_cis = freqs_cis.reshape([freqs_cis.shape[0], -1, 2]).unsqueeze(1) + else: + freqs_cis = paddle.complex( + paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs) + ) return freqs_cis - @paddle.jit.to_static(backend="inference", with_trt=False, - cache_static_model=False, - collect_shape=False) + @paddle.jit.to_static(backend='inference', with_trt=False, + cache_static_model=False, + exp_enable_use_cutlass=True, + enable_new_ir=True, + delete_pass_lists = ["trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass", + "add_support_int8_pass", + "fc_fuse_pass", + "add_norm_fuse_pass",]) def transformer_blocks(self, x, adaln_input): for i, layer in enumerate(self.layers): if self.gradient_checkpointing and False: @@ -667,7 +685,7 @@ def forward( # 2. Blocks if not self.callZKK: for i, layer in enumerate(self.layers): - if self.gradient_checkpointing: + if self.gradient_checkpointing: x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) else: x = layer( @@ -676,6 +694,7 @@ def forward( adaln_input, ) else: + self.freqs_cis = paddle.expand(self.freqs_cis, [-1, self.num_attention_heads, -1, -1]) x = self.transformer_blocks(x, adaln_input) # 3. 
Output From 19cbd9020ea315db128ae65186dd12e46b7df225 Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Mon, 5 Aug 2024 08:50:13 +0000 Subject: [PATCH 12/21] with horizontal_fuse_pass opt --- paddlemix/triton_ops/triton_ops.py | 1 - ppdiffusers/ppdiffusers/models/dit_llama.py | 133 +++++++----------- .../ppdiffusers/models/modeling_utils.py | 2 +- 3 files changed, 53 insertions(+), 83 deletions(-) diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py index 3ade229c3..fff3a5947 100644 --- a/paddlemix/triton_ops/triton_ops.py +++ b/paddlemix/triton_ops/triton_ops.py @@ -1466,7 +1466,6 @@ def fused_rotary_emb( freqs_cis, epsilon=1e-5, ): - assert x.is_contiguous() assert q_norm_weight is not None, "q_norm_weight should not be none" assert q_norm_bias is not None, "q_norm_bias should not be none" assert k_norm_weight is not None, "k_norm_weight should not be none" diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index d88631a00..c9b4fa490 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -14,20 +14,19 @@ import math from typing import Optional +import os import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle.nn.functional.flash_attention import flash_attention +from paddle.framework import LayerHelper, in_dynamic_mode from ..configuration_utils import ConfigMixin, register_to_config from .embeddings import LabelEmbedding from .modeling_utils import ModelMixin from .transformer_2d import Transformer2DModelOutput -from paddle.framework import LayerHelper, in_dynamic_mode -from paddle.incubate.tt import adaptive_layer_norm, fused_adaLN_scale_residual, fused_rotary_emb -import os def TypePromote(x, y): TYPE_PROMOTE_DICT = { @@ -93,7 +92,7 @@ def forward(self, t): class Attention(nn.Layer): - def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, callZKK=False): + def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, optimize_inference_for_ditllama=False): """ Initialize the Attention module. @@ -111,7 +110,6 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, call wq (nn.Linear): Linear transformation for queries. wk (nn.Linear): Linear transformation for keys. wv (nn.Linear): Linear transformation for values. - qkv (nn.Linear): Linear transformation for queries, keys and values. wo (nn.Linear): Linear transformation for output. cache_k (paddle.Tensor): Cached keys for attention. cache_v (paddle.Tensor): Cached values for attention. 
@@ -124,14 +122,10 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, call self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads - self.callZKK = callZKK - if not callZKK: - self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) - self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) - self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) - else: - self.qkv = nn.Linear(dim, (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias_attr=False) - + self.optimize_inference_for_ditllama = optimize_inference_for_ditllama + self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) + self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) + self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) self.wo = nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) if qk_norm: @@ -214,12 +208,9 @@ def forward(self, x, freqs_cis): """ bsz, seqlen, _ = tuple(x.shape) - - if not self.callZKK: - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - dtype = xq.dtype - + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + dtype = xq.dtype + if not self.optimize_inference_for_ditllama: xq = self.q_norm(xq) xk = self.k_norm(xk) @@ -230,9 +221,10 @@ def forward(self, x, freqs_cis): xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) xq, xk = xq.cast(dtype), xk.cast(dtype) else: - qkv_out = self.qkv(x) - xq, xk, xv = fused_rotary_emb(qkv_out, self.q_norm.weight, self.q_norm.bias, - self.k_norm.weight, self.k_norm.bias, freqs_cis) + qkv_out = paddle.concat([xq, xk, xv], axis=-1) + import paddlemix + xq, xk, xv = paddlemix.triton_ops.fused_rotary_emb(qkv_out, self.q_norm.weight, self.q_norm.bias, + self.k_norm.weight, self.k_norm.bias, freqs_cis) if dtype in [paddle.float16, paddle.bfloat16]: output, _ = flash_attention( @@ -269,7 +261,7 @@ def forward(self, x, freqs_cis): class FeedForward(nn.Layer): - def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None, callZKK=False): + def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None, optimize_inference_for_ditllama=False): """ Initialize the FeedForward module. @@ -282,11 +274,12 @@ def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None, ca dimension. Defaults to None. Attributes: - w1 (nn.Linear): Linear transformation for the first layer. + w1 (nn.Linear): Linear transformation for the first + layer. w2 (nn.Linear): Linear transformation for the second layer. - w3 (nn.Linear): Linear transformation for the third layer. - w13 (nn.Linear): Linear transformation for the first and the third layer. - + w3 (nn.Linear): Linear transformation for the third + layer. 
+ """ super().__init__() hidden_dim = int(2 * hidden_dim / 3) @@ -295,12 +288,12 @@ def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None, ca hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) - self.callZKK = callZKK - if not callZKK: + if not optimize_inference_for_ditllama: self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False) self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) else: self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False) + self.optimize_inference_for_ditllama = optimize_inference_for_ditllama def compute_activation(self, ffn1_out, @@ -357,9 +350,9 @@ def compute_activation(self, attrs=attrs, ) return out - + def forward(self, x): - if not self.callZKK: + if not self.optimize_inference_for_ditllama: xw1 = F.silu(self.w1(x)) xw3 = self.w3(x) output = self.w2(xw1 * xw3) @@ -368,7 +361,7 @@ def forward(self, x): ffn1_out = self.w13(x) ffn1_out = self.compute_activation(ffn1_out) ffn2_out = self.w2(ffn1_out) - return ffn2_out + return ffn2_out class TransformerBlock(nn.Layer): @@ -384,7 +377,7 @@ def __init__( norm_eps: float, qk_norm: bool, fused_attn: bool, - callZKK=False, + optimize_inference_for_ditllama=False, ) -> None: """ Initialize a TransformerBlock. @@ -419,11 +412,11 @@ def __init__( super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn, callZKK=callZKK) + self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn, optimize_inference_for_ditllama=optimize_inference_for_ditllama) mlp_hidden_dim = int(dim * mlp_ratio) self.feed_forward = FeedForward( - dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, - ffn_dim_multiplier=ffn_dim_multiplier, callZKK=callZKK + dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, + optimize_inference_for_ditllama=optimize_inference_for_ditllama ) self.layer_id = layer_id self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) @@ -433,8 +426,8 @@ def __init__( nn.Silu(), nn.Linear(min(dim, 1024), 6 * dim), ) + self.optimize_inference_for_ditllama = optimize_inference_for_ditllama self.norm_eps = norm_eps - self.callZKK = callZKK def forward(self, x, freqs_cis, adaln_input=None): """ @@ -455,15 +448,16 @@ def forward(self, x, freqs_cis, adaln_input=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk( 6, axis=1 ) - if not self.callZKK: + if not self.optimize_inference_for_ditllama: h = x + gate_msa.unsqueeze(1) * self.attention( modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis ) out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp)) else: - attention_out = self.attention(adaptive_layer_norm(x, scale_msa, shift_msa, + import paddlemix + attention_out = self.attention(paddlemix.triton_ops.adaptive_layer_norm(x, scale_msa, shift_msa, weight=self.attention_norm.weight, epsilon=self.norm_eps), freqs_cis) - residual_out, adaLN_out = fused_adaLN_scale_residual(x, attention_out, gate_msa, scale_mlp, shift_mlp, + residual_out, adaLN_out = paddlemix.triton_ops.fused_adaLN_scale_residual(x, attention_out, gate_msa, scale_mlp, shift_mlp, weight=self.ffn_norm.weight, epsilon=self.norm_eps) out = residual_out + gate_mlp.unsqueeze(1) * self.feed_forward(adaLN_out) else: @@ -527,6 +521,7 @@ def __init__( self.num_classes = num_classes 
self.learn_sigma = learn_sigma self.qk_norm = qk_norm + self.gradient_checkpointing = True self.fused_attn = True @@ -534,7 +529,8 @@ def __init__( self.t_embedder = TimestepEmbedder(min(dim, 1024)) self.y_embedder = LabelEmbedding(num_classes, min(dim, 1024), class_dropout_prob) - self.callZKK = True if os.getenv('callZKK') else False + self.optimize_inference_for_ditllama = True if os.getenv('optimize_inference_for_ditllama') else False + # 2. Define transformers blocks self.layers = nn.LayerList( [ @@ -549,13 +545,11 @@ def __init__( norm_eps=norm_eps, qk_norm=qk_norm, fused_attn=self.fused_attn, - callZKK=self.callZKK, + optimize_inference_for_ditllama=self.optimize_inference_for_ditllama, ) for idx in range(num_layers) ] ) - - # del self.layers # 3. Define output layers self.final_layer = FinalLayer(dim, patch_size, self.out_channels) @@ -620,17 +614,13 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): freqs = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].cast("float32") / dim) t = paddle.arange(end=end) input_0, vec2_0 = TypePromote(t, freqs) - # [end, dim // 2] freqs = paddle.outer(input_0, vec2_0).cast("float32") - - if os.getenv('callZKK'): - # [end, dim // 2, 4] + if os.getenv('optimize_inference_for_ditllama'): freqs_cis = paddle.stack([ paddle.cos(freqs), -paddle.sin(freqs), paddle.sin(freqs), paddle.cos(freqs)], axis=-1) - # [end, dim, 2] freqs_cis = freqs_cis.reshape([freqs_cis.shape[0], -1, 2]).unsqueeze(1) else: freqs_cis = paddle.complex( @@ -638,24 +628,16 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): ) return freqs_cis - @paddle.jit.to_static(backend='inference', with_trt=False, - cache_static_model=False, - exp_enable_use_cutlass=True, - enable_new_ir=True, - delete_pass_lists = ["trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass", - "add_support_int8_pass", - "fc_fuse_pass", - "add_norm_fuse_pass",]) + @paddle.incubate.jit.inference(cache_static_model=False, + enable_new_ir=True, + exp_enable_use_cutlass=True,) def transformer_blocks(self, x, adaln_input): for i, layer in enumerate(self.layers): - if self.gradient_checkpointing and False: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) - else: - x = layer( - x, - self.freqs_cis[: x.shape[1]], - adaln_input, - ) + x = layer( + x, + self.freqs_cis[: x.shape[1]], + adaln_input, + ) return x def forward( @@ -683,9 +665,9 @@ def forward( adaln_input = t + y # 2. Blocks - if not self.callZKK: + if not self.optimize_inference_for_ditllama: for i, layer in enumerate(self.layers): - if self.gradient_checkpointing: + if self.gradient_checkpointing: x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) else: x = layer( @@ -696,7 +678,7 @@ def forward( else: self.freqs_cis = paddle.expand(self.freqs_cis, [-1, self.num_attention_heads, -1, -1]) x = self.transformer_blocks(x, adaln_input) - + # 3. Output hidden_states = self.final_layer(x, adaln_input) output = self.unpatchify(hidden_states) @@ -705,11 +687,10 @@ def forward( return (output,) return Transformer2DModelOutput(sample=output) - + @classmethod def custom_modify_weight(cls, state_dict): - # If you're not invited to zkk, you won't get any performance optimizations. 
- if os.getenv('callZKK'): + if os.getenv('optimize_inference_for_ditllama'): for key in list(state_dict.keys()): if 'feed_forward.w1.weight' in key: w1 = state_dict.pop(key) @@ -717,13 +698,3 @@ def custom_modify_weight(cls, state_dict): w3 = state_dict.pop(w3_key) w13 = paddle.concat([w1, w3], axis=1) state_dict[key.replace('w1', 'w13')] = w13 - if 'attention.wq.weight' in key or 'attention.wk.weight' in key or 'attention.wv.weight' in key: - part = key.split('.')[-2] - layer_id = key.split('.')[1] - qkv_key = f'layers.{layer_id}.attention.qkv.weight' - if part == 'wq' and qkv_key not in state_dict: - state_dict[qkv_key] = state_dict.pop(key) - elif part in ('wk', 'wv'): - qkv = state_dict.get(qkv_key) - if qkv is not None: - state_dict[qkv_key] = paddle.concat([qkv, state_dict.pop(key)], axis=1) diff --git a/ppdiffusers/ppdiffusers/models/modeling_utils.py b/ppdiffusers/ppdiffusers/models/modeling_utils.py index 3bcb620a6..ea9b74253 100644 --- a/ppdiffusers/ppdiffusers/models/modeling_utils.py +++ b/ppdiffusers/ppdiffusers/models/modeling_utils.py @@ -1053,7 +1053,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P @classmethod def custom_modify_weight(cls, state_dict): pass - + @classmethod def _load_pretrained_model( cls, From c000d4ca486207caa175ab93aa57fb24d2bce19c Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Mon, 5 Aug 2024 08:55:01 +0000 Subject: [PATCH 13/21] env --- .../class_conditional_image_generation-large_dit_7b.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py index 9be81375b..5559be8f5 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py @@ -19,9 +19,8 @@ dtype = paddle.bfloat16 -# To speed up this code, call zkk and let him run for you, -# then you will get a speed increase of almost 100%. 
-os.environ['callZKK']= "True" +# If you want to turn off optimization, comment this code +os.environ['optimize_inference_for_ditllama']= "True" with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype) From 9f04a1ca9f38bf2730badab90fac7bd1d47b06eb Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Wed, 7 Aug 2024 09:55:34 +0000 Subject: [PATCH 14/21] ReNet --- ...nditional_image_generation-large_dit_3b.py | 6 +- ...nditional_image_generation-large_dit_7b.py | 2 +- ppdiffusers/ppdiffusers/models/dit_llama.py | 204 ++++++------------ .../ppdiffusers/models/modeling_utils.py | 2 +- 4 files changed, 73 insertions(+), 141 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py index 93f02a955..926f19fd4 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py @@ -17,7 +17,11 @@ from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float32 +dtype = paddle.bfloat16 + +# If you want to turn off optimization, comment this code +os.environ['Inference_Optimize'] = "True" + with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-3B-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py index 5559be8f5..3c2edd7d8 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py @@ -20,7 +20,7 @@ dtype = paddle.bfloat16 # If you want to turn off optimization, comment this code -os.environ['optimize_inference_for_ditllama']= "True" +os.environ['Inference_Optimize'] = "True" with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index c9b4fa490..d1bf3c5d2 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -14,19 +14,19 @@ import math from typing import Optional -import os import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle.nn.functional.flash_attention import flash_attention -from paddle.framework import LayerHelper, in_dynamic_mode from ..configuration_utils import ConfigMixin, register_to_config from .embeddings import LabelEmbedding from .modeling_utils import ModelMixin from .transformer_2d import Transformer2DModelOutput +from .simplified_dit_llama import SimplifiedDiTLLaMA2DModel +import os def TypePromote(x, y): TYPE_PROMOTE_DICT = { @@ -92,7 +92,7 @@ def forward(self, t): class Attention(nn.Layer): - def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, optimize_inference_for_ditllama=False): + def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True): """ Initialize the Attention module. 
@@ -122,7 +122,6 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, opti self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads - self.optimize_inference_for_ditllama = optimize_inference_for_ditllama self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) @@ -210,21 +209,16 @@ def forward(self, x, freqs_cis): bsz, seqlen, _ = tuple(x.shape) xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) dtype = xq.dtype - if not self.optimize_inference_for_ditllama: - xq = self.q_norm(xq) - xk = self.k_norm(xk) - xq = xq.reshape([bsz, seqlen, self.n_local_heads, self.head_dim]) - xk = xk.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) - xv = xv.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) + xq = self.q_norm(xq) + xk = self.k_norm(xk) - xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - xq, xk = xq.cast(dtype), xk.cast(dtype) - else: - qkv_out = paddle.concat([xq, xk, xv], axis=-1) - import paddlemix - xq, xk, xv = paddlemix.triton_ops.fused_rotary_emb(qkv_out, self.q_norm.weight, self.q_norm.bias, - self.k_norm.weight, self.k_norm.bias, freqs_cis) + xq = xq.reshape([bsz, seqlen, self.n_local_heads, self.head_dim]) + xk = xk.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) + xv = xv.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) + + xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = xq.cast(dtype), xk.cast(dtype) if dtype in [paddle.float16, paddle.bfloat16]: output, _ = flash_attention( @@ -261,7 +255,7 @@ def forward(self, x, freqs_cis): class FeedForward(nn.Layer): - def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None, optimize_inference_for_ditllama=False): + def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): """ Initialize the FeedForward module. 
@@ -287,81 +281,15 @@ def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None, op hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) + self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False) self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) - if not optimize_inference_for_ditllama: - self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False) - self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) - else: - self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False) - self.optimize_inference_for_ditllama = optimize_inference_for_ditllama - - def compute_activation(self, - ffn1_out, - bias=None, - dequant_scales=None, - shift=None, - smooth=None, - act_method="swiglu", - compute_dtype="default", - quant_scale=-1, - quant_round_type=0, - quant_max_bound=0, - quant_min_bound=0): - if in_dynamic_mode(): - out = paddle._C_ops.fused_bias_act( - ffn1_out, - bias, - dequant_scales, - shift, - smooth, - act_method, - compute_dtype, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound - ) - return out - - helper = LayerHelper("fused_bias_act") - out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype) - inputs = {} - inputs["x"] = ffn1_out - if bias is not None: - inputs["bias"] = bias - if dequant_scales is not None: - inputs["dequant_scales"] = dequant_scales - if shift is not None: - inputs["shift"] = shift - if smooth is not None: - inputs["smooth"] = smooth - attrs = { - "act_method": act_method, - "compute_dtype": compute_dtype, - "quant_scale": quant_scale, - "quant_round_type": quant_round_type, - "quant_max_bound": quant_max_bound, - "quant_min_bound": quant_min_bound, - } - helper.append_op( - type="fused_bias_act", - inputs=inputs, - outputs={"out": out}, - attrs=attrs, - ) - return out - + self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) + def forward(self, x): - if not self.optimize_inference_for_ditllama: - xw1 = F.silu(self.w1(x)) - xw3 = self.w3(x) - output = self.w2(xw1 * xw3) - return output - else: - ffn1_out = self.w13(x) - ffn1_out = self.compute_activation(ffn1_out) - ffn2_out = self.w2(ffn1_out) - return ffn2_out + xw1 = F.silu(self.w1(x)) + xw3 = self.w3(x) + output = self.w2(xw1 * xw3) + return output class TransformerBlock(nn.Layer): @@ -377,7 +305,6 @@ def __init__( norm_eps: float, qk_norm: bool, fused_attn: bool, - optimize_inference_for_ditllama=False, ) -> None: """ Initialize a TransformerBlock. 
@@ -412,11 +339,10 @@ def __init__( super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn, optimize_inference_for_ditllama=optimize_inference_for_ditllama) + self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn) mlp_hidden_dim = int(dim * mlp_ratio) self.feed_forward = FeedForward( - dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, - optimize_inference_for_ditllama=optimize_inference_for_ditllama + dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier ) self.layer_id = layer_id self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) @@ -426,8 +352,6 @@ def __init__( nn.Silu(), nn.Linear(min(dim, 1024), 6 * dim), ) - self.optimize_inference_for_ditllama = optimize_inference_for_ditllama - self.norm_eps = norm_eps def forward(self, x, freqs_cis, adaln_input=None): """ @@ -448,18 +372,10 @@ def forward(self, x, freqs_cis, adaln_input=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk( 6, axis=1 ) - if not self.optimize_inference_for_ditllama: - h = x + gate_msa.unsqueeze(1) * self.attention( - modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis - ) - out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp)) - else: - import paddlemix - attention_out = self.attention(paddlemix.triton_ops.adaptive_layer_norm(x, scale_msa, shift_msa, - weight=self.attention_norm.weight, epsilon=self.norm_eps), freqs_cis) - residual_out, adaLN_out = paddlemix.triton_ops.fused_adaLN_scale_residual(x, attention_out, gate_msa, scale_mlp, shift_mlp, - weight=self.ffn_norm.weight, epsilon=self.norm_eps) - out = residual_out + gate_mlp.unsqueeze(1) * self.feed_forward(adaLN_out) + h = x + gate_msa.unsqueeze(1) * self.attention( + modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis + ) + out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp)) else: h = x + self.attention(self.attention_norm(x), freqs_cis) out = h + self.feed_forward(self.ffn_norm(h)) @@ -529,8 +445,6 @@ def __init__( self.t_embedder = TimestepEmbedder(min(dim, 1024)) self.y_embedder = LabelEmbedding(num_classes, min(dim, 1024), class_dropout_prob) - self.optimize_inference_for_ditllama = True if os.getenv('optimize_inference_for_ditllama') else False - # 2. Define transformers blocks self.layers = nn.LayerList( [ @@ -545,7 +459,6 @@ def __init__( norm_eps=norm_eps, qk_norm=qk_norm, fused_attn=self.fused_attn, - optimize_inference_for_ditllama=self.optimize_inference_for_ditllama, ) for idx in range(num_layers) ] @@ -554,6 +467,13 @@ def __init__( # 3. 
Define output layers self.final_layer = FinalLayer(dim, patch_size, self.out_channels) self.freqs_cis = self.precompute_freqs_cis(dim // num_attention_heads, 4096) + + self.Inference_Optimize = bool(os.getenv('Inference_Optimize')) + if self.Inference_Optimize: + self.simplified_dit_llama = SimplifiedDiTLLaMA2DModel(num_layers, dim, num_attention_heads, + multiple_of, mlp_ratio, norm_eps) + del self.layers + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): @@ -615,7 +535,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): t = paddle.arange(end=end) input_0, vec2_0 = TypePromote(t, freqs) freqs = paddle.outer(input_0, vec2_0).cast("float32") - if os.getenv('optimize_inference_for_ditllama'): + if bool(os.getenv('Inference_Optimize')): freqs_cis = paddle.stack([ paddle.cos(freqs), -paddle.sin(freqs), @@ -628,18 +548,6 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): ) return freqs_cis - @paddle.incubate.jit.inference(cache_static_model=False, - enable_new_ir=True, - exp_enable_use_cutlass=True,) - def transformer_blocks(self, x, adaln_input): - for i, layer in enumerate(self.layers): - x = layer( - x, - self.freqs_cis[: x.shape[1]], - adaln_input, - ) - return x - def forward( self, hidden_states: paddle.Tensor, @@ -665,7 +573,9 @@ def forward( adaln_input = t + y # 2. Blocks - if not self.optimize_inference_for_ditllama: + if self.Inference_Optimize: + x = self.simplified_dit_llama(x, self.freqs_cis[: x.shape[1]], adaln_input) + else: for i, layer in enumerate(self.layers): if self.gradient_checkpointing: x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) @@ -675,9 +585,6 @@ def forward( self.freqs_cis[: x.shape[1]], adaln_input, ) - else: - self.freqs_cis = paddle.expand(self.freqs_cis, [-1, self.num_attention_heads, -1, -1]) - x = self.transformer_blocks(x, adaln_input) # 3. 
Output hidden_states = self.final_layer(x, adaln_input) @@ -687,14 +594,35 @@ def forward( return (output,) return Transformer2DModelOutput(sample=output) - + @classmethod def custom_modify_weight(cls, state_dict): - if os.getenv('optimize_inference_for_ditllama'): - for key in list(state_dict.keys()): - if 'feed_forward.w1.weight' in key: - w1 = state_dict.pop(key) - w3_key = key.replace('w1', 'w3') - w3 = state_dict.pop(w3_key) - w13 = paddle.concat([w1, w3], axis=1) - state_dict[key.replace('w1', 'w13')] = w13 + if os.getenv('Inference_Optimize'): + map_from_my_dit = {} + for i in range(32): + map_from_my_dit[f'simplified_dit_llama.adaLN_modulations.{i}.weight'] = f'layers.{i}.adaLN_modulation.1.weight' + map_from_my_dit[f'simplified_dit_llama.adaLN_modulations.{i}.bias'] = f'layers.{i}.adaLN_modulation.1.bias' + + map_from_my_dit[f'simplified_dit_llama.attention_norms.{i}.weight'] = f'layers.{i}.attention_norm.weight' + + map_from_my_dit[f'simplified_dit_llama.wqs.{i}.weight'] = f'layers.{i}.attention.wq.weight' + map_from_my_dit[f'simplified_dit_llama.wks.{i}.weight'] = f'layers.{i}.attention.wk.weight' + map_from_my_dit[f'simplified_dit_llama.wvs.{i}.weight'] = f'layers.{i}.attention.wv.weight' + map_from_my_dit[f'simplified_dit_llama.wos.{i}.weight'] = f'layers.{i}.attention.wo.weight' + + map_from_my_dit[f'simplified_dit_llama.q_norms.{i}.weight'] = f'layers.{i}.attention.q_norm.weight' + map_from_my_dit[f'simplified_dit_llama.q_norms.{i}.bias'] = f'layers.{i}.attention.q_norm.bias' + map_from_my_dit[f'simplified_dit_llama.k_norms.{i}.weight'] = f'layers.{i}.attention.k_norm.weight' + map_from_my_dit[f'simplified_dit_llama.k_norms.{i}.bias'] = f'layers.{i}.attention.k_norm.bias' + + map_from_my_dit[f'simplified_dit_llama.ffn_norms.{i}.weight'] = f'layers.{i}.ffn_norm.weight' + map_from_my_dit[f'simplified_dit_llama.w2s.{i}.weight'] = f'layers.{i}.feed_forward.w2.weight' + + + for key in map_from_my_dit.keys(): + state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) + + for i in range(32): + w1 = state_dict[f'layers.{i}.feed_forward.w1.weight'] + w3 = state_dict[f'layers.{i}.feed_forward.w3.weight'] + state_dict[f'simplified_dit_llama.w13s.{i}.weight'] = paddle.concat([w1, w3], axis=1) diff --git a/ppdiffusers/ppdiffusers/models/modeling_utils.py b/ppdiffusers/ppdiffusers/models/modeling_utils.py index ea9b74253..3bcb620a6 100644 --- a/ppdiffusers/ppdiffusers/models/modeling_utils.py +++ b/ppdiffusers/ppdiffusers/models/modeling_utils.py @@ -1053,7 +1053,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P @classmethod def custom_modify_weight(cls, state_dict): pass - + @classmethod def _load_pretrained_model( cls, From f2966f7ee444c110b43669d4b8cca87fe3edb2ca Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Wed, 7 Aug 2024 09:56:54 +0000 Subject: [PATCH 15/21] new net --- .../models/simplified_dit_llama.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 ppdiffusers/ppdiffusers/models/simplified_dit_llama.py diff --git a/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py new file mode 100644 index 000000000..908715f2d --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py @@ -0,0 +1,127 @@ +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle.nn.functional.flash_attention import flash_attention +from paddle.framework import LayerHelper, in_dynamic_mode + +class 
SimplifiedDiTLLaMA2DModel(nn.Layer): + def __init__(self, + num_layers: int, + dim: int, + n_heads: int, + multiple_of: int, + mlp_ratio: float, + norm_eps: float): + super().__init__() + self.num_layers = num_layers + self.dim = dim + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.norm_eps = norm_eps + + self.adaLN_modulations = nn.LayerList([nn.Linear(min(dim, 1024), 6 * dim) for i in range(num_layers)]) + + self.attention_norms = nn.LayerList([nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) for i in range(num_layers)]) + + self.wqs = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) + self.wks = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) + self.wvs = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) + self.wos = nn.LayerList([nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) for i in range(num_layers)]) + + self.q_norms = nn.LayerList([nn.LayerNorm(n_heads * self.head_dim) for i in range(num_layers)]) + self.k_norms = nn.LayerList([nn.LayerNorm(n_heads * self.head_dim) for i in range(num_layers)]) + + self.ffn_norms = nn.LayerList([nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) for i in range(num_layers)]) + + hidden_dim = int(dim * mlp_ratio) + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) + self.w13s = nn.LayerList([nn.Linear(dim, hidden_dim * 2, bias_attr=False) for i in range(num_layers)]) + self.w2s = nn.LayerList([nn.Linear(hidden_dim, dim, bias_attr=False) for i in range(num_layers)]) + + def compute_activation(self, + ffn1_out, + bias=None, + dequant_scales=None, + shift=None, + smooth=None, + act_method="swiglu", + compute_dtype="default", + quant_scale=-1, + quant_round_type=0, + quant_max_bound=0, + quant_min_bound=0): + if in_dynamic_mode(): + out = paddle._C_ops.fused_bias_act( + ffn1_out, + bias, + dequant_scales, + shift, + smooth, + act_method, + compute_dtype, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound + ) + return out + + helper = LayerHelper("fused_bias_act") + out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype) + inputs = {} + inputs["x"] = ffn1_out + if bias is not None: + inputs["bias"] = bias + if dequant_scales is not None: + inputs["dequant_scales"] = dequant_scales + if shift is not None: + inputs["shift"] = shift + if smooth is not None: + inputs["smooth"] = smooth + attrs = { + "act_method": act_method, + "compute_dtype": compute_dtype, + "quant_scale": quant_scale, + "quant_round_type": quant_round_type, + "quant_max_bound": quant_max_bound, + "quant_min_bound": quant_min_bound, + } + helper.append_op( + type="fused_bias_act", + inputs=inputs, + outputs={"out": out}, + attrs=attrs, + ) + return out + + @paddle.incubate.jit.inference(cache_static_model=False, + enable_new_ir=True, + exp_enable_use_cutlass=True,) + def forward(self, x, freqs_cis, adaln_input): + freqs_cis = paddle.expand(freqs_cis, [-1, self.n_heads, -1, -1]) + adaln_input = F.silu(adaln_input) + for i in range(self.num_layers): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulations[i](adaln_input).chunk(6, axis=1) + ## AdaLN + import paddlemix + attn_in = paddlemix.triton_ops.adaptive_layer_norm(x, scale_msa, shift_msa, + weight=self.attention_norms[i].weight, epsilon=self.norm_eps) + ## Attention + xq, xk, xv = self.wqs[i](attn_in), 
self.wks[i](attn_in), self.wvs[i](attn_in) + qkv_out = paddle.concat([xq, xk, xv], axis=-1) + xq, xk, xv = paddlemix.triton_ops.fused_rotary_emb(qkv_out, self.q_norms[i].weight, self.q_norms[i].bias, + self.k_norms[i].weight, self.k_norms[i].bias, freqs_cis) + attn_out, _ = flash_attention(xq, xk, xv, dropout=0.0, causal=False, return_softmax=False) + attn_out = attn_out.flatten(start_axis=-2) + attn_out = self.wos[i](attn_out) + ## Fused_adaLN + resi_out, adaLN_out = paddlemix.triton_ops.fused_adaLN_scale_residual(x, attn_out, gate_msa, scale_mlp, shift_mlp, + weight=self.ffn_norms[i].weight, epsilon=self.norm_eps) + ## FFN + ffn_out = self.w13s[i](adaLN_out) + ffn_out = self.compute_activation(ffn_out) + ffn_out = self.w2s[i](ffn_out) + x = resi_out + gate_mlp.unsqueeze(1) * ffn_out + return x + \ No newline at end of file From 437cbbbc40884ef6614ef22c2f567a16c980f90f Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Wed, 7 Aug 2024 09:57:44 +0000 Subject: [PATCH 16/21] little mod --- ppdiffusers/ppdiffusers/models/simplified_dit_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py index 908715f2d..6eeda2504 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py @@ -124,4 +124,3 @@ def forward(self, x, freqs_cis, adaln_input): ffn_out = self.w2s[i](ffn_out) x = resi_out + gate_mlp.unsqueeze(1) * ffn_out return x - \ No newline at end of file From fb4d478bfd26e0d740dee6f9c91d50e0f46e098f Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Wed, 7 Aug 2024 11:21:04 +0000 Subject: [PATCH 17/21] pre-commit --- .../models/simplified_dit_llama.py | 184 ++++++++++-------- 1 file changed, 106 insertions(+), 78 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py index 6eeda2504..82d6b8aab 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py @@ -1,56 +1,71 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import paddle -from paddle import nn import paddle.nn.functional as F -from paddle.nn.functional.flash_attention import flash_attention +from paddle import nn from paddle.framework import LayerHelper, in_dynamic_mode +from paddle.nn.functional.flash_attention import flash_attention + class SimplifiedDiTLLaMA2DModel(nn.Layer): - def __init__(self, - num_layers: int, - dim: int, - n_heads: int, - multiple_of: int, - mlp_ratio: float, - norm_eps: float): - super().__init__() - self.num_layers = num_layers - self.dim = dim - self.n_heads = n_heads - self.head_dim = dim // n_heads - self.norm_eps = norm_eps - - self.adaLN_modulations = nn.LayerList([nn.Linear(min(dim, 1024), 6 * dim) for i in range(num_layers)]) - - self.attention_norms = nn.LayerList([nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) for i in range(num_layers)]) - - self.wqs = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) - self.wks = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) - self.wvs = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) - self.wos = nn.LayerList([nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) for i in range(num_layers)]) - - self.q_norms = nn.LayerList([nn.LayerNorm(n_heads * self.head_dim) for i in range(num_layers)]) - self.k_norms = nn.LayerList([nn.LayerNorm(n_heads * self.head_dim) for i in range(num_layers)]) - - self.ffn_norms = nn.LayerList([nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) for i in range(num_layers)]) - - hidden_dim = int(dim * mlp_ratio) - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) - self.w13s = nn.LayerList([nn.Linear(dim, hidden_dim * 2, bias_attr=False) for i in range(num_layers)]) - self.w2s = nn.LayerList([nn.Linear(hidden_dim, dim, bias_attr=False) for i in range(num_layers)]) + def __init__(self, num_layers: int, dim: int, n_heads: int, multiple_of: int, mlp_ratio: float, norm_eps: float): + super().__init__() + self.num_layers = num_layers + self.dim = dim + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.norm_eps = norm_eps + + self.adaLN_modulations = nn.LayerList([nn.Linear(min(dim, 1024), 6 * dim) for i in range(num_layers)]) + + self.attention_norms = nn.LayerList( + [nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) for i in range(num_layers)] + ) + + self.wqs = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) + self.wks = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) + self.wvs = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) + self.wos = nn.LayerList([nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) for i in range(num_layers)]) - def compute_activation(self, - ffn1_out, - bias=None, - dequant_scales=None, - shift=None, - smooth=None, - act_method="swiglu", - compute_dtype="default", - quant_scale=-1, - quant_round_type=0, - quant_max_bound=0, - quant_min_bound=0): + self.q_norms = nn.LayerList([nn.LayerNorm(n_heads * self.head_dim) for i in range(num_layers)]) + self.k_norms = nn.LayerList([nn.LayerNorm(n_heads * self.head_dim) for i in range(num_layers)]) + + self.ffn_norms = nn.LayerList( + [nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) for i in range(num_layers)] + ) + + hidden_dim = int(dim * mlp_ratio) + hidden_dim = int(2 
* hidden_dim / 3) + hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) + self.w13s = nn.LayerList([nn.Linear(dim, hidden_dim * 2, bias_attr=False) for i in range(num_layers)]) + self.w2s = nn.LayerList([nn.Linear(hidden_dim, dim, bias_attr=False) for i in range(num_layers)]) + + def compute_activation( + self, + ffn1_out, + bias=None, + dequant_scales=None, + shift=None, + smooth=None, + act_method="swiglu", + compute_dtype="default", + quant_scale=-1, + quant_round_type=0, + quant_max_bound=0, + quant_min_bound=0, + ): if in_dynamic_mode(): out = paddle._C_ops.fused_bias_act( ffn1_out, @@ -63,10 +78,10 @@ def compute_activation(self, quant_scale, quant_round_type, quant_max_bound, - quant_min_bound + quant_min_bound, ) return out - + helper = LayerHelper("fused_bias_act") out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype) inputs = {} @@ -95,32 +110,45 @@ def compute_activation(self, ) return out - @paddle.incubate.jit.inference(cache_static_model=False, - enable_new_ir=True, - exp_enable_use_cutlass=True,) - def forward(self, x, freqs_cis, adaln_input): - freqs_cis = paddle.expand(freqs_cis, [-1, self.n_heads, -1, -1]) - adaln_input = F.silu(adaln_input) - for i in range(self.num_layers): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulations[i](adaln_input).chunk(6, axis=1) - ## AdaLN - import paddlemix - attn_in = paddlemix.triton_ops.adaptive_layer_norm(x, scale_msa, shift_msa, - weight=self.attention_norms[i].weight, epsilon=self.norm_eps) - ## Attention - xq, xk, xv = self.wqs[i](attn_in), self.wks[i](attn_in), self.wvs[i](attn_in) - qkv_out = paddle.concat([xq, xk, xv], axis=-1) - xq, xk, xv = paddlemix.triton_ops.fused_rotary_emb(qkv_out, self.q_norms[i].weight, self.q_norms[i].bias, - self.k_norms[i].weight, self.k_norms[i].bias, freqs_cis) - attn_out, _ = flash_attention(xq, xk, xv, dropout=0.0, causal=False, return_softmax=False) - attn_out = attn_out.flatten(start_axis=-2) - attn_out = self.wos[i](attn_out) - ## Fused_adaLN - resi_out, adaLN_out = paddlemix.triton_ops.fused_adaLN_scale_residual(x, attn_out, gate_msa, scale_mlp, shift_mlp, - weight=self.ffn_norms[i].weight, epsilon=self.norm_eps) - ## FFN - ffn_out = self.w13s[i](adaLN_out) - ffn_out = self.compute_activation(ffn_out) - ffn_out = self.w2s[i](ffn_out) - x = resi_out + gate_mlp.unsqueeze(1) * ffn_out - return x + @paddle.incubate.jit.inference( + cache_static_model=False, + enable_new_ir=True, + exp_enable_use_cutlass=True, + ) + def forward(self, x, freqs_cis, adaln_input): + freqs_cis = paddle.expand(freqs_cis, [-1, self.n_heads, -1, -1]) + adaln_input = F.silu(adaln_input) + for i in range(self.num_layers): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulations[i]( + adaln_input + ).chunk(6, axis=1) + # AdaLN + import paddlemix + + attn_in = paddlemix.triton_ops.adaptive_layer_norm( + x, scale_msa, shift_msa, weight=self.attention_norms[i].weight, epsilon=self.norm_eps + ) + # Attention + xq, xk, xv = self.wqs[i](attn_in), self.wks[i](attn_in), self.wvs[i](attn_in) + qkv_out = paddle.concat([xq, xk, xv], axis=-1) + xq, xk, xv = paddlemix.triton_ops.fused_rotary_emb( + qkv_out, + self.q_norms[i].weight, + self.q_norms[i].bias, + self.k_norms[i].weight, + self.k_norms[i].bias, + freqs_cis, + ) + attn_out, _ = flash_attention(xq, xk, xv, dropout=0.0, causal=False, return_softmax=False) + attn_out = attn_out.flatten(start_axis=-2) + attn_out = self.wos[i](attn_out) + # Fused_adaLN + resi_out, 
adaLN_out = paddlemix.triton_ops.fused_adaLN_scale_residual( + x, attn_out, gate_msa, scale_mlp, shift_mlp, weight=self.ffn_norms[i].weight, epsilon=self.norm_eps + ) + # FFN + ffn_out = self.w13s[i](adaLN_out) + ffn_out = self.compute_activation(ffn_out) + ffn_out = self.w2s[i](ffn_out) + x = resi_out + gate_mlp.unsqueeze(1) * ffn_out + return x From e63831300d049ee98734eb79f27dc311a6f7155f Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Thu, 8 Aug 2024 06:42:41 +0000 Subject: [PATCH 18/21] update largedit --- paddlemix/triton_ops/triton_ops.py | 183 ++++++++++++------ ...nditional_image_generation-large_dit_3b.py | 2 +- ...nditional_image_generation-large_dit_7b.py | 2 +- ppdiffusers/ppdiffusers/models/dit_llama.py | 83 ++++---- .../models/simplified_dit_llama.py | 44 +++-- 5 files changed, 203 insertions(+), 111 deletions(-) diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py index fff3a5947..f7b7bf7f9 100644 --- a/paddlemix/triton_ops/triton_ops.py +++ b/paddlemix/triton_ops/triton_ops.py @@ -633,7 +633,6 @@ def weight_only_int8(x, qweight, scales, bias=None, bool_trans_w=True): return out -########################### adaptive layer norm ############################### fused_adaLN_scale_residual_template = ( """ @@ -1317,26 +1316,27 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05): fused_rotary_emb_template = ( """ std::vector ${op_name}_func( - const paddle::Tensor &x, + const paddle::Tensor &q, + const paddle::Tensor &k, const paddle::Tensor &q_norm_weight, const paddle::Tensor &q_norm_bias, const paddle::Tensor &k_norm_weight, const paddle::Tensor &k_norm_bias, const paddle::Tensor &freqs_cis, float epsilon) { - int BSZ = x.dims()[0]; - int SEQ_LEN = x.dims()[1]; + int BSZ = q.dims()[0]; + int SEQ_LEN = q.dims()[1]; + int DIM = q.dims()[2]; int HEAD_DIM = freqs_cis.dims()[2]; - int DIM = q_norm_weight.dims()[0]; + int NUM_HEAD = DIM / HEAD_DIM; int M = BSZ * SEQ_LEN; - int DIM_concat = x.dims()[2]; - auto q_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, x.dtype(), x.place()); - auto k_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, x.dtype(), x.place()); - auto v_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, x.dtype(), x.place()); + auto q_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, q.dtype(), q.place()); + auto k_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, k.dtype(), k.place()); - auto x_ptr = get_tensor_ptr(x); + auto q_ptr = get_tensor_ptr(q); + auto k_ptr = get_tensor_ptr(k); auto q_norm_weight_ptr = get_tensor_ptr(q_norm_weight); auto q_norm_bias_ptr = get_tensor_ptr(q_norm_bias); auto k_norm_weight_ptr = get_tensor_ptr(k_norm_weight); @@ -1344,13 +1344,13 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05): auto freqs_cis_ptr = get_tensor_ptr(freqs_cis); auto q_out_ptr = get_tensor_ptr(q_out); auto k_out_ptr = get_tensor_ptr(k_out); - auto v_out_ptr = get_tensor_ptr(v_out); + const paddle::Tensor &x = q; auto run_stream = q_out.stream(); """ + tune_and_invoke_part + """ - return {q_out, k_out, v_out}; + return {q_out, k_out}; } std::vector> ${op_name}_InferShape( @@ -1359,23 +1359,24 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05): const std::vector& C_shape, const std::vector& D_shape, const std::vector& E_shape, - const std::vector& F_shape) { + const std::vector& F_shape, + const std::vector& G_shape) { int BSZ = A_shape[0]; int SEQ_LEN = A_shape[1]; - int HEAD_DIM = F_shape[2]; - int DIM = B_shape[0]; + int DIM = A_shape[2]; + int HEAD_DIM = 
G_shape[2]; int NUM_HEAD = DIM / HEAD_DIM; std::vector res_shape = {BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}; - return {res_shape, res_shape, res_shape}; + return {res_shape, res_shape}; } std::vector ${op_name}_InferDtype(const paddle::DataType& A_dtype) { - return {A_dtype, A_dtype, A_dtype}; + return {A_dtype, A_dtype}; } PD_BUILD_OP(${op_name}) - .Inputs({"x", "q_norm_weight", "q_norm_bias", "k_norm_weight", "k_norm_bias", "freqs_cis"}) - .Outputs({"q_out", "k_out", "v_out"}) + .Inputs({"q", "k", "q_norm_weight", "q_norm_bias", "k_norm_weight", "k_norm_bias", "freqs_cis"}) + .Outputs({"q_out", "k_out"}) .SetKernelFn(PD_KERNEL(${op_name}_func)) .Attrs({"epsilon: float"}) .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype)) @@ -1389,29 +1390,28 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05): key=["M"], ) def fused_rotary_emb_kernel( - x_ptr, # [BSZ, SEQ_LEN, DIM_concat] - q_out_ptr, - k_out_ptr, # [BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM, 2] - v_out_ptr, # [BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM] - q_norm_weight_ptr, + q_ptr, # [BSZ, SEQ_LEN, DIM] + k_ptr, + q_out_ptr, # [BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM] + k_out_ptr, + q_norm_weight_ptr, # [DIM] q_norm_bias_ptr, k_norm_weight_ptr, - k_norm_bias_ptr, # [DIM] - freqs_cis_ptr, # [1, seq_len, 1, head_dim, 2] + k_norm_bias_ptr, + freqs_cis_ptr, # [SEQ_LEN, 1, HEAD_DIM, 2] epsilon, SEQ_LEN, M, DIM, - DIM_concat, DIM_npo2: tl.constexpr, ): row = tl.program_id(axis=0) - x_ptr += row * DIM_concat + q_ptr += row * DIM + k_ptr += row * DIM offs = tl.arange(0, DIM_npo2) masks = offs < DIM - q_eles = tl.load(x_ptr + offs, mask=masks, other=0.0).to(tl.float32) - k_eles = tl.load(x_ptr + DIM + offs, mask=masks, other=0.0).to(tl.float32) - v_eles = tl.load(x_ptr + 2 * DIM + offs, mask=masks, other=0.0) + q_eles = tl.load(q_ptr + offs, mask=masks, other=0.0).to(tl.float32) + k_eles = tl.load(k_ptr + offs, mask=masks, other=0.0).to(tl.float32) # qk layernorm q_mean = tl.sum(q_eles, axis=0) / DIM @@ -1451,48 +1451,116 @@ def fused_rotary_emb_kernel( k_resi_hat = tl.reshape(k_resi_hat, (DIM_npo2, 2)) k_res = tl.sum(k_resi_hat * freqs_cis, axis=1) - out_offs = row * DIM + offs - tl.store(q_out_ptr + out_offs, q_res, mask=masks) - tl.store(k_out_ptr + out_offs, k_res, mask=masks) - tl.store(v_out_ptr + out_offs, v_eles, mask=masks) + tl.store(q_out_ptr + row * DIM + offs, q_res, mask=masks) + tl.store(k_out_ptr + row * DIM + offs, k_res, mask=masks) def fused_rotary_emb( - x, - q_norm_weight, + q, # [BSZ, SEQ_LEN, DIM] + k, + q_norm_weight, # [DIM] q_norm_bias, k_norm_weight, k_norm_bias, freqs_cis, epsilon=1e-5, ): + """ + Examples: + + import os + os.environ["CUDA_VISIBLE_DEVICES"] = "7" + import paddle + + def apply_rotary_emb(xq, xk, freqs_cis): + xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) + xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) + shape = [(d if i == 1 or i == xq_.ndim - 1 else 1) for i, d in enumerate(tuple(xq_.shape))] + freqs_cis = freqs_cis.reshape([*shape]) + xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) + xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) + return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) + + def get_freqs_cis(dim: int, end: int): + freqs = 1.0 / 10000.0 ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].cast("float32") / dim) + t = paddle.arange(end=end).cast("float32") + # [SEQ_LEN, HEAD_DIM//2] + freqs = paddle.outer(t, freqs).cast("float32") + # [SEQ_LEN, HEAD_DIM//2] + freqs_cis_ref = paddle.complex( + 
paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs) + ) + freqs_cis = paddle.stack([ + paddle.cos(freqs), + -paddle.sin(freqs), + paddle.sin(freqs), + paddle.cos(freqs)], axis=-1) + # [SEQ_LEN, HEAD_DIM, 2] + freqs_cis = freqs_cis.reshape([freqs_cis.shape[0], -1, 2]).unsqueeze(1) + return freqs_cis, freqs_cis_ref + + BSZ = 2 + SEQ_LEN = 64 + NUM_HEAD = 16 + HEAD_DIM = 72 + DIM = NUM_HEAD * HEAD_DIM + epsilon = 1e-5 + dtype_= "float16" + q = paddle.rand([BSZ, SEQ_LEN, DIM], dtype=dtype_) + k = paddle.rand([BSZ, SEQ_LEN, DIM], dtype=dtype_) + q_norm_weight = paddle.rand([DIM], dtype=dtype_) + k_norm_weight = paddle.rand([DIM], dtype=dtype_) + q_norm_bias = paddle.rand([DIM], dtype=dtype_) + k_norm_bias = paddle.rand([DIM], dtype=dtype_) + + freqs_cis, freqs_cis_ref = get_freqs_cis(HEAD_DIM, SEQ_LEN) + freqs_cis = paddle.expand(freqs_cis, [-1, NUM_HEAD, -1, -1]) + + # paddle + q_ln = paddle.nn.functional.layer_norm(q, [DIM], weight=q_norm_weight, bias=q_norm_bias, epsilon=epsilon) + k_ln = paddle.nn.functional.layer_norm(k, [DIM], weight=k_norm_weight, bias=k_norm_bias, epsilon=epsilon) + q_ln = q_ln.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]) + k_ln = k_ln.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]) + q_out_baseline, k_out_baseline= apply_rotary_emb(q_ln, k_ln, freqs_cis_ref) + q_out_baseline = q_out_baseline.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]) + k_out_baseline = k_out_baseline.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]) + + # triton + import paddlemix + q_out, k_out = paddlemix.triton_ops.fused_rotary_emb(q, k, q_norm_weight, q_norm_bias, k_norm_weight, k_norm_bias, freqs_cis, epsilon) + + print("Q allclose: ", paddle.allclose(q_out_baseline, q_out, rtol=1e-03, atol=1e-02).numpy()) + print("K allclose: ", paddle.allclose(k_out_baseline, k_out, rtol=1e-03, atol=1e-02).numpy()) + """ + assert q.shape == k.shape, "q and k should have the same shape" + assert len(q.shape) == 3, "q should be [BSZ, SEQ_LEN, DIM]" assert q_norm_weight is not None, "q_norm_weight should not be none" assert q_norm_bias is not None, "q_norm_bias should not be none" assert k_norm_weight is not None, "k_norm_weight should not be none" assert k_norm_bias is not None, "k_norm_bias should not be none" - DIM = q_norm_weight.shape[0] + assert ( + q_norm_weight.shape == q_norm_bias.shape == k_norm_weight.shape == k_norm_bias.shape + ), "q_norm_weight, q_norm_bias, k_norm_weight, k_norm_bias should have the same shape" + assert q.shape[-1] == q_norm_weight.shape[0], "q_norm_weight should be [DIM]" + + BSZ, SEQ_LEN, DIM = q.shape HEAD_DIM = freqs_cis.shape[-2] assert (DIM % HEAD_DIM) == 0, "dim should be divisible by head_dim" - DIM_concat = x.shape[-1] - assert (DIM * 3) == DIM_concat, "not support GQA, qkv num_head should be equal" - BSZ = x.shape[0] - SEQ_LEN = x.shape[1] NUM_HEAD = DIM // HEAD_DIM M = BSZ * SEQ_LEN DIM_npo2 = triton.next_power_of_2(DIM) - dtype_ = x.dtype + dtype_ = q.dtype # q_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=dtype_) # k_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=dtype_) - # v_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=dtype_) # fused_rotary_emb_kernel[(M,)]( - # input_tensor, q_out_tensor, k_out_tensor, v_out_tensor, - # q_norm_weight, q_norm_bias, k_norm_weight, k_norm_bias, freqs_cis, epsilon, - # SEQ_LEN, M, DIM, DIM_concat, + # q, k, q_out_tensor, k_out_tensor, q_norm_weight, q_norm_bias, + # k_norm_weight, k_norm_bias, freqs_cis, epsilon, + # SEQ_LEN, M, DIM, # DIM_npo2, 
num_warps=4, # ) - # return q_out_tensor, k_out_tensor, v_out_tensor + # return q_out_tensor, k_out_tensor op_name = "triton_fused_rotary_emb" op_name += get_dtype_str(dtype_) @@ -1510,13 +1578,12 @@ def fused_rotary_emb( empty_dtype = dtype_ if dtype_ != paddle.bfloat16 else paddle.float16 q_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_) k_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_) - v_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_) grid = ("M",) fused_rotary_emb_kernel[(op_name, grid, fused_rotary_emb_kernel_config)]( - x, + q, + k, q_out_tensor, k_out_tensor, - v_out_tensor, q_norm_weight, q_norm_bias, k_norm_weight, @@ -1526,7 +1593,6 @@ def fused_rotary_emb( SEQ_LEN, M, DIM, - DIM_concat, DIM_npo2, ) @@ -1534,7 +1600,8 @@ def fused_rotary_emb( print(f"== we are in dynamic mode, op_name: {op_name}") outs = _C_ops._run_custom_op( op_name, - x, + q, + k, q_norm_weight, q_norm_bias, k_norm_weight, @@ -1542,12 +1609,13 @@ def fused_rotary_emb( freqs_cis, epsilon, ) - return outs[0], outs[1], outs[2] + return outs[0], outs[1] else: print(f"== we are in dynamic to static mode, op_name: {op_name}") helper = LayerHelper(op_name, **locals()) inputs = { - "x": x, + "q": q, + "k": k, "q_norm_weight": q_norm_weight, "q_norm_bias": q_norm_bias, "k_norm_weight": k_norm_weight, @@ -1556,13 +1624,12 @@ def fused_rotary_emb( } q_out = helper.create_variable_for_type_inference(dtype=dtype_) k_out = helper.create_variable_for_type_inference(dtype=dtype_) - v_out = helper.create_variable_for_type_inference(dtype=dtype_) helper.append_op( type=op_name, inputs=inputs, attrs={ "epsilon": epsilon, }, - outputs={"q_out": q_out, "k_out": k_out, "v_out": v_out}, + outputs={"q_out": q_out, "k_out": k_out}, ) - return q_out, k_out, v_out + return q_out, k_out diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py index 926f19fd4..f19bf228d 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py @@ -20,7 +20,7 @@ dtype = paddle.bfloat16 # If you want to turn off optimization, comment this code -os.environ['Inference_Optimize'] = "True" +os.environ["INFOPTIMIZE"] = "True" with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-3B-256", paddle_dtype=dtype) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py index 3c2edd7d8..0fe57394e 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py @@ -20,7 +20,7 @@ dtype = paddle.bfloat16 # If you want to turn off optimization, comment this code -os.environ['Inference_Optimize'] = "True" +os.environ["INFOPTIMIZE"] = "True" with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index d1bf3c5d2..84e4ce325 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -13,6 +13,7 @@ # limitations under the 
License. import math +import os from typing import Optional import paddle @@ -23,10 +24,9 @@ from ..configuration_utils import ConfigMixin, register_to_config from .embeddings import LabelEmbedding from .modeling_utils import ModelMixin +from .simplified_dit_llama import SimplifiedDiTLLaMA2DModel from .transformer_2d import Transformer2DModelOutput -from .simplified_dit_llama import SimplifiedDiTLLaMA2DModel -import os def TypePromote(x, y): TYPE_PROMOTE_DICT = { @@ -467,13 +467,13 @@ def __init__( # 3. Define output layers self.final_layer = FinalLayer(dim, patch_size, self.out_channels) self.freqs_cis = self.precompute_freqs_cis(dim // num_attention_heads, 4096) - - self.Inference_Optimize = bool(os.getenv('Inference_Optimize')) - if self.Inference_Optimize: - self.simplified_dit_llama = SimplifiedDiTLLaMA2DModel(num_layers, dim, num_attention_heads, - multiple_of, mlp_ratio, norm_eps) + + self.INFOPTIMIZE = os.getenv("INFOPTIMIZE") == "True" + if self.INFOPTIMIZE: + self.simplified_dit_llama = SimplifiedDiTLLaMA2DModel( + num_layers, dim, num_attention_heads, multiple_of, mlp_ratio, norm_eps + ) del self.layers - def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): @@ -535,12 +535,10 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): t = paddle.arange(end=end) input_0, vec2_0 = TypePromote(t, freqs) freqs = paddle.outer(input_0, vec2_0).cast("float32") - if bool(os.getenv('Inference_Optimize')): - freqs_cis = paddle.stack([ - paddle.cos(freqs), - -paddle.sin(freqs), - paddle.sin(freqs), - paddle.cos(freqs)], axis=-1) + if bool(os.getenv("INFOPTIMIZE")): + freqs_cis = paddle.stack( + [paddle.cos(freqs), -paddle.sin(freqs), paddle.sin(freqs), paddle.cos(freqs)], axis=-1 + ) freqs_cis = freqs_cis.reshape([freqs_cis.shape[0], -1, 2]).unsqueeze(1) else: freqs_cis = paddle.complex( @@ -573,7 +571,7 @@ def forward( adaln_input = t + y # 2. 
Blocks - if self.Inference_Optimize: + if self.INFOPTIMIZE: x = self.simplified_dit_llama(x, self.freqs_cis[: x.shape[1]], adaln_input) else: for i, layer in enumerate(self.layers): @@ -597,32 +595,37 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): - if os.getenv('Inference_Optimize'): + if os.getenv("INFOPTIMIZE"): map_from_my_dit = {} for i in range(32): - map_from_my_dit[f'simplified_dit_llama.adaLN_modulations.{i}.weight'] = f'layers.{i}.adaLN_modulation.1.weight' - map_from_my_dit[f'simplified_dit_llama.adaLN_modulations.{i}.bias'] = f'layers.{i}.adaLN_modulation.1.bias' - - map_from_my_dit[f'simplified_dit_llama.attention_norms.{i}.weight'] = f'layers.{i}.attention_norm.weight' - - map_from_my_dit[f'simplified_dit_llama.wqs.{i}.weight'] = f'layers.{i}.attention.wq.weight' - map_from_my_dit[f'simplified_dit_llama.wks.{i}.weight'] = f'layers.{i}.attention.wk.weight' - map_from_my_dit[f'simplified_dit_llama.wvs.{i}.weight'] = f'layers.{i}.attention.wv.weight' - map_from_my_dit[f'simplified_dit_llama.wos.{i}.weight'] = f'layers.{i}.attention.wo.weight' - - map_from_my_dit[f'simplified_dit_llama.q_norms.{i}.weight'] = f'layers.{i}.attention.q_norm.weight' - map_from_my_dit[f'simplified_dit_llama.q_norms.{i}.bias'] = f'layers.{i}.attention.q_norm.bias' - map_from_my_dit[f'simplified_dit_llama.k_norms.{i}.weight'] = f'layers.{i}.attention.k_norm.weight' - map_from_my_dit[f'simplified_dit_llama.k_norms.{i}.bias'] = f'layers.{i}.attention.k_norm.bias' - - map_from_my_dit[f'simplified_dit_llama.ffn_norms.{i}.weight'] = f'layers.{i}.ffn_norm.weight' - map_from_my_dit[f'simplified_dit_llama.w2s.{i}.weight'] = f'layers.{i}.feed_forward.w2.weight' - - + map_from_my_dit[ + f"simplified_dit_llama.adaLN_modulations.{i}.weight" + ] = f"layers.{i}.adaLN_modulation.1.weight" + map_from_my_dit[ + f"simplified_dit_llama.adaLN_modulations.{i}.bias" + ] = f"layers.{i}.adaLN_modulation.1.bias" + + map_from_my_dit[ + f"simplified_dit_llama.attention_norms.{i}.weight" + ] = f"layers.{i}.attention_norm.weight" + + map_from_my_dit[f"simplified_dit_llama.wqs.{i}.weight"] = f"layers.{i}.attention.wq.weight" + map_from_my_dit[f"simplified_dit_llama.wks.{i}.weight"] = f"layers.{i}.attention.wk.weight" + map_from_my_dit[f"simplified_dit_llama.wvs.{i}.weight"] = f"layers.{i}.attention.wv.weight" + map_from_my_dit[f"simplified_dit_llama.wos.{i}.weight"] = f"layers.{i}.attention.wo.weight" + + map_from_my_dit[f"simplified_dit_llama.q_norms.{i}.weight"] = f"layers.{i}.attention.q_norm.weight" + map_from_my_dit[f"simplified_dit_llama.q_norms.{i}.bias"] = f"layers.{i}.attention.q_norm.bias" + map_from_my_dit[f"simplified_dit_llama.k_norms.{i}.weight"] = f"layers.{i}.attention.k_norm.weight" + map_from_my_dit[f"simplified_dit_llama.k_norms.{i}.bias"] = f"layers.{i}.attention.k_norm.bias" + + map_from_my_dit[f"simplified_dit_llama.ffn_norms.{i}.weight"] = f"layers.{i}.ffn_norm.weight" + map_from_my_dit[f"simplified_dit_llama.w2s.{i}.weight"] = f"layers.{i}.feed_forward.w2.weight" + for key in map_from_my_dit.keys(): state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) - - for i in range(32): - w1 = state_dict[f'layers.{i}.feed_forward.w1.weight'] - w3 = state_dict[f'layers.{i}.feed_forward.w3.weight'] - state_dict[f'simplified_dit_llama.w13s.{i}.weight'] = paddle.concat([w1, w3], axis=1) + + for i in range(32): + w1 = state_dict[f"layers.{i}.feed_forward.w1.weight"] + w3 = state_dict[f"layers.{i}.feed_forward.w3.weight"] + state_dict[f"simplified_dit_llama.w13s.{i}.weight"] = 
paddle.concat([w1, w3], axis=1) diff --git a/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py index 82d6b8aab..2912041eb 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py @@ -118,37 +118,59 @@ def compute_activation( def forward(self, x, freqs_cis, adaln_input): freqs_cis = paddle.expand(freqs_cis, [-1, self.n_heads, -1, -1]) adaln_input = F.silu(adaln_input) + prev_gate_mlp = None + + from paddlemix.triton_ops import ( + adaptive_layer_norm, + fused_adaLN_scale_residual, + fused_rotary_emb, + ) + for i in range(self.num_layers): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulations[i]( adaln_input ).chunk(6, axis=1) - # AdaLN - import paddlemix - - attn_in = paddlemix.triton_ops.adaptive_layer_norm( - x, scale_msa, shift_msa, weight=self.attention_norms[i].weight, epsilon=self.norm_eps - ) + # (Fused_)adaLN + if i == 0: + attn_in = adaptive_layer_norm( + x, scale_msa, shift_msa, weight=self.attention_norms[i].weight, epsilon=self.norm_eps + ) + else: + x, attn_in = fused_adaLN_scale_residual( + resi_out, + ffn_out, + prev_gate_mlp, + scale_msa, + shift_msa, + weight=self.attention_norms[i].weight, + epsilon=self.norm_eps, + ) # Attention xq, xk, xv = self.wqs[i](attn_in), self.wks[i](attn_in), self.wvs[i](attn_in) - qkv_out = paddle.concat([xq, xk, xv], axis=-1) - xq, xk, xv = paddlemix.triton_ops.fused_rotary_emb( - qkv_out, + xq, xk = fused_rotary_emb( + xq, + xk, self.q_norms[i].weight, self.q_norms[i].bias, self.k_norms[i].weight, self.k_norms[i].bias, freqs_cis, + self.norm_eps, ) + xv = xv.reshape([xv.shape[0], xv.shape[1], self.n_heads, self.head_dim]) attn_out, _ = flash_attention(xq, xk, xv, dropout=0.0, causal=False, return_softmax=False) attn_out = attn_out.flatten(start_axis=-2) attn_out = self.wos[i](attn_out) # Fused_adaLN - resi_out, adaLN_out = paddlemix.triton_ops.fused_adaLN_scale_residual( + resi_out, adaLN_out = fused_adaLN_scale_residual( x, attn_out, gate_msa, scale_mlp, shift_mlp, weight=self.ffn_norms[i].weight, epsilon=self.norm_eps ) # FFN ffn_out = self.w13s[i](adaLN_out) ffn_out = self.compute_activation(ffn_out) ffn_out = self.w2s[i](ffn_out) - x = resi_out + gate_mlp.unsqueeze(1) * ffn_out + # + prev_gate_mlp = gate_mlp + + x = resi_out + prev_gate_mlp.unsqueeze(1) * ffn_out return x From c933f809a2d638ef0a0411f450baefbdc0b87d3e Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Thu, 8 Aug 2024 06:45:37 +0000 Subject: [PATCH 19/21] update largedit --- ppdiffusers/ppdiffusers/models/dit_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 84e4ce325..e838dd1e4 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -535,7 +535,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): t = paddle.arange(end=end) input_0, vec2_0 = TypePromote(t, freqs) freqs = paddle.outer(input_0, vec2_0).cast("float32") - if bool(os.getenv("INFOPTIMIZE")): + if os.getenv("INFOPTIMIZE"): freqs_cis = paddle.stack( [paddle.cos(freqs), -paddle.sin(freqs), paddle.sin(freqs), paddle.cos(freqs)], axis=-1 ) From 9903122fa316dab9caa84465be411a9e5836b4ae Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Thu, 8 Aug 2024 07:26:48 +0000 Subject: [PATCH 20/21] INFERENCE_OPTIMIZE --- 
.../class_conditional_image_generation-large_dit_3b.py | 6 ++++-- .../class_conditional_image_generation-large_dit_7b.py | 6 ++++-- ppdiffusers/ppdiffusers/models/dit_llama.py | 10 +++++----- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py index f19bf228d..853d2c401 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import paddle from paddlenlp.trainer import set_seed @@ -19,8 +21,8 @@ dtype = paddle.bfloat16 -# If you want to turn off optimization, comment this code -os.environ["INFOPTIMIZE"] = "True" +# True for inference optimizate +os.environ["INFERENCE_OPTIMIZE"] = "False" with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-3B-256", paddle_dtype=dtype) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py index 0fe57394e..906e0aa28 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import paddle from paddlenlp.trainer import set_seed @@ -19,8 +21,8 @@ dtype = paddle.bfloat16 -# If you want to turn off optimization, comment this code -os.environ["INFOPTIMIZE"] = "True" +# True for inference optimizate +os.environ["INFERENCE_OPTIMIZE"] = "False" with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index e838dd1e4..67fc68237 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -468,8 +468,8 @@ def __init__( self.final_layer = FinalLayer(dim, patch_size, self.out_channels) self.freqs_cis = self.precompute_freqs_cis(dim // num_attention_heads, 4096) - self.INFOPTIMIZE = os.getenv("INFOPTIMIZE") == "True" - if self.INFOPTIMIZE: + self.INFERENCE_OPTIMIZE = os.getenv("INFERENCE_OPTIMIZE") == "True" + if self.INFERENCE_OPTIMIZE: self.simplified_dit_llama = SimplifiedDiTLLaMA2DModel( num_layers, dim, num_attention_heads, multiple_of, mlp_ratio, norm_eps ) @@ -535,7 +535,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): t = paddle.arange(end=end) input_0, vec2_0 = TypePromote(t, freqs) freqs = paddle.outer(input_0, vec2_0).cast("float32") - if os.getenv("INFOPTIMIZE"): + if os.getenv("INFERENCE_OPTIMIZE") == "True": freqs_cis = paddle.stack( [paddle.cos(freqs), -paddle.sin(freqs), paddle.sin(freqs), paddle.cos(freqs)], axis=-1 ) @@ -571,7 +571,7 @@ def forward( adaln_input = t + y # 2. 
Blocks - if self.INFOPTIMIZE: + if self.INFERENCE_OPTIMIZE: x = self.simplified_dit_llama(x, self.freqs_cis[: x.shape[1]], adaln_input) else: for i, layer in enumerate(self.layers): @@ -595,7 +595,7 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): - if os.getenv("INFOPTIMIZE"): + if os.getenv("INFERENCE_OPTIMIZE") == "True": map_from_my_dit = {} for i in range(32): map_from_my_dit[ From f54958a8cbd364fb9f5c0179b05a3de6b106969b Mon Sep 17 00:00:00 2001 From: YKTian-x2b <2084984251@qq.com> Date: Sat, 10 Aug 2024 07:54:25 +0000 Subject: [PATCH 21/21] new modify_weight --- ppdiffusers/ppdiffusers/models/dit_llama.py | 45 ++++++++------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 67fc68237..cd0ebcfc1 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -596,36 +596,25 @@ def forward( @classmethod def custom_modify_weight(cls, state_dict): if os.getenv("INFERENCE_OPTIMIZE") == "True": - map_from_my_dit = {} for i in range(32): - map_from_my_dit[ - f"simplified_dit_llama.adaLN_modulations.{i}.weight" - ] = f"layers.{i}.adaLN_modulation.1.weight" - map_from_my_dit[ - f"simplified_dit_llama.adaLN_modulations.{i}.bias" - ] = f"layers.{i}.adaLN_modulation.1.bias" + mappings = [ + (f"adaLN_modulations.{i}.weight", f"{i}.adaLN_modulation.1.weight"), + (f"adaLN_modulations.{i}.bias", f"{i}.adaLN_modulation.1.bias"), + (f"attention_norms.{i}.weight", f"{i}.attention_norm.weight"), + (f"wqs.{i}.weight", f"{i}.attention.wq.weight"), + (f"wks.{i}.weight", f"{i}.attention.wk.weight"), + (f"wvs.{i}.weight", f"{i}.attention.wv.weight"), + (f"wos.{i}.weight", f"{i}.attention.wo.weight"), + (f"q_norms.{i}.weight", f"{i}.attention.q_norm.weight"), + (f"q_norms.{i}.bias", f"{i}.attention.q_norm.bias"), + (f"k_norms.{i}.weight", f"{i}.attention.k_norm.weight"), + (f"k_norms.{i}.bias", f"{i}.attention.k_norm.bias"), + (f"ffn_norms.{i}.weight", f"{i}.ffn_norm.weight"), + (f"w2s.{i}.weight", f"{i}.feed_forward.w2.weight"), + ] + for to_, from_ in mappings: + state_dict["simplified_dit_llama." + to_] = paddle.assign(state_dict["layers." 
+ from_]) - map_from_my_dit[ - f"simplified_dit_llama.attention_norms.{i}.weight" - ] = f"layers.{i}.attention_norm.weight" - - map_from_my_dit[f"simplified_dit_llama.wqs.{i}.weight"] = f"layers.{i}.attention.wq.weight" - map_from_my_dit[f"simplified_dit_llama.wks.{i}.weight"] = f"layers.{i}.attention.wk.weight" - map_from_my_dit[f"simplified_dit_llama.wvs.{i}.weight"] = f"layers.{i}.attention.wv.weight" - map_from_my_dit[f"simplified_dit_llama.wos.{i}.weight"] = f"layers.{i}.attention.wo.weight" - - map_from_my_dit[f"simplified_dit_llama.q_norms.{i}.weight"] = f"layers.{i}.attention.q_norm.weight" - map_from_my_dit[f"simplified_dit_llama.q_norms.{i}.bias"] = f"layers.{i}.attention.q_norm.bias" - map_from_my_dit[f"simplified_dit_llama.k_norms.{i}.weight"] = f"layers.{i}.attention.k_norm.weight" - map_from_my_dit[f"simplified_dit_llama.k_norms.{i}.bias"] = f"layers.{i}.attention.k_norm.bias" - - map_from_my_dit[f"simplified_dit_llama.ffn_norms.{i}.weight"] = f"layers.{i}.ffn_norm.weight" - map_from_my_dit[f"simplified_dit_llama.w2s.{i}.weight"] = f"layers.{i}.feed_forward.w2.weight" - - for key in map_from_my_dit.keys(): - state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]]) - - for i in range(32): w1 = state_dict[f"layers.{i}.feed_forward.w1.weight"] w3 = state_dict[f"layers.{i}.feed_forward.w3.weight"] state_dict[f"simplified_dit_llama.w13s.{i}.weight"] = paddle.concat([w1, w3], axis=1)