Skip to content

Commit

Permalink
Support gang scheduling with Yunikorn (kubeflow#2107)
Browse files Browse the repository at this point in the history
* Add Yunikorn scheduler and example

Signed-off-by: Jacob Salway <jacob.salway@gmail.com>

* Add test cases

Signed-off-by: Jacob Salway <jacob.salway@gmail.com>

* Add code comments

Signed-off-by: Jacob Salway <jacob.salway@gmail.com>

* Add license comment

Signed-off-by: Jacob Salway <jacob.salway@gmail.com>

* Inline mergeNodeSelector

Signed-off-by: Jacob Salway <jacob.salway@gmail.com>

* Fix initial number implementation

Signed-off-by: Jacob Salway <jacob.salway@gmail.com>

---------

Signed-off-by: Jacob Salway <jacob.salway@gmail.com>
  • Loading branch information
jacobsalway authored Aug 22, 2024
1 parent 5972482 commit 8fcda12
Show file tree
Hide file tree
Showing 13 changed files with 943 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cmd/operator/controller/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import (
"github.com/kubeflow/spark-operator/internal/metrics"
"github.com/kubeflow/spark-operator/internal/scheduler"
"github.com/kubeflow/spark-operator/internal/scheduler/volcano"
"github.com/kubeflow/spark-operator/internal/scheduler/yunikorn"
"github.com/kubeflow/spark-operator/pkg/common"
"github.com/kubeflow/spark-operator/pkg/util"
// +kubebuilder:scaffold:imports
Expand Down Expand Up @@ -206,9 +207,8 @@ func start() {
var registry *scheduler.Registry
if enableBatchScheduler {
registry = scheduler.GetRegistry()

// Register volcano scheduler.
registry.Register(common.VolcanoSchedulerName, volcano.Factory)
registry.Register(yunikorn.SchedulerName, yunikorn.Factory)
}

// Setup controller for SparkApplication.
Expand Down
45 changes: 45 additions & 0 deletions examples/spark-pi-yunikorn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright 2024 The Kubeflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

apiVersion: sparkoperator.k8s.io/v1beta2
kind: SparkApplication
metadata:
name: spark-pi-yunikorn
namespace: default
spec:
type: Scala
mode: cluster
image: spark:3.5.0
imagePullPolicy: IfNotPresent
mainClass: org.apache.spark.examples.SparkPi
mainApplicationFile: local:///opt/spark/examples/jars/spark-examples_2.12-3.5.0.jar
sparkVersion: 3.5.0
driver:
labels:
version: 3.5.0
cores: 1
coreLimit: 1200m
memory: 512m
serviceAccount: spark-operator-spark
executor:
labels:
version: 3.5.0
instances: 2
cores: 1
coreLimit: 1200m
memory: 512m
batchScheduler: yunikorn
batchSchedulerOptions:
queue: root.default
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,11 @@ replace (
k8s.io/cluster-bootstrap => k8s.io/cluster-bootstrap v0.29.3
k8s.io/code-generator => k8s.io/code-generator v0.29.3
k8s.io/component-base => k8s.io/component-base v0.29.3
k8s.io/controller-manager => k8s.io/controller-manager v0.29.3
k8s.io/cri-api => k8s.io/cri-api v0.29.3
k8s.io/csi-translation-lib => k8s.io/csi-translation-lib v0.29.3
k8s.io/dynamic-resource-allocation => k8s.io/dynamic-resource-allocation v0.29.3
k8s.io/endpointslice => k8s.io/endpointslice v0.29.3
k8s.io/kube-aggregator => k8s.io/kube-aggregator v0.29.3
k8s.io/kube-controller-manager => k8s.io/kube-controller-manager v0.29.3
k8s.io/kube-proxy => k8s.io/kube-proxy v0.29.3
Expand All @@ -237,7 +240,9 @@ replace (
k8s.io/kubelet => k8s.io/kubelet v0.29.3
k8s.io/legacy-cloud-providers => k8s.io/legacy-cloud-providers v0.29.3
k8s.io/metrics => k8s.io/metrics v0.29.3
k8s.io/mount-utils => k8s.io/mount-utils v0.29.3
k8s.io/node-api => k8s.io/node-api v0.29.3
k8s.io/pod-security-admission => k8s.io/pod-security-admission v0.29.3
k8s.io/sample-apiserver => k8s.io/sample-apiserver v0.29.3
k8s.io/sample-cli-plugin => k8s.io/sample-cli-plugin v0.29.3
k8s.io/sample-controller => k8s.io/sample-controller v0.29.3
Expand Down
3 changes: 3 additions & 0 deletions internal/controller/sparkapplication/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/kubeflow/spark-operator/internal/metrics"
"github.com/kubeflow/spark-operator/internal/scheduler"
"github.com/kubeflow/spark-operator/internal/scheduler/volcano"
"github.com/kubeflow/spark-operator/internal/scheduler/yunikorn"
"github.com/kubeflow/spark-operator/pkg/common"
"github.com/kubeflow/spark-operator/pkg/util"
)
Expand Down Expand Up @@ -1197,6 +1198,8 @@ func (r *Reconciler) shouldDoBatchScheduling(app *v1beta2.SparkApplication) (boo
RestConfig: r.manager.GetConfig(),
}
scheduler, err = r.registry.GetScheduler(schedulerName, config)
case yunikorn.SchedulerName:
scheduler, err = r.registry.GetScheduler(schedulerName, nil)
}

if err != nil || scheduler == nil {
Expand Down
56 changes: 56 additions & 0 deletions internal/scheduler/yunikorn/resourceusage/java.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
Copyright 2024 The Kubeflow authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package resourceusage

import (
"fmt"
"regexp"
"strconv"
"strings"
)

var (
javaStringSuffixes = map[string]int64{
"b": 1,
"kb": 1 << 10,
"k": 1 << 10,
"mb": 1 << 20,
"m": 1 << 20,
"gb": 1 << 30,
"g": 1 << 30,
"tb": 1 << 40,
"t": 1 << 40,
"pb": 1 << 50,
"p": 1 << 50,
}

javaStringPattern = regexp.MustCompile(`^([0-9]+)([a-z]+)?$`)
)

func byteStringAsBytes(byteString string) (int64, error) {
matches := javaStringPattern.FindStringSubmatch(strings.ToLower(byteString))
if matches != nil {
value, err := strconv.ParseInt(matches[1], 10, 64)
if err != nil {
return 0, err
}
if multiplier, present := javaStringSuffixes[matches[2]]; present {
return value * multiplier, nil
}
}
return 0, fmt.Errorf("unable to parse byte string: %s", byteString)
}
63 changes: 63 additions & 0 deletions internal/scheduler/yunikorn/resourceusage/java_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
Copyright 2024 The Kubeflow authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package resourceusage

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestByteStringAsMb(t *testing.T) {
testCases := []struct {
input string
expected int
}{
{"1k", 1024},
{"1m", 1024 * 1024},
{"1g", 1024 * 1024 * 1024},
{"1t", 1024 * 1024 * 1024 * 1024},
{"1p", 1024 * 1024 * 1024 * 1024 * 1024},
}

for _, tc := range testCases {
t.Run(tc.input, func(t *testing.T) {
actual, err := byteStringAsBytes(tc.input)
assert.Nil(t, err)
assert.Equal(t, int64(tc.expected), actual)
})
}
}

func TestByteStringAsMbInvalid(t *testing.T) {
invalidInputs := []string{
"0.064",
"0.064m",
"500ub",
"This breaks 600b",
"This breaks 600",
"600gb This breaks",
"This 123mb breaks",
}

for _, input := range invalidInputs {
t.Run(input, func(t *testing.T) {
_, err := byteStringAsBytes(input)
assert.NotNil(t, err)
})
}
}
108 changes: 108 additions & 0 deletions internal/scheduler/yunikorn/resourceusage/memory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
Copyright 2024 The Kubeflow authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package resourceusage

import (
"fmt"
"math"
"strconv"

"github.com/kubeflow/spark-operator/api/v1beta2"
"github.com/kubeflow/spark-operator/pkg/common"
)

func isJavaApp(appType v1beta2.SparkApplicationType) bool {
return appType == v1beta2.SparkApplicationTypeJava || appType == v1beta2.SparkApplicationTypeScala
}

func getMemoryOverheadFactor(app *v1beta2.SparkApplication) (float64, error) {
if app.Spec.MemoryOverheadFactor != nil {
parsed, err := strconv.ParseFloat(*app.Spec.MemoryOverheadFactor, 64)
if err != nil {
return 0, fmt.Errorf("failed to parse memory overhead factor as float: %w", err)
}
return parsed, nil
} else if isJavaApp(app.Spec.Type) {
return common.DefaultJVMMemoryOverheadFactor, nil
}

return common.DefaultNonJVMMemoryOverheadFactor, nil
}

func memoryRequestBytes(podSpec *v1beta2.SparkPodSpec, memoryOverheadFactor float64) (int64, error) {
var memoryBytes, memoryOverheadBytes int64

if podSpec.Memory != nil {
parsed, err := byteStringAsBytes(*podSpec.Memory)
if err != nil {
return 0, err
}
memoryBytes = parsed
}

if podSpec.MemoryOverhead != nil {
parsed, err := byteStringAsBytes(*podSpec.MemoryOverhead)
if err != nil {
return 0, err
}
memoryOverheadBytes = parsed
} else {
memoryOverheadBytes = int64(math.Max(
float64(memoryBytes)*memoryOverheadFactor,
common.MinMemoryOverhead,
))
}

return memoryBytes + memoryOverheadBytes, nil
}

func bytesToMi(b int64) string {
// this floors the value to the nearest mebibyte
return fmt.Sprintf("%dMi", b/1024/1024)
}

func driverMemoryRequest(app *v1beta2.SparkApplication) (string, error) {
memoryOverheadFactor, err := getMemoryOverheadFactor(app)
if err != nil {
return "", err
}

requestBytes, err := memoryRequestBytes(&app.Spec.Driver.SparkPodSpec, memoryOverheadFactor)
if err != nil {
return "", err
}

// Convert memory quantity to mebibytes even if larger than a gibibyte to match Spark
// https://github.com/apache/spark/blob/11b682cf5b7c5360a02410be288b7905eecc1d28/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala#L88
// https://github.com/apache/spark/blob/11b682cf5b7c5360a02410be288b7905eecc1d28/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala#L121
return bytesToMi(requestBytes), nil
}

func executorMemoryRequest(app *v1beta2.SparkApplication) (string, error) {
memoryOverheadFactor, err := getMemoryOverheadFactor(app)
if err != nil {
return "", err
}

requestBytes, err := memoryRequestBytes(&app.Spec.Executor.SparkPodSpec, memoryOverheadFactor)
if err != nil {
return "", err
}

// See comment above in driver
return bytesToMi(requestBytes), nil
}
39 changes: 39 additions & 0 deletions internal/scheduler/yunikorn/resourceusage/memory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
Copyright 2024 The Kubeflow authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package resourceusage

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestBytesToMi(t *testing.T) {
testCases := []struct {
input int64
expected string
}{
{(2 * 1024 * 1024) - 1, "1Mi"},
{2 * 1024 * 1024, "2Mi"},
{(1024 * 1024 * 1024) - 1, "1023Mi"},
{1024 * 1024 * 1024, "1024Mi"},
}

for _, tc := range testCases {
assert.Equal(t, tc.expected, bytesToMi(tc.input))
}
}
Loading

0 comments on commit 8fcda12

Please sign in to comment.