diff --git a/gateway/api/serializers.py b/gateway/api/serializers.py index c8b93aac3..1ac2b0a4c 100644 --- a/gateway/api/serializers.py +++ b/gateway/api/serializers.py @@ -14,7 +14,7 @@ from django.db.models import Q from rest_framework import serializers -from api.utils import build_env_variables, encrypt_env_vars +from api.utils import build_env_variables, encrypt_env_vars, sanitize_name from .models import ( Provider, Program, @@ -90,10 +90,10 @@ def retrieve_provider_function(self, title, provider_name): return Program.objects.filter(title=title, provider__name=provider_name).first() def create(self, validated_data): - title = validated_data.get("title") + title = sanitize_name(validated_data.get("title")) logger.info("Creating program [%s] with UploadProgramSerializer", title) - provider_name = validated_data.get("provider", None) + provider_name = sanitize_name(validated_data.get("provider", None)) if provider_name: validated_data["provider"] = Provider.objects.filter( name=provider_name diff --git a/gateway/api/utils.py b/gateway/api/utils.py index 7fdece114..8c71ddf8c 100644 --- a/gateway/api/utils.py +++ b/gateway/api/utils.py @@ -417,3 +417,14 @@ def create_dependency_allowlist(): raise ValueError("Unable to decode dependency allowlist") from e return allowlist + + +def sanitize_name(name: str): + """Sanitize name""" + if name: + sanitized_name = "" + for c in name: + if c.isalnum() or c in ["_", "-", "/"]: + sanitized_name += c + return sanitized_name + return name diff --git a/gateway/api/views.py b/gateway/api/views.py index a0114df57..f0bb2b800 100644 --- a/gateway/api/views.py +++ b/gateway/api/views.py @@ -31,6 +31,7 @@ from rest_framework.response import Response from qiskit_ibm_runtime import RuntimeInvalidStateError, QiskitRuntimeService +from api.utils import sanitize_name from utils import sanitize_file_path from .models import ( @@ -122,8 +123,8 @@ def get_object(self): def get_queryset(self): author = self.request.user - title = self.request.query_params.get("title") - provider_name = self.request.query_params.get("provider") + title = sanitize_name(self.request.query_params.get("title")) + provider_name = sanitize_name(self.request.query_params.get("provider")) type_filter = self.request.query_params.get("filter") author_programs = self._get_program_queryset_for_title_and_provider( @@ -263,7 +264,7 @@ def run(self, request): author_program = self.get_run_queryset() author = request.user - title = serializer.data.get("title") + title = sanitize_name(serializer.data.get("title")) program = author_program.filter(title=title).first() if program is None: logger.error("Qiskit Pattern [%s] was not found.", title) @@ -378,6 +379,8 @@ def _get_program_queryset_for_title_and_provider( author_groups_with_view_permissions_criteria = Q( instances__in=author_groups_with_view_permissions ) + title = sanitize_name(title) + provider_name = sanitize_name(provider_name) if title: serializer = self.get_serializer_upload_program(data=self.request.data) provider_name, title = serializer.get_provider_name_and_title(