Skip to content

refactor workload controller to prepare to upstream to Kueue #298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 6 additions & 57 deletions internal/controller/workload/workload_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ limitations under the License.
package workload

import (
"fmt"

"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime/schema"

"sigs.k8s.io/controller-runtime/pkg/builder"
Expand Down Expand Up @@ -75,72 +72,24 @@ func (aw *AppWrapper) GVK() schema.GroupVersionKind {
}

func (aw *AppWrapper) PodSets() []kueue.PodSet {
podSets := []kueue.PodSet{}
if err := utils.EnsureComponentStatusInitialized((*workloadv1beta2.AppWrapper)(aw)); err != nil {
// Kueue will raise an error on zero length PodSet. Unfortunately, the Kueue API prevents propagating the actual error
return podSets
}
for idx := range aw.Status.ComponentStatus {
if len(aw.Status.ComponentStatus[idx].PodSets) > 0 {
obj := &unstructured.Unstructured{}
if _, _, err := unstructured.UnstructuredJSONScheme.Decode(aw.Spec.Components[idx].Template.Raw, nil, obj); err != nil {
// Should be unreachable; Template.Raw validated by AppWrapper AdmissionController
return []kueue.PodSet{} // Kueue will raise an error on zero length PodSet.
}
for psIdx, podSet := range aw.Status.ComponentStatus[idx].PodSets {
replicas := utils.Replicas(podSet)
if template, err := utils.GetPodTemplateSpec(obj, podSet.Path); err == nil {
podSets = append(podSets, kueue.PodSet{
Name: fmt.Sprintf("%s-%v-%v", aw.Name, idx, psIdx),
Template: *template,
Count: replicas,
})
}
}
}
podSets, err := utils.GetPodSets((*workloadv1beta2.AppWrapper)(aw))
if err != nil {
// Kueue will raise an error on zero length PodSet; the Kueue GenericJob API prevents propagating the actual error.
return []kueue.PodSet{}
}
return podSets
}

// RunWithPodSetsInfo records the assigned PodSetInfos for each component and sets aw.spec.Suspend to false
func (aw *AppWrapper) RunWithPodSetsInfo(podSetsInfo []podset.PodSetInfo) error {
if err := utils.EnsureComponentStatusInitialized((*workloadv1beta2.AppWrapper)(aw)); err != nil {
if err := utils.SetPodSetInfos((*workloadv1beta2.AppWrapper)(aw), podSetsInfo); err != nil {
return err
}
podSetsInfoIndex := 0
for idx := range aw.Spec.Components {
if len(aw.Spec.Components[idx].PodSetInfos) != len(aw.Status.ComponentStatus[idx].PodSets) {
aw.Spec.Components[idx].PodSetInfos = make([]workloadv1beta2.AppWrapperPodSetInfo, len(aw.Status.ComponentStatus[idx].PodSets))
}
for podSetIdx := range aw.Status.ComponentStatus[idx].PodSets {
podSetsInfoIndex += 1
if podSetsInfoIndex > len(podSetsInfo) {
continue // we will return an error below...continuing to get an accurate count for the error message
}
aw.Spec.Components[idx].PodSetInfos[podSetIdx] = workloadv1beta2.AppWrapperPodSetInfo{
Annotations: podSetsInfo[podSetsInfoIndex-1].Annotations,
Labels: podSetsInfo[podSetsInfoIndex-1].Labels,
NodeSelector: podSetsInfo[podSetsInfoIndex-1].NodeSelector,
Tolerations: podSetsInfo[podSetsInfoIndex-1].Tolerations,
}
}
}

if podSetsInfoIndex != len(podSetsInfo) {
return podset.BadPodSetsInfoLenError(podSetsInfoIndex, len(podSetsInfo))
}

aw.Spec.Suspend = false

return nil
}

// RestorePodSetsInfo clears the PodSetInfos saved by RunWithPodSetsInfo
func (aw *AppWrapper) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool {
for idx := range aw.Spec.Components {
aw.Spec.Components[idx].PodSetInfos = nil
}
return true
return utils.ClearPodSetInfos((*workloadv1beta2.AppWrapper)(aw))
}

func (aw *AppWrapper) Finished() (message string, success, finished bool) {
Expand Down
69 changes: 69 additions & 0 deletions pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/utils/ptr"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/podset"

workloadv1beta2 "github.com/project-codeflare/appwrapper/api/v1beta2"
)

Expand Down Expand Up @@ -327,6 +330,72 @@ func EnsureComponentStatusInitialized(aw *workloadv1beta2.AppWrapper) error {
return nil
}

// GetPodSets constructs the kueue.PodSets for an AppWrapper
func GetPodSets(aw *workloadv1beta2.AppWrapper) ([]kueue.PodSet, error) {
podSets := []kueue.PodSet{}
if err := EnsureComponentStatusInitialized(aw); err != nil {
return nil, err
}
for idx := range aw.Status.ComponentStatus {
if len(aw.Status.ComponentStatus[idx].PodSets) > 0 {
obj := &unstructured.Unstructured{}
if _, _, err := unstructured.UnstructuredJSONScheme.Decode(aw.Spec.Components[idx].Template.Raw, nil, obj); err != nil {
// Should be unreachable; Template.Raw validated by AppWrapper AdmissionController
return nil, err
}
for psIdx, podSet := range aw.Status.ComponentStatus[idx].PodSets {
replicas := Replicas(podSet)
if template, err := GetPodTemplateSpec(obj, podSet.Path); err == nil {
podSets = append(podSets, kueue.PodSet{
Name: fmt.Sprintf("%s-%v-%v", aw.Name, idx, psIdx),
Template: *template,
Count: replicas,
})
}
}
}
}
return podSets, nil
}

// SetPodSetInfos propagates podSetsInfo into the PodSetInfos of aw.Spec.Components
func SetPodSetInfos(aw *workloadv1beta2.AppWrapper, podSetsInfo []podset.PodSetInfo) error {
if err := EnsureComponentStatusInitialized(aw); err != nil {
return err
}
podSetsInfoIndex := 0
for idx := range aw.Spec.Components {
if len(aw.Spec.Components[idx].PodSetInfos) != len(aw.Status.ComponentStatus[idx].PodSets) {
aw.Spec.Components[idx].PodSetInfos = make([]workloadv1beta2.AppWrapperPodSetInfo, len(aw.Status.ComponentStatus[idx].PodSets))
}
for podSetIdx := range aw.Status.ComponentStatus[idx].PodSets {
podSetsInfoIndex += 1
if podSetsInfoIndex > len(podSetsInfo) {
continue // we will return an error below...continuing to get an accurate count for the error message
}
aw.Spec.Components[idx].PodSetInfos[podSetIdx] = workloadv1beta2.AppWrapperPodSetInfo{
Annotations: podSetsInfo[podSetsInfoIndex-1].Annotations,
Labels: podSetsInfo[podSetsInfoIndex-1].Labels,
NodeSelector: podSetsInfo[podSetsInfoIndex-1].NodeSelector,
Tolerations: podSetsInfo[podSetsInfoIndex-1].Tolerations,
}
}
}

if podSetsInfoIndex != len(podSetsInfo) {
return podset.BadPodSetsInfoLenError(podSetsInfoIndex, len(podSetsInfo))
}
return nil
}

// ClearPodSetInfos clears the PodSetInfos saved by SetPodSetInfos
func ClearPodSetInfos(aw *workloadv1beta2.AppWrapper) bool {
for idx := range aw.Spec.Components {
aw.Spec.Components[idx].PodSetInfos = nil
}
return true
}

// inferReplicas parses the value at the given path within obj as an int or return 1 or error
func inferReplicas(obj map[string]interface{}, path string) (int32, error) {
if path == "" {
Expand Down
Loading