Use pre-compiled python module for ExLlama (#2770)

This commit is contained in:
jllllll 2023-06-24 18:24:17 -05:00 committed by GitHub
parent a70a2ac3be
commit bef67af23c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 16 deletions

View file

@ -20,10 +20,13 @@ def add_lora_to_model(lora_names):
def add_lora_exllama(lora_names):
try:
from repositories.exllama.lora import ExLlamaLora
from exllama.lora import ExLlamaLora
except:
logger.error("Could not find the file repositories/exllama/lora.py. Make sure that exllama is cloned inside repositories/ and is up to date.")
return
try:
from repositories.exllama.lora import ExLlamaLora
except:
logger.error("Could not find the file repositories/exllama/lora.py. Make sure that exllama is cloned inside repositories/ and is up to date.")
return
if len(lora_names) == 0:
shared.model.generator.lora = None

View file

@ -3,12 +3,23 @@ from pathlib import Path
from modules import shared
from modules.logging_colors import logger
from modules.relative_imports import RelativeImport
with RelativeImport("repositories/exllama"):
from generator import ExLlamaGenerator
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
try:
from exllama.generator import ExLlamaGenerator
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from exllama.tokenizer import ExLlamaTokenizer
except:
logger.warning('Exllama module failed to load. Will attempt to load from repositories.')
try:
from modules.relative_imports import RelativeImport
with RelativeImport("repositories/exllama"):
from generator import ExLlamaGenerator
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
except:
logger.error("Could not find repositories/exllama/. Make sure that exllama is cloned inside repositories/ and is up to date.")
raise
class ExllamaModel:

View file

@ -9,10 +9,19 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from modules import shared
from modules.logging_colors import logger
from modules.relative_imports import RelativeImport
with RelativeImport("repositories/exllama"):
from model import ExLlama, ExLlamaCache, ExLlamaConfig
try:
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
except:
logger.warning('Exllama module failed to load. Will attempt to load from repositories.')
try:
from modules.relative_imports import RelativeImport
with RelativeImport("repositories/exllama"):
from model import ExLlama, ExLlamaCache, ExLlamaConfig
except:
logger.error("Could not find repositories/exllama/. Make sure that exllama is cloned inside repositories/ and is up to date.")
raise
class ExllamaHF(PreTrainedModel):
@ -68,7 +77,7 @@ class ExllamaHF(PreTrainedModel):
assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
config = ExLlamaConfig(pretrained_model_name_or_path / 'config.json')
@ -86,7 +95,7 @@ class ExllamaHF(PreTrainedModel):
if shared.args.gpu_split:
config.set_auto_map(shared.args.gpu_split)
config.gpu_peer_fix = True
# This slowes down a bit but align better with autogptq generation.
# TODO: Should give user choice to tune the exllama config
# config.fused_attn = False