Skip to content

Commit 01a2d15

Browse files
committed
New ClusterConfiguration parameter for user labels
1 parent a9b314e commit 01a2d15

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

Diff for: src/codeflare_sdk/cluster/cluster.py

+2
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def create_app_wrapper(self):
187187
write_to_file = self.config.write_to_file
188188
verify_tls = self.config.verify_tls
189189
local_queue = self.config.local_queue
190+
labels = self.config.labels
190191
return generate_appwrapper(
191192
name=name,
192193
namespace=namespace,
@@ -211,6 +212,7 @@ def create_app_wrapper(self):
211212
write_to_file=write_to_file,
212213
verify_tls=verify_tls,
213214
local_queue=local_queue,
215+
labels=labels,
214216
)
215217

216218
# creates a new cluster with the provided or default spec

Diff for: src/codeflare_sdk/cluster/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class ClusterConfiguration:
5454
dispatch_priority: str = None
5555
write_to_file: bool = False
5656
verify_tls: bool = True
57+
labels: dict = field(default_factory=dict)
5758

5859
def __post_init__(self):
5960
if not self.verify_tls:

Diff for: src/codeflare_sdk/utils/generate_yaml.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,11 @@ def get_default_kueue_name(namespace: str):
309309

310310

311311
def write_components(
312-
user_yaml: dict, output_file_name: str, namespace: str, local_queue: Optional[str]
312+
user_yaml: dict,
313+
output_file_name: str,
314+
namespace: str,
315+
local_queue: Optional[str],
316+
user_labels: dict,
313317
):
314318
# Create the directory if it doesn't exist
315319
directory_path = os.path.dirname(output_file_name)
@@ -331,6 +335,7 @@ def write_components(
331335
]
332336
labels = component["generictemplate"]["metadata"]["labels"]
333337
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
338+
labels.update(user_labels)
334339
outfile.write("---\n")
335340
yaml.dump(
336341
component["generictemplate"], outfile, default_flow_style=False
@@ -339,7 +344,11 @@ def write_components(
339344

340345

341346
def load_components(
342-
user_yaml: dict, name: str, namespace: str, local_queue: Optional[str]
347+
user_yaml: dict,
348+
name: str,
349+
namespace: str,
350+
local_queue: Optional[str],
351+
user_labels: dict,
343352
):
344353
component_list = []
345354
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
@@ -355,6 +364,7 @@ def load_components(
355364
]
356365
labels = component["generictemplate"]["metadata"]["labels"]
357366
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
367+
labels.update(user_labels)
358368
component_list.append(component["generictemplate"])
359369

360370
resources = "---\n" + "---\n".join(
@@ -395,6 +405,7 @@ def generate_appwrapper(
395405
write_to_file: bool,
396406
verify_tls: bool,
397407
local_queue: Optional[str],
408+
user_labels,
398409
):
399410
user_yaml = read_template(template)
400411
appwrapper_name, cluster_name = gen_names(name)
@@ -446,11 +457,13 @@ def generate_appwrapper(
446457
if mcad:
447458
write_user_appwrapper(user_yaml, outfile)
448459
else:
449-
write_components(user_yaml, outfile, namespace, local_queue)
460+
write_components(user_yaml, outfile, namespace, local_queue, user_labels)
450461
return outfile
451462
else:
452463
if mcad:
453464
user_yaml = load_appwrapper(user_yaml, name)
454465
else:
455-
user_yaml = load_components(user_yaml, name, namespace, local_queue)
466+
user_yaml = load_components(
467+
user_yaml, name, namespace, local_queue, user_labels
468+
)
456469
return user_yaml

0 commit comments

Comments
 (0)