Add tail-free and top-a sampling (#2357)

This commit is contained in:
Luis Lopez 2023-05-30 08:40:01 +08:00 committed by GitHub
parent b4662bf4af
commit 9e7204bef4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 5 deletions

View file

@@ -15,7 +15,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared
from modules import llama_attn_hijack
from modules import llama_attn_hijack, sampler_hijack
from modules.logging_colors import logger
transformers.logging.set_verbosity_error()
@@ -36,6 +36,8 @@ if shared.args.deepspeed:
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
sampler_hijack.hijack_samplers()
# Some models require special treatment in various parts of the code.
# This function detects those models