Add cache_8bit option

This commit is contained in:
oobabooga 2023-11-02 11:23:04 -07:00
parent 42f816312d
commit c0655475ae
7 changed files with 32 additions and 5 deletions

View file

@@ -6,6 +6,7 @@ import torch
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,
+    ExLlamaV2Cache_8bit,
     ExLlamaV2Config,
     ExLlamaV2Tokenizer
 )
@@ -57,7 +58,11 @@ class Exllamav2Model:
         model.load(split)
         tokenizer = ExLlamaV2Tokenizer(config)
-        cache = ExLlamaV2Cache(model)
+        if shared.args.cache_8bit:
+            cache = ExLlamaV2Cache_8bit(model)
+        else:
+            cache = ExLlamaV2Cache(model)
         generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
result = self()