From 9f48adaad21f93176fbcc2d990279cc80e600eac Mon Sep 17 00:00:00 2001
From: Gaël
Date: Wed, 11 Sep 2024 15:32:26 +0300
Subject: [PATCH] feat(gguf): add falcon q2 k

---
 docs/source/en/gguf.md                |  1 +
 src/transformers/integrations/ggml.py | 41 ++++++++++++++++++++++++++-
 tests/quantization/ggml/test_ggml.py  | 27 ++++++++++++++++++
 3 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md
index 8e6741a306d898..914fb00794e6c1 100644
--- a/docs/source/en/gguf.md
+++ b/docs/source/en/gguf.md
@@ -80,6 +80,7 @@ For now the supported model architectures are the architectures that have been v
 - Qwen2
 - Qwen2Moe
 - Phi3
+- Falcon
 
 ## Example usage
 
diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index b5471574a13d8b..fd98d6a62a18a3 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -25,7 +25,7 @@
 from tokenizers.models import BPE
 
 from .. import AddedToken
-from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter
+from ..convert_slow_tokenizer import GPT2Converter, LlamaConverter, Qwen2Converter
 from ..utils import logging
 from ..utils.logging import tqdm
@@ -107,6 +107,19 @@
         "output.weight": "lm_head.weight",
         "output_norm": "model.norm",
     },
+    "falcon": {
+        "token_embd": "model.embed_tokens",
+        "blk": "model.h",
+        "ffn_up": "mlp.dense_h_to_4h",
+        "ffn_down": "mlp.dense_4h_to_h",
+        "ffn_gate": "mlp.gate_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "attn_norm": "ln_mlp",
+        "attn_qkv": "self_attention.query_key_value",
+        "attn_output": "self_attention.o_proj",
+        "output.weight": "lm_head.weight",
+        "output_norm": "model.ln_f",
+    },
 }
@@ -163,6 +176,18 @@
         "attention.layer_norm_rms_epsilon": "rms_norm_eps",
         "vocab_size": "vocab_size",
     },
+    "falcon": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "rms_norm_eps",
+        "vocab_size": "vocab_size",
+    },
     "tokenizer": {
         "ggml.bos_token_id": "bos_token_id",
         "ggml.eos_token_id": "eos_token_id",
@@ -490,11 +515,25 @@ def converted(self) -> Tokenizer:
         return tokenizer
 
 
+class GGUFFalcon2Converter(GPT2Converter):
+    def __init__(self, tokenizer_dict):
+        self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
+        self.additional_kwargs = {}
+
+    def converted(self) -> Tokenizer:
+        vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
+        merges = self.original_tokenizer.merges
+        tokenizer = super().converted(vocab, merges)
+
+        return tokenizer
+
+
 GGUF_TO_FAST_CONVERTERS = {
     "llama": GGUFLlamaConverter,
     "qwen2": GGUFQwen2Converter,
     "qwen2_moe": GGUFQwen2Converter,
     "phi3": GGUFPhi3Converter,
+    "falcon": GGUFFalcon2Converter,
 }
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index 6d3bb3f5337185..04c1b89b10c912 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -42,6 +42,7 @@ class GgufIntegrationTests(unittest.TestCase):
     llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
     tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
     phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
+    falcon_model_id = "YokaiKoibito/falcon-40b-GGUF"
 
     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
= "tinyllama-1.1b-chat-v1.0.Q4_0.gguf" @@ -71,6 +72,17 @@ class GgufIntegrationTests(unittest.TestCase): q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf" f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf" + q2_k_falcon_model_id = "falcon-40b-Q2_K.gguf" + q3_k_l_falcon_model_id = "falcon-40b-Q3_K_L.gguf" + q3_k_m_falcon_model_id = "falcon-40b-Q3_K_M.gguf" + q3_k_s_falcon_model_id = "falcon-40b-Q3_K_S.gguf" + q4_k_m_falcon_model_id = "falcon-40b-Q4_K_M.gguf" + q4_k_s_falcon_model_id = "falcon-40b-Q4_K_S.gguf" + q5_k_m_falcon_model_id = "falcon-40b-Q5_K_M.gguf" + q5_k_s_falcon_model_id = "falcon-40b-Q5_K_S.gguf" + q6_k_falcon_model_id = "falcon-40b-Q6_K.gguf" + q8_0_falcon_model_id = "falcon-40b-Q8_0.gguf" + example_text = "Hello" def test_q2_k(self): @@ -385,6 +397,21 @@ def test_llama3_q4_0(self): EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_falcon_q2_k(self): + tokenizer = AutoTokenizer.from_pretrained(self.falcon_model_id, gguf_file=self.q2_k_falcon_model_id) + model = AutoModelForCausalLM.from_pretrained( + self.falcon_model_id, + gguf_file=self.q2_k_falcon_model_id, + device_map="auto", + torch_dtype=torch.float16, + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_tokenization_xnli(self): import tqdm from datasets import load_dataset