diff --git a/llm/config/llama/AdvertiseGen/w8a8_ptq_argument.json b/llm/config/llama/AdvertiseGen/w8a8_ptq_argument.json index b5f4a53f7f23..445182cf0704 100644 --- a/llm/config/llama/AdvertiseGen/w8a8_ptq_argument.json +++ b/llm/config/llama/AdvertiseGen/w8a8_ptq_argument.json @@ -21,6 +21,5 @@ "smooth_piecewise_search": true, "smooth_k_piece": 3, "smooth_search_piece": true, - "act_quant_method": "avg", - "cachekv_quant_method": "avg_headwise" + "act_quant_method": "avg" } \ No newline at end of file diff --git a/llm/config/llama/AdvertiseGen/wfp8afp8_ptq_argument.json b/llm/config/llama/AdvertiseGen/wfp8afp8_ptq_argument.json index 8604d81b766f..ff6fd648e104 100644 --- a/llm/config/llama/AdvertiseGen/wfp8afp8_ptq_argument.json +++ b/llm/config/llama/AdvertiseGen/wfp8afp8_ptq_argument.json @@ -1,8 +1,6 @@ { "model_name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "quant_type": "a8w8", - "use_fp8": "WA", - "fp8_type": ["e4m3", "e4m3"], + "quant_type": "a8w8_fp8", "per_device_train_batch_size": 8, "per_device_eval_batch_size": 8, "eval_accumulation_steps":16, @@ -11,7 +9,7 @@ "fp16": true, "fp16_opt_level": "O2", "dataset_name_or_path": "../dataset/AdvertiseGen", - "output_dir": "../output/llama3.1/w8a8_ptq_ckpts_AdvertiseGen", + "output_dir": "../output/llama3.1/wfp8afp8_ptq_ckpts_AdvertiseGen", "do_eval": true, "eval_with_do_generation": false, "do_ptq": true, @@ -19,6 +17,5 @@ "unified_checkpoint": false, "smooth": false, "weight_quant_method": "abs_max", - "act_quant_method": "abs_max", - "cachekv_quant_method": "abs_max" + "act_quant_method": "abs_max" } \ No newline at end of file diff --git a/llm/config/llama/ceval/w8a8_ptq_argument.json b/llm/config/llama/ceval/w8a8_ptq_argument.json index 73aed3234933..06a41592308e 100644 --- a/llm/config/llama/ceval/w8a8_ptq_argument.json +++ b/llm/config/llama/ceval/w8a8_ptq_argument.json @@ -21,6 +21,5 @@ "smooth_piecewise_search": true, "smooth_k_piece": 3, "smooth_search_piece": true, - "act_quant_method": "avg", - "cachekv_quant_method": "avg_headwise" + "act_quant_method": "avg" } \ No newline at end of file diff --git a/llm/config/llama/ceval/wfp8afp8_ptq_argument.json b/llm/config/llama/ceval/wfp8afp8_ptq_argument.json index e7119a0971fc..9e3dfc4c8c01 100644 --- a/llm/config/llama/ceval/wfp8afp8_ptq_argument.json +++ b/llm/config/llama/ceval/wfp8afp8_ptq_argument.json @@ -1,7 +1,6 @@ { "model_name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "quant_type": "a8w8", - "use_fp8": "WA", + "quant_type": "a8w8_fp8", "per_device_train_batch_size": 8, "per_device_eval_batch_size": 8, "eval_accumulation_steps":16, @@ -18,6 +17,5 @@ "unified_checkpoint": false, "smooth": false, "weight_quant_method": "abs_max", - "act_quant_method": "abs_max", - "cachekv_quant_method": "abs_max" + "act_quant_method": "abs_max" } \ No newline at end of file diff --git a/llm/config/llama/ceval_ptq_argument.json b/llm/config/llama/ceval_ptq_argument.json deleted file mode 100644 index 7ac0b507fba7..000000000000 --- a/llm/config/llama/ceval_ptq_argument.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "model_name_or_path": "meta-llama/Meta-Llama-3-8B", - "per_device_train_batch_size": 8, - "per_device_eval_batch_size": 8, - "eval_accumulation_steps":16, - "src_length": 1024, - "max_length": 2048, - "bf16": true, - "fp16_opt_level": "O2", - "dataset_name_or_path": "./data", - "output_dir": "./checkpoints/llama_ptq_ckpts", - "do_eval": true, - "eval_with_do_generation": false, - "do_ptq": false, - "ptq_step": 1, - "unified_checkpoint": false, - "smooth": true, - 
"smooth_step": 8, - "smooth_all_linears": true, - "smooth_piecewise_search": true, - "smooth_k_piece": 1, - "smooth_search_piece": true, - "load_quant_model": true, - "do_ceval": true, - "ceval_data_path": "../dataset/ceval" -} \ No newline at end of file diff --git a/llm/config/llama/fp8_ptq_argument.json b/llm/config/llama/fp8_ptq_argument.json index b8506102a4a0..4bd45594f775 100644 --- a/llm/config/llama/fp8_ptq_argument.json +++ b/llm/config/llama/fp8_ptq_argument.json @@ -1,8 +1,6 @@ { "model_name_or_path": "meta-llama/Meta-Llama-3-8B", - "quant_type": "W8A8", - "use_fp8": "WA", - "fp8_type": ["e4m3", "e4m3"], + "quant_type": "a8w8_fp8", "per_device_train_batch_size": 8, "per_device_eval_batch_size": 8, "eval_accumulation_steps":16, diff --git a/llm/config/qwen/AdvertiseGen/w8a8_ptq_argument.json b/llm/config/qwen/AdvertiseGen/w8a8_ptq_argument.json index eb5166fd70c1..7ea9895b05bd 100644 --- a/llm/config/qwen/AdvertiseGen/w8a8_ptq_argument.json +++ b/llm/config/qwen/AdvertiseGen/w8a8_ptq_argument.json @@ -22,5 +22,5 @@ "smooth_k_piece": 3, "smooth_search_piece": true, "act_quant_method": "abs_max", - "cachekv_quant_method": "abs_max_headwise" + "skip_list_names": ["down_proj"] } \ No newline at end of file diff --git a/llm/config/qwen/AdvertiseGen/w8a8c8_ptq_argument.json b/llm/config/qwen/AdvertiseGen/w8a8c8_ptq_argument.json index 3ab3a0fd65b3..d37b5c7ceccb 100644 --- a/llm/config/qwen/AdvertiseGen/w8a8c8_ptq_argument.json +++ b/llm/config/qwen/AdvertiseGen/w8a8c8_ptq_argument.json @@ -22,5 +22,6 @@ "smooth_k_piece": 3, "smooth_search_piece": true, "act_quant_method": "abs_max", - "cachekv_quant_method": "abs_max_headwise" + "cachekv_quant_method": "abs_max_headwise", + "skip_list_names": ["down_proj"] } \ No newline at end of file diff --git a/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json b/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json index 2e80c6d0a4f5..ff261704b8cc 100644 --- a/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json +++ b/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json @@ -1,8 +1,6 @@ { "model_name_or_path": "Qwen/Qwen2-7B-Instruct", - "quant_type": "a8w8", - "use_fp8": "WA", - "fp8_type": ["e4m3", "e4m3"], + "quant_type": "a8w8_fp8", "per_device_train_batch_size": 8, "per_device_eval_batch_size": 8, "eval_accumulation_steps":16, @@ -20,5 +18,5 @@ "smooth": false, "weight_quant_method": "abs_max", "act_quant_method": "abs_max", - "cachekv_quant_method": "abs_max" + "skip_list_names": ["down_proj"] } \ No newline at end of file diff --git a/llm/config/qwen/ceval/w8a8_ptq_argument.json b/llm/config/qwen/ceval/w8a8_ptq_argument.json index 27cea97d6f1a..c1c7db7bae47 100644 --- a/llm/config/qwen/ceval/w8a8_ptq_argument.json +++ b/llm/config/qwen/ceval/w8a8_ptq_argument.json @@ -22,6 +22,5 @@ "smooth_k_piece": 3, "smooth_search_piece": true, "act_quant_method": "abs_max", - "cachekv_quant_method": "abs_max_headwise", "skip_list_names": ["down_proj"] } \ No newline at end of file diff --git a/llm/config/qwen/ceval/wfp8afp8_ptq_argument.json b/llm/config/qwen/ceval/wfp8afp8_ptq_argument.json index 57af3f55d5ed..f5512a0e105d 100644 --- a/llm/config/qwen/ceval/wfp8afp8_ptq_argument.json +++ b/llm/config/qwen/ceval/wfp8afp8_ptq_argument.json @@ -1,7 +1,6 @@ { "model_name_or_path": "Qwen/Qwen2-7B-Instruct", - "quant_type": "a8w8", - "use_fp8": "WA", + "quant_type": "a8w8_fp8", "per_device_train_batch_size": 8, "per_device_eval_batch_size": 8, "eval_accumulation_steps":16, @@ -18,6 +17,5 @@ "unified_checkpoint": false, "smooth": false, 
"weight_quant_method": "abs_max", - "act_quant_method": "abs_max", - "cachekv_quant_method": "abs_max" + "act_quant_method": "abs_max" } \ No newline at end of file diff --git a/llm/docs/quantization.md b/llm/docs/quantization.md index 63fd9f7f74ff..9a507c6d1815 100644 --- a/llm/docs/quantization.md +++ b/llm/docs/quantization.md @@ -94,15 +94,19 @@ python run_finetune.py ./config/llama/ptq_c8_argument.json python run_finetune.py ./config/llama/fp8_ptq_argument.json ``` -### 2.9 量化参数介绍 +### 2.8 量化参数介绍   量化参数(QuantArgument)
-- `quant_type`: PTQ,QAT 量化类型,默认为 a8w8(不区分大小写)。支持 a8w8,a8w8c8,wint4/weight_only_int4,wint8/weight_only_int8:a8w8指对激活(输入)进行 8位量化,对模型权重进行 8位量化,具体量化类型通过`use_fp8`字段给出;a8w8c8指对激活、权重、kvcache 进行8位量化,具体量化类型通过`use_fp8`字段给出;wint4/weight_only_int4指仅对模型权重进行 INT4量化,后续使用 WeightOnly 进行推理;wint8/weight_only_int8指仅对模型权重进行 INT8量化,后续使用 WeightOnly 进行推理。
-- `use_fp8`: 是否使用 FP8 量化,默认为空字符串。输入`"WA"`(不区分大小写)则将权重和激活的8位量化转换为 FP8量化。
-- `fp8_type`: FP8量化类型,长度应与`use_fp8`相同。默认为`["e4m3","e4m3"]`。
+- `quant_type`: PTQ,QAT 量化类型,默认为 a8w8(不区分大小写)。支持 a8w8,a8w8c8,a8w8_fp8,wint4/weight_only_int4,wint8/weight_only_int8:
+  - a8w8指对激活(输入)进行 INT8量化,对模型权重进行 INT8量化
+  - a8w8c8指对激活、权重、kvcache 进行 INT8量化
+  - a8w8_fp8指对激活、权重进行 FP8量化
+  - wint4/weight_only_int4指仅对模型权重进行 INT4量化,后续使用 WeightOnly 进行推理
+  - wint8/weight_only_int8指仅对模型权重进行 INT8量化,后续使用 WeightOnly 进行推理
+- `fp8_type`: FP8量化类型,指定 activation,weight 的 fp8类型,默认为`["e4m3","e4m3"]`。
 - `do_ptq`: 是否进行 PTQ 量化,默认为 False。
 - `weight_quant_method`: 权重量化方式,INT8量化可选 groupwise 或者 abs_max_channel_wise,FP8量化可选 abs_max 或 avg。
 - `act_quant_method`: 激活量化方式,INT8可选 avg 或者 abs_max,FP8量化可选 abs_max 或 avg。
diff --git a/llm/experimental/observer/abs_max.py b/llm/experimental/observer/abs_max.py
new file mode 100644
index 000000000000..9d30db49cba3
--- /dev/null
+++ b/llm/experimental/observer/abs_max.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.quantization.factory import ObserverFactory
+
+from .uniform import UniformObserver
+
+
+class AbsmaxObserver(ObserverFactory):
+    r"""
+    It collects the maximum absolute value of the target tensor.
+    Args:
+        bit_length(int, optional): Number of bits to represent a quantized integer in binary.
+        dtype(str, optional): The data type of input tensor.
+        name (str, optional): This parameter is used by developers to print debugging information. \
+            For details, please refer to :ref:`api_guide_Name`. Default is None.
+    Examples:
+       .. code-block:: python
+            from paddle.quantization import QuantConfig
+            from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
+            quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99)
+            q_config = QuantConfig(activation=quanter, weight=quanter)
+    """
+
+    def __init__(self, quant_bits=8):
+        super(AbsmaxObserver, self).__init__(quant_bits=quant_bits)
+
+    def _get_class(self):
+        return AbsmaxObserverLayer
+
+
+class AbsmaxObserverLayer(UniformObserver):
+    def __init__(
+        self,
+        layer,
+        quant_bits=8,
+    ):
+        super(AbsmaxObserverLayer, self).__init__(quant_bits=quant_bits)
+        self._quant_bits = quant_bits
+        self._layer = layer
+        self._scale = None
+        self._zero_point = None
+        self._min = None
+        self._max = paddle.to_tensor(1e-7, dtype="float32")
+        self.step = 0
+
+    def forward(self, inputs):
+        """Calculate forward pass."""
+        self._min, self._max = self.cal_min_max(inputs)
+        return inputs
+
+    def cal_min_max(self, inputs):
+        abs_max_val = paddle.max(paddle.abs(inputs.cast("float32")))
+        abs_max_val = paddle.maximum(abs_max_val, self._max)
+        return 0, abs_max_val
+
+    def cal_thresholds(self):
+        """Compute thresholds for MAX function."""
+        if self._scale is not None:
+            self._zero_point = 0
+            return
+        self._scale, self._zero_point = self.cal_scales_zero_points()
+
+    def min_value(self) -> float:
+        return self._min
+
+    def max_value(self) -> float:
+        return self._max
+
+    def bit_length(self):
+        """Return the bit length of quantized data."""
+        return self._quant_bits
+
+    def quant_axis(self):
+        """Return quantization axis."""
+        return -1
+
+    def scales(self):
+        """Return output scales."""
+        if self._scale is None:
+            self.cal_thresholds()
+        return self._scale
+
+    def zero_points(self):
+        """Return output zero points."""
+        if self._zero_point is None:
+            self.cal_thresholds()
+        return self._zero_point
diff --git a/llm/experimental/observer/avg.py b/llm/experimental/observer/avg.py
new file mode 100644
index 000000000000..c38b3ec45c78
--- /dev/null
+++ b/llm/experimental/observer/avg.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.quantization.factory import ObserverFactory
+
+from .uniform import UniformObserver
+
+
+class AVGObserver(ObserverFactory):
+    r"""
+    It collects the per-batch maximum absolute values of the target tensor and averages them to compute the quantization scale.
+    Args:
+        bit_length(int, optional): Number of bits to represent a quantized integer in binary.
+        dtype(str, optional): The data type of input tensor.
+        name (str, optional): This parameter is used by developers to print debugging information. \
+            For details, please refer to :ref:`api_guide_Name`. Default is None.
+    Examples:
+       ..
code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver + quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99) + q_config = QuantConfig(activation=quanter, weight=quanter) + """ + + def __init__(self, quant_bits=8): + super(AVGObserver, self).__init__(quant_bits=quant_bits) + + def _get_class(self): + return AVGObserverLayer + + +class AVGObserverLayer(UniformObserver): + def __init__( + self, + layer, + quant_bits=8, + ): + super(AVGObserverLayer, self).__init__(quant_bits=quant_bits) + self._quant_bits = quant_bits + self._avg_list = [] + + def forward(self, inputs): + """Calculate forward pass.""" + self._scale = None + self._zero_point = None + self._min = None + self._max = None + self._avg_min, self._avg_max = self.cal_min_max(inputs) + self._avg_list.append(self._avg_max) + + return inputs + + def cal_min_max(self, inputs): + abs_avg_value = paddle.abs(inputs.reshape((inputs.shape[0], -1))) + abs_avg_value = float(paddle.mean(paddle.max(abs_avg_value, axis=(1)))) + return 0, abs_avg_value + + def cal_thresholds(self): + """Compute thresholds for MAX function.""" + if self._scale is not None: + self._zero_point = 0 + return + self._min, self._max = self._avg_min, paddle.mean(paddle.to_tensor(self._avg_list)) + self._scale, self._zero_point = self.cal_scales_zero_points() + + def min_value(self) -> float: + return self._min + + def max_value(self) -> float: + return self._max + + def bit_length(self): + """Return the bit length of quantized data.""" + return self._quant_bits + + def quant_axis(self): + """Return quantization axis.""" + return -1 + + def scales(self): + """Return output scales.""" + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + """Return output zero points.""" + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point diff --git a/llm/experimental/observer/uniform.py b/llm/experimental/observer/uniform.py new file mode 100644 index 000000000000..6c8882f5142f --- /dev/null +++ b/llm/experimental/observer/uniform.py @@ -0,0 +1,114 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Tuple + +import numpy as np +from paddle.quantization.base_observer import BaseObserver + + +class UniformObserver(BaseObserver): + """This is the base class for a uniform quantization observer, which provides + common functions for calculating the scale and zero-point used in uniform quantization. + Uniform quantization maps floating point values to integers, where the scale determines + the step size of the quantizer and the floating point zero is mapped to the zero-point, + an integer value ensuring that zero is quantized without error. + + Args: + quant_bits (int): The number of bits for quantization. + sign (bool): Whether the quantized integer includes a sign. + symmetric (bool): Whether it is symmetric quantization. 
+            In symmetric quantization, the range of floating point values is relaxed to be symmetric
+            around zero and the zero-point is always 0.
+
+    """
+
+    def __init__(
+        self,
+        quant_bits=8,
+        sign=True,
+        symmetric=True,
+    ):
+        super(UniformObserver, self).__init__()
+        self._quant_bits = quant_bits
+        self._sign = sign
+        self._symmetric = symmetric
+
+        self._min = None
+        self._max = None
+        self._qmin = None
+        self._qmax = None
+
+        self._scale = None
+        self._zero_point = None
+
+    @property
+    def qmin_qmax(self):
+        """Calculate the range of the quantized integer based on the specified
+        quant_bits, sign, and symmetric properties."""
+        if isinstance(self._quant_bits, tuple):
+            if self._quant_bits[0] == 4 and self._quant_bits[1] == 3 and len(self._quant_bits) == 2:
+                self._qmin = -448.0
+                self._qmax = 448.0
+            elif self._quant_bits[0] == 5 and self._quant_bits[1] == 2 and len(self._quant_bits) == 2:
+                self._qmin = -57344.0
+                self._qmax = 57344.0
+            else:
+                raise NotImplementedError(
+                    "Currently, only float8_e4m3 and float8_e5m2 formats are supported. Please set quant_bits to (4,3) or (5,2) for the corresponding format."
+                )
+        else:
+            if self._sign:
+                self._qmin = -(2 ** (self.bit_length() - 1))
+                self._qmax = 2 ** (self.bit_length() - 1) - 1
+            else:
+                self._qmin = 0
+                self._qmax = 2 ** self.bit_length()
+        return self._qmin, self._qmax
+
+    @abc.abstractmethod
+    def min_value(self) -> float:
+        """The minimum value of floating-point numbers."""
+        raise NotImplementedError(
+            "Please implement the abstract method to get the minimum value of floating-point numbers."
+        )
+
+    @abc.abstractmethod
+    def max_value(self) -> float:
+        """The maximum value of floating-point numbers."""
+        raise NotImplementedError(
+            "Please implement the abstract method to get the maximum value of floating-point numbers."
+        )
+
+    def cal_scales_zero_points(self) -> Tuple[float, float]:
+        """Calculate the scales and zero points based on the min_value and max_value."""
+        assert self.min_value() is not None and self.max_value() is not None
+        _qmin, _qmax = self.qmin_qmax
+        # For one-sided distributions, the range (_min, _max) is relaxed to include zero.
+        # It is important to ensure that common operations like zero padding do not cause quantization errors.
+        _min = min(self.min_value(), 0.0)
+        _max = max(self.max_value(), 0.0)
+
+        if self._symmetric:
+            self._scale = max(-_min, _max)
+            if self._sign:
+                self._zero_point = 0
+            else:
+                self._zero_point = (_qmax + _qmin) / 2
+        else:
+            self._scale = (_max - _min) / float(_qmax - _qmin)
+            self._zero_point = _qmin - round(_min / self._scale)
+            self._zero_point = np.clip(self._zero_point, _qmin, _qmax)
+        return self._scale, self._zero_point
diff --git a/llm/utils/argument.py b/llm/utils/argument.py
index b5c0e7437c72..006d21aaec90 100644
--- a/llm/utils/argument.py
+++ b/llm/utils/argument.py
@@ -239,16 +239,9 @@ class QuantArgument:
         metadata={"help": "Quantization type. Supported values: weight_only_int8, weight_only_int4, a8w8, a8w8c8"},
     )
 
-    use_fp8: str = field(
-        default="",
-        metadata={
-            "help": "Whether to use FP8 on (activation, weight, cachekv), e.g.
WAC means weight , activation, cachekv use fp8" - }, - ) - fp8_type: List[str] = field( default_factory=lambda: ["e4m3", "e4m3"], - metadata={"help": "Quantization type for (weight, activation, cachekv)", "nargs": "+"}, + metadata={"help": "Quantization type for (activation, weight)", "nargs": "+"}, ) skip_list_names: List[str] = field( diff --git a/llm/utils/quant.py b/llm/utils/quant.py index 11a22fe6fa5b..8d5cee2f4d74 100644 --- a/llm/utils/quant.py +++ b/llm/utils/quant.py @@ -16,7 +16,9 @@ import paddle from experimental.layers.custom_attention import QuantizedCustomAttentionLayer +from experimental.observer.abs_max import AbsmaxObserver from experimental.observer.abs_max_headwise import AbsMaxHeadwiseObserver +from experimental.observer.avg import AVGObserver from experimental.observer.avg_headwise import AvgHeadwiseObserver from experimental.observer.channel_wise import ChannelWiseObserver from paddle import nn @@ -44,10 +46,8 @@ ) from paddleslim.quant.observers import ( AbsMaxChannelWiseWeightObserver, - AVGObserver, GroupWiseWeightObserver, ) -from paddleslim.quant.observers.abs_max import AbsmaxObserver from paddlenlp.peft import PrefixModelForCausalLM from paddlenlp.peft.lora import ( @@ -214,32 +214,29 @@ def prepare_qconfig(args): Prepare qconfig """ args.quant_type = args.quant_type.lower() - args.use_fp8 = args.use_fp8.lower() + + if args.quant_type in ["a8w8_fp8"]: + use_fp8 = "aw" + args.quant_type = args.quant_type.replace("_fp8", "") + else: + use_fp8 = "" weight_observer = ( WEIGHT_OBSERVER.get(args.weight_quant_method, None) - if "w" not in args.use_fp8 + if "w" not in use_fp8 else FP8_OBSERVER.get(args.weight_quant_method, None) ) act_observer = ( ACT_OBSERVER.get(args.act_quant_method, None) - if "a" not in args.use_fp8 + if "a" not in use_fp8 else FP8_OBSERVER.get(args.act_quant_method, None) ) - cachekv_observer = ( - CACHEKV_OBSERVER.get(args.cachekv_quant_method, None) - if "c" not in args.use_fp8 - else FP8_OBSERVER.get(args.cachekv_quant_method, None) - ) + cachekv_observer = CACHEKV_OBSERVER.get(args.cachekv_quant_method, None) if "c8" in args.quant_type: quant_type = args.quant_type.replace("c8", "") cachekv_quant = True - - if "c" in args.use_fp8: - cachekv_quant_bits = "fp8" - else: - cachekv_quant_bits = "int8" + cachekv_quant_bits = "int8" else: quant_type = args.quant_type.replace("c16", "") cachekv_quant = False @@ -247,13 +244,13 @@ def prepare_qconfig(args): q_config = QuantConfig(activation=None, weight=None) if quant_type in ["a8w8", "w8a8"]: - if "w" in args.use_fp8: - w_quant_bit = (4, 3) if args.fp8_type[args.use_fp8.index("w")] == "e4m3" else (5, 2) + if "w" in use_fp8: + w_quant_bit = (4, 3) if args.fp8_type[use_fp8.index("w")] == "e4m3" else (5, 2) else: w_quant_bit = 8 - if "a" in args.use_fp8: - a_quant_bit = (4, 3) if args.fp8_type[args.use_fp8.index("a")] == "e4m3" else (5, 2) + if "a" in use_fp8: + a_quant_bit = (4, 3) if args.fp8_type[use_fp8.index("a")] == "e4m3" else (5, 2) else: a_quant_bit = 8 activation = act_observer(quant_bits=a_quant_bit) @@ -265,13 +262,13 @@ def prepare_qconfig(args): elif quant_type in ["wint8", "w8a16", "weight_only_int8"]: activation = None - if "w" in args.use_fp8: + if "w" in use_fp8: weight = weight_observer(quant_bits=(4, 3)) else: weight = weight_observer(quant_bits=8) else: raise ValueError( - "quant_type should be in ['weight_only_int8/wint8', 'weight_only_int4/wint4', 'a8w8', 'a8w8c8']" + "quant_type should be in ['weight_only_int8/wint8', 'weight_only_int4/wint4', 'a8w8', 'a8w8c8', 'a8w8_fp8']" ) 
q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear) @@ -292,23 +289,8 @@ def prepare_qconfig(args): cachekv_observer(quant_bits=cachekv_quant_bit), ] q_config.add_qat_layer_mapping(FuncWrapper, QuantizedCustomAttentionLayer) - - elif cachekv_quant_bits == "fp8": - cachekv_quant_bit = (4, 3) if args.fp8_type[args.use_fp8.index("c")] == "e4m3" else (5, 2) - - if "headwise" in args.cachekv_quant_method: - cachekv = [ - cachekv_observer(quant_bits=cachekv_quant_bit, quant_axis=1), - cachekv_observer(quant_bits=cachekv_quant_bit, quant_axis=1), - ] - else: - cachekv = [ - cachekv_observer(quant_bits=cachekv_quant_bit), - cachekv_observer(quant_bits=cachekv_quant_bit), - ] - q_config.add_qat_layer_mapping(FuncWrapper, QuantizedCustomAttentionLayer) else: - raise ValueError("cachekv_quant_bits should be 8") + raise ValueError("cachekv_quant_bits should be int8") return activation, weight, cachekv, q_config
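
Editor's note (illustrative, not part of the patch): the new `a8w8_fp8` quant_type resolves in `prepare_qconfig` to FP8 observers with `quant_bits=(4, 3)` (e4m3) or `(5, 2)` (e5m2), and `UniformObserver.qmin_qmax` maps those tuples to the ranges ±448 and ±57344 with a symmetric scale of `max(-min, max)` and zero-point 0. The standalone sketch below mirrors that logic; the function names `fp8_qmin_qmax` and `symmetric_scale` are made up for the example and do not exist in the codebase.

```python
# Standalone sketch of the FP8 range/scale selection mirrored from UniformObserver.
def fp8_qmin_qmax(quant_bits):
    """Return (qmin, qmax) for FP8 formats given as (exponent_bits, mantissa_bits)."""
    if quant_bits == (4, 3):  # float8_e4m3
        return -448.0, 448.0
    if quant_bits == (5, 2):  # float8_e5m2
        return -57344.0, 57344.0
    raise NotImplementedError("Only (4, 3) and (5, 2) are supported.")


def symmetric_scale(min_val, max_val):
    """Symmetric scale: the range is relaxed to include zero, zero-point stays 0."""
    _min = min(min_val, 0.0)
    _max = max(max_val, 0.0)
    return max(-_min, _max)


if __name__ == "__main__":
    qmin, qmax = fp8_qmin_qmax((4, 3))   # e4m3 range used by the a8w8_fp8 path
    scale = symmetric_scale(-0.3, 1.7)   # abs_max-style calibration result
    print(qmin, qmax, scale)             # -448.0 448.0 1.7
```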