Skip to content

Commit

Permalink
sanitize title and provider input string (#1504)
Browse files Browse the repository at this point in the history
* sanitize title and provider input string
  • Loading branch information
akihikokuroda authored Sep 30, 2024
1 parent f337fe6 commit eed51de
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
6 changes: 3 additions & 3 deletions gateway/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from django.db.models import Q
from rest_framework import serializers

from api.utils import build_env_variables, encrypt_env_vars
from api.utils import build_env_variables, encrypt_env_vars, sanitize_name
from .models import (
Provider,
Program,
Expand Down Expand Up @@ -90,10 +90,10 @@ def retrieve_provider_function(self, title, provider_name):
return Program.objects.filter(title=title, provider__name=provider_name).first()

def create(self, validated_data):
title = validated_data.get("title")
title = sanitize_name(validated_data.get("title"))
logger.info("Creating program [%s] with UploadProgramSerializer", title)

provider_name = validated_data.get("provider", None)
provider_name = sanitize_name(validated_data.get("provider", None))
if provider_name:
validated_data["provider"] = Provider.objects.filter(
name=provider_name
Expand Down
11 changes: 11 additions & 0 deletions gateway/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,14 @@ def create_dependency_allowlist():
raise ValueError("Unable to decode dependency allowlist") from e

return allowlist


def sanitize_name(name: str):
"""Sanitize name"""
if name:
sanitized_name = ""
for c in name:
if c.isalnum() or c in ["_", "-", "/"]:
sanitized_name += c
return sanitized_name
return name
9 changes: 6 additions & 3 deletions gateway/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from rest_framework.response import Response

from qiskit_ibm_runtime import RuntimeInvalidStateError, QiskitRuntimeService
from api.utils import sanitize_name
from utils import sanitize_file_path

from .models import (
Expand Down Expand Up @@ -122,8 +123,8 @@ def get_object(self):

def get_queryset(self):
author = self.request.user
title = self.request.query_params.get("title")
provider_name = self.request.query_params.get("provider")
title = sanitize_name(self.request.query_params.get("title"))
provider_name = sanitize_name(self.request.query_params.get("provider"))
type_filter = self.request.query_params.get("filter")

author_programs = self._get_program_queryset_for_title_and_provider(
Expand Down Expand Up @@ -263,7 +264,7 @@ def run(self, request):

author_program = self.get_run_queryset()
author = request.user
title = serializer.data.get("title")
title = sanitize_name(serializer.data.get("title"))
program = author_program.filter(title=title).first()
if program is None:
logger.error("Qiskit Pattern [%s] was not found.", title)
Expand Down Expand Up @@ -378,6 +379,8 @@ def _get_program_queryset_for_title_and_provider(
author_groups_with_view_permissions_criteria = Q(
instances__in=author_groups_with_view_permissions
)
title = sanitize_name(title)
provider_name = sanitize_name(provider_name)
if title:
serializer = self.get_serializer_upload_program(data=self.request.data)
provider_name, title = serializer.get_provider_name_and_title(
Expand Down

0 comments on commit eed51de

Please sign in to comment.