diff --git a/charts/quantum-serverless/charts/gateway/templates/deployment.yaml b/charts/quantum-serverless/charts/gateway/templates/deployment.yaml
index bcd1ea858..3d1b97bc8 100644
--- a/charts/quantum-serverless/charts/gateway/templates/deployment.yaml
+++ b/charts/quantum-serverless/charts/gateway/templates/deployment.yaml
@@ -111,6 +111,12 @@ spec:
             value: {{ .Values.application.auth.keycloak.realm | quote }}
           - name: SETTINGS_KEYCLOAK_CLIENT_SECRET
             value: {{ .Values.application.auth.keycloak.clientSecret | quote }}
+          - name: RAY_CLUSTER_WORKER_REPLICAS
+            value: {{ .Values.application.ray.replicas | quote }}
+          - name: RAY_CLUSTER_WORKER_MIN_REPLICAS
+            value: {{ .Values.application.ray.minReplicas | quote }}
+          - name: RAY_CLUSTER_WORKER_MAX_REPLICAS
+            value: {{ .Values.application.ray.maxReplicas | quote }}
 {{- if .Values.application.superuser.enable }}
           - name: DJANGO_SUPERUSER_USERNAME
             valueFrom:
diff --git a/charts/quantum-serverless/charts/gateway/templates/rayclustertemplate.yaml b/charts/quantum-serverless/charts/gateway/templates/rayclustertemplate.yaml
index 27f21d9cd..bc7bc10c7 100644
--- a/charts/quantum-serverless/charts/gateway/templates/rayclustertemplate.yaml
+++ b/charts/quantum-serverless/charts/gateway/templates/rayclustertemplate.yaml
@@ -17,7 +17,8 @@ data:
 {{- if .Values.application.ray.scrapeWithPrometheus }}
     headServiceAnnotations:
       prometheus.io/scrape: "true"
-{{- end }}
+{{- end }}
+    enableInTreeAutoscaling: {{`{{ auto_scaling }}`}}
     headGroupSpec:
       rayStartParams:
         dashboard-host: 0.0.0.0
@@ -193,11 +194,11 @@ data:
             claimName: {{ .Values.cos.claimName }}
     workerGroupSpecs:
       - groupName: g
-        maxReplicas: {{ .Values.application.ray.maxReplicas }}
-        minReplicas: {{ .Values.application.ray.minReplicas }}
+        maxReplicas: {{`{{ max_workers }}`}}
+        minReplicas: {{`{{ min_workers }}`}}
         rayStartParams:
           block: 'true'
-        replicas: {{ .Values.application.ray.replicas }}
+        replicas: {{`{{ workers }}`}}
         template:
 {{- if .Values.application.ray.scrapeWithPrometheus }}
           metadata:
diff --git a/client/quantum_serverless/__init__.py b/client/quantum_serverless/__init__.py
index d8cb2bc88..36a491c14 100644
--- a/client/quantum_serverless/__init__.py
+++ b/client/quantum_serverless/__init__.py
@@ -20,7 +20,7 @@
     QuantumServerlessException
     get_auto_discovered_provider
 """
-
+# pylint: disable=W0404
 from importlib_metadata import version as metadata_version, PackageNotFoundError
 
 from .core import (
@@ -36,6 +36,7 @@
     RayProvider,
     LocalProvider,
     save_result,
+    Configuration,
 )
 from .quantum_serverless import (
     QuantumServerless,
diff --git a/client/quantum_serverless/core/__init__.py b/client/quantum_serverless/core/__init__.py
index 015c5b3c5..2f0511922 100644
--- a/client/quantum_serverless/core/__init__.py
+++ b/client/quantum_serverless/core/__init__.py
@@ -71,10 +71,10 @@
     LocalJobClient,
     Job,
     save_result,
+    Configuration,
 )
 from .pattern import (
     QiskitPattern,
-    Program,
     ProgramStorage,
     ProgramRepository,
     download_and_unpack_artifact,
diff --git a/client/quantum_serverless/core/job.py b/client/quantum_serverless/core/job.py
index 16d4c93cb..05350d6a3 100644
--- a/client/quantum_serverless/core/job.py
+++ b/client/quantum_serverless/core/job.py
@@ -27,6 +27,7 @@
     RuntimeEnv
     Job
 """
+# pylint: disable=duplicate-code
 import json
 import logging
 import os
@@ -36,6 +37,7 @@
 from pathlib import Path
 from typing import Dict, Any, Optional, List, Union
 from uuid import uuid4
+from dataclasses import asdict, dataclass
 
 import subprocess
 from subprocess import Popen
@@ -58,6 +60,7 @@
     MAX_ARTIFACT_FILE_SIZE_MB,
     ENV_JOB_ARGUMENTS,
 )
+
 from quantum_serverless.core.pattern import QiskitPattern
 from quantum_serverless.exception import QuantumServerlessException
 from quantum_serverless.serializers.program_serializers import (
@@ -69,11 +72,31 @@
 RuntimeEnv = ray.runtime_env.RuntimeEnv
 
 
+@dataclass
+class Configuration:  # pylint: disable=too-many-instance-attributes
+    """Program Configuration.
+
+    Args:
+        workers: number of worker pods when auto scaling is NOT enabled
+        auto_scaling: set to True to enable auto scaling of the workers
+        min_workers: minimum number of workers when auto scaling is enabled
+        max_workers: maximum number of workers when auto scaling is enabled
+    """
+
+    workers: Optional[int] = None
+    min_workers: Optional[int] = None
+    max_workers: Optional[int] = None
+    auto_scaling: Optional[bool] = False
+
+
 class BaseJobClient:
     """Base class for Job clients."""
 
     def run(
-        self, program: QiskitPattern, arguments: Optional[Dict[str, Any]] = None
+        self,
+        program: QiskitPattern,
+        arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ) -> "Job":
         """Runs program."""
         raise NotImplementedError
@@ -86,6 +109,7 @@ def run_existing(
         self,
         program: Union[str, QiskitPattern],
         arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ):
         """Executes existing program."""
         raise NotImplementedError
@@ -151,7 +175,12 @@ def list(self, **kwargs) -> List["Job"]:
             Job(job.job_id, job_client=self) for job in self._job_client.list_jobs()
         ]
 
-    def run(self, program: QiskitPattern, arguments: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        program: QiskitPattern,
+        arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
+    ):
         arguments = arguments or {}
         entrypoint = f"python {program.entrypoint}"
 
@@ -180,6 +209,7 @@ def run_existing(
         self,
         program: Union[str, QiskitPattern],
         arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ):
         raise NotImplementedError("Run existing is not available for RayJobClient.")
 
@@ -214,7 +244,12 @@ def get(self, job_id) -> Optional["Job"]:
     def list(self, **kwargs) -> List["Job"]:
         return [job["job"] for job in list(self._jobs.values())]
 
-    def run(self, program: QiskitPattern, arguments: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        program: QiskitPattern,
+        arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
+    ):
         if program.dependencies:
             for dependency in program.dependencies:
                 subprocess.check_call(
@@ -266,10 +301,11 @@ def upload(self, program: QiskitPattern):
         }
         return program.title
 
-    def run_existing(
+    def run_existing(  # pylint: disable=too-many-locals
         self,
         program: Union[str, QiskitPattern],
         arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ):
         if isinstance(program, QiskitPattern):
             title = program.title
@@ -333,7 +369,10 @@ def __init__(self, host: str, token: str, version: str):
         self._token = token
 
     def run(  # pylint: disable=too-many-locals
-        self, program: QiskitPattern, arguments: Optional[Dict[str, Any]] = None
+        self,
+        program: QiskitPattern,
+        arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ) -> "Job":
         tracer = trace.get_tracer("client.tracer")
         with tracer.start_as_current_span("job.run") as span:
@@ -367,21 +406,26 @@ def run(  # pylint: disable=too-many-locals
             )
 
             with open(artifact_file_path, "rb") as file:
+                data = {
+                    "title": program.title,
+                    "entrypoint": program.entrypoint,
+                    "arguments": json.dumps(arguments or {}, cls=QiskitObjectsEncoder),
+                    "dependencies": json.dumps(program.dependencies or []),
+                }
+                if config:
+                    data["config"] = json.dumps(asdict(config))
+                else:
+                    data["config"] = "{}"
+
                 response_data = safe_json_request(
                     request=lambda: requests.post(
                         url=url,
-                        data={
-                            "title": program.title,
-                            "entrypoint": program.entrypoint,
-                            "arguments": json.dumps(
-                                arguments or {}, cls=QiskitObjectsEncoder
-                            ),
-                            "dependencies": json.dumps(program.dependencies or []),
-                        },
+                        data=data,
                         files={"artifact": file},
                         headers={"Authorization": f"Bearer {self._token}"},
                         timeout=REQUESTS_TIMEOUT,
-                    )
+                    ),
+                    verbose=True,
                 )
                 job_id = response_data.get("id")
                 span.set_attribute("job.id", job_id)
@@ -449,6 +493,7 @@ def run_existing(
         self,
         program: Union[str, QiskitPattern],
         arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ):
         if isinstance(program, QiskitPattern):
             title = program.title
@@ -462,15 +507,19 @@ def run_existing(
 
         url = f"{self.host}/api/{self.version}/programs/run_existing/"
 
+        data = {
+            "title": title,
+            "arguments": json.dumps(arguments or {}, cls=QiskitObjectsEncoder),
+        }
+        if config:
+            data["config"] = json.dumps(asdict(config))
+        else:
+            data["config"] = "{}"
+
         response_data = safe_json_request(
             request=lambda: requests.post(
                 url=url,
-                data={
-                    "title": title,
-                    "arguments": json.dumps(
-                        arguments or {}, cls=QiskitObjectsEncoder
-                    ),
-                },
+                data=data,
                 headers={"Authorization": f"Bearer {self._token}"},
                 timeout=REQUESTS_TIMEOUT,
             )
diff --git a/client/quantum_serverless/core/provider.py b/client/quantum_serverless/core/provider.py
index 146877f0a..7a86c8307 100644
--- a/client/quantum_serverless/core/provider.py
+++ b/client/quantum_serverless/core/provider.py
@@ -26,6 +26,7 @@
     ComputeResource
     ServerlessProvider
 """
+# pylint: disable=duplicate-code
 import logging
 import warnings
 import os.path
@@ -53,6 +54,7 @@
     GatewayJobClient,
     LocalJobClient,
     BaseJobClient,
+    Configuration,
 )
 from quantum_serverless.core.pattern import QiskitPattern
 from quantum_serverless.core.tracing import _trace_env_vars
@@ -265,6 +267,7 @@ def run(
         self,
         program: Union[QiskitPattern, str],
         arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ) -> Job:
         """Execute a program as an async job.
@@ -295,7 +298,7 @@ def run(
             )
             return None
 
-        return job_client.run(program, arguments)
+        return job_client.run(program, arguments, config)
 
     def upload(self, program: QiskitPattern):
         """Uploads program."""
@@ -415,13 +418,14 @@ def run(
         self,
         program: Union[QiskitPattern, str],
         arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ) -> Job:
         tracer = trace.get_tracer("client.tracer")
         with tracer.start_as_current_span("Provider.run"):
             if isinstance(program, QiskitPattern) and program.entrypoint is not None:
-                job = self._job_client.run(program, arguments)
+                job = self._job_client.run(program, arguments, config)
             else:
-                job = self._job_client.run_existing(program, arguments)
+                job = self._job_client.run_existing(program, arguments, config)
             return job
 
     def upload(self, program: QiskitPattern):
@@ -584,11 +588,12 @@ def run(
         self,
         program: Union[QiskitPattern, str],
         arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ) -> Job:
         if isinstance(program, str):
             raise NotImplementedError("Ray provider only supports full Programs.")
 
-        return self.client.run(program, arguments)
+        return self.client.run(program, arguments, config)
 
     def get_job_by_id(self, job_id: str) -> Optional[Job]:
         return self.client.get(job_id)
@@ -615,6 +620,7 @@ def run(
         self,
         program: Union[QiskitPattern, str],
         arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ) -> Job:
         if isinstance(program, QiskitPattern) and program.entrypoint is not None:
             job = self.client.run(program, arguments)
diff --git a/client/quantum_serverless/quantum_serverless.py b/client/quantum_serverless/quantum_serverless.py
index 6dfb5e0d2..aaa578d5c 100644
--- a/client/quantum_serverless/quantum_serverless.py
+++ b/client/quantum_serverless/quantum_serverless.py
@@ -41,7 +41,7 @@
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from opentelemetry.instrumentation.requests import RequestsInstrumentor
 
-from quantum_serverless.core.job import Job
+from quantum_serverless.core.job import Job, Configuration
 from quantum_serverless.core.pattern import QiskitPattern
 from quantum_serverless.core.provider import BaseProvider, ComputeResource
 from quantum_serverless.exception import QuantumServerlessException
@@ -103,7 +103,10 @@ def job_client(self):
         return self._selected_provider.job_client()
 
     def run(
-        self, program: QiskitPattern, arguments: Optional[Dict[str, Any]] = None
+        self,
+        program: QiskitPattern,
+        arguments: Optional[Dict[str, Any]] = None,
+        config: Optional[Configuration] = None,
     ) -> Optional[Job]:
         """Execute a program as an async job
 
@@ -119,13 +122,14 @@
         Args:
             arguments: arguments to run program with
             program: Program object
+            config: Configuration object
 
         Returns:
             Job
         """
         tracer = trace.get_tracer("client.tracer")
         with tracer.start_as_current_span("QuantumServerless.run"):
-            job = self._selected_provider.run(program, arguments)
+            job = self._selected_provider.run(program, arguments, config)
             return job
 
     def upload(self, program: QiskitPattern):
diff --git a/gateway/README.md b/gateway/README.md
index 9f5794f63..a3a61f76b 100644
--- a/gateway/README.md
+++ b/gateway/README.md
@@ -38,8 +38,11 @@ docker build -t qiskit/quantum-serverless-gateway:<version> .
 | RAY_CLUSTER_TEMPLATE_CPU | default compute kuberay template cpu setting |
 | RAY_CLUSTER_TEMPLATE_MEM | default compute kuberay template memory setting |
 | RAY_CLUSTER_WORKER_REPLICAS | worker replicas per cluster |
-| RAY_CLUSTER_WORKER_MIN_REPLICAS | min worker replicas per cluster |
-| RAY_CLUSTER_WORKER_MAX_REPLICAS | max replicas per cluster |
+| RAY_CLUSTER_WORKER_REPLICAS_MAX | maximum number of worker replicas a user may request per cluster |
+| RAY_CLUSTER_WORKER_MIN_REPLICAS | minimum number of worker replicas per cluster when auto scaling is enabled |
+| RAY_CLUSTER_WORKER_MIN_REPLICAS_MAX | maximum value a user may request for min worker replicas when auto scaling is enabled |
+| RAY_CLUSTER_WORKER_MAX_REPLICAS | maximum number of worker replicas per cluster when auto scaling is enabled |
+| RAY_CLUSTER_WORKER_MAX_REPLICAS_MAX | maximum value a user may request for max worker replicas when auto scaling is enabled |
 | RAY_CLUSTER_MAX_READINESS_TIME | max time in seconds to wait for cluster readiness. Will fail job if cluster is not ready in time. |
 | QISKIT_IBM_CHANNEL | Channel that will be set in env variables in jobs for QiskitRuntimeService client |
 | QISKIT_IBM_URL | Authentication url for QiskitRuntimeService that will be set for each job |
diff --git a/gateway/api/migrations/0011_jobconfig_job_config.py b/gateway/api/migrations/0011_jobconfig_job_config.py
new file mode 100644
index 000000000..458144316
--- /dev/null
+++ b/gateway/api/migrations/0011_jobconfig_job_config.py
@@ -0,0 +1,61 @@
+# Generated by Django 4.2.2 on 2023-11-03 17:54
+
+import django.core.validators
+from django.db import migrations, models
+import django.db.models.deletion
+import uuid
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("api", "0010_job_version"),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="JobConfig",
+            fields=[
+                (
+                    "id",
+                    models.UUIDField(
+                        default=uuid.uuid4,
+                        editable=False,
+                        primary_key=True,
+                        serialize=False,
+                    ),
+                ),
+                ("created", models.DateTimeField(auto_now_add=True)),
+                ("auto_scaling", models.BooleanField(default=False, null=True)),
+                (
+                    "workers",
+                    models.IntegerField(
+                        null=True,
+                    ),
+                ),
+                (
+                    "min_workers",
+                    models.IntegerField(
+                        null=True,
+                    ),
+                ),
+                (
+                    "max_workers",
+                    models.IntegerField(
+                        null=True,
+                    ),
+                ),
+            ],
+        ),
+        migrations.AddField(
+            model_name="job",
+            name="config",
+            field=models.ForeignKey(
+                blank=True,
+                default=None,
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                to="api.jobconfig",
+            ),
+        ),
+    ]
diff --git a/gateway/api/models.py b/gateway/api/models.py
index b78aa7a52..636171526 100644
--- a/gateway/api/models.py
+++ b/gateway/api/models.py
@@ -2,7 +2,9 @@
 import uuid
 
 from concurrency.fields import IntegerVersionField
-from django.core.validators import FileExtensionValidator
+from django.core.validators import (
+    FileExtensionValidator,
+)
 from django.db import models
 from django.conf import settings
 from django_prometheus.models import ExportModelOperationsMixin
@@ -13,6 +15,27 @@ def get_upload_path(instance, filename):
     return f"{instance.author.username}/{instance.id}/{filename}"
 
 
+class JobConfig(models.Model):
+    """Job Configuration model."""
+
+    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
+    created = models.DateTimeField(auto_now_add=True)
+
+    auto_scaling = models.BooleanField(default=False, null=True)
+    workers = models.IntegerField(
+        null=True,
+    )
+    min_workers = models.IntegerField(
+        null=True,
+    )
+    max_workers = models.IntegerField(
+        null=True,
+    )
+
+    def __str__(self):
+        return str(self.id)
+
+
 class Program(ExportModelOperationsMixin("program"), models.Model):
     """Program model."""
 
@@ -109,6 +132,14 @@ class Job(models.Model):
 
     version = IntegerVersionField()
 
+    config = models.ForeignKey(
+        to=JobConfig,
+        on_delete=models.CASCADE,
+        default=None,
+        null=True,
+        blank=True,
+    )
+
     def __str__(self):
         return f"<Job {self.id}>"
diff --git a/gateway/api/ray.py b/gateway/api/ray.py
index c0e6be5a3..fe217efcb 100644
--- a/gateway/api/ray.py
+++ b/gateway/api/ray.py
@@ -20,7 +20,7 @@
 from opentelemetry import trace
 from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 
-from api.models import ComputeResource, Job
+from api.models import ComputeResource, Job, JobConfig
 from api.utils import try_json_loads, retry_function, decrypt_env_vars
 from main import settings
 
@@ -157,7 +157,10 @@ def submit_job(job: Job) -> Job:
 
 
 def create_ray_cluster(
-    user: Any, cluster_name: Optional[str] = None, cluster_data: Optional[str] = None
+    user: Any,
+    cluster_name: Optional[str] = None,
+    cluster_data: Optional[str] = None,
+    job_config: Optional[JobConfig] = None,
 ) -> Optional[ComputeResource]:
     """Creates ray cluster.
 
@@ -174,9 +177,27 @@ def create_ray_cluster(
     namespace = settings.RAY_KUBERAY_NAMESPACE
     cluster_name = cluster_name or f"{user.username}-{str(uuid.uuid4())[:8]}"
     if not cluster_data:
+        if not job_config:
+            job_config = JobConfig()
+        if not job_config.workers:
+            job_config.workers = settings.RAY_CLUSTER_WORKER_REPLICAS
+        if not job_config.min_workers:
+            job_config.min_workers = settings.RAY_CLUSTER_WORKER_MIN_REPLICAS
+        if not job_config.max_workers:
+            job_config.max_workers = settings.RAY_CLUSTER_WORKER_MAX_REPLICAS
+        if not job_config.auto_scaling:
+            job_config.auto_scaling = settings.RAY_CLUSTER_WORKER_AUTO_SCALING
+
         cluster = get_template("rayclustertemplate.yaml")
         manifest = cluster.render(
-            {"cluster_name": cluster_name, "user_id": user.username}
+            {
+                "cluster_name": cluster_name,
+                "user_id": user.username,
+                "workers": job_config.workers,
+                "min_workers": job_config.min_workers,
+                "max_workers": job_config.max_workers,
+                "auto_scaling": job_config.auto_scaling,
+            }
         )
         cluster_data = yaml.safe_load(manifest)
diff --git a/gateway/api/schedule.py b/gateway/api/schedule.py
index 98b83fa7d..1ac00dde7 100644
--- a/gateway/api/schedule.py
+++ b/gateway/api/schedule.py
@@ -17,6 +17,7 @@
 from api.ray import submit_job, create_ray_cluster, kill_ray_cluster
 from main import settings as config
 
+
 User: Model = get_user_model()
 
 logger = logging.getLogger("commands")
@@ -25,11 +26,12 @@ def save_program(serializer, request) -> Program:
     """Save program.
 
     Args:
         serializer: program serializer with data attached.
+        request: request data.
 
     Returns:
         saved program
     """
+
     existing_program = (
         Program.objects.filter(title=serializer.data.get("title"), author=request.user)
         .order_by("-created")
@@ -97,7 +99,9 @@ def execute_job(job: Job) -> Job:
             job.status = Job.FAILED
             job.logs = "Compute resource was not found."
         else:
-            compute_resource = create_ray_cluster(job.author, cluster_name=cluster_name)
+            compute_resource = create_ray_cluster(
+                job.author, cluster_name=cluster_name, job_config=job.config
+            )
             if compute_resource:
                 # if compute resource was created in time with no problems
                 job.compute_resource = compute_resource
diff --git a/gateway/api/serializers.py b/gateway/api/serializers.py
index cc8b7d3a6..8cd57935d 100644
--- a/gateway/api/serializers.py
+++ b/gateway/api/serializers.py
@@ -6,9 +6,43 @@
 Version serializers inherit from the different serializers.
""" +from django.conf import settings from rest_framework import serializers +from .models import Program, Job, JobConfig -from .models import Program, Job + +class JobConfigSerializer(serializers.ModelSerializer): + """ + Serializer for the Job Config model. + """ + + class Meta: + model = JobConfig + fields = [ + "workers", + "min_workers", + "max_workers", + "auto_scaling", + ] + + workers = serializers.IntegerField( + max_value=settings.RAY_CLUSTER_WORKER_REPLICAS_MAX, + required=False, + allow_null=True, + ) + min_workers = serializers.IntegerField( + max_value=settings.RAY_CLUSTER_WORKER_MIN_REPLICAS_MAX, + required=False, + allow_null=True, + ) + max_workers = serializers.IntegerField( + max_value=settings.RAY_CLUSTER_WORKER_MAX_REPLICAS_MAX, + required=False, + allow_null=True, + ) + auto_scaling = serializers.BooleanField( + default=False, required=False, allow_null=True + ) class ProgramSerializer(serializers.ModelSerializer): diff --git a/gateway/api/v1/serializers.py b/gateway/api/v1/serializers.py index 19ed41698..d5c557b1a 100644 --- a/gateway/api/v1/serializers.py +++ b/gateway/api/v1/serializers.py @@ -11,7 +11,13 @@ class ProgramSerializer(serializers.ProgramSerializer): """ class Meta(serializers.ProgramSerializer.Meta): - fields = ["title", "entrypoint", "artifact", "dependencies", "arguments"] + fields = [ + "title", + "entrypoint", + "artifact", + "dependencies", + "arguments", + ] class JobSerializer(serializers.JobSerializer): diff --git a/gateway/api/views.py b/gateway/api/views.py index 54f2f0c3d..6fc312769 100644 --- a/gateway/api/views.py +++ b/gateway/api/views.py @@ -37,7 +37,7 @@ from .models import Program, Job from .ray import get_job_handler from .schedule import save_program -from .serializers import JobSerializer, ExistingProgramSerializer +from .serializers import JobSerializer, ExistingProgramSerializer, JobConfigSerializer from .utils import build_env_variables, encrypt_env_vars logger = logging.getLogger("gateway") @@ -115,11 +115,23 @@ def run_existing(self, request): status=status.HTTP_404_NOT_FOUND, ) + jobconfig = None + config_data = request.data.get("config") + if config_data: + config_serializer = JobConfigSerializer(data=json.loads(config_data)) + if not config_serializer.is_valid(): + return Response( + config_serializer.errors, status=status.HTTP_400_BAD_REQUEST + ) + + jobconfig = config_serializer.save() + job = Job( program=program, arguments=serializer.data.get("arguments"), author=request.user, status=Job.QUEUED, + config=jobconfig, ) job.save() @@ -150,12 +162,25 @@ def run(self, request): if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + jobconfig = None + config_data = request.data.get("config") + if config_data: + config_serializer = JobConfigSerializer(data=json.loads(config_data)) + if not config_serializer.is_valid(): + return Response( + config_serializer.errors, status=status.HTTP_400_BAD_REQUEST + ) + + jobconfig = config_serializer.save() + program = save_program(serializer=serializer, request=request) + job = Job( program=program, arguments=program.arguments, author=request.user, status=Job.QUEUED, + config=jobconfig, ) job.save() diff --git a/gateway/main/settings.py b/gateway/main/settings.py index fd6ce8c53..608facef7 100644 --- a/gateway/main/settings.py +++ b/gateway/main/settings.py @@ -300,13 +300,25 @@ "RAY_CLUSTER_MODE_LOCAL_HOST", "http://localhost:8265" ), } -RAY_CLUSTER_WORKER_REPLICAS = int(os.environ.get("RAY_CLUSTER_WORKER_REPLICAS", "0")) 
+RAY_CLUSTER_WORKER_REPLICAS = int(os.environ.get("RAY_CLUSTER_WORKER_REPLICAS", "1"))
+RAY_CLUSTER_WORKER_REPLICAS_MAX = int(
+    os.environ.get("RAY_CLUSTER_WORKER_REPLICAS_MAX", "5")
+)
 RAY_CLUSTER_WORKER_MIN_REPLICAS = int(
-    os.environ.get("RAY_CLUSTER_WORKER_MIN_REPLICAS", "0")
+    os.environ.get("RAY_CLUSTER_WORKER_MIN_REPLICAS", "1")
+)
+RAY_CLUSTER_WORKER_MIN_REPLICAS_MAX = int(
+    os.environ.get("RAY_CLUSTER_WORKER_MIN_REPLICAS_MAX", "2")
 )
 RAY_CLUSTER_WORKER_MAX_REPLICAS = int(
     os.environ.get("RAY_CLUSTER_WORKER_MAX_REPLICAS", "4")
 )
+RAY_CLUSTER_WORKER_MAX_REPLICAS_MAX = int(
+    os.environ.get("RAY_CLUSTER_WORKER_MAX_REPLICAS_MAX", "10")
+)
+RAY_CLUSTER_WORKER_AUTO_SCALING = os.environ.get(
+    "RAY_CLUSTER_WORKER_AUTO_SCALING", "False"
+).lower() in ("true", "1")
 RAY_CLUSTER_MAX_READINESS_TIME = int(
     os.environ.get("RAY_CLUSTER_MAX_READINESS_TIME", "120")
 )
diff --git a/gateway/tests/api/test_serializer.py b/gateway/tests/api/test_serializer.py
new file mode 100644
index 000000000..db6a6cdda
--- /dev/null
+++ b/gateway/tests/api/test_serializer.py
@@ -0,0 +1,33 @@
+"""Tests for serializer functions."""
+
+import json
+
+from rest_framework.test import APITestCase
+
+from api.serializers import JobConfigSerializer
+from api.models import JobConfig
+
+
+class SerializerTest(APITestCase):
+    """Tests for serializer."""
+
+    def test_job_config_serializer(self):
+        """Tests JobConfig serializer."""
+        data = '{"workers": null, "min_workers": 1, "max_workers": 5, "auto_scaling": true}'
+        config_serializer = JobConfigSerializer(data=json.loads(data))
+        assert config_serializer.is_valid()
+        jobconfig = config_serializer.save()
+
+        entry = JobConfig.objects.get(id=jobconfig.id)
+        assert not entry.workers
+        assert entry.min_workers == 1
+        assert entry.max_workers == 5
+        assert entry.auto_scaling
+
+        data = '{"workers": 3, "min_workers": null, "max_workers": null, "auto_scaling": null}'
+        config_serializer = JobConfigSerializer(data=json.loads(data))
+        assert config_serializer.is_valid()
+        jobconfig = config_serializer.save()
+
+        entry = JobConfig.objects.get(id=jobconfig.id)
+        assert entry.workers == 3
+        assert not entry.min_workers
+        assert not entry.max_workers
+        assert not entry.auto_scaling
diff --git a/gateway/tests/api/test_v1_program.py b/gateway/tests/api/test_v1_program.py
index ac23a5220..757ce7e30 100644
--- a/gateway/tests/api/test_v1_program.py
+++ b/gateway/tests/api/test_v1_program.py
@@ -1,8 +1,8 @@
 """Tests program APIs."""
-
 from django.urls import reverse
 from rest_framework import status
 from rest_framework.test import APITestCase
+from api.models import Job, JobConfig
 
 
 class TestProgramApi(APITestCase):
@@ -54,3 +54,30 @@ def test_program_detail(self):
         self.assertEqual(programs_response.status_code, status.HTTP_200_OK)
         self.assertEqual(programs_response.data.get("title"), "Program")
         self.assertEqual(programs_response.data.get("entrypoint"), "program.py")
+
+    def test_run_existing(self):
+        """Tests run existing authorized."""
+        auth = reverse("rest_login")
+        response = self.client.post(
+            auth, {"username": "test_user", "password": "123"}, format="json"
+        )
+        token = response.data.get("access")
+        self.client.credentials(HTTP_AUTHORIZATION="Bearer " + token)
+
+        programs_response = self.client.post(
+            "/api/v1/programs/run_existing/",
+            data={
+                "title": "Program",
+                "entrypoint": "program.py",
+                "arguments": {},
+                "dependencies": [],
+                "config": '{"workers": null, "min_workers": 1, "max_workers": 5, "auto_scaling": true}',
+            },
+            format="json",
+        )
+        job_id = programs_response.data.get("id")
+        job = Job.objects.get(id=job_id)
+        self.assertEqual(job.config.min_workers, 1)
+        self.assertEqual(job.config.max_workers, 5)
+        self.assertEqual(job.config.workers, None)
+        self.assertEqual(job.config.auto_scaling, True)
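
For reviewers, a minimal sketch of how the new `Configuration` object is exercised from the client side. The provider constructor arguments, gateway URL, token, and pattern details below are placeholders for illustration, not values taken from this diff:

```python
from quantum_serverless import QuantumServerless, ServerlessProvider, Configuration
from quantum_serverless.core import QiskitPattern

# Placeholder gateway host and token -- adjust to your deployment.
provider = ServerlessProvider(host="http://localhost:8000", token="<token>")
serverless = QuantumServerless(provider)

pattern = QiskitPattern(title="my-pattern", entrypoint="pattern.py")

# Fixed-size cluster: three worker pods, autoscaling off (the default).
job = serverless.run(pattern, config=Configuration(workers=3))

# Autoscaled cluster: KubeRay keeps between 1 and 5 workers; the gateway
# validates these against the *_REPLICAS_MAX settings before accepting them.
job = serverless.run(
    pattern,
    config=Configuration(min_workers=1, max_workers=5, auto_scaling=True),
)
```

On the wire, the client serializes the dataclass with `asdict` into the `config` form field; the gateway validates it with `JobConfigSerializer`, persists a `JobConfig` row linked to the `Job`, and `create_ray_cluster` threads the values (`workers`, `min_workers`, `max_workers`, `auto_scaling`) into the rendered KubeRay cluster template.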