lint

commit e202190c4f
parent 9b55d3a9f9
24 changed files with 146 additions and 125 deletions
@@ -66,7 +66,7 @@ def add_lora_autogptq(lora_names):
         logger.error("This version of AutoGPTQ does not support LoRA. You need to install from source or wait for a new release.")
         return
 
-    if len(lora_names) == 0:
+    if len(lora_names) == 0:
         reload_model()
 
         shared.lora_names = []
@@ -108,14 +108,14 @@ def add_lora_transformers(lora_names):
     # If any LoRA needs to be removed, start over
     if len(removed_set) > 0:
         # shared.model may no longer be PeftModel
-        if hasattr(shared.model, 'disable_adapter'):
-            shared.model.disable_adapter()
+        if hasattr(shared.model, 'disable_adapter'):
+            shared.model.disable_adapter()
             shared.model = shared.model.base_model.model
 
     if len(lora_names) > 0:
         params = {}
         if not shared.args.cpu:
-            if shared.args.load_in_4bit or shared.args.load_in_8bit:
+            if shared.args.load_in_4bit or shared.args.load_in_8bit:
                 params['peft_type'] = shared.model.dtype
             else:
                 params['dtype'] = shared.model.dtype
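Review note: the hasattr() guard exists because, as the comment says, shared.model may no longer be a PeftModel after a reload. A minimal sketch of the same duck-typing unwrap pattern, using a hypothetical stand-in class rather than a real PeftModel:

class AdapterWrapper:
    '''Hypothetical stand-in for an adapter wrapper such as peft's PeftModel.'''

    def __init__(self, base):
        self.base_model = base

    def disable_adapter(self):
        print('adapter disabled')


def unwrap(model):
    # Only call adapter methods the object actually provides; after a
    # reload, `model` may be a bare transformers model again.
    if hasattr(model, 'disable_adapter'):
        model.disable_adapter()
        model = model.base_model

    return model


print(unwrap(AdapterWrapper('base model')))  # -> 'base model'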
@@ -54,14 +54,14 @@ loaders_and_params = {
         'trust_remote_code',
         'transformers_info'
     ],
-    'ExLlama' : [
+    'ExLlama': [
         'gpu_split',
         'max_seq_len',
         'compress_pos_emb',
         'alpha_value',
         'exllama_info',
     ],
-    'ExLlama_HF' : [
+    'ExLlama_HF': [
         'gpu_split',
         'max_seq_len',
         'compress_pos_emb',
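The only change in this hunk is style: pycodestyle's E203 check flags whitespace before a colon, so 'ExLlama' : becomes 'ExLlama':. A before/after illustration:

# Flagged by pycodestyle (E203, whitespace before ':'):
loaders = {'ExLlama' : ['gpu_split']}

# Lint-clean spelling of the same dict:
loaders = {'ExLlama': ['gpu_split']}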
@@ -106,11 +106,11 @@ def load_tokenizer(model_name, model):
             use_fast=False
         )
     except ValueError:
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             path_to_model,
             trust_remote_code=shared.args.trust_remote_code,
             use_fast=True
-        )
+        )
 
     if tokenizer.__class__.__name__ == 'LlamaTokenizer':
         pairs = [
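The full try/except here (only partly visible in the hunk) is a fallback pattern: prefer the slow tokenizer, and retry with the fast one when loading raises ValueError. A self-contained sketch of that shape, with the argument names standing in for the webui's values:

from transformers import AutoTokenizer

def load_tokenizer_with_fallback(path_to_model, trust_remote_code=False):
    # Prefer the slow (SentencePiece-based) tokenizer; some checkpoints
    # cannot provide one and raise ValueError, so retry with use_fast=True.
    try:
        return AutoTokenizer.from_pretrained(
            path_to_model,
            trust_remote_code=trust_remote_code,
            use_fast=False
        )
    except ValueError:
        return AutoTokenizer.from_pretrained(
            path_to_model,
            trust_remote_code=trust_remote_code,
            use_fast=True
        )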
@@ -126,6 +126,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
     '''
     Copied from the transformers library
     '''
+
     def __init__(self, penalty: float, _range: int):
         if not isinstance(penalty, float) or not (penalty > 0):
             raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
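Per its docstring, this class is adapted from transformers' LogitsProcessor machinery. A minimal sketch of the core repetition-penalty step without the _range restriction that the webui subclass adds (simplified, not the exact class from this file):

import torch
from transformers import LogitsProcessor

class SimpleRepetitionPenaltyProcessor(LogitsProcessor):
    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Penalize every token already present in the sequence:
        # positive logits are divided, negative ones multiplied.
        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        scores.scatter_(1, input_ids, score)
        return scores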
@@ -181,7 +181,7 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
 # API
 parser.add_argument('--api', action='store_true', help='Enable the API extension.')
 parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
-parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
+parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
 parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
 
 # Multimodal
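These are plain argparse flags. A standalone sketch of how they parse; note that dashes in option names become underscores on the namespace:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
parser.add_argument('--api-blocking-port', type=int, default=5000)
parser.add_argument('--api-streaming-port', type=int, default=5005)

args = parser.parse_args(['--api', '--api-blocking-port', '5001'])
assert args.api is True
assert args.api_blocking_port == 5001
assert args.api_streaming_port == 5005  # default kept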
@@ -116,7 +116,7 @@ def get_available_loras():
 
 def get_datasets(path: str, ext: str):
     # include subdirectories for raw txt files to allow training from a subdirectory of txt files
     if ext == "txt":
-        return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt'))+list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
+        return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt')) + list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
 
     return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
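The changed line only adds spaces around +; the function itself globs with pathlib, deduplicates by file stem, and sorts with a natural key. A self-contained sketch of the non-txt branch, using a simple stand-in for the webui's natural_keys helper:

import re
from pathlib import Path

def natural_keys(text):
    # Split digit runs out so '10-data' sorts after '2-data'.
    return [int(part) if part.isdigit() else part
            for part in re.split(r'(\d+)', text)]

def get_datasets(path: str, ext: str):
    return ['None'] + sorted(
        {p.stem for p in Path(path).glob(f'*.{ext}')
         if p.stem != 'put-trainer-datasets-here'},
        key=natural_keys,
    )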