Intel Gpu support initialization (#4340)
This commit is contained in:
parent
317e2c857e
commit
778a010df8
14 changed files with 106 additions and 42 deletions
|
@ -3,6 +3,7 @@ from typing import List, Optional
|
|||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import is_torch_xpu_available
|
||||
|
||||
|
||||
class AbstractMultimodalPipeline(ABC):
|
||||
|
@ -55,7 +56,7 @@ class AbstractMultimodalPipeline(ABC):
|
|||
|
||||
def _get_device(self, setting_name: str, params: dict):
|
||||
if params[setting_name] is None:
|
||||
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device("cuda:0" if torch.cuda.is_available() else "xpu:0" if is_torch_xpu_available() else "cpu")
|
||||
return torch.device(params[setting_name])
|
||||
|
||||
def _get_dtype(self, setting_name: str, params: dict):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue