[GPU] Apply dynamic padding for onednn gemm (#24605)
### Details:
 - Support dynamically padded inputs in the oneDNN gemm implementation instead of falling back to the OCL impl
 - Build the input memory descriptors from the padded dims and explicit strides when padding is present

### Tickets:
 - 140516
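
For context, the heart of the change is that a padded input is now described to oneDNN through explicit strides over the padded buffer rather than a plain format tag. A minimal sketch of the idea, assuming the oneDNN v3 C++ API (`make_padded_desc` and the shapes are illustrative, not part of this change):

```cpp
#include <oneapi/dnnl/dnnl.hpp>

// Given logical dims and the physically padded buffer dims, build a
// memory::desc that addresses the valid sub-tensor in place, without a copy.
dnnl::memory::desc make_padded_desc(const dnnl::memory::dims& dims,        // logical shape
                                    const dnnl::memory::dims& padded_dims, // buffer shape incl. padding
                                    dnnl::memory::data_type dt) {
    dnnl::memory::dims strides(padded_dims.size(), 1);
    for (int i = static_cast<int>(padded_dims.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * padded_dims[i + 1];  // dense strides of the padded buffer
    return dnnl::memory::desc(dims, dt, strides);          // logical dims + padded strides
}
// e.g. dims {1, 2, 32, 100} over a buffer padded to {1, 2, 32, 128}
// gives strides {8192, 4096, 128, 1}.
```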

---------

Signed-off-by: Andrew Park <andrew.park@intel.com>
andrew-k-park committed May 22, 2024
1 parent 50b7316 commit 4b0868c
Showing 5 changed files with 334 additions and 137 deletions.
src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp
@@ -681,14 +681,6 @@ void prepare_buffer_fusing::run(program& p) {
         if (gather_prim) {
             update_dep(gather_prim);
         }
-
-        // Fallback to ocl impl since oneDNN doesn't support dynamic paddings
-        for (auto user : node.get_users()) {
-            if (user->get_preferred_impl_type() == impl_types::onednn) {
-                GPU_DEBUG_TRACE_DETAIL << user->id() << ": change impl to ocl because of dynamic input paddings\n";
-                user->set_preferred_impl_type(impl_types::ocl);
-            }
-        }
     }
 });
 program_helpers::do_for_types<read_value>(*node, [](read_value_node& node) {
62 changes: 54 additions & 8 deletions src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp
@@ -64,6 +64,8 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
                                   dnnl::memory::data_type& out_dt,
                                   dnnl::memory::dims& in0_dims,
                                   dnnl::memory::dims& in1_dims,
+                                  dnnl::memory::dims& in0_strides,
+                                  dnnl::memory::dims& in1_strides,
                                   dnnl::memory::dims& out_dims,
                                   dnnl::memory::format_tag& in0_fmt,
                                   dnnl::memory::format_tag& in1_fmt,
@@ -111,6 +113,22 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
         in1_fmt = onednn::convert_gemm_data_format(in1_dims, in1_l.format);
         out_fmt = onednn::convert_gemm_data_format(out_dims, out_l.format);

+        if (in0_l.data_padding) {
+            dnnl::memory::dims in0_padded_dims = onednn::convert_gemm_tensor(in0_l.get_buffer_size(), rank, batched_dims_can_be_removed);
+            if (prim->transpose_input0) {
+                std::swap(in0_padded_dims[in0_padded_dims.size() - 1], in0_padded_dims[in0_padded_dims.size() - 2]);
+            }
+            in0_strides = onednn::get_strides(in0_padded_dims);
+        }
+
+        if (in1_l.data_padding) {
+            dnnl::memory::dims in1_padded_dims = onednn::convert_gemm_tensor(in1_l.get_buffer_size(), rank, batched_dims_can_be_removed);
+            if (prim->transpose_input1) {
+                std::swap(in1_padded_dims[in1_padded_dims.size() - 1], in1_padded_dims[in1_padded_dims.size() - 2]);
+            }
+            in1_strides = onednn::get_strides(in1_padded_dims);
+        }
+
         if (prim->transpose_input0) {
             in0_fmt = transpose_format(in0_fmt);
             std::swap(in0_dims[in0_dims.size() - 1], in0_dims[in0_dims.size() - 2]);
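
Note what the new block computes: the strides come from the padded buffer shape (`in0_l.get_buffer_size()`), while the dims handed to oneDNN stay unpadded; for an input marked transposed, the padded dims are swapped first so they line up with the swap applied to the logical dims just below. A hypothetical trace of the non-transposed path:

```cpp
// Hypothetical values, not taken from the PR: input0 with logical dims
// {1, 1, 32, 100} whose buffer is padded on the innermost axis to 128.
//
//   in0_padded_dims = {1, 1, 32, 128}
//   in0_strides     = onednn::get_strides(in0_padded_dims);  // {4096, 4096, 128, 1}
//
// Pairing the unpadded logical dims with these padded strides lets the matmul
// read only the valid 32x100 region of each row-padded buffer, in place.
```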
@@ -130,6 +148,19 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
             }
         }
     }

+    static dnnl::memory::desc get_input_memory_desc(const dnnl::memory::dims& dims,
+                                                    dnnl::memory::data_type dt,
+                                                    dnnl::memory::format_tag fmt,
+                                                    const dnnl::memory::dims& strides) {
+        dnnl::memory::desc res;
+        if (strides.empty()) {
+            res = dnnl::memory::desc(dims, dt, fmt);
+        } else {
+            res = dnnl::memory::desc(dims, dt, strides);
+        }
+        return res;
+    }
+
     static std::shared_ptr<dnnl::matmul::primitive_desc> get_gemm_primitive_descriptor(const kernel_impl_params& impl_params,
                                                                                        const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
         auto& engine = impl_params.prog->get_engine();
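
The helper above keeps the old format-tag construction for dense inputs and uses stride-based construction only when padding produced non-empty strides. For a dense tensor the two forms describe the same layout, so routing only padded inputs through the stride path changes nothing else; a standalone check, assuming the oneDNN v3 C++ API:

```cpp
#include <cassert>
#include <oneapi/dnnl/dnnl.hpp>

int main() {
    using namespace dnnl;
    memory::dims dims = {2, 3, 4, 5};
    memory::desc by_tag(dims, memory::data_type::f32, memory::format_tag::abcd);
    memory::desc by_strides(dims, memory::data_type::f32, {60, 20, 5, 1});
    assert(by_tag == by_strides);  // dense strides == plain row-major tag
    return 0;
}
```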
@@ -146,16 +177,19 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
         dnnl::memory::dims out_dims;
         dnnl::memory::dims bias_dims;

+        dnnl::memory::dims in0_strides;
+        dnnl::memory::dims in1_strides;
+
         dnnl::memory::format_tag in0_fmt;
         dnnl::memory::format_tag in1_fmt;
         dnnl::memory::format_tag out_fmt;
         dnnl::memory::format_tag bias_fmt;

-        get_gemm_primitive_md(impl_params, in0_dt, in1_dt, out_dt, in0_dims, in1_dims, out_dims, in0_fmt, in1_fmt, out_fmt,
-                              gemm_with_bias, bias_dt, bias_dims, bias_fmt);
+        get_gemm_primitive_md(impl_params, in0_dt, in1_dt, out_dt, in0_dims, in1_dims, in0_strides, in1_strides,
+                              out_dims, in0_fmt, in1_fmt, out_fmt, gemm_with_bias, bias_dt, bias_dims, bias_fmt);

-        dnnl::memory::desc in0_md(in0_dims, in0_dt, in0_fmt);
-        dnnl::memory::desc in1_md(in1_dims, in1_dt, in1_fmt);
+        dnnl::memory::desc in0_md = get_input_memory_desc(in0_dims, in0_dt, in0_fmt, in0_strides);
+        dnnl::memory::desc in1_md = get_input_memory_desc(in1_dims, in1_dt, in1_fmt, in1_strides);
         dnnl::memory::desc out_md(out_dims, out_dt, out_fmt);

         if (gemm_with_bias) {
@@ -199,13 +233,16 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
         dnnl::memory::dims out_dims;
         dnnl::memory::dims bias_dims;

+        dnnl::memory::dims in0_strides;
+        dnnl::memory::dims in1_strides;
+
         dnnl::memory::format_tag in0_fmt;
         dnnl::memory::format_tag in1_fmt;
         dnnl::memory::format_tag out_fmt;
         dnnl::memory::format_tag bias_fmt;

-        get_gemm_primitive_md(*impl_params, in0_dt, in1_dt, out_dt, in0_dims, in1_dims, out_dims, in0_fmt, in1_fmt, out_fmt,
-                              gemm_with_bias, bias_dt, bias_dims, bias_fmt);
+        get_gemm_primitive_md(*impl_params, in0_dt, in1_dt, out_dt, in0_dims, in1_dims, in0_strides, in1_strides,
+                              out_dims, in0_fmt, in1_fmt, out_fmt, gemm_with_bias, bias_dt, bias_dims, bias_fmt);

         ob << make_data(&in0_dt, sizeof(dnnl::memory::data_type));
         ob << make_data(&in1_dt, sizeof(dnnl::memory::data_type));
@@ -215,6 +252,9 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
         ob << in1_dims;
         ob << out_dims;

+        ob << in0_strides;
+        ob << in1_strides;
+
         ob << make_data(&in0_fmt, sizeof(dnnl::memory::format_tag));
         ob << make_data(&in1_fmt, sizeof(dnnl::memory::format_tag));
         ob << make_data(&out_fmt, sizeof(dnnl::memory::format_tag));
@@ -248,6 +288,9 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
         dnnl::memory::dims out_dims;
         dnnl::memory::dims bias_dims;

+        dnnl::memory::dims in0_strides;
+        dnnl::memory::dims in1_strides;
+
         dnnl::memory::format_tag in0_fmt = dnnl::memory::format_tag::undef;
         dnnl::memory::format_tag in1_fmt = dnnl::memory::format_tag::undef;
         dnnl::memory::format_tag out_fmt = dnnl::memory::format_tag::undef;
@@ -261,6 +304,9 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
         ib >> in1_dims;
         ib >> out_dims;

+        ib >> in0_strides;
+        ib >> in1_strides;
+
         ib >> make_data(&in0_fmt, sizeof(dnnl::memory::format_tag));
         ib >> make_data(&in1_fmt, sizeof(dnnl::memory::format_tag));
         ib >> make_data(&out_fmt, sizeof(dnnl::memory::format_tag));
@@ -271,8 +317,8 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
             ib >> make_data(&bias_fmt, sizeof(dnnl::memory::format_tag));
         }

-        dnnl::memory::desc in0_md(in0_dims, in0_dt, in0_fmt);
-        dnnl::memory::desc in1_md(in1_dims, in1_dt, in1_fmt);
+        dnnl::memory::desc in0_md = get_input_memory_desc(in0_dims, in0_dt, in0_fmt, in0_strides);
+        dnnl::memory::desc in1_md = get_input_memory_desc(in1_dims, in1_dt, in1_fmt, in1_strides);
         dnnl::memory::desc out_md(out_dims, out_dt, out_fmt);

         if (gemm_with_bias) {
6 changes: 6 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/onednn/utils.cpp
@@ -94,6 +94,12 @@ dnnl::memory::dims flatten_tensor(cldnn::tensor t) {
     return {static_cast<int64_t>(t.count())};
 }

+dnnl::memory::dims get_strides(dnnl::memory::dims dims) {
+    dnnl::memory::dims strides(dims.size(), dnnl::memory::dim(1));
+    std::partial_sum(dims.rbegin(), dims.rend() - 1, strides.rbegin() + 1, std::multiplies<dnnl::memory::dim>());
+    return strides;
+}
+
 dnnl::memory::data_type convert_data_type(cldnn::data_types dt) {
     switch (dt) {
     case cldnn::data_types::f32: return dnnl::memory::data_type::f32;
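
`get_strides` builds dense row-major strides with a reverse running product; it relies on `std::partial_sum` from `<numeric>` and `std::multiplies` from `<functional>`, which utils.cpp presumably already includes since the diff adds no headers. A standalone trace of the helper (`dnnl::memory::dims` is a vector of `int64_t` underneath):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Standalone copy of the helper, for illustration.
std::vector<int64_t> get_strides(std::vector<int64_t> dims) {
    std::vector<int64_t> strides(dims.size(), 1);
    // Walk the dims innermost-to-outermost, accumulating a running product:
    // strides[n-1] = 1, strides[i] = strides[i+1] * dims[i+1].
    std::partial_sum(dims.rbegin(), dims.rend() - 1, strides.rbegin() + 1,
                     std::multiplies<int64_t>());
    return strides;
}

int main() {
    for (int64_t s : get_strides({2, 3, 4, 5}))
        std::cout << s << ' ';  // prints: 60 20 5 1
}
```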
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/onednn/utils.hpp
@@ -28,6 +28,7 @@ dnnl::memory::dims convert_tensor(cldnn::tensor t, size_t dims = 2, bool is_grouped = false);
 dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batched_dims_can_be_removed);
 dnnl::memory::dims convert_spatials(cldnn::tensor t, size_t dims = 2);
 dnnl::memory::dims flatten_tensor(cldnn::tensor t);
+dnnl::memory::dims get_strides(dnnl::memory::dims dims);
 dnnl::memory::data_type convert_data_type(cldnn::data_types dt);
 dnnl::memory::format_tag convert_data_format(cldnn::format fmt);
 cldnn::format convert_data_format(dnnl::memory::format_tag fmt);
[fifth changed file not loaded]
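
Taken together, the create and deserialize paths now hand oneDNN a matmul whose source descriptor may be strided over a padded buffer, which is why the OCL fallback removed from prepare_buffer_fusing.cpp is no longer needed. A self-contained sketch that oneDNN accepts such a descriptor, assuming the v3 C++ API (CPU engine and made-up shapes for brevity; the plugin itself targets its GPU engine):

```cpp
#include <oneapi/dnnl/dnnl.hpp>

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);

    // src is logically 1x32x100 but lives in a buffer whose rows are padded
    // to 128 elements, so the M axis advances by 128, the batch by 32*128.
    memory::desc src_md({1, 32, 100}, memory::data_type::f32, memory::dims{32 * 128, 128, 1});
    memory::desc wei_md({1, 100, 64}, memory::data_type::f32, memory::format_tag::abc);
    memory::desc dst_md({1, 32, 64},  memory::data_type::f32, memory::format_tag::abc);

    // Primitive descriptor creation validates the strided source; no reorder
    // and no impl fallback is required for the padded layout.
    matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
    matmul prim(pd);
    return 0;
}
```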
