Skip to content

Commit

Permalink
determine model type from model name
Browse files Browse the repository at this point in the history
  • Loading branch information
Zerogoki00 committed Mar 13, 2023
1 parent b6c5c57 commit a6a6522
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions modules/quant_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,17 @@
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))


# 4-bit LLaMA
def load_quantized(model_name, model_type):
def load_quantized(model_name):
if not shared.args.gptq_model_type:
# Try to determine model type from model name
model_type = model_name.split('-')[0].lower()
if model_type not in ('llama', 'opt'):
print("Can't determine model type from model name. Please specify it manually using --gptq-model-type "
"argument")
exit()
else:
model_type = shared.args.gptq_model_type.lower()

if model_type == 'llama':
from llama import load_quant
elif model_type == 'opt':
Expand All @@ -20,7 +29,16 @@ def load_quantized(model_name, model_type):
exit()

path_to_model = Path(f'models/{model_name}')
pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
if path_to_model.name.lower().startswith('llama-7b'):
pt_model = f'llama-7b-{shared.args.gptq_bits}bit.pt'
elif path_to_model.name.lower().startswith('llama-13b'):
pt_model = f'llama-13b-{shared.args.gptq_bits}bit.pt'
elif path_to_model.name.lower().startswith('llama-30b'):
pt_model = f'llama-30b-{shared.args.gptq_bits}bit.pt'
elif path_to_model.name.lower().startswith('llama-65b'):
pt_model = f'llama-65b-{shared.args.gptq_bits}bit.pt'
else:
pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'

# Try to find the .pt both in models/ and in the subfolder
pt_path = None
Expand Down

0 comments on commit a6a6522

Please sign in to comment.