Skip to content

Commit 29fcbc1

Browse files
committed
[external resources support]
- support for kubernetes / argo
1 parent 10bfd16 commit 29fcbc1

File tree

6 files changed

+31
-0
lines changed

6 files changed

+31
-0
lines changed

Diff for: metaflow/plugins/argo/argo_workflows.py

+6
Original file line numberDiff line numberDiff line change
@@ -1906,6 +1906,11 @@ def _container_templates(self):
19061906
resources["disk"],
19071907
)
19081908

1909+
extended_resources = resources.get("extended_resources", {})
1910+
1911+
qos_requests = {**qos_requests, **extended_resources}
1912+
qos_limits = {**qos_limits, **extended_resources}
1913+
19091914
# Create a ContainerTemplate for this node. Ideally, we would have
19101915
# liked to inline this ContainerTemplate and avoid scanning the workflow
19111916
# twice, but due to issues with variable substitution, we will have to
@@ -1962,6 +1967,7 @@ def _container_templates(self):
19621967
shared_memory=shared_memory,
19631968
port=port,
19641969
qos=resources["qos"],
1970+
extended_resources=extended_resources,
19651971
)
19661972

19671973
for k, v in env.items():

Diff for: metaflow/plugins/kubernetes/kubernetes.py

+4
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def create_jobset(
194194
port=None,
195195
num_parallel=None,
196196
qos=None,
197+
extended_resources=None,
197198
):
198199
name = "js-%s" % str(uuid4())[:6]
199200
jobset = (
@@ -227,6 +228,7 @@ def create_jobset(
227228
port=port,
228229
num_parallel=num_parallel,
229230
qos=qos,
231+
extended_resources=extended_resources,
230232
)
231233
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
232234
.environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -488,6 +490,7 @@ def create_job_object(
488490
name_pattern=None,
489491
qos=None,
490492
annotations=None,
493+
extended_resources=None,
491494
):
492495
if env is None:
493496
env = {}
@@ -530,6 +533,7 @@ def create_job_object(
530533
shared_memory=shared_memory,
531534
port=port,
532535
qos=qos,
536+
extended_resources=extended_resources,
533537
)
534538
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
535539
.environment_variable("METAFLOW_CODE_URL", code_package_url)

Diff for: metaflow/plugins/kubernetes/kubernetes_cli.py

+8
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ def kubernetes():
145145
type=JSONTypeClass(),
146146
multiple=False,
147147
)
148+
@click.option(
149+
"--extended-resources",
150+
default=None,
151+
type=JSONTypeClass(),
152+
multiple=False,
153+
)
148154
@click.pass_context
149155
def step(
150156
ctx,
@@ -176,6 +182,7 @@ def step(
176182
qos=None,
177183
labels=None,
178184
annotations=None,
185+
extended_resources=None,
179186
**kwargs
180187
):
181188
def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -319,6 +326,7 @@ def _sync_metadata():
319326
qos=qos,
320327
labels=labels,
321328
annotations=annotations,
329+
extended_resources=extended_resources,
322330
)
323331
except Exception:
324332
traceback.print_exc(chain=False)

Diff for: metaflow/plugins/kubernetes/kubernetes_decorator.py

+5
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ class KubernetesDecorator(StepDecorator):
121121
Only applicable when @parallel is used.
122122
qos: str, default: Burstable
123123
Quality of Service class to assign to the pod. Supported values are: Guaranteed, Burstable, BestEffort
124+
extended_resources: Dict[str, str], optional, default {}
125+
Extended resources to request for the pod. The same values are applied to both the resource requests and the resource limits.
126+
https://kubernetes.io/docs/tasks/administer-cluster/extended-resource-node/
124127
"""
125128

126129
name = "kubernetes"
@@ -151,6 +154,7 @@ class KubernetesDecorator(StepDecorator):
151154
"executable": None,
152155
"hostname_resolution_timeout": 10 * 60,
153156
"qos": KUBERNETES_QOS,
157+
"extended_resources": {},
154158
}
155159
package_url = None
156160
package_sha = None
@@ -473,6 +477,7 @@ def runtime_step_cli(
473477
"persistent_volume_claims",
474478
"labels",
475479
"annotations",
480+
"extended_resources",
476481
]:
477482
cli_args.command_options[k] = json.dumps(v)
478483
else:

Diff for: metaflow/plugins/kubernetes/kubernetes_job.py

+3
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def create_job_spec(self):
7979
self._kwargs["memory"],
8080
self._kwargs["disk"],
8181
)
82+
extended_resources = self._kwargs.get("extended_resources", {})
83+
qos_requests = {**qos_requests, **extended_resources}
84+
qos_limits = {**qos_limits, **extended_resources}
8285

8386
return client.V1JobSpec(
8487
# Retries are handled by Metaflow when it is responsible for

Diff for: metaflow/plugins/kubernetes/kubernetes_jobsets.py

+5
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,11 @@ def dump(self):
560560
self._kwargs["memory"],
561561
self._kwargs["disk"],
562562
)
563+
564+
extended_resources = self._kwargs.get("extended_resources", {})
565+
qos_requests = {**qos_requests, **extended_resources}
566+
qos_limits = {**qos_limits, **extended_resources}
567+
563568
return dict(
564569
name=self.name,
565570
template=client.api_client.ApiClient().sanitize_for_serialization(

0 commit comments

Comments
 (0)