diff --git a/config/armada/config.yaml b/config/armada/config.yaml index 58b849ca288..f931bacb814 100644 --- a/config/armada/config.yaml +++ b/config/armada/config.yaml @@ -1,5 +1,6 @@ grpcPort: ":50051" redis: addr: "localhost:6379" + priorityHalfTime: 20m password: "" db: 0 diff --git a/internal/armada/configuration/types.go b/internal/armada/configuration/types.go index cc42820b3c9..0393d7068dd 100644 --- a/internal/armada/configuration/types.go +++ b/internal/armada/configuration/types.go @@ -1,8 +1,11 @@ package configuration +import "time" + type ArmadaConfig struct { - GrpcPort string - Redis RedisConfig + GrpcPort string + Redis RedisConfig + PriorityHalfTime time.Duration } type RedisConfig struct { diff --git a/internal/armada/repository/jobs.go b/internal/armada/repository/jobs.go index 0d2821588cb..4cc4d04c7ab 100644 --- a/internal/armada/repository/jobs.go +++ b/internal/armada/repository/jobs.go @@ -10,8 +10,9 @@ import ( //"github.com/golang/protobuf/ptypes/timestamp" ) -const jobObjectPrefix = "job:" -const queuePrefix = "Job:Queue:" +const jobObjectPrefix = "Job:" +const jobQueuePrefix = "Job:Queue:" + type JobRepository interface { AddJob(request *api.JobRequest) (string, error) @@ -32,7 +33,7 @@ func (repo RedisJobRepository) AddJob(request *api.JobRequest) (string, error) { return "", e } - pipe.ZAdd(queuePrefix+job.Queue, redis.Z{ + pipe.ZAdd(jobQueuePrefix+job.Queue, redis.Z{ Member: job.Id, Score: job.Priority}) @@ -45,7 +46,7 @@ func (repo RedisJobRepository) AddJob(request *api.JobRequest) (string, error) { } func (repo RedisJobRepository) PeekQueue(queue string, limit int64) ([]*api.Job, error) { - ids, e := repo.Db.ZRange(queuePrefix+queue, 0, limit-1).Result() + ids, e := repo.Db.ZRange(jobQueuePrefix+queue, 0, limit-1).Result() if e != nil { return nil, e } diff --git a/internal/armada/repository/queues.go b/internal/armada/repository/queues.go new file mode 100644 index 00000000000..138a6a7a5f0 --- /dev/null +++ b/internal/armada/repository/queues.go @@ -0,0 +1,20 @@ +package repository + +import ( + "github.com/go-redis/redis" +) + +type Queue struct { +} + +type QueueRepository interface { + GetQueues() ([]string, error) +} + +type RedisQueueRepository struct { + Db *redis.Client +} + +func (RedisQueueRepository) GetQueues() ([]string, error) { + panic("implement me") +} diff --git a/internal/armada/repository/usage.go b/internal/armada/repository/usage.go new file mode 100644 index 00000000000..45f03af782d --- /dev/null +++ b/internal/armada/repository/usage.go @@ -0,0 +1,87 @@ +package repository + +import ( + "github.com/G-Research/k8s-batch/internal/armada/api" + "github.com/go-redis/redis" + "github.com/gogo/protobuf/proto" + "strconv" +) + +type Usage struct { + PriorityPerQueue map[string]float64 + CurrentUsagePerQueue map[string]float64 +} + +const clusterReportKey = "Cluster:Report" +const clusterPrioritiesPrefix = "Cluster:Priority:" + +type UsageRepository interface { + + GetClusterUsageReports() (map[string]*api.ClusterUsageReport, error) + GetClusterPriority(clusterId string) (map[string]float64, error) + + UpdateCluster(report *api.ClusterUsageReport, priorities map[string]float64) error +} + +type RedisUsageRepository struct { + Db *redis.Client +} + + +func (r RedisUsageRepository) GetClusterUsageReports() (map[string]*api.ClusterUsageReport, error) { + result, err := r.Db.HGetAll(clusterReportKey).Result() + if err != nil { + return nil, err + } + reports := make(map[string]*api.ClusterUsageReport) + + for k, v := range result { + report := &api.ClusterUsageReport{} + e := proto.Unmarshal([]byte(v), report) + if e!= nil { + return nil, e + } + reports[k] = report + } + return reports, nil +} + +func (r RedisUsageRepository) GetClusterPriority(clusterId string) (map[string]float64, error) { + result, err := r.Db.HGetAll(clusterPrioritiesPrefix+clusterId).Result() + if err != nil { + return nil, err + } + return toFloat64Map(result) +} + +func (r RedisUsageRepository) UpdateCluster(report *api.ClusterUsageReport, priorities map[string]float64) error { + + pipe := r.Db.TxPipeline() + + data, e := proto.Marshal(report) + if e != nil { + return e + } + pipe.HSet(clusterReportKey, report.ClusterId, data) + + untyped := make(map[string]interface{}) + for k, v := range priorities { + untyped[k] = v + } + pipe.HMSet(clusterPrioritiesPrefix+report.ClusterId, untyped) + + _, err := pipe.Exec() + return err +} + +func toFloat64Map(result map[string]string) (map[string]float64, error) { + reports := make(map[string]float64) + for k, v := range result { + priority, e := strconv.ParseFloat(v, 64) + if e!= nil { + return nil, e + } + reports[k] = priority + } + return reports, nil +} diff --git a/internal/armada/server.go b/internal/armada/server.go index 72a81f5e02f..735717063b4 100644 --- a/internal/armada/server.go +++ b/internal/armada/server.go @@ -4,18 +4,19 @@ import ( "github.com/G-Research/k8s-batch/internal/armada/api" "github.com/G-Research/k8s-batch/internal/armada/configuration" "github.com/G-Research/k8s-batch/internal/armada/repository" - "github.com/G-Research/k8s-batch/internal/armada/service" + "github.com/G-Research/k8s-batch/internal/armada/server" "github.com/go-redis/redis" "google.golang.org/grpc" "log" "net" "sync" + "time" ) func Serve(config *configuration.ArmadaConfig) (*grpc.Server, *sync.WaitGroup) { wg := &sync.WaitGroup{} wg.Add(1) - server := grpc.NewServer() + grpcServer := grpc.NewServer() go func () { log.Printf("Grpc listening on %s", config.GrpcPort) defer log.Println("Stopping server.") @@ -27,22 +28,26 @@ func Serve(config *configuration.ArmadaConfig) (*grpc.Server, *sync.WaitGroup) { }) jobRepository := &repository.RedisJobRepository{ Db: db } - submitServer := &service.SubmitServer{ JobRepository: jobRepository } - aggregatedQueueServer := &service.AggregatedQueueServer{ JobRepository: jobRepository } + usageRepository := &repository.RedisUsageRepository{ Db: db } + + submitServer := &server.SubmitServer{ JobRepository: jobRepository } + usageServer := &server.UsageServer { UsageRepository: usageRepository, PriorityHalfTime: time.Minute } + aggregatedQueueServer := &server.AggregatedQueueServer{ JobRepository: jobRepository, UsageRepository: usageRepository } lis, err := net.Listen("tcp", config.GrpcPort) if err != nil { log.Fatalf("failed to listen: %v", err) } - api.RegisterSubmitServer(server, submitServer) - api.RegisterAggregatedQueueServer(server, aggregatedQueueServer) + api.RegisterSubmitServer(grpcServer, submitServer) + api.RegisterUsageServer(grpcServer, usageServer) + api.RegisterAggregatedQueueServer(grpcServer, aggregatedQueueServer) - if err := server.Serve(lis); err != nil { + if err := grpcServer.Serve(lis); err != nil { log.Fatalf("failed to serve: %v", err) } wg.Done() } () - return server, wg + return grpcServer, wg } diff --git a/internal/armada/service/queue.go b/internal/armada/server/queue.go similarity index 93% rename from internal/armada/service/queue.go rename to internal/armada/server/queue.go index 1d066107728..daadd21d089 100644 --- a/internal/armada/service/queue.go +++ b/internal/armada/server/queue.go @@ -1,4 +1,4 @@ -package service +package server import ( "context" @@ -10,9 +10,11 @@ import ( type AggregatedQueueServer struct { JobRepository repository.JobRepository + UsageRepository repository.UsageRepository } func (AggregatedQueueServer) LeaseJobs(context.Context, *api.LeaseRequest) (*api.JobLease, error) { + //TODO Implement me fmt.Println("Lease jobs called") jobLease := api.JobLease{ diff --git a/internal/armada/service/submit.go b/internal/armada/server/submit.go similarity index 97% rename from internal/armada/service/submit.go rename to internal/armada/server/submit.go index aa7112208f6..ea03c853a84 100644 --- a/internal/armada/service/submit.go +++ b/internal/armada/server/submit.go @@ -1,4 +1,4 @@ -package service +package server import ( "context" diff --git a/internal/armada/server/usage.go b/internal/armada/server/usage.go new file mode 100644 index 00000000000..10c002c0d23 --- /dev/null +++ b/internal/armada/server/usage.go @@ -0,0 +1,125 @@ + +package server + +import ( + "context" + "github.com/G-Research/k8s-batch/internal/armada/api" + "github.com/G-Research/k8s-batch/internal/armada/repository" + "github.com/G-Research/k8s-batch/internal/common" + "github.com/gogo/protobuf/types" + "k8s.io/apimachinery/pkg/api/resource" + "math" + "math/big" + "time" +) + +type UsageServer struct { + PriorityHalfTime time.Duration + UsageRepository repository.UsageRepository +} + +func (s UsageServer) ReportUsage(ctx context.Context, report *api.ClusterUsageReport) (*types.Empty, error) { + + reports, err := s.UsageRepository.GetClusterUsageReports() + if err != nil { + return nil, err + } + + previousPriority, err := s.UsageRepository.GetClusterPriority(report.ClusterId) + if err != nil { + return nil, err + } + + previousReport := reports[report.ClusterId] + timeChange := time.Minute + if previousReport != nil { + timeChange = report.ReportTime.Sub(previousReport.ReportTime) + } + + reports[report.ClusterId] = report + availableResources := sumResources(reports) + resourceScarcity := calculateResourceScarcity(availableResources) + usage := calculateUsage(resourceScarcity, report.Queues) + newPriority := calculatePriority(usage, previousPriority, timeChange, s.PriorityHalfTime) + + err = s.UsageRepository.UpdateCluster(report, newPriority) + if err != nil { + return nil, err + } + return nil, nil +} + +func calculatePriority(usage map[string]float64, previousPriority map[string]float64, timeChange time.Duration, halfTime time.Duration) map[string]float64 { + + newPriority := map[string]float64{} + timeChangeFactor := math.Pow(0.5, timeChange.Seconds() / halfTime.Seconds()) + + for queue, oldPriority := range previousPriority { + newPriority[queue] = timeChangeFactor * getOrDefault(usage, queue,0) + + (1 - timeChangeFactor) * oldPriority + } + for queue, usage := range usage { + _, exists := newPriority[queue] + if !exists { + newPriority[queue] = timeChangeFactor * usage + } + } + return newPriority +} + +func calculateUsage(resourceScarcity map[string]float64, queues []*api.QueueReport) map[string]float64 { + usages := map[string]float64{} + for _, queue := range queues { + usage := 0.0 + for resourceName, quantity := range queue.Resources { + scarcity := getOrDefault(resourceScarcity, resourceName, 1) + usage += asFloat64(quantity) * scarcity + } + usages[queue.Name] = usage + } + return usages +} + +// Calculates inverse of resources per cpu unit +// { cpu: 4, memory: 20GB, gpu: 2 } -> { cpu: 1.0, memory: 0.2, gpu: 2 } +func calculateResourceScarcity(res common.ComputeResources) map[string]float64 { + importance := map[string]float64{ + "cpu": 1, + } + cpu := asFloat64(res["cpu"]) + + for k, v := range res { + if k == "cpu"{ + continue + } + q := asFloat64(v) + if q >= 0.00001 { + importance[k] = cpu / q + } + } + return importance +} + +func sumResources(reports map[string]*api.ClusterUsageReport) common.ComputeResources { + result := common.ComputeResources{} + for _, report := range reports { + result.Add(report.ClusterCapacity) + } + return result +} + +func getOrDefault(m map[string]float64, key string, def float64) float64 { + v, ok := m[key] + if ok { + return v + } + return def +} + +func asFloat64(q resource.Quantity) float64 { + dec:= q.AsDec() + unscaled := dec.UnscaledBig() + scale := dec.Scale() + unscaledFloat, _ := new(big.Float).SetInt(unscaled).Float64() + return unscaledFloat * math.Pow10(-int(scale)) +} diff --git a/internal/armada/server/usage_test.go b/internal/armada/server/usage_test.go new file mode 100644 index 00000000000..ca8adc92b5b --- /dev/null +++ b/internal/armada/server/usage_test.go @@ -0,0 +1,64 @@ + +package server + +import ( + "context" + "github.com/G-Research/k8s-batch/internal/armada/api" + "github.com/G-Research/k8s-batch/internal/armada/repository" + "github.com/G-Research/k8s-batch/internal/common" + "github.com/alicebob/miniredis" + "github.com/go-redis/redis" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/api/resource" + "testing" + "time" +) + +func TestUsageServer_ReportUsage(t *testing.T) { + withUsageServer(func (s *UsageServer){ + now := time.Now() + cpu, _ := resource.ParseQuantity("10") + memory, _ := resource.ParseQuantity("360Gi") + + _, err := s.ReportUsage(context.Background(), oneQueueReport(now, cpu, memory)) + assert.Nil(t, err) + + priority, err := s.UsageRepository.GetClusterPriority("clusterA") + assert.Nil(t, err) + assert.Equal(t, 10.0, priority["q1"], "Priority should be updated for the new cluster.") + + _, err = s.ReportUsage(context.Background(), oneQueueReport(now.Add(time.Minute), cpu, memory)) + assert.Nil(t, err) + + priority, err = s.UsageRepository.GetClusterPriority("clusterA") + assert.Nil(t, err) + assert.Equal(t, 15.0, priority["q1"], "Priority schould be updated considering previous report.") + }) +} + +func oneQueueReport(t time.Time, cpu resource.Quantity, memory resource.Quantity) *api.ClusterUsageReport { + return &api.ClusterUsageReport{ + ClusterId: "clusterA", + ReportTime: t, + ClusterCapacity: common.ComputeResources{"cpu": cpu, "memory": memory}, + Queues: []*api.QueueReport{ + { + Name: "q1", + Resources: common.ComputeResources{"cpu": cpu, "memory": memory}, + }, + }, + } +} + +func withUsageServer(action func (s *UsageServer)) { + db, err := miniredis.Run() + if err != nil { + panic(err) + } + defer db.Close() + + repo := &repository.RedisUsageRepository { Db: redis.NewClient(&redis.Options{Addr: db.Addr()})} + server := &UsageServer{UsageRepository: repo, PriorityHalfTime: time.Minute } + + action(server) +}