Add cache_8bit option
This commit is contained in:
parent
42f816312d
commit
c0655475ae
7 changed files with 32 additions and 5 deletions
|
|
@ -4,7 +4,12 @@ from pathlib import Path
|
|||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Cache,
|
||||
ExLlamaV2Cache_8bit,
|
||||
ExLlamaV2Config
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
|
@ -40,11 +45,18 @@ class Exllamav2HF(PreTrainedModel):
|
|||
self.generation_config = GenerationConfig()
|
||||
self.loras = None
|
||||
|
||||
self.ex_cache = ExLlamaV2Cache(self.ex_model)
|
||||
self.past_seq = None
|
||||
if shared.args.cache_8bit:
|
||||
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
|
||||
else:
|
||||
self.ex_cache = ExLlamaV2Cache(self.ex_model)
|
||||
|
||||
self.past_seq = None
|
||||
if shared.args.cfg_cache:
|
||||
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
|
||||
if shared.args.cache_8bit:
|
||||
self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
|
||||
else:
|
||||
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
|
||||
|
||||
self.past_seq_negative = None
|
||||
|
||||
def _validate_model_class(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue