Skip to content

Commit 6b15fc6

Browse files
authored
Added GPU selection feature to python inference (#321)
* Added GPU selection feature to python inference * pylint pep8 fixes * pep8 fixes
1 parent bc77ca5 commit 6b15fc6

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

inference_realesrgan.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def main():
3939
type=str,
4040
default='auto',
4141
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
42+
parser.add_argument(
43+
'-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')
44+
4245
args = parser.parse_args()
4346

4447
# determine models according to model names
@@ -71,7 +74,8 @@ def main():
7174
tile=args.tile,
7275
tile_pad=args.tile_pad,
7376
pre_pad=args.pre_pad,
74-
half=not args.fp32)
77+
half=not args.fp32,
78+
gpu_id=args.gpu_id)
7579

7680
if args.face_enhance: # Use GFPGAN for face enhancement
7781
from gfpgan import GFPGANer

realesrgan/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,16 @@ class RealESRGANer():
2626
half (float): Whether to use half precision during inference. Default: False.
2727
"""
2828

29-
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None):
29+
def __init__(self,
30+
scale,
31+
model_path,
32+
model=None,
33+
tile=0,
34+
tile_pad=10,
35+
pre_pad=10,
36+
half=False,
37+
device=None,
38+
gpu_id=None):
3039
self.scale = scale
3140
self.tile_size = tile
3241
self.tile_pad = tile_pad
@@ -35,7 +44,11 @@ def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=1
3544
self.half = half
3645

3746
# initialize model
38-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
47+
if gpu_id:
48+
self.device = torch.device(
49+
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
50+
else:
51+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
3952
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
4053
if model_path.startswith('https://'):
4154
model_path = load_file_from_url(

0 commit comments

Comments
 (0)