Some minor fixes to the GPTQ loader
oobabooga committed Mar 13, 2023
1 parent 8778b75 commit 518e5c4
Showing 1 changed file with 5 additions and 3 deletions.
modules/quant_loader.py — 8 changes: 5 additions & 3 deletions
@@ -7,6 +7,8 @@
 import modules.shared as shared
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
+import llama
+import opt
 
 
 def load_quantized(model_name):
@@ -21,9 +23,9 @@ def load_quantized(model_name):
     model_type = shared.args.gptq_model_type.lower()
 
     if model_type == 'llama':
-        from llama import load_quant
+        load_quant = llama.load_quant
     elif model_type == 'opt':
-        from opt import load_quant
+        load_quant = opt.load_quant
     else:
         print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
         exit()
@@ -50,7 +52,7 @@ def load_quantized(model_name):
         print(f"Could not find {pt_model}, exiting...")
         exit()
 
-    model = load_quant(path_to_model, str(pt_path), shared.args.gptq_bits)
+    model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
 
     # Multiple GPUs or GPU+CPU
     if shared.args.gpu_memory:
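For context, below is a minimal, self-contained sketch of the pattern this commit settles on: both GPTQ-for-LLaMa modules are imported once at module level, the correct load_quant is bound at call time instead of re-imported inside the branch, and paths are passed as plain strings. The llama/opt stand-ins and the assumed load_quant signature (model_path: str, checkpoint_path: str, bits: int) are placeholders for illustration, not GPTQ-for-LLaMa's actual code.

# Sketch of the dispatch pattern from this commit. The SimpleNamespace objects
# stand in for the real llama.py and opt.py modules (assumption, illustration only).
from pathlib import Path
from types import SimpleNamespace

llama = SimpleNamespace(load_quant=lambda model, ckpt, bits: f"llama({model}, {ckpt}, {bits}-bit)")
opt = SimpleNamespace(load_quant=lambda model, ckpt, bits: f"opt({model}, {ckpt}, {bits}-bit)")

def load_quantized_sketch(model_type: str, path_to_model: Path, pt_path: Path, bits: int):
    # Bind the right loader at call time; the modules were already imported above,
    # so no import happens inside the function body.
    if model_type == 'llama':
        load_quant = llama.load_quant
    elif model_type == 'opt':
        load_quant = opt.load_quant
    else:
        raise ValueError("Only 'llama' and 'opt' are supported")
    # Pass plain strings rather than pathlib.Path objects, mirroring the
    # str(path_to_model) wrapping added in this commit (assumed requirement
    # of the underlying loaders).
    return load_quant(str(path_to_model), str(pt_path), bits)

print(load_quantized_sketch('llama', Path("models/llama-7b"), Path("models/llama-7b-4bit.pt"), 4))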
