Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DiT with decorator, triton fused_AdaLN and fineGrained #552

Open
wants to merge 27 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b93c57a
DiT FFN fineGrained
YKTian-x2b May 24, 2024
bca3484
DiT FFN fineGrained
YKTian-x2b May 24, 2024
6771b36
clear fine_grained_FFN
YKTian-x2b May 31, 2024
b4d92a3
Merge branch 'develop' into DiT_FFN_fineGrained
westfish May 31, 2024
ae34336
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleMIX i…
YKTian-x2b Jun 7, 2024
668f1ac
decorator + fineGrained_qkv_ffn + triton_adaLN_fusedAdaLN
YKTian-x2b Jun 7, 2024
fb96011
Merge branch 'DiT_FFN_fineGrained' of https://github.com/YKTian-x2b/P…
YKTian-x2b Jun 7, 2024
cb8bacb
clear up pr ing...
YKTian-x2b Jun 12, 2024
bdefd3b
Optional acceleration
YKTian-x2b Jun 12, 2024
18b5945
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleMIX i…
YKTian-x2b Jun 19, 2024
1cc8ca4
no reshape
YKTian-x2b Jul 22, 2024
3251f13
Revert "no reshape"
YKTian-x2b Jul 22, 2024
55b5042
no reshape
YKTian-x2b Jul 22, 2024
ee3df60
Merge branch 'DiT_FFN_fineGrained' of https://github.com/YKTian-x2b/P…
YKTian-x2b Jul 22, 2024
70c6dc0
no reshape
YKTian-x2b Jul 22, 2024
5d3d29f
fuse_repo triton kernel
YKTian-x2b Jul 25, 2024
8238c08
Merge remote-tracking branch 'upstream/develop' into DiT_FFN_fineGrained
YKTian-x2b Aug 5, 2024
19cbd90
with horizontal_fuse_pass opt
YKTian-x2b Aug 5, 2024
c000d4c
env
YKTian-x2b Aug 5, 2024
9f04a1c
ReNet
YKTian-x2b Aug 7, 2024
f2966f7
new net
YKTian-x2b Aug 7, 2024
437cbbb
little mod
YKTian-x2b Aug 7, 2024
fb4d478
pre-commit
YKTian-x2b Aug 7, 2024
e638313
update largedit
YKTian-x2b Aug 8, 2024
c933f80
update largedit
YKTian-x2b Aug 8, 2024
9903122
INFERENCE_OPTIMIZE
YKTian-x2b Aug 8, 2024
f54958a
new modify_weight
YKTian-x2b Aug 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 125 additions & 59 deletions paddlemix/triton_ops/triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,6 @@ def weight_only_int8(x, qweight, scales, bias=None, bool_trans_w=True):
return out


########################### adaptive layer norm ###############################
fused_adaLN_scale_residual_template = (
"""

Expand Down Expand Up @@ -1317,40 +1316,41 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05):
fused_rotary_emb_template = (
"""
std::vector<paddle::Tensor> ${op_name}_func(
const paddle::Tensor &x,
const paddle::Tensor &q,
const paddle::Tensor &k,
const paddle::Tensor &q_norm_weight,
const paddle::Tensor &q_norm_bias,
const paddle::Tensor &k_norm_weight,
const paddle::Tensor &k_norm_bias,
const paddle::Tensor &freqs_cis,
float epsilon) {
int BSZ = x.dims()[0];
int SEQ_LEN = x.dims()[1];
int BSZ = q.dims()[0];
int SEQ_LEN = q.dims()[1];
int DIM = q.dims()[2];
int HEAD_DIM = freqs_cis.dims()[2];
int DIM = q_norm_weight.dims()[0];

int NUM_HEAD = DIM / HEAD_DIM;
int M = BSZ * SEQ_LEN;
int DIM_concat = x.dims()[2];

auto q_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, x.dtype(), x.place());
auto k_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, x.dtype(), x.place());
auto v_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, x.dtype(), x.place());
auto q_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, q.dtype(), q.place());
auto k_out = paddle::empty({BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM}, k.dtype(), k.place());

auto x_ptr = get_tensor_ptr(x);
auto q_ptr = get_tensor_ptr(q);
auto k_ptr = get_tensor_ptr(k);
auto q_norm_weight_ptr = get_tensor_ptr(q_norm_weight);
auto q_norm_bias_ptr = get_tensor_ptr(q_norm_bias);
auto k_norm_weight_ptr = get_tensor_ptr(k_norm_weight);
auto k_norm_bias_ptr = get_tensor_ptr(k_norm_bias);
auto freqs_cis_ptr = get_tensor_ptr(freqs_cis);
auto q_out_ptr = get_tensor_ptr(q_out);
auto k_out_ptr = get_tensor_ptr(k_out);
auto v_out_ptr = get_tensor_ptr(v_out);

const paddle::Tensor &x = q;
auto run_stream = q_out.stream();
"""
+ tune_and_invoke_part
+ """
return {q_out, k_out, v_out};
return {q_out, k_out};
}

std::vector<std::vector<int64_t>> ${op_name}_InferShape(
Expand All @@ -1359,23 +1359,24 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05):
const std::vector<int64_t>& C_shape,
const std::vector<int64_t>& D_shape,
const std::vector<int64_t>& E_shape,
const std::vector<int64_t>& F_shape) {
const std::vector<int64_t>& F_shape,
const std::vector<int64_t>& G_shape) {
int BSZ = A_shape[0];
int SEQ_LEN = A_shape[1];
int HEAD_DIM = F_shape[2];
int DIM = B_shape[0];
int DIM = A_shape[2];
int HEAD_DIM = G_shape[2];
int NUM_HEAD = DIM / HEAD_DIM;
std::vector<int64_t> res_shape = {BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM};
return {res_shape, res_shape, res_shape};
return {res_shape, res_shape};
}

std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
return {A_dtype, A_dtype, A_dtype};
return {A_dtype, A_dtype};
}

PD_BUILD_OP(${op_name})
.Inputs({"x", "q_norm_weight", "q_norm_bias", "k_norm_weight", "k_norm_bias", "freqs_cis"})
.Outputs({"q_out", "k_out", "v_out"})
.Inputs({"q", "k", "q_norm_weight", "q_norm_bias", "k_norm_weight", "k_norm_bias", "freqs_cis"})
.Outputs({"q_out", "k_out"})
.SetKernelFn(PD_KERNEL(${op_name}_func))
.Attrs({"epsilon: float"})
.SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
Expand All @@ -1389,29 +1390,28 @@ def rms_norm(x, weight=None, bias=None, epsilon=1e-05):
key=["M"],
)
def fused_rotary_emb_kernel(
x_ptr, # [BSZ, SEQ_LEN, DIM_concat]
q_out_ptr,
k_out_ptr, # [BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM, 2]
v_out_ptr, # [BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]
q_norm_weight_ptr,
q_ptr, # [BSZ, SEQ_LEN, DIM]
k_ptr,
q_out_ptr, # [BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM]
k_out_ptr,
q_norm_weight_ptr, # [DIM]
q_norm_bias_ptr,
k_norm_weight_ptr,
k_norm_bias_ptr, # [DIM]
freqs_cis_ptr, # [1, seq_len, 1, head_dim, 2]
k_norm_bias_ptr,
freqs_cis_ptr, # [SEQ_LEN, 1, HEAD_DIM, 2]
epsilon,
SEQ_LEN,
M,
DIM,
DIM_concat,
DIM_npo2: tl.constexpr,
):
row = tl.program_id(axis=0)
x_ptr += row * DIM_concat
q_ptr += row * DIM
k_ptr += row * DIM
offs = tl.arange(0, DIM_npo2)
masks = offs < DIM
q_eles = tl.load(x_ptr + offs, mask=masks, other=0.0).to(tl.float32)
k_eles = tl.load(x_ptr + DIM + offs, mask=masks, other=0.0).to(tl.float32)
v_eles = tl.load(x_ptr + 2 * DIM + offs, mask=masks, other=0.0)
q_eles = tl.load(q_ptr + offs, mask=masks, other=0.0).to(tl.float32)
k_eles = tl.load(k_ptr + offs, mask=masks, other=0.0).to(tl.float32)

# qk layernorm
q_mean = tl.sum(q_eles, axis=0) / DIM
Expand Down Expand Up @@ -1451,49 +1451,116 @@ def fused_rotary_emb_kernel(
k_resi_hat = tl.reshape(k_resi_hat, (DIM_npo2, 2))
k_res = tl.sum(k_resi_hat * freqs_cis, axis=1)

out_offs = row * DIM + offs
tl.store(q_out_ptr + out_offs, q_res, mask=masks)
tl.store(k_out_ptr + out_offs, k_res, mask=masks)
tl.store(v_out_ptr + out_offs, v_eles, mask=masks)
tl.store(q_out_ptr + row * DIM + offs, q_res, mask=masks)
tl.store(k_out_ptr + row * DIM + offs, k_res, mask=masks)


def fused_rotary_emb(
x,
q_norm_weight,
q, # [BSZ, SEQ_LEN, DIM]
k,
q_norm_weight, # [DIM]
q_norm_bias,
k_norm_weight,
k_norm_bias,
freqs_cis,
epsilon=1e-5,
):
assert x.is_contiguous()
"""
Examples:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import paddle

def apply_rotary_emb(xq, xk, freqs_cis):
xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2]))
xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2]))
shape = [(d if i == 1 or i == xq_.ndim - 1 else 1) for i, d in enumerate(tuple(xq_.shape))]
freqs_cis = freqs_cis.reshape([*shape])
xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3)
xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3)
return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype)

def get_freqs_cis(dim: int, end: int):
freqs = 1.0 / 10000.0 ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].cast("float32") / dim)
t = paddle.arange(end=end).cast("float32")
# [SEQ_LEN, HEAD_DIM//2]
freqs = paddle.outer(t, freqs).cast("float32")
# [SEQ_LEN, HEAD_DIM//2]
freqs_cis_ref = paddle.complex(
paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs)
)
freqs_cis = paddle.stack([
paddle.cos(freqs),
-paddle.sin(freqs),
paddle.sin(freqs),
paddle.cos(freqs)], axis=-1)
# [SEQ_LEN, HEAD_DIM, 2]
freqs_cis = freqs_cis.reshape([freqs_cis.shape[0], -1, 2]).unsqueeze(1)
return freqs_cis, freqs_cis_ref

BSZ = 2
SEQ_LEN = 64
NUM_HEAD = 16
HEAD_DIM = 72
DIM = NUM_HEAD * HEAD_DIM
epsilon = 1e-5
dtype_= "float16"
q = paddle.rand([BSZ, SEQ_LEN, DIM], dtype=dtype_)
k = paddle.rand([BSZ, SEQ_LEN, DIM], dtype=dtype_)
q_norm_weight = paddle.rand([DIM], dtype=dtype_)
k_norm_weight = paddle.rand([DIM], dtype=dtype_)
q_norm_bias = paddle.rand([DIM], dtype=dtype_)
k_norm_bias = paddle.rand([DIM], dtype=dtype_)

freqs_cis, freqs_cis_ref = get_freqs_cis(HEAD_DIM, SEQ_LEN)
freqs_cis = paddle.expand(freqs_cis, [-1, NUM_HEAD, -1, -1])

# paddle
q_ln = paddle.nn.functional.layer_norm(q, [DIM], weight=q_norm_weight, bias=q_norm_bias, epsilon=epsilon)
k_ln = paddle.nn.functional.layer_norm(k, [DIM], weight=k_norm_weight, bias=k_norm_bias, epsilon=epsilon)
q_ln = q_ln.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM])
k_ln = k_ln.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM])
q_out_baseline, k_out_baseline= apply_rotary_emb(q_ln, k_ln, freqs_cis_ref)
q_out_baseline = q_out_baseline.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM])
k_out_baseline = k_out_baseline.reshape([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM])

# triton
import paddlemix
q_out, k_out = paddlemix.triton_ops.fused_rotary_emb(q, k, q_norm_weight, q_norm_bias, k_norm_weight, k_norm_bias, freqs_cis, epsilon)

print("Q allclose: ", paddle.allclose(q_out_baseline, q_out, rtol=1e-03, atol=1e-02).numpy())
print("K allclose: ", paddle.allclose(k_out_baseline, k_out, rtol=1e-03, atol=1e-02).numpy())
"""
assert q.shape == k.shape, "q and k should have the same shape"
assert len(q.shape) == 3, "q should be [BSZ, SEQ_LEN, DIM]"
assert q_norm_weight is not None, "q_norm_weight should not be none"
assert q_norm_bias is not None, "q_norm_bias should not be none"
assert k_norm_weight is not None, "k_norm_weight should not be none"
assert k_norm_bias is not None, "k_norm_bias should not be none"
DIM = q_norm_weight.shape[0]
assert (
q_norm_weight.shape == q_norm_bias.shape == k_norm_weight.shape == k_norm_bias.shape
), "q_norm_weight, q_norm_bias, k_norm_weight, k_norm_bias should have the same shape"
assert q.shape[-1] == q_norm_weight.shape[0], "q_norm_weight should be [DIM]"

BSZ, SEQ_LEN, DIM = q.shape
HEAD_DIM = freqs_cis.shape[-2]
assert (DIM % HEAD_DIM) == 0, "dim should be divisible by head_dim"
DIM_concat = x.shape[-1]
assert (DIM * 3) == DIM_concat, "not support GQA, qkv num_head should be equal"

BSZ = x.shape[0]
SEQ_LEN = x.shape[1]
NUM_HEAD = DIM // HEAD_DIM
M = BSZ * SEQ_LEN
DIM_npo2 = triton.next_power_of_2(DIM)
dtype_ = x.dtype
dtype_ = q.dtype

# q_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=dtype_)
# k_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=dtype_)
# v_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=dtype_)
# fused_rotary_emb_kernel[(M,)](
# input_tensor, q_out_tensor, k_out_tensor, v_out_tensor,
# q_norm_weight, q_norm_bias, k_norm_weight, k_norm_bias, freqs_cis, epsilon,
# SEQ_LEN, M, DIM, DIM_concat,
# q, k, q_out_tensor, k_out_tensor, q_norm_weight, q_norm_bias,
# k_norm_weight, k_norm_bias, freqs_cis, epsilon,
# SEQ_LEN, M, DIM,
# DIM_npo2, num_warps=4,
# )
# return q_out_tensor, k_out_tensor, v_out_tensor
# return q_out_tensor, k_out_tensor

op_name = "triton_fused_rotary_emb"
op_name += get_dtype_str(dtype_)
Expand All @@ -1511,13 +1578,12 @@ def fused_rotary_emb(
empty_dtype = dtype_ if dtype_ != paddle.bfloat16 else paddle.float16
q_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_)
k_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_)
v_out_tensor = paddle.empty([BSZ, SEQ_LEN, NUM_HEAD, HEAD_DIM], dtype=empty_dtype).astype(dtype_)
grid = ("M",)
fused_rotary_emb_kernel[(op_name, grid, fused_rotary_emb_kernel_config)](
x,
q,
k,
q_out_tensor,
k_out_tensor,
v_out_tensor,
q_norm_weight,
q_norm_bias,
k_norm_weight,
Expand All @@ -1527,28 +1593,29 @@ def fused_rotary_emb(
SEQ_LEN,
M,
DIM,
DIM_concat,
DIM_npo2,
)

if in_dynamic_or_pir_mode():
print(f"== we are in dynamic mode, op_name: {op_name}")
outs = _C_ops._run_custom_op(
op_name,
x,
q,
k,
q_norm_weight,
q_norm_bias,
k_norm_weight,
k_norm_bias,
freqs_cis,
epsilon,
)
return outs[0], outs[1], outs[2]
return outs[0], outs[1]
else:
print(f"== we are in dynamic to static mode, op_name: {op_name}")
helper = LayerHelper(op_name, **locals())
inputs = {
"x": x,
"q": q,
"k": k,
"q_norm_weight": q_norm_weight,
"q_norm_bias": q_norm_bias,
"k_norm_weight": k_norm_weight,
Expand All @@ -1557,13 +1624,12 @@ def fused_rotary_emb(
}
q_out = helper.create_variable_for_type_inference(dtype=dtype_)
k_out = helper.create_variable_for_type_inference(dtype=dtype_)
v_out = helper.create_variable_for_type_inference(dtype=dtype_)
helper.append_op(
type=op_name,
inputs=inputs,
attrs={
"epsilon": epsilon,
},
outputs={"q_out": q_out, "k_out": k_out, "v_out": v_out},
outputs={"q_out": q_out, "k_out": k_out},
)
return q_out, k_out, v_out
return q_out, k_out
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
from paddlenlp.trainer import set_seed

from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32
dtype = paddle.bfloat16

# True for inference optimizate
os.environ["INFERENCE_OPTIMIZE"] = "False"

with paddle.LazyGuard():
pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-3B-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
from paddlenlp.trainer import set_seed

from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32
dtype = paddle.bfloat16

# True for inference optimizate
os.environ["INFERENCE_OPTIMIZE"] = "False"

with paddle.LazyGuard():
pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
Expand Down
Loading