Add --max-gpu-memory parameter for #7

This commit is contained in:
oobabooga 2023-01-15 22:33:35 -03:00
parent bb1a172da0
commit ebf4d5f506
2 changed files with 10 additions and 4 deletions

View file

@ -131,6 +131,7 @@ Optionally, you can use the following command-line flags:
| `--cpu` | Use the CPU to generate text.| | `--cpu` | Use the CPU to generate text.|
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.| | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--load-in-8bit` | Load the model with 8-bit precision.| | `--load-in-8bit` | Load the model with 8-bit precision.|
| `--max-gpu-memory MAX_GPU_MEMORY` | Maximum memory in GiB to allocate to the GPU while loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. |
| `--no-listen` | Make the webui unreachable from your local network.| | `--no-listen` | Make the webui unreachable from your local network.|
| `--settings-file SETTINGS_FILE` | Load default interface settings from this json file. See settings-template.json for an example.| | `--settings-file SETTINGS_FILE` | Load default interface settings from this json file. See settings-template.json for an example.|

View file

@ -9,7 +9,7 @@ from pathlib import Path
import gradio as gr import gradio as gr
import transformers import transformers
from html_generator import * from html_generator import *
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import warnings import warnings
@ -23,6 +23,7 @@ parser.add_argument('--cai-chat', action='store_true', help='Launch the webui in
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU while loading the model. This is useful if get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--no-listen', action='store_true', help='Make the webui unreachable from your local network.') parser.add_argument('--no-listen', action='store_true', help='Make the webui unreachable from your local network.')
parser.add_argument('--settings-file', type=str, help='Load default interface settings from this json file. See settings-template.json for an example.') parser.add_argument('--settings-file', type=str, help='Load default interface settings from this json file. See settings-template.json for an example.')
args = parser.parse_args() args = parser.parse_args()
@ -61,7 +62,7 @@ def load_model(model_name):
t0 = time.time() t0 = time.time()
# Default settings # Default settings
if not (args.cpu or args.auto_devices or args.load_in_8bit): if not (args.cpu or args.auto_devices or args.load_in_8bit or args.max_gpu_memory is not None):
if Path(f"torch-dumps/{model_name}.pt").exists(): if Path(f"torch-dumps/{model_name}.pt").exists():
print("Loading in .pt format...") print("Loading in .pt format...")
model = torch.load(Path(f"torch-dumps/{model_name}.pt")) model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
@ -79,7 +80,11 @@ def load_model(model_name):
if args.cpu: if args.cpu:
settings.append("torch_dtype=torch.float32") settings.append("torch_dtype=torch.float32")
else: else:
if args.load_in_8bit: if args.max_gpu_memory is not None:
settings.append(f"max_memory={{0: '{args.max_gpu_memory}GiB', 'cpu': '99GiB'}}")
settings.append("device_map='auto'")
settings.append("torch_dtype=torch.float16")
elif args.load_in_8bit:
settings.append("device_map='auto'") settings.append("device_map='auto'")
settings.append("load_in_8bit=True") settings.append("load_in_8bit=True")
else: else:
@ -89,7 +94,7 @@ def load_model(model_name):
else: else:
cuda = ".cuda()" cuda = ".cuda()"
settings = ', '.join(settings) settings = ', '.join(list(set(settings)))
command = f"{command}(Path(f'models/{model_name}'), {settings}){cuda}" command = f"{command}(Path(f'models/{model_name}'), {settings}){cuda}"
model = eval(command) model = eval(command)