
Commit

feat(gguf): add falcon q2 k
g-prz committed Sep 13, 2024
1 parent ecf7024 commit 9f48ada
Showing 3 changed files with 68 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -80,6 +80,7 @@ For now the supported model architectures are the architectures that have been v
- Qwen2
- Qwen2Moe
- Phi3
- Falcon

## Example usage

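The documentation's `## Example usage` section is collapsed in this diff. For the newly listed Falcon architecture, loading a GGUF checkpoint follows the same pattern as the other supported models; a minimal sketch, reusing the repository and file names from the test added below (`YokaiKoibito/falcon-40b-GGUF`, `falcon-40b-Q2_K.gguf`), might look like this:

```python
# Minimal sketch (not part of this diff): loading a Falcon GGUF checkpoint.
# Repo and file names are taken from the test added in this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "YokaiKoibito/falcon-40b-GGUF"
gguf_file = "falcon-40b-Q2_K.gguf"

# Pass the GGUF file name via `gguf_file`; the quantized weights are
# dequantized when the checkpoint is loaded.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```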
41 changes: 40 additions & 1 deletion src/transformers/integrations/ggml.py
@@ -25,7 +25,7 @@
from tokenizers.models import BPE

from .. import AddedToken
from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter
from ..convert_slow_tokenizer import GPT2Converter, LlamaConverter, Qwen2Converter
from ..utils import logging
from ..utils.logging import tqdm

@@ -107,6 +107,19 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"falcon": {
"token_embd": "model.embed_tokens",
"blk": "model.h",
"ffn_up": "mlp.dense_h_to_4h",
"ffn_down": "mlp.dense_4h_to_h",
"ffn_gate": "mlp.gate_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "ln_mlp",
"attn_qkv": "self_attention.query_key_value",
"attn_output": "self_attention.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.ln_f",
},
}


@@ -163,6 +176,18 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"falcon": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"tokenizer": {
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
@@ -490,11 +515,25 @@ def converted(self) -> Tokenizer:
return tokenizer


class GGUFFalcon2Converter(GPT2Converter):
def __init__(self, tokenizer_dict):
self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
self.additional_kwargs = {}

def converted(self) -> Tokenizer:
vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
merges = self.original_tokenizer.merges
tokenizer = super().converted(vocab, merges)

return tokenizer


GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
"qwen2": GGUFQwen2Converter,
"qwen2_moe": GGUFQwen2Converter,
"phi3": GGUFPhi3Converter,
"falcon": GGUFFalcon2Converter,
}


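As a side note on the `GGUF_TENSOR_MAPPING` entries added above: each key is a GGUF tensor-name fragment and each value is the corresponding `transformers` parameter path, so a Falcon GGUF tensor name is, roughly speaking, rewritten by fragment substitution when the checkpoint is converted. A rough, self-contained illustration (the helper below is hypothetical, not code from this diff):

```python
# Hypothetical helper to illustrate how the "falcon" tensor mapping above is used:
# GGUF name fragments are replaced with their transformers counterparts.
FALCON_MAPPING = {
    "token_embd": "model.embed_tokens",
    "blk": "model.h",
    "attn_qkv": "self_attention.query_key_value",
    "attn_output": "self_attention.o_proj",
}

def rename_tensor(ggml_name: str, mapping: dict) -> str:
    for ggml_key, hf_key in mapping.items():
        if ggml_key in ggml_name:
            ggml_name = ggml_name.replace(ggml_key, hf_key)
    return ggml_name

# "blk.0.attn_qkv.weight" -> "model.h.0.self_attention.query_key_value.weight"
print(rename_tensor("blk.0.attn_qkv.weight", FALCON_MAPPING))
```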
27 changes: 27 additions & 0 deletions tests/quantization/ggml/test_ggml.py
@@ -42,6 +42,7 @@ class GgufIntegrationTests(unittest.TestCase):
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
falcon_model_id = "YokaiKoibito/falcon-40b-GGUF"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -71,6 +72,17 @@ class GgufIntegrationTests(unittest.TestCase):
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"

q2_k_falcon_model_id = "falcon-40b-Q2_K.gguf"
q3_k_l_falcon_model_id = "falcon-40b-Q3_K_L.gguf"
q3_k_m_falcon_model_id = "falcon-40b-Q3_K_M.gguf"
q3_k_s_falcon_model_id = "falcon-40b-Q3_K_S.gguf"
q4_k_m_falcon_model_id = "falcon-40b-Q4_K_M.gguf"
q4_k_s_falcon_model_id = "falcon-40b-Q4_K_S.gguf"
q5_k_m_falcon_model_id = "falcon-40b-Q5_K_M.gguf"
q5_k_s_falcon_model_id = "falcon-40b-Q5_K_S.gguf"
q6_k_falcon_model_id = "falcon-40b-Q6_K.gguf"
q8_0_falcon_model_id = "falcon-40b-Q8_0.gguf"

example_text = "Hello"

def test_q2_k(self):
@@ -385,6 +397,21 @@ def test_llama3_q4_0(self):
EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_falcon_q2_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.falcon_model_id, gguf_file=self.q2_k_falcon_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.falcon_model_id,
gguf_file=self.q2_k_falcon_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
