This commit is contained in:
oobabooga 2023-07-12 11:33:25 -07:00
parent 9b55d3a9f9
commit e202190c4f
24 changed files with 146 additions and 125 deletions

View file

@ -66,7 +66,7 @@ def add_lora_autogptq(lora_names):
logger.error("This version of AutoGPTQ does not support LoRA. You need to install from source or wait for a new release.")
return
if len(lora_names) == 0:
if len(lora_names) == 0:
reload_model()
shared.lora_names = []
@ -108,14 +108,14 @@ def add_lora_transformers(lora_names):
# If any LoRA needs to be removed, start over
if len(removed_set) > 0:
# shared.model may no longer be PeftModel
if hasattr(shared.model, 'disable_adapter'):
shared.model.disable_adapter()
if hasattr(shared.model, 'disable_adapter'):
shared.model.disable_adapter()
shared.model = shared.model.base_model.model
if len(lora_names) > 0:
params = {}
if not shared.args.cpu:
if shared.args.load_in_4bit or shared.args.load_in_8bit:
if shared.args.load_in_4bit or shared.args.load_in_8bit:
params['peft_type'] = shared.model.dtype
else:
params['dtype'] = shared.model.dtype

View file

@ -54,14 +54,14 @@ loaders_and_params = {
'trust_remote_code',
'transformers_info'
],
'ExLlama' : [
'ExLlama': [
'gpu_split',
'max_seq_len',
'compress_pos_emb',
'alpha_value',
'exllama_info',
],
'ExLlama_HF' : [
'ExLlama_HF': [
'gpu_split',
'max_seq_len',
'compress_pos_emb',

View file

@ -106,11 +106,11 @@ def load_tokenizer(model_name, model):
use_fast=False
)
except ValueError:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
path_to_model,
trust_remote_code=shared.args.trust_remote_code,
use_fast=True
)
)
if tokenizer.__class__.__name__ == 'LlamaTokenizer':
pairs = [

View file

@ -126,6 +126,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
'''
Copied from the transformers library
'''
def __init__(self, penalty: float, _range: int):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

View file

@ -181,7 +181,7 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
# API
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
# Multimodal

View file

@ -116,7 +116,7 @@ def get_available_loras():
def get_datasets(path: str, ext: str):
    """List the dataset names available in ``path`` for a given extension.

    Args:
        path: Directory to scan.
        ext: File extension without the leading dot (e.g. ``"json"`` or
            ``"txt"``).

    Returns:
        A list starting with ``'None'``, followed by the unique stems of the
        matching entries ordered with ``natural_keys``. The
        ``put-trainer-datasets-here`` placeholder is excluded.
    """
    folder = Path(path)
    if ext == "txt":
        # include subdirectories for raw txt files to allow training from a subdirectory of txt files
        # The file pattern has to be '*.txt'; a bare 'txt' would only match an
        # entry literally named "txt" and never list any raw text dataset.
        candidates = list(folder.glob('*.txt')) + list(folder.glob('*/'))
    else:
        candidates = folder.glob(f'*.{ext}')

    # A set comprehension removes duplicate stems (e.g. "a.txt" next to a
    # directory named "a") before sorting.
    names = {k.stem for k in candidates if k.stem != 'put-trainer-datasets-here'}
    return ['None'] + sorted(names, key=natural_keys)