Skip to content

Commit

Permalink
fix weight_only_linear_pass
Browse files Browse the repository at this point in the history
  • Loading branch information
freeliuzc committed Dec 25, 2023
1 parent 54d4bf3 commit 97e9192
Showing 1 changed file with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,14 @@ class FusedWeightOnlyLinearPattern
return getSMVersion();
});

const auto &group_size_attr = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> int { return -1; });

const auto &weight_quantize =
res.Op(paddle::dialect::WeightQuantizeOp::name(),
{{"algo", weight_only_int8_attr}, {"arch", arch_attr}});
{{"algo", weight_only_int8_attr},
{"arch", arch_attr},
{"group_size", group_size_attr}});
weight_quantize({&res.Tensor("w")},
{&res.Tensor("quanted_weight_tensor"),
&res.Tensor("weight_scale_tensor")});
Expand All @@ -110,7 +115,9 @@ class FusedWeightOnlyLinearPattern

const auto &weight_only_linear =
res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
{{"weight_dtype", weight_dtype_attr}, {"arch", arch_attr}});
{{"weight_dtype", weight_dtype_attr},
{"arch", arch_attr},
{"group_size", group_size_attr}});
weight_only_linear({&res.Tensor("x"),
&res.Tensor("quanted_weight_tensor"),
&res.Tensor("bias"),
Expand Down

0 comments on commit 97e9192

Please sign in to comment.