diff --git a/llm/config/llama/AdvertiseGen/w8a8_ptq_argument.json b/llm/config/llama/AdvertiseGen/w8a8_ptq_argument.json
index b5f4a53f7f23..445182cf0704 100644
--- a/llm/config/llama/AdvertiseGen/w8a8_ptq_argument.json
+++ b/llm/config/llama/AdvertiseGen/w8a8_ptq_argument.json
@@ -21,6 +21,5 @@
"smooth_piecewise_search": true,
"smooth_k_piece": 3,
"smooth_search_piece": true,
- "act_quant_method": "avg",
- "cachekv_quant_method": "avg_headwise"
+ "act_quant_method": "avg"
}
\ No newline at end of file
diff --git a/llm/config/llama/AdvertiseGen/wfp8afp8_ptq_argument.json b/llm/config/llama/AdvertiseGen/wfp8afp8_ptq_argument.json
index 8604d81b766f..ff6fd648e104 100644
--- a/llm/config/llama/AdvertiseGen/wfp8afp8_ptq_argument.json
+++ b/llm/config/llama/AdvertiseGen/wfp8afp8_ptq_argument.json
@@ -1,8 +1,6 @@
{
"model_name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
- "quant_type": "a8w8",
- "use_fp8": "WA",
- "fp8_type": ["e4m3", "e4m3"],
+ "quant_type": "a8w8_fp8",
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
@@ -11,7 +9,7 @@
"fp16": true,
"fp16_opt_level": "O2",
"dataset_name_or_path": "../dataset/AdvertiseGen",
- "output_dir": "../output/llama3.1/w8a8_ptq_ckpts_AdvertiseGen",
+ "output_dir": "../output/llama3.1/wfp8afp8_ptq_ckpts_AdvertiseGen",
"do_eval": true,
"eval_with_do_generation": false,
"do_ptq": true,
@@ -19,6 +17,5 @@
"unified_checkpoint": false,
"smooth": false,
"weight_quant_method": "abs_max",
- "act_quant_method": "abs_max",
- "cachekv_quant_method": "abs_max"
+ "act_quant_method": "abs_max"
}
\ No newline at end of file
diff --git a/llm/config/llama/ceval/w8a8_ptq_argument.json b/llm/config/llama/ceval/w8a8_ptq_argument.json
index 73aed3234933..06a41592308e 100644
--- a/llm/config/llama/ceval/w8a8_ptq_argument.json
+++ b/llm/config/llama/ceval/w8a8_ptq_argument.json
@@ -21,6 +21,5 @@
"smooth_piecewise_search": true,
"smooth_k_piece": 3,
"smooth_search_piece": true,
- "act_quant_method": "avg",
- "cachekv_quant_method": "avg_headwise"
+ "act_quant_method": "avg"
}
\ No newline at end of file
diff --git a/llm/config/llama/ceval/wfp8afp8_ptq_argument.json b/llm/config/llama/ceval/wfp8afp8_ptq_argument.json
index e7119a0971fc..9e3dfc4c8c01 100644
--- a/llm/config/llama/ceval/wfp8afp8_ptq_argument.json
+++ b/llm/config/llama/ceval/wfp8afp8_ptq_argument.json
@@ -1,7 +1,6 @@
{
"model_name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
- "quant_type": "a8w8",
- "use_fp8": "WA",
+ "quant_type": "a8w8_fp8",
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
@@ -18,6 +17,5 @@
"unified_checkpoint": false,
"smooth": false,
"weight_quant_method": "abs_max",
- "act_quant_method": "abs_max",
- "cachekv_quant_method": "abs_max"
+ "act_quant_method": "abs_max"
}
\ No newline at end of file
diff --git a/llm/config/llama/ceval_ptq_argument.json b/llm/config/llama/ceval_ptq_argument.json
deleted file mode 100644
index 7ac0b507fba7..000000000000
--- a/llm/config/llama/ceval_ptq_argument.json
+++ /dev/null
@@ -1,26 +0,0 @@
-{
- "model_name_or_path": "meta-llama/Meta-Llama-3-8B",
- "per_device_train_batch_size": 8,
- "per_device_eval_batch_size": 8,
- "eval_accumulation_steps":16,
- "src_length": 1024,
- "max_length": 2048,
- "bf16": true,
- "fp16_opt_level": "O2",
- "dataset_name_or_path": "./data",
- "output_dir": "./checkpoints/llama_ptq_ckpts",
- "do_eval": true,
- "eval_with_do_generation": false,
- "do_ptq": false,
- "ptq_step": 1,
- "unified_checkpoint": false,
- "smooth": true,
- "smooth_step": 8,
- "smooth_all_linears": true,
- "smooth_piecewise_search": true,
- "smooth_k_piece": 1,
- "smooth_search_piece": true,
- "load_quant_model": true,
- "do_ceval": true,
- "ceval_data_path": "../dataset/ceval"
-}
\ No newline at end of file
diff --git a/llm/config/llama/fp8_ptq_argument.json b/llm/config/llama/fp8_ptq_argument.json
index b8506102a4a0..4bd45594f775 100644
--- a/llm/config/llama/fp8_ptq_argument.json
+++ b/llm/config/llama/fp8_ptq_argument.json
@@ -1,8 +1,6 @@
{
"model_name_or_path": "meta-llama/Meta-Llama-3-8B",
- "quant_type": "W8A8",
- "use_fp8": "WA",
- "fp8_type": ["e4m3", "e4m3"],
+ "quant_type": "a8w8_fp8",
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
diff --git a/llm/config/qwen/AdvertiseGen/w8a8_ptq_argument.json b/llm/config/qwen/AdvertiseGen/w8a8_ptq_argument.json
index eb5166fd70c1..7ea9895b05bd 100644
--- a/llm/config/qwen/AdvertiseGen/w8a8_ptq_argument.json
+++ b/llm/config/qwen/AdvertiseGen/w8a8_ptq_argument.json
@@ -22,5 +22,5 @@
"smooth_k_piece": 3,
"smooth_search_piece": true,
"act_quant_method": "abs_max",
- "cachekv_quant_method": "abs_max_headwise"
+ "skip_list_names": ["down_proj"]
}
\ No newline at end of file
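Note on the new key: `skip_list_names` leaves the named sublayers unquantized; `down_proj` is the usual candidate because its input activations tend to carry outliers that hurt 8-bit accuracy. Below is a minimal sketch of the name-matching semantics, assuming substring matching against layer names; the `should_skip` helper and layer names are illustrative, not the repository's implementation.

```python
# Illustrative sketch of skip_list_names: layers whose names contain any of
# the patterns are kept in full precision. The actual filtering lives in
# PaddleNLP's quantization utilities and may differ in detail.
from typing import List


def should_skip(layer_name: str, skip_list_names: List[str]) -> bool:
    """Return True if the layer name matches any skip pattern."""
    return any(pattern in layer_name for pattern in skip_list_names)


layer_names = [
    "qwen2.layers.0.mlp.down_proj",   # hypothetical layer names
    "qwen2.layers.0.mlp.up_proj",
    "qwen2.layers.0.self_attn.q_proj",
]
to_quantize = [n for n in layer_names if not should_skip(n, ["down_proj"])]
print(to_quantize)  # down_proj stays in full precision
```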
diff --git a/llm/config/qwen/AdvertiseGen/w8a8c8_ptq_argument.json b/llm/config/qwen/AdvertiseGen/w8a8c8_ptq_argument.json
index 3ab3a0fd65b3..d37b5c7ceccb 100644
--- a/llm/config/qwen/AdvertiseGen/w8a8c8_ptq_argument.json
+++ b/llm/config/qwen/AdvertiseGen/w8a8c8_ptq_argument.json
@@ -22,5 +22,6 @@
"smooth_k_piece": 3,
"smooth_search_piece": true,
"act_quant_method": "abs_max",
- "cachekv_quant_method": "abs_max_headwise"
+ "cachekv_quant_method": "abs_max_headwise",
+ "skip_list_names": ["down_proj"]
}
\ No newline at end of file
diff --git a/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json b/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json
index 2e80c6d0a4f5..ff261704b8cc 100644
--- a/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json
+++ b/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json
@@ -1,8 +1,6 @@
{
"model_name_or_path": "Qwen/Qwen2-7B-Instruct",
- "quant_type": "a8w8",
- "use_fp8": "WA",
- "fp8_type": ["e4m3", "e4m3"],
+ "quant_type": "a8w8_fp8",
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
@@ -20,5 +18,5 @@
"smooth": false,
"weight_quant_method": "abs_max",
"act_quant_method": "abs_max",
- "cachekv_quant_method": "abs_max"
+ "skip_list_names": ["down_proj"]
}
\ No newline at end of file
diff --git a/llm/config/qwen/ceval/w8a8_ptq_argument.json b/llm/config/qwen/ceval/w8a8_ptq_argument.json
index 27cea97d6f1a..c1c7db7bae47 100644
--- a/llm/config/qwen/ceval/w8a8_ptq_argument.json
+++ b/llm/config/qwen/ceval/w8a8_ptq_argument.json
@@ -22,6 +22,5 @@
"smooth_k_piece": 3,
"smooth_search_piece": true,
"act_quant_method": "abs_max",
- "cachekv_quant_method": "abs_max_headwise",
"skip_list_names": ["down_proj"]
}
\ No newline at end of file
diff --git a/llm/config/qwen/ceval/wfp8afp8_ptq_argument.json b/llm/config/qwen/ceval/wfp8afp8_ptq_argument.json
index 57af3f55d5ed..f5512a0e105d 100644
--- a/llm/config/qwen/ceval/wfp8afp8_ptq_argument.json
+++ b/llm/config/qwen/ceval/wfp8afp8_ptq_argument.json
@@ -1,7 +1,6 @@
{
"model_name_or_path": "Qwen/Qwen2-7B-Instruct",
- "quant_type": "a8w8",
- "use_fp8": "WA",
+ "quant_type": "a8w8_fp8",
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
@@ -18,6 +17,5 @@
"unified_checkpoint": false,
"smooth": false,
"weight_quant_method": "abs_max",
- "act_quant_method": "abs_max",
- "cachekv_quant_method": "abs_max"
+ "act_quant_method": "abs_max"
}
\ No newline at end of file
diff --git a/llm/docs/quantization.md b/llm/docs/quantization.md
index 63fd9f7f74ff..9a507c6d1815 100644
--- a/llm/docs/quantization.md
+++ b/llm/docs/quantization.md
@@ -94,15 +94,19 @@ python run_finetune.py ./config/llama/ptq_c8_argument.json
python run_finetune.py ./config/llama/fp8_ptq_argument.json
```
-### 2.9 Quantization Parameters
+### 2.8 Quantization Parameters
 Quantization parameters (QuantArgument)
-- `quant_type`: PTQ/QAT quantization type, a8w8 by default (case-insensitive). Supported values: a8w8, a8w8c8, wint4/weight_only_int4, wint8/weight_only_int8. a8w8 applies 8-bit quantization to activations (inputs) and model weights, with the concrete format given by the `use_fp8` field; a8w8c8 applies 8-bit quantization to activations, weights, and the KV cache, with the concrete format given by the `use_fp8` field; wint4/weight_only_int4 applies INT4 quantization to model weights only, for subsequent WeightOnly inference; wint8/weight_only_int8 applies INT8 quantization to model weights only, for subsequent WeightOnly inference.
-- `use_fp8`: Whether to use FP8 quantization; an empty string by default. Passing `"WA"` (case-insensitive) switches the 8-bit quantization of weights and activations to FP8.
-- `fp8_type`: FP8 formats; the list length should match `use_fp8`. Defaults to `["e4m3","e4m3"]`.
+- `quant_type`: PTQ/QAT quantization type, a8w8 by default (case-insensitive). Supported values: a8w8, a8w8c8, a8w8_fp8, wint4/weight_only_int4, wint8/weight_only_int8:
+  - a8w8 applies INT8 quantization to activations (inputs) and model weights
+  - a8w8c8 applies INT8 quantization to activations, weights, and the KV cache
+  - a8w8_fp8 applies FP8 quantization to activations and weights
+  - wint4/weight_only_int4 applies INT4 quantization to model weights only, for subsequent WeightOnly inference
+  - wint8/weight_only_int8 applies INT8 quantization to model weights only, for subsequent WeightOnly inference
+- `fp8_type`: FP8 formats for (activation, weight). Defaults to `["e4m3","e4m3"]`.
 - `do_ptq`: Whether to run PTQ quantization. Defaults to False.
 - `weight_quant_method`: Weight quantization method. For INT8 choose groupwise or abs_max_channel_wise; for FP8 choose abs_max or avg.
 - `act_quant_method`: Activation quantization method. For INT8 choose avg or abs_max; for FP8 choose abs_max or avg.
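To make the consolidated flag concrete, here is a minimal sketch of an `a8w8_fp8` PTQ config assembled from Python and trimmed to the quantization-relevant fields; paths, model name, and output filename are illustrative, and the full JSON files above also set batch sizes and eval options.

```python
# Illustrative only: builds a minimal a8w8_fp8 PTQ config like the JSON
# files updated in this PR, then dumps it to disk for run_finetune.py.
import json

config = {
    "model_name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "quant_type": "a8w8_fp8",       # replaces quant_type="a8w8" + use_fp8="WA"
    "fp8_type": ["e4m3", "e4m3"],   # (activation, weight); this is the default
    "weight_quant_method": "abs_max",
    "act_quant_method": "abs_max",
    "smooth": False,
    "do_ptq": True,
    "ptq_step": 1,
    "dataset_name_or_path": "../dataset/AdvertiseGen",
    "output_dir": "../output/llama3.1/wfp8afp8_ptq_ckpts_AdvertiseGen",
}

with open("wfp8afp8_ptq_argument.json", "w") as f:
    json.dump(config, f, indent=4)
# Launch as usual: python run_finetune.py ./wfp8afp8_ptq_argument.json
```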
diff --git a/llm/experimental/observer/abs_max.py b/llm/experimental/observer/abs_max.py
new file mode 100644
index 000000000000..9d30db49cba3
--- /dev/null
+++ b/llm/experimental/observer/abs_max.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.quantization.factory import ObserverFactory
+
+from .uniform import UniformObserver
+
+
+class AbsmaxObserver(ObserverFactory):
+ r"""
+ It collects maximum absolute values of target tensor.
+ Args:
+ bit_length(int, optional): Number of bits to represent an quantized integer in binary.
+ dtype(str, optional): The data type of input tensor.
+ name (str, optional): This parameter is used by developers to print debugging information. \
+ For details, please refer to :ref:`api_guide_Name`. Default is None.
+ Examples:
+ .. code-block:: python
+ from paddle.quantization import QuantConfig
+ from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
+ quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
+ q_config = QuantConfig(activation=quanter, weight=quanter)
+ """
+
+ def __init__(self, quant_bits=8):
+ super(AbsmaxObserver, self).__init__(quant_bits=quant_bits)
+
+ def _get_class(self):
+ return AbsmaxObserverLayer
+
+
+class AbsmaxObserverLayer(UniformObserver):
+ def __init__(
+ self,
+ layer,
+ quant_bits=8,
+ ):
+ super(AbsmaxObserverLayer, self).__init__(quant_bits=quant_bits)
+ self._quant_bits = quant_bits
+ self._layer = layer
+ self._scale = None
+ self._zero_point = None
+ self._min = None
+ self._max = paddle.to_tensor(1e-7, dtype="float32")
+ self.step = 0
+
+ def forward(self, inputs):
+ """Calculate forward pass."""
+ self._min, self._max = self.cal_min_max(inputs)
+ return inputs
+
+ def cal_min_max(self, inputs):
+ abs_max_val = paddle.max(paddle.abs(inputs.cast("float32")))
+ abs_max_val = paddle.maximum(abs_max_val, self._max)
+ return 0, abs_max_val
+
+ def cal_thresholds(self):
+ """Compute thresholds for MAX function."""
+ if self._scale is not None:
+ self._zero_point = 0
+ return
+ self._scale, self._zero_point = self.cal_scales_zero_points()
+
+ def min_value(self) -> float:
+ return self._min
+
+ def max_value(self) -> float:
+ return self._max
+
+ def bit_length(self):
+ """Return the bit length of quantized data."""
+ return self._quant_bits
+
+ def quant_axis(self):
+ """Return quantization axis."""
+ return -1
+
+ def scales(self):
+ """Return output scales."""
+ if self._scale is None:
+ self.cal_thresholds()
+ return self._scale
+
+ def zero_points(self):
+ """Return output zero points."""
+ if self._zero_point is None:
+ self.cal_thresholds()
+ return self._zero_point
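A standalone sketch of the calibration math `AbsmaxObserverLayer` implements: a running maximum of |x| across calibration steps becomes the symmetric clipping threshold returned by `scales()`. The random inputs are purely illustrative.

```python
import paddle

# Running abs-max across calibration batches, as in AbsmaxObserverLayer.forward.
running_max = paddle.to_tensor(1e-7, dtype="float32")
for _ in range(8):
    batch = paddle.uniform([4, 16], min=-3.0, max=3.0)
    running_max = paddle.maximum(running_max, paddle.max(paddle.abs(batch)))

# With sign=True and symmetric=True (the UniformObserver defaults), the zero
# point is 0 and the scale is the abs-max threshold itself; downstream kernels
# map [-scale, scale] onto the integer (or FP8) range.
print(float(running_max))
```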
diff --git a/llm/experimental/observer/avg.py b/llm/experimental/observer/avg.py
new file mode 100644
index 000000000000..c38b3ec45c78
--- /dev/null
+++ b/llm/experimental/observer/avg.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.quantization.factory import ObserverFactory
+
+from .uniform import UniformObserver
+
+
+class AVGObserver(ObserverFactory):
+ r"""
+ It collects maximum absolute values of target tensor.
+ Args:
+ bit_length(int, optional): Number of bits to represent an quantized integer in binary.
+ dtype(str, optional): The data type of input tensor.
+ name (str, optional): This parameter is used by developers to print debugging information. \
+ For details, please refer to :ref:`api_guide_Name`. Default is None.
+ Examples:
+ .. code-block:: python
+ from paddle.quantization import QuantConfig
+ from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
+ quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
+ q_config = QuantConfig(activation=quanter, weight=quanter)
+ """
+
+ def __init__(self, quant_bits=8):
+ super(AVGObserver, self).__init__(quant_bits=quant_bits)
+
+ def _get_class(self):
+ return AVGObserverLayer
+
+
+class AVGObserverLayer(UniformObserver):
+ def __init__(
+ self,
+ layer,
+ quant_bits=8,
+ ):
+ super(AVGObserverLayer, self).__init__(quant_bits=quant_bits)
+ self._quant_bits = quant_bits
+ self._avg_list = []
+
+ def forward(self, inputs):
+ """Calculate forward pass."""
+ self._scale = None
+ self._zero_point = None
+ self._min = None
+ self._max = None
+ self._avg_min, self._avg_max = self.cal_min_max(inputs)
+ self._avg_list.append(self._avg_max)
+
+ return inputs
+
+ def cal_min_max(self, inputs):
+ abs_avg_value = paddle.abs(inputs.reshape((inputs.shape[0], -1)))
+        abs_avg_value = float(paddle.mean(paddle.max(abs_avg_value, axis=1)))
+ return 0, abs_avg_value
+
+ def cal_thresholds(self):
+ """Compute thresholds for MAX function."""
+ if self._scale is not None:
+ self._zero_point = 0
+ return
+ self._min, self._max = self._avg_min, paddle.mean(paddle.to_tensor(self._avg_list))
+ self._scale, self._zero_point = self.cal_scales_zero_points()
+
+ def min_value(self) -> float:
+ return self._min
+
+ def max_value(self) -> float:
+ return self._max
+
+ def bit_length(self):
+ """Return the bit length of quantized data."""
+ return self._quant_bits
+
+ def quant_axis(self):
+ """Return quantization axis."""
+ return -1
+
+ def scales(self):
+ """Return output scales."""
+ if self._scale is None:
+ self.cal_thresholds()
+ return self._scale
+
+ def zero_points(self):
+ """Return output zero points."""
+ if self._zero_point is None:
+ self.cal_thresholds()
+ return self._zero_point
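For contrast, a sketch of the statistic `AVGObserverLayer` collects: each calibration step records the batch mean of the per-sample max |x|, and `cal_thresholds` then averages those records, yielding a threshold less sensitive to a single outlier step than plain abs-max. Random inputs again are illustrative.

```python
import paddle

avg_list = []
for _ in range(8):
    batch = paddle.uniform([4, 16], min=-3.0, max=3.0)
    # Per-sample max of |x|, averaged over the batch (see cal_min_max above).
    flat = paddle.abs(batch.reshape((batch.shape[0], -1)))
    avg_list.append(float(paddle.mean(paddle.max(flat, axis=1))))

# cal_thresholds: the final threshold is the mean of the per-step statistics.
threshold = float(paddle.mean(paddle.to_tensor(avg_list)))
print(threshold)
```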
diff --git a/llm/experimental/observer/uniform.py b/llm/experimental/observer/uniform.py
new file mode 100644
index 000000000000..6c8882f5142f
--- /dev/null
+++ b/llm/experimental/observer/uniform.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from typing import Tuple
+
+import numpy as np
+from paddle.quantization.base_observer import BaseObserver
+
+
+class UniformObserver(BaseObserver):
+ """This is the base class for a uniform quantization observer, which provides
+ common functions for calculating the scale and zero-point used in uniform quantization.
+ Uniform quantization maps floating point values to integers, where the scale determines
+ the step size of the quantizer and the floating point zero is mapped to the zero-point,
+ an integer value ensuring that zero is quantized without error.
+
+ Args:
+ quant_bits (int): The number of bits for quantization.
+ sign (bool): Whether the quantized integer includes a sign.
+        symmetric (bool): Whether the quantization is symmetric. In symmetric quantization,
+            the range of floating point values is relaxed to be symmetric around zero and
+            the zero-point is always 0.
+
+ """
+
+ def __init__(
+ self,
+ quant_bits=8,
+ sign=True,
+ symmetric=True,
+ ):
+ super(UniformObserver, self).__init__()
+ self._quant_bits = quant_bits
+ self._sign = sign
+ self._symmetric = symmetric
+
+ self._min = None
+ self._max = None
+ self._qmin = None
+ self._qmax = None
+
+ self._scale = None
+ self._zero_point = None
+
+ @property
+ def qmin_qmax(self):
+ """Calculate the range of the quantized integer based on the specified
+ quant_bits, sign, and symmetric properties."""
+        if isinstance(self._quant_bits, tuple):
+            # Comparing against the full tuple avoids an IndexError on
+            # malformed settings and keeps the intended NotImplementedError.
+            if self._quant_bits == (4, 3):
+                self._qmin = -448.0
+                self._qmax = 448.0
+            elif self._quant_bits == (5, 2):
+                self._qmin = -57344.0
+                self._qmax = 57344.0
+            else:
+                raise NotImplementedError(
+                    "Currently, only float8_e4m3 and float8_e5m2 formats are supported. Please set quant_bits to (4,3) or (5,2) for the corresponding format."
+                )
+        else:
+            if self._sign:
+                self._qmin = -(2 ** (self.bit_length() - 1))
+                self._qmax = 2 ** (self.bit_length() - 1) - 1
+            else:
+                self._qmin = 0
+                # 2**n - 1 is the largest value an unsigned n-bit integer holds.
+                self._qmax = 2 ** self.bit_length() - 1
+ return self._qmin, self._qmax
+
+ @abc.abstractmethod
+ def min_value(self) -> float:
+ """The minimum value of floating-point numbers."""
+ raise NotImplementedError(
+ "Please implement the abstract method to get the The minimum value of floating-point numbers."
+ )
+
+ @abc.abstractmethod
+ def max_value(self) -> float:
+ """The maximum value of floating-point numbers."""
+ raise NotImplementedError(
+ "Please implement the abstract method to get the the maximum value value of floating-point numbers."
+ )
+
+ def cal_scales_zero_points(self) -> Tuple[float, float]:
+ """Calculate the scales and zero points based on the min_value and max_value."""
+ assert self.min_value() is not None and self.max_value() is not None
+ _qmin, _qmax = self.qmin_qmax
+        # For one-sided distributions, the range (_min, _max) is relaxed to include zero.
+ # It is important to ensure that common operations like zero padding do not cause quantization errors.
+ _min = min(self.min_value(), 0.0)
+ _max = max(self.max_value(), 0.0)
+
+ if self._symmetric:
+ self._scale = max(-_min, _max)
+ if self._sign:
+ self._zero_point = 0
+ else:
+ self._zero_point = (_qmax + _qmin) / 2
+ else:
+ self._scale = (_max - _min) / float(_qmax - _qmin)
+ self._zero_point = _qmin - round(_min / self._scale)
+ self._zero_point = np.clip(self._zero_point, _qmin, _qmax)
+ return self._scale, self._zero_point
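The ranges encoded by `qmin_qmax` can be summarized in a few lines; the FP8 bounds are the largest finite values of the respective formats (448 for e4m3, 57344 for e5m2). A reference sketch, using the corrected unsigned bound from above:

```python
def qmin_qmax(quant_bits, sign=True):
    """Condensed mirror of UniformObserver.qmin_qmax, for reference only."""
    if isinstance(quant_bits, tuple):
        if quant_bits == (4, 3):      # float8_e4m3
            return -448.0, 448.0
        if quant_bits == (5, 2):      # float8_e5m2
            return -57344.0, 57344.0
        raise NotImplementedError("only (4,3)/e4m3 and (5,2)/e5m2 are supported")
    if sign:
        return -(2 ** (quant_bits - 1)), 2 ** (quant_bits - 1) - 1
    return 0, 2 ** quant_bits - 1


print(qmin_qmax((4, 3)))  # (-448.0, 448.0)
print(qmin_qmax(8))       # (-128, 127)
```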
diff --git a/llm/utils/argument.py b/llm/utils/argument.py
index b5c0e7437c72..006d21aaec90 100644
--- a/llm/utils/argument.py
+++ b/llm/utils/argument.py
@@ -239,16 +239,9 @@ class QuantArgument:
metadata={"help": "Quantization type. Supported values: weight_only_int8, weight_only_int4, a8w8, a8w8c8"},
)
- use_fp8: str = field(
- default="",
- metadata={
- "help": "Whether to use FP8 on (activation, weight, cachekv), e.g. WAC means weight , activation, cachekv use fp8"
- },
- )
-
fp8_type: List[str] = field(
default_factory=lambda: ["e4m3", "e4m3"],
- metadata={"help": "Quantization type for (weight, activation, cachekv)", "nargs": "+"},
+ metadata={"help": "Quantization type for (activation, weight)", "nargs": "+"},
)
skip_list_names: List[str] = field(
diff --git a/llm/utils/quant.py b/llm/utils/quant.py
index 11a22fe6fa5b..8d5cee2f4d74 100644
--- a/llm/utils/quant.py
+++ b/llm/utils/quant.py
@@ -16,7 +16,9 @@
import paddle
from experimental.layers.custom_attention import QuantizedCustomAttentionLayer
+from experimental.observer.abs_max import AbsmaxObserver
from experimental.observer.abs_max_headwise import AbsMaxHeadwiseObserver
+from experimental.observer.avg import AVGObserver
from experimental.observer.avg_headwise import AvgHeadwiseObserver
from experimental.observer.channel_wise import ChannelWiseObserver
from paddle import nn
@@ -44,10 +46,8 @@
)
from paddleslim.quant.observers import (
AbsMaxChannelWiseWeightObserver,
- AVGObserver,
GroupWiseWeightObserver,
)
-from paddleslim.quant.observers.abs_max import AbsmaxObserver
from paddlenlp.peft import PrefixModelForCausalLM
from paddlenlp.peft.lora import (
@@ -214,32 +214,29 @@ def prepare_qconfig(args):
Prepare qconfig
"""
args.quant_type = args.quant_type.lower()
- args.use_fp8 = args.use_fp8.lower()
+
+    if args.quant_type == "a8w8_fp8":
+ use_fp8 = "aw"
+ args.quant_type = args.quant_type.replace("_fp8", "")
+ else:
+ use_fp8 = ""
weight_observer = (
WEIGHT_OBSERVER.get(args.weight_quant_method, None)
- if "w" not in args.use_fp8
+ if "w" not in use_fp8
else FP8_OBSERVER.get(args.weight_quant_method, None)
)
act_observer = (
ACT_OBSERVER.get(args.act_quant_method, None)
- if "a" not in args.use_fp8
+ if "a" not in use_fp8
else FP8_OBSERVER.get(args.act_quant_method, None)
)
- cachekv_observer = (
- CACHEKV_OBSERVER.get(args.cachekv_quant_method, None)
- if "c" not in args.use_fp8
- else FP8_OBSERVER.get(args.cachekv_quant_method, None)
- )
+ cachekv_observer = CACHEKV_OBSERVER.get(args.cachekv_quant_method, None)
if "c8" in args.quant_type:
quant_type = args.quant_type.replace("c8", "")
cachekv_quant = True
-
- if "c" in args.use_fp8:
- cachekv_quant_bits = "fp8"
- else:
- cachekv_quant_bits = "int8"
+ cachekv_quant_bits = "int8"
else:
quant_type = args.quant_type.replace("c16", "")
cachekv_quant = False
@@ -247,13 +244,13 @@ def prepare_qconfig(args):
q_config = QuantConfig(activation=None, weight=None)
if quant_type in ["a8w8", "w8a8"]:
- if "w" in args.use_fp8:
- w_quant_bit = (4, 3) if args.fp8_type[args.use_fp8.index("w")] == "e4m3" else (5, 2)
+ if "w" in use_fp8:
+ w_quant_bit = (4, 3) if args.fp8_type[use_fp8.index("w")] == "e4m3" else (5, 2)
else:
w_quant_bit = 8
- if "a" in args.use_fp8:
- a_quant_bit = (4, 3) if args.fp8_type[args.use_fp8.index("a")] == "e4m3" else (5, 2)
+ if "a" in use_fp8:
+ a_quant_bit = (4, 3) if args.fp8_type[use_fp8.index("a")] == "e4m3" else (5, 2)
else:
a_quant_bit = 8
activation = act_observer(quant_bits=a_quant_bit)
@@ -265,13 +262,13 @@ def prepare_qconfig(args):
elif quant_type in ["wint8", "w8a16", "weight_only_int8"]:
activation = None
- if "w" in args.use_fp8:
+ if "w" in use_fp8:
weight = weight_observer(quant_bits=(4, 3))
else:
weight = weight_observer(quant_bits=8)
else:
raise ValueError(
- "quant_type should be in ['weight_only_int8/wint8', 'weight_only_int4/wint4', 'a8w8', 'a8w8c8']"
+ "quant_type should be in ['weight_only_int8/wint8', 'weight_only_int4/wint4', 'a8w8', 'a8w8c8', 'a8w8_fp8']"
)
q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear)
@@ -292,23 +289,8 @@ def prepare_qconfig(args):
cachekv_observer(quant_bits=cachekv_quant_bit),
]
q_config.add_qat_layer_mapping(FuncWrapper, QuantizedCustomAttentionLayer)
-
- elif cachekv_quant_bits == "fp8":
- cachekv_quant_bit = (4, 3) if args.fp8_type[args.use_fp8.index("c")] == "e4m3" else (5, 2)
-
- if "headwise" in args.cachekv_quant_method:
- cachekv = [
- cachekv_observer(quant_bits=cachekv_quant_bit, quant_axis=1),
- cachekv_observer(quant_bits=cachekv_quant_bit, quant_axis=1),
- ]
- else:
- cachekv = [
- cachekv_observer(quant_bits=cachekv_quant_bit),
- cachekv_observer(quant_bits=cachekv_quant_bit),
- ]
- q_config.add_qat_layer_mapping(FuncWrapper, QuantizedCustomAttentionLayer)
else:
- raise ValueError("cachekv_quant_bits should be 8")
+ raise ValueError("cachekv_quant_bits should be int8")
return activation, weight, cachekv, q_config
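In summary, `prepare_qconfig` now derives the former `use_fp8` flag from the `quant_type` string itself. A condensed sketch of that normalization (the function name is illustrative):

```python
def normalize_quant_type(quant_type: str):
    """Split an a8w8_fp8 quant_type into the base type plus an FP8 marker."""
    quant_type = quant_type.lower()
    if quant_type == "a8w8_fp8":
        # FP8 applies to (a)ctivations and (w)eights; cache-KV FP8 is no
        # longer supported on this path.
        return "a8w8", "aw"
    return quant_type, ""


print(normalize_quant_type("a8w8_fp8"))  # ('a8w8', 'aw')
print(normalize_quant_type("a8w8c8"))    # ('a8w8c8', '')
```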