Intel Gpu support initialization (#4340)
This commit is contained in:
parent
317e2c857e
commit
778a010df8
14 changed files with 106 additions and 42 deletions
|
@ -9,7 +9,7 @@ import traceback
|
|||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import LogitsProcessorList
|
||||
from transformers import LogitsProcessorList, is_torch_xpu_available
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.callbacks import (
|
||||
|
@ -132,8 +132,8 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||
elif torch.backends.mps.is_available():
|
||||
device = torch.device('mps')
|
||||
return input_ids.to(device)
|
||||
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
|
||||
return input_ids.to('xpu')
|
||||
elif is_torch_xpu_available():
|
||||
return input_ids.to("xpu:0")
|
||||
else:
|
||||
return input_ids.cuda()
|
||||
|
||||
|
@ -238,7 +238,8 @@ def set_manual_seed(seed):
|
|||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
elif is_torch_xpu_available():
|
||||
torch.xpu.manual_seed_all(seed)
|
||||
return seed
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue