|
6 | 6 | """
|
7 | 7 | import os
|
8 | 8 | import warnings
|
| 9 | +from shutil import which |
9 | 10 |
|
10 | 11 | from ...utils.filemanip import fname_presuffix, split_filename, copyfile
|
11 | 12 | from ..base import (
|
@@ -383,6 +384,7 @@ class BEDPOSTX5InputSpec(FSLXCommandInputSpec):
|
383 | 384 | )
|
384 | 385 | grad_dev = File(exists=True, desc="grad_dev file, if gradnonlin, -g is True")
|
385 | 386 | use_gpu = traits.Bool(False, desc="Use the GPU version of bedpostx")
|
| 387 | + num_threads = traits.Int(nohash=True, desc="Number of threads to use") |
386 | 388 |
|
387 | 389 |
|
388 | 390 | class BEDPOSTX5OutputSpec(TraitedSpec):
|
@@ -451,13 +453,25 @@ class BEDPOSTX5(FSLXCommand):
|
451 | 453 | def __init__(self, **inputs):
|
452 | 454 | super().__init__(**inputs)
|
453 | 455 | self.inputs.on_trait_change(self._cuda_update, "use_gpu")
|
| 456 | + self.inputs.on_trait_change(self._num_threads_update, "num_threads") |
454 | 457 |
|
455 | 458 | def _cuda_update(self):
|
456 |
| - if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu: |
| 459 | + if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu and which("bedpostx_gpu") is not None: |
457 | 460 | self._cmd = "bedpostx_gpu"
|
| 461 | + self.inputs.num_threads = 1 |
458 | 462 | else:
|
459 | 463 | self._cmd = self._default_cmd
|
460 | 464 |
|
| 465 | + def _num_threads_update(self): |
| 466 | + if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu and which("bedpostx_gpu") is not None: |
| 467 | + self.inputs.num_threads = 1 |
| 468 | + self._num_threads = self.inputs.num_threads |
| 469 | + if not isdefined(self.inputs.num_threads): |
| 470 | + if "FSLSUB_PARALLEL" in self.inputs.environ: |
| 471 | + del self.inputs.environ["FSLSUB_PARALLEL"] |
| 472 | + else: |
| 473 | + self.inputs.environ["FSLSUB_PARALLEL"] = str(self.inputs.num_threads) |
| 474 | + |
461 | 475 | def _run_interface(self, runtime):
|
462 | 476 | subjectdir = os.path.abspath(self.inputs.out_dir)
|
463 | 477 | if not os.path.exists(subjectdir):
|
@@ -1024,6 +1038,7 @@ class ProbTrackX2InputSpec(ProbTrackXBaseInputSpec):
|
1024 | 1038 | '"vox"'
|
1025 | 1039 | ),
|
1026 | 1040 | )
|
| 1041 | + use_gpu = traits.Bool(False, desc="Use the GPU version of probtrackx2") |
1027 | 1042 |
|
1028 | 1043 |
|
1029 | 1044 | class ProbTrackX2OutputSpec(ProbTrackXOutputSpec):
|
@@ -1059,9 +1074,20 @@ class ProbTrackX2(ProbTrackX):
|
1059 | 1074 | """
|
1060 | 1075 |
|
1061 | 1076 | _cmd = "probtrackx2"
|
| 1077 | + _default_cmd = _cmd |
1062 | 1078 | input_spec = ProbTrackX2InputSpec
|
1063 | 1079 | output_spec = ProbTrackX2OutputSpec
|
1064 | 1080 |
|
| 1081 | + def __init__(self, **inputs): |
| 1082 | + super().__init__(**inputs) |
| 1083 | + self.inputs.on_trait_change(self._cuda_update, "use_gpu") |
| 1084 | + |
| 1085 | + def _cuda_update(self): |
| 1086 | + if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu and which("probtrackx2_gpu") is not None: |
| 1087 | + self._cmd = "probtrackx2_gpu" |
| 1088 | + else: |
| 1089 | + self._cmd = self._default_cmd |
| 1090 | + |
1065 | 1091 | def _list_outputs(self):
|
1066 | 1092 | outputs = super()._list_outputs()
|
1067 | 1093 |
|
|
0 commit comments