Use pre-compiled python module for ExLlama (#2770)

This commit is contained in:
jllllll 2023-06-24 18:24:17 -05:00 committed by GitHub
parent a70a2ac3be
commit bef67af23c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 16 deletions

View file

@ -9,10 +9,19 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from modules import shared
from modules.logging_colors import logger
from modules.relative_imports import RelativeImport
with RelativeImport("repositories/exllama"):
from model import ExLlama, ExLlamaCache, ExLlamaConfig
try:
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
except:
logger.warning('Exllama module failed to load. Will attempt to load from repositories.')
try:
from modules.relative_imports import RelativeImport
with RelativeImport("repositories/exllama"):
from model import ExLlama, ExLlamaCache, ExLlamaConfig
except:
logger.error("Could not find repositories/exllama/. Make sure that exllama is cloned inside repositories/ and is up to date.")
raise
class ExllamaHF(PreTrainedModel):
@ -68,7 +77,7 @@ class ExllamaHF(PreTrainedModel):
assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
config = ExLlamaConfig(pretrained_model_name_or_path / 'config.json')
@ -86,7 +95,7 @@ class ExllamaHF(PreTrainedModel):
if shared.args.gpu_split:
config.set_auto_map(shared.args.gpu_split)
config.gpu_peer_fix = True
# This slows things down a bit but aligns better with autogptq generation.
# TODO: Should give user choice to tune the exllama config
# config.fused_attn = False