diff --git a/paddlemix/triton_ops/triton_ops.py b/paddlemix/triton_ops/triton_ops.py index 3ade229c3..f7b7bf7f9 100644 --- a/paddlemix/triton_ops/triton_ops.py +++ b/paddlemix/triton_ops/triton_ops.py @@ -633,7 +633,6 @@ def weight_only_int8(x, qweight, scales, bias=None, bool_trans_w=True): return out -########################### adaptive layer norm ############################### fused_adaLN_scale_residual_template = ( """ @@ -1317,26 +1316,27 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05): fused_rotary_emb_template = ( """ std::vector ${op_name}_func( - const paddle::Tensor &x, + const paddle::Tensor &q, + const paddle::Tensor &k, const paddle::Tensor &q_norm_weight, const paddle::Tensor &q_norm_bias, const paddle::Tensor &k_norm_weight, const paddle::Tensor &k_norm_bias, const paddle::Tensor &freqs_cis, float epsilon) { - int BSZ = x.dims()[0]; - int SEQ_LEN = x.dims()[1]; + int BSZ = q.dims()[0]; + int SEQ_LEN = q.dims()[1]; + int DIM = q.dims()[2]; int HEAD_DIM = freqs_cis.dims()[2]; - int DIM = q_norm_weight.dims()[0]; + int NUM_HEAD = DIM / HEAD_DIM; int M = BSZ * SEQ_LEN; - int DIM_concat = x.dims()[2]; - auto q_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, x.dtype(), x.place()); - auto k_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, x.dtype(), x.place()); - auto v_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, x.dtype(), x.place()); + auto q_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, q.dtype(), q.place()); + auto k_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, k.dtype(), k.place()); - auto x_ptr = get_tensor_ptr(x); + auto q_ptr = get_tensor_ptr(q); + auto k_ptr = get_tensor_ptr(k); auto q_norm_weight_ptr = get_tensor_ptr(q_norm_weight); auto q_norm_bias_ptr = get_tensor_ptr(q_norm_bias); auto k_norm_weight_ptr = get_tensor_ptr(k_norm_weight); @@ -1344,13 +1344,13 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05): auto freqs_cis_ptr = get_tensor_ptr(freqs_cis); auto q_out_ptr = get_tensor_ptr(q_out); auto k_out_ptr = get_tensor_ptr(k_out); - auto v_out_ptr = get_tensor_ptr(v_out); + const paddle::Tensor &x = q; auto run_stream = q_out.stream(); """ + tune_and_invoke_part + """ - return {q_out, k_out, v_out}; + return {q_out, k_out}; } std::vector> ${op_name}_InferShape( @@ -1359,23 +1359,24 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05): const std::vector& C_shape, const std::vector& D_shape, const std::vector& E_shape, - const std::vector& F_shape) { + const std::vector& F_shape, + const std::vector& G_shape) { int BSZ = A_shape[0]; int SEQ_LEN = A_shape[1]; - int HEAD_DIM = F_shape[2]; - int DIM = B_shape[0]; + int DIM = A_shape[2]; + int HEAD_DIM = G_shape[2]; int NUM_HEAD = DIM / HEAD_DIM; std::vector res_shape = {BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}; - return {res_shape, res_shape, res_shape}; + return {res_shape, res_shape}; } std::vector ${op_name}_InferDtype(const paddle::DataType& A_dtype) { - return {A_dtype, A_dtype, A_dtype}; + return {A_dtype, A_dtype}; } PD_BUILD_OP(${op_name}) - .Inputs({"x", "q_norm_weight", "q_norm_bias", "k_norm_weight", "k_norm_bias", "freqs_cis"}) - .Outputs({"q_out", "k_out", "v_out"}) + .Inputs({"q", "k", "q_norm_weight", "q_norm_bias", "k_norm_weight", "k_norm_bias", "freqs_cis"}) + .Outputs({"q_out", "k_out"}) .SetKernelFn(PD_KERNEL(${op_name}_func)) .Attrs({"epsilon: float"}) .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype)) @@ -1389,29 +1390,28 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05): key=["M"], ) def fused_rotary_emb_kernel( - x_ptr, # 
[BSZ, SEQ_LEN, DIM_concat] - q_out_ptr, - k_out_ptr, # [BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM, 2] - v_out_ptr, # [BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM] - q_norm_weight_ptr, + q_ptr, # [BSZ, SEQ_LEN, DIM] + k_ptr, + q_out_ptr, # [BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM] + k_out_ptr, + q_norm_weight_ptr, # [DIM] q_norm_bias_ptr, k_norm_weight_ptr, - k_norm_bias_ptr, # [DIM] - freqs_cis_ptr, # [1, seq_len, 1, head_dim, 2] + k_norm_bias_ptr, + freqs_cis_ptr, # [SEQ_LEN, 1, HEAD_DIM, 2] epsilon, SEQ_LEN, M, DIM, - DIM_concat, DIM_npo2: tl.constexpr, ): row = tl.program_id(axis=0) - x_ptr += row * DIM_concat + q_ptr += row * DIM + k_ptr += row * DIM offs = tl.arange(0, DIM_npo2) masks = offs < DIM - q_eles = tl.load(x_ptr + offs, mask=masks, other=0.0).to(tl.float32) - k_eles = tl.load(x_ptr + DIM + offs, mask=masks, other=0.0).to(tl.float32) - v_eles = tl.load(x_ptr + 2 * DIM + offs, mask=masks, other=0.0) + q_eles = tl.load(q_ptr + offs, mask=masks, other=0.0).to(tl.float32) + k_eles = tl.load(k_ptr + offs, mask=masks, other=0.0).to(tl.float32) # qk layernorm q_mean = tl.sum(q_eles, axis=0) / DIM @@ -1451,49 +1451,116 @@ def fused_rotary_emb_kernel( k_resi_hat = tl.reshape(k_resi_hat, (DIM_npo2, 2)) k_res = tl.sum(k_resi_hat * freqs_cis, axis=1) - out_offs = row * DIM + offs - tl.store(q_out_ptr + out_offs, q_res, mask=masks) - tl.store(k_out_ptr + out_offs, k_res, mask=masks) - tl.store(v_out_ptr + out_offs, v_eles, mask=masks) + tl.store(q_out_ptr + row * DIM + offs, q_res, mask=masks) + tl.store(k_out_ptr + row * DIM + offs, k_res, mask=masks) def fused_rotary_emb( - x, - q_norm_weight, + q, # [BSZ, SEQ_LEN, DIM] + k, + q_norm_weight, # [DIM] q_norm_bias, k_norm_weight, k_norm_bias, freqs_cis, epsilon=1e-5, ): - assert x.is_contiguous() + """ + Examples: + + import os + os.environ["CUDA_VISIBLE_DEVICES"] = "7" + import paddle + + def apply_rotary_emb(xq, xk, freqs_cis): + xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) + xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) + shape = [(d if i == 1 or i == xq_.ndim - 1 else 1) for i, d in enumerate(tuple(xq_.shape))] + freqs_cis = freqs_cis.reshape([*shape]) + xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) + xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) + return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) + + def get_freqs_cis(dim: int, end: int): + freqs = 1.0 / 10000.0 ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].cast("float32") / dim) + t = paddle.arange(end=end).cast("float32") + # [SEQ_LEN, HEAD_DIM//2] + freqs = paddle.outer(t, freqs).cast("float32") + # [SEQ_LEN, HEAD_DIM//2] + freqs_cis_ref = paddle.complex( + paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs) + ) + freqs_cis = paddle.stack([ + paddle.cos(freqs), + -paddle.sin(freqs), + paddle.sin(freqs), + paddle.cos(freqs)], axis=-1) + # [SEQ_LEN, HEAD_DIM, 2] + freqs_cis = freqs_cis.reshape([freqs_cis.shape[0], -1, 2]).unsqueeze(1) + return freqs_cis, freqs_cis_ref + + BSZ = 2 + SEQ_LEN = 64 + NUM_HEAD = 16 + HEAD_DIM = 72 + DIM = NUM_HEAD * HEAD_DIM + epsilon = 1e-5 + dtype_= "float16" + q = paddle.rand([BSZ, SEQ_LEN, DIM], dtype=dtype_) + k = paddle.rand([BSZ, SEQ_LEN, DIM], dtype=dtype_) + q_norm_weight = paddle.rand([DIM], dtype=dtype_) + k_norm_weight = paddle.rand([DIM], dtype=dtype_) + q_norm_bias = paddle.rand([DIM], dtype=dtype_) + k_norm_bias = paddle.rand([DIM], dtype=dtype_) + + freqs_cis, freqs_cis_ref = get_freqs_cis(HEAD_DIM, SEQ_LEN) 
+ freqs_cis = paddle.expand(freqs_cis, [-1, NUM_HEAD, -1, -1]) + + # paddle + q_ln = paddle.nn.functional.layer_norm(q, [DIM], weight=q_norm_weight, bias=q_norm_bias, epsilon=epsilon) + k_ln = paddle.nn.functional.layer_norm(k, [DIM], weight=k_norm_weight, bias=k_norm_bias, epsilon=epsilon) + q_ln = q_ln.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]) + k_ln = k_ln.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]) + q_out_baseline, k_out_baseline= apply_rotary_emb(q_ln, k_ln, freqs_cis_ref) + q_out_baseline = q_out_baseline.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]) + k_out_baseline = k_out_baseline.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]) + + # triton + import paddlemix + q_out, k_out = paddlemix.triton_ops.fused_rotary_emb(q, k, q_norm_weight, q_norm_bias, k_norm_weight, k_norm_bias, freqs_cis, epsilon) + + print("Q allclose: ", paddle.allclose(q_out_baseline, q_out, rtol=1e-03, atol=1e-02).numpy()) + print("K allclose: ", paddle.allclose(k_out_baseline, k_out, rtol=1e-03, atol=1e-02).numpy()) + """ + assert q.shape == k.shape, "q and k should have the same shape" + assert len(q.shape) == 3, "q should be [BSZ, SEQ_LEN, DIM]" assert q_norm_weight is not None, "q_norm_weight should not be none" assert q_norm_bias is not None, "q_norm_bias should not be none" assert k_norm_weight is not None, "k_norm_weight should not be none" assert k_norm_bias is not None, "k_norm_bias should not be none" - DIM = q_norm_weight.shape[0] + assert ( + q_norm_weight.shape == q_norm_bias.shape == k_norm_weight.shape == k_norm_bias.shape + ), "q_norm_weight, q_norm_bias, k_norm_weight, k_norm_bias should have the same shape" + assert q.shape[-1] == q_norm_weight.shape[0], "q_norm_weight should be [DIM]" + + BSZ, SEQ_LEN, DIM = q.shape HEAD_DIM = freqs_cis.shape[-2] assert (DIM % HEAD_DIM) == 0, "dim should be divisible by head_dim" - DIM_concat = x.shape[-1] - assert (DIM * 3) == DIM_concat, "not support GQA, qkv num_head should be equal" - BSZ = x.shape[0] - SEQ_LEN = x.shape[1] NUM_HEAD = DIM // HEAD_DIM M = BSZ * SEQ_LEN DIM_npo2 = triton.next_power_of_2(DIM) - dtype_ = x.dtype + dtype_ = q.dtype # q_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=dtype_) # k_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=dtype_) - # v_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=dtype_) # fused_rotary_emb_kernel[(M,)]( - # input_tensor, q_out_tensor, k_out_tensor, v_out_tensor, - # q_norm_weight, q_norm_bias, k_norm_weight, k_norm_bias, freqs_cis, epsilon, - # SEQ_LEN, M, DIM, DIM_concat, + # q, k, q_out_tensor, k_out_tensor, q_norm_weight, q_norm_bias, + # k_norm_weight, k_norm_bias, freqs_cis, epsilon, + # SEQ_LEN, M, DIM, # DIM_npo2, num_warps=4, # ) - # return q_out_tensor, k_out_tensor, v_out_tensor + # return q_out_tensor, k_out_tensor op_name = "triton_fused_rotary_emb" op_name += get_dtype_str(dtype_) @@ -1511,13 +1578,12 @@ def fused_rotary_emb( empty_dtype = dtype_ if dtype_ != paddle.bfloat16 else paddle.float16 q_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_) k_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_) - v_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_) grid = ("M",) fused_rotary_emb_kernel[(op_name, grid, fused_rotary_emb_kernel_config)]( - x, + q, + k, q_out_tensor, k_out_tensor, - v_out_tensor, q_norm_weight, q_norm_bias, k_norm_weight, @@ -1527,7 +1593,6 @@ def fused_rotary_emb( SEQ_LEN, M, DIM, - 
DIM_concat, DIM_npo2, ) @@ -1535,7 +1600,8 @@ def fused_rotary_emb( print(f"== we are in dynamic mode, op_name: {op_name}") outs = _C_ops._run_custom_op( op_name, - x, + q, + k, q_norm_weight, q_norm_bias, k_norm_weight, @@ -1543,12 +1609,13 @@ def fused_rotary_emb( freqs_cis, epsilon, ) - return outs[0], outs[1], outs[2] + return outs[0], outs[1] else: print(f"== we are in dynamic to static mode, op_name: {op_name}") helper = LayerHelper(op_name, **locals()) inputs = { - "x": x, + "q": q, + "k": k, "q_norm_weight": q_norm_weight, "q_norm_bias": q_norm_bias, "k_norm_weight": k_norm_weight, @@ -1557,13 +1624,12 @@ def fused_rotary_emb( } q_out = helper.create_variable_for_type_inference(dtype=dtype_) k_out = helper.create_variable_for_type_inference(dtype=dtype_) - v_out = helper.create_variable_for_type_inference(dtype=dtype_) helper.append_op( type=op_name, inputs=inputs, attrs={ "epsilon": epsilon, }, - outputs={"q_out": q_out, "k_out": k_out, "v_out": v_out}, + outputs={"q_out": q_out, "k_out": k_out}, ) - return q_out, k_out, v_out + return q_out, k_out diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py index 93f02a955..853d2c401 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_3b.py @@ -12,12 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import paddle from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float32 +dtype = paddle.bfloat16 + +# Set to "True" to enable the inference optimization +os.environ["INFERENCE_OPTIMIZE"] = "False" + with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-3B-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) diff --git a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py index 170ecfe4b..906e0aa28 100644 --- a/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py +++ b/ppdiffusers/examples/inference/class_conditional_image_generation-large_dit_7b.py @@ -12,12 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import paddle from paddlenlp.trainer import set_seed from ppdiffusers import DDIMScheduler, DiTPipeline -dtype = paddle.float32 +dtype = paddle.bfloat16 + +# Set to "True" to enable the inference optimization +os.environ["INFERENCE_OPTIMIZE"] = "False" + with paddle.LazyGuard(): pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 4c1b55f6a..cd0ebcfc1 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -13,6 +13,7 @@ # limitations under the License.
import math +import os from typing import Optional import paddle @@ -23,6 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from .embeddings import LabelEmbedding from .modeling_utils import ModelMixin +from .simplified_dit_llama import SimplifiedDiTLLaMA2DModel from .transformer_2d import Transformer2DModelOutput @@ -466,6 +468,13 @@ def __init__( self.final_layer = FinalLayer(dim, patch_size, self.out_channels) self.freqs_cis = self.precompute_freqs_cis(dim // num_attention_heads, 4096) + self.INFERENCE_OPTIMIZE = os.getenv("INFERENCE_OPTIMIZE") == "True" + if self.INFERENCE_OPTIMIZE: + self.simplified_dit_llama = SimplifiedDiTLLaMA2DModel( + num_layers, dim, num_attention_heads, multiple_of, mlp_ratio, norm_eps + ) + del self.layers + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value @@ -526,9 +535,15 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): t = paddle.arange(end=end) input_0, vec2_0 = TypePromote(t, freqs) freqs = paddle.outer(input_0, vec2_0).cast("float32") - freqs_cis = paddle.complex( - paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs) - ) + if os.getenv("INFERENCE_OPTIMIZE") == "True": + freqs_cis = paddle.stack( + [paddle.cos(freqs), -paddle.sin(freqs), paddle.sin(freqs), paddle.cos(freqs)], axis=-1 + ) + freqs_cis = freqs_cis.reshape([freqs_cis.shape[0], -1, 2]).unsqueeze(1) + else: + freqs_cis = paddle.complex( + paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs) + ) return freqs_cis def forward( @@ -556,15 +571,18 @@ def forward( adaln_input = t + y # 2. Blocks - for i, layer in enumerate(self.layers): - if self.gradient_checkpointing: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) - else: - x = layer( - x, - self.freqs_cis[: x.shape[1]], - adaln_input, - ) + if self.INFERENCE_OPTIMIZE: + x = self.simplified_dit_llama(x, self.freqs_cis[: x.shape[1]], adaln_input) + else: + for i, layer in enumerate(self.layers): + if self.gradient_checkpointing: + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + else: + x = layer( + x, + self.freqs_cis[: x.shape[1]], + adaln_input, + ) # 3. Output hidden_states = self.final_layer(x, adaln_input) @@ -574,3 +592,29 @@ def forward( return (output,) return Transformer2DModelOutput(sample=output) + + @classmethod + def custom_modify_weight(cls, state_dict): + if os.getenv("INFERENCE_OPTIMIZE") == "True": + for i in range(32): + mappings = [ + (f"adaLN_modulations.{i}.weight", f"{i}.adaLN_modulation.1.weight"), + (f"adaLN_modulations.{i}.bias", f"{i}.adaLN_modulation.1.bias"), + (f"attention_norms.{i}.weight", f"{i}.attention_norm.weight"), + (f"wqs.{i}.weight", f"{i}.attention.wq.weight"), + (f"wks.{i}.weight", f"{i}.attention.wk.weight"), + (f"wvs.{i}.weight", f"{i}.attention.wv.weight"), + (f"wos.{i}.weight", f"{i}.attention.wo.weight"), + (f"q_norms.{i}.weight", f"{i}.attention.q_norm.weight"), + (f"q_norms.{i}.bias", f"{i}.attention.q_norm.bias"), + (f"k_norms.{i}.weight", f"{i}.attention.k_norm.weight"), + (f"k_norms.{i}.bias", f"{i}.attention.k_norm.bias"), + (f"ffn_norms.{i}.weight", f"{i}.ffn_norm.weight"), + (f"w2s.{i}.weight", f"{i}.feed_forward.w2.weight"), + ] + for to_, from_ in mappings: + state_dict["simplified_dit_llama." + to_] = paddle.assign(state_dict["layers." 
+ from_]) + + w1 = state_dict[f"layers.{i}.feed_forward.w1.weight"] + w3 = state_dict[f"layers.{i}.feed_forward.w3.weight"] + state_dict[f"simplified_dit_llama.w13s.{i}.weight"] = paddle.concat([w1, w3], axis=1) diff --git a/ppdiffusers/ppdiffusers/models/modeling_utils.py b/ppdiffusers/ppdiffusers/models/modeling_utils.py index e4462d284..3bcb620a6 100644 --- a/ppdiffusers/ppdiffusers/models/modeling_utils.py +++ b/ppdiffusers/ppdiffusers/models/modeling_utils.py @@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model + @classmethod + def custom_modify_weight(cls, state_dict): + pass + @classmethod def _load_pretrained_model( cls, @@ -1130,6 +1134,7 @@ def _find_mismatched_keys( error_msgs.append( f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}." ) + cls.custom_modify_weight(state_dict) faster_set_state_dict(model_to_load, state_dict) missing_keys = sorted(list(set(expected_keys) - set(loaded_keys))) diff --git a/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py new file mode 100644 index 000000000..2912041eb --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/simplified_dit_llama.py @@ -0,0 +1,176 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn.functional as F +from paddle import nn +from paddle.framework import LayerHelper, in_dynamic_mode +from paddle.nn.functional.flash_attention import flash_attention + + +class SimplifiedDiTLLaMA2DModel(nn.Layer): + def __init__(self, num_layers: int, dim: int, n_heads: int, multiple_of: int, mlp_ratio: float, norm_eps: float): + super().__init__() + self.num_layers = num_layers + self.dim = dim + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.norm_eps = norm_eps + + self.adaLN_modulations = nn.LayerList([nn.Linear(min(dim, 1024), 6 * dim) for i in range(num_layers)]) + + self.attention_norms = nn.LayerList( + [nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) for i in range(num_layers)] + ) + + self.wqs = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) + self.wks = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) + self.wvs = nn.LayerList([nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) for i in range(num_layers)]) + self.wos = nn.LayerList([nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) for i in range(num_layers)]) + + self.q_norms = nn.LayerList([nn.LayerNorm(n_heads * self.head_dim) for i in range(num_layers)]) + self.k_norms = nn.LayerList([nn.LayerNorm(n_heads * self.head_dim) for i in range(num_layers)]) + + self.ffn_norms = nn.LayerList( + [nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) for i in range(num_layers)] + ) + + hidden_dim = int(dim * mlp_ratio) + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) + self.w13s = nn.LayerList([nn.Linear(dim, hidden_dim * 2, bias_attr=False) for i in range(num_layers)]) + self.w2s = nn.LayerList([nn.Linear(hidden_dim, dim, bias_attr=False) for i in range(num_layers)]) + + def compute_activation( + self, + ffn1_out, + bias=None, + dequant_scales=None, + shift=None, + smooth=None, + act_method="swiglu", + compute_dtype="default", + quant_scale=-1, + quant_round_type=0, + quant_max_bound=0, + quant_min_bound=0, + ): + if in_dynamic_mode(): + out = paddle._C_ops.fused_bias_act( + ffn1_out, + bias, + dequant_scales, + shift, + smooth, + act_method, + compute_dtype, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + ) + return out + + helper = LayerHelper("fused_bias_act") + out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype) + inputs = {} + inputs["x"] = ffn1_out + if bias is not None: + inputs["bias"] = bias + if dequant_scales is not None: + inputs["dequant_scales"] = dequant_scales + if shift is not None: + inputs["shift"] = shift + if smooth is not None: + inputs["smooth"] = smooth + attrs = { + "act_method": act_method, + "compute_dtype": compute_dtype, + "quant_scale": quant_scale, + "quant_round_type": quant_round_type, + "quant_max_bound": quant_max_bound, + "quant_min_bound": quant_min_bound, + } + helper.append_op( + type="fused_bias_act", + inputs=inputs, + outputs={"out": out}, + attrs=attrs, + ) + return out + + @paddle.incubate.jit.inference( + cache_static_model=False, + enable_new_ir=True, + exp_enable_use_cutlass=True, + ) + def forward(self, x, freqs_cis, adaln_input): + freqs_cis = paddle.expand(freqs_cis, [-1, self.n_heads, -1, -1]) + adaln_input = F.silu(adaln_input) + prev_gate_mlp = None + + from paddlemix.triton_ops import ( + adaptive_layer_norm, + fused_adaLN_scale_residual, + fused_rotary_emb, + ) + + for i in range(self.num_layers): + 
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulations[i]( + adaln_input + ).chunk(6, axis=1) + # (Fused_)adaLN + if i == 0: + attn_in = adaptive_layer_norm( + x, scale_msa, shift_msa, weight=self.attention_norms[i].weight, epsilon=self.norm_eps + ) + else: + x, attn_in = fused_adaLN_scale_residual( + resi_out, + ffn_out, + prev_gate_mlp, + scale_msa, + shift_msa, + weight=self.attention_norms[i].weight, + epsilon=self.norm_eps, + ) + # Attention + xq, xk, xv = self.wqs[i](attn_in), self.wks[i](attn_in), self.wvs[i](attn_in) + xq, xk = fused_rotary_emb( + xq, + xk, + self.q_norms[i].weight, + self.q_norms[i].bias, + self.k_norms[i].weight, + self.k_norms[i].bias, + freqs_cis, + self.norm_eps, + ) + xv = xv.reshape([xv.shape[0], xv.shape[1], self.n_heads, self.head_dim]) + attn_out, _ = flash_attention(xq, xk, xv, dropout=0.0, causal=False, return_softmax=False) + attn_out = attn_out.flatten(start_axis=-2) + attn_out = self.wos[i](attn_out) + # Fused_adaLN + resi_out, adaLN_out = fused_adaLN_scale_residual( + x, attn_out, gate_msa, scale_mlp, shift_mlp, weight=self.ffn_norms[i].weight, epsilon=self.norm_eps + ) + # FFN + ffn_out = self.w13s[i](adaLN_out) + ffn_out = self.compute_activation(ffn_out) + ffn_out = self.w2s[i](ffn_out) + # + prev_gate_mlp = gate_mlp + + x = resi_out + prev_gate_mlp.unsqueeze(1) * ffn_out + return x
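
Reviewer note: below is a minimal sketch of how the optimized path added in this patch is expected to be driven end to end. The INFERENCE_OPTIMIZE flag, bfloat16 dtype, LazyGuard construction, and model id all come from the example scripts above; the label lookup and sampling call (get_label_ids, class_labels, num_inference_steps) follow the generic DiTPipeline API and, like the output filename, are assumptions here rather than part of the patch. The flag must be set before the pipeline is built, since DiTLLaMA2DModel reads it in __init__ and precompute_freqs_cis, and custom_modify_weight remaps the checkpoint while from_pretrained loads it.

import os

# Enable the SimplifiedDiTLLaMA2DModel / Triton fused path before the model is constructed.
os.environ["INFERENCE_OPTIMIZE"] = "True"

import paddle
from paddlenlp.trainer import set_seed

from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.bfloat16
with paddle.LazyGuard():
    pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-3B-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

set_seed(42)
# Label lookup and call signature follow the standard DiTPipeline API (assumed, not shown in this diff).
class_ids = pipe.get_label_ids(["golden retriever"])
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
image.save("large_dit_3b_golden_retriever.png")  # hypothetical output path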