diff --git a/CHANGELOG.md b/CHANGELOG.md index 13adfea..7bbcb8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ ## Unreleased ## 2.3.0 +### Adds +- [#25](https://github.com/ibleducation/ibl-edx-lti-1p3-provider-app/issues/25): Adds `allowed_courses` to `LaunchGate` - [#21](https://github.com/ibleducation/ibl-edx-lti-1p3-provider-app/issues/21): Adds API to manage `LtiToolKey`'s at a tenant level - [#22](https://github.com/ibleducation/ibl-edx-lti-1p3-provider-app/issues/22): Adds API to manage `LtiTool`'s at a tenant level diff --git a/src/lti_1p3_provider/admin.py b/src/lti_1p3_provider/admin.py index ddcc028..7760329 100644 --- a/src/lti_1p3_provider/admin.py +++ b/src/lti_1p3_provider/admin.py @@ -46,7 +46,7 @@ class LtiGradedResourceAdmin(admin.ModelAdmin): @admin.register(models.LtiToolOrg) -class LtiTooLOrgAdmin(admin.ModelAdmin): +class LtiToolOrgAdmin(admin.ModelAdmin): list_display = ("tool_name", "tool_issuer", "tool_client_id", "edx_org_name") def tool_name(self, obj) -> str: @@ -67,3 +67,8 @@ def formfield_for_foreignkey(self, db_field, request, **kwargs): if db_field.name == "tool": formfield.label_from_instance = lambda obj: f"{obj.title} ({obj.client_id})" return formfield + + +@admin.register(models.LtiKeyOrg) +class LtiKeyOrgAdmin(admin.ModelAdmin): + list_display = ("key", "org") diff --git a/src/lti_1p3_provider/api/serializers.py b/src/lti_1p3_provider/api/serializers.py index 3be5652..711b73c 100644 --- a/src/lti_1p3_provider/api/serializers.py +++ b/src/lti_1p3_provider/api/serializers.py @@ -1,11 +1,16 @@ from __future__ import annotations +import json + from django.db import IntegrityError +from opaque_keys import InvalidKeyError +from opaque_keys.edx.keys import CourseKey, UsageKey +from openedx.core.lib.api.serializers import CourseKeyField, UsageKeyField from organizations.models import Organization from pylti1p3.contrib.django.lti1p3_tool_config.models import LtiTool, LtiToolKey from rest_framework import serializers -from ..models import LtiKeyOrg, LtiToolOrg +from ..models import LaunchGate, LtiKeyOrg, LtiToolOrg from . import ssl_services @@ -13,6 +18,27 @@ class StringListField(serializers.ListField): child = serializers.CharField() +class TextBackedListField(StringListField): + """A ListField backed by a Char-Type field in the db + + - Writes as a JSON String + - Reads object from a JSON string + """ + + def to_representation(self, data): + if isinstance(data, str): + try: + data = json.loads(data) + except ValueError: + data = [] + return super().to_representation(data) + + def to_internal_value(self, data): + if data: + return json.dumps(data) + return "[]" + + class LtiToolKeySerializer(serializers.ModelSerializer): class Meta: model = LtiToolKey @@ -41,7 +67,10 @@ def to_representation(self, instance): def update(self, instance, validated_data): name = validated_data["name"] validated_data["name"] = f"{self.context['org_short_name']}-{name}" - return super().update(instance, validated_data) + try: + return super().update(instance, validated_data) + except IntegrityError: + raise serializers.ValidationError(f"Key name: '{name}' already exists") def create(self, validated_data): """Autogenerate private/public key pairs""" @@ -59,16 +88,100 @@ def create(self, validated_data): tool_key = LtiToolKey.objects.create(**validated_data) LtiKeyOrg.objects.create(key=tool_key, org=lti_org) except IntegrityError: - raise serializers.ValidationError(f"Tool name: '{name}' already exists") + raise serializers.ValidationError(f"Key name: '{name}' already exists") return tool_key +class LaunchGateSerializer(serializers.ModelSerializer): + class Meta: + model = LaunchGate + fields = ["allowed_keys", "allowed_courses", "allow_all_within_org"] + + allowed_keys = StringListField( + allow_empty=True, + default=lambda: [], + ) + allowed_courses = StringListField( + allow_empty=True, + default=lambda: [], + ) + allow_all_within_org = serializers.BooleanField( + default=False, + help_text="If True, a target_link_uri will work with any content within this org", + ) + + def validate(self, attrs): + """Ensure at least one of allow* is set""" + if not ( + attrs["allowed_keys"] + or attrs["allowed_courses"] + or attrs["allow_all_within_org"] + ): + raise serializers.ValidationError( + "Set either allow_all_within_org or one or more of allowed_courses/allowed_keys" + ) + return attrs + + def to_representation(self, instance): + rep = super().to_representation(instance) + rep["allow_all_within_org"] = instance.allowed_orgs == [ + self.context["org_short_name"] + ] + return rep + + def validate_allowed_courses(self, value): + for key in value: + try: + key = CourseKey.from_string(key) + org_short_name = self.context["org_short_name"] + if key.org != org_short_name: + raise serializers.ValidationError( + f"Course Key must be within org: {org_short_name}" + ) + except InvalidKeyError: + raise serializers.ValidationError( + "Invalid Course Key. Format is: course-v1:++" + ) + return value + + def validate_allowed_keys(self, value): + for key in value: + try: + key = UsageKey.from_string(key) + org_short_name = self.context["org_short_name"] + if key.course_key.org != org_short_name: + raise serializers.ValidationError( + f"Usage Key must be within org: {org_short_name}" + ) + except InvalidKeyError: + raise serializers.ValidationError( + "Invalid Usage Key. Format is: " + "block-v1:+++type@+block@" + ) + return value + + class LtiToolSerializer(serializers.ModelSerializer): class Meta: model = LtiTool - fields = "__all__" - - deployment_ids = StringListField() + fields = [ + "id", + "title", + "issuer", + "is_active", + "client_id", + "auth_login_url", + "auth_token_url", + "auth_audience", + "key_set_url", + "key_set", + "tool_key", + "deployment_ids", + "launch_gate", + ] + + deployment_ids = TextBackedListField() + launch_gate = LaunchGateSerializer() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -79,6 +192,12 @@ def __init__(self, *args, **kwargs): def validate(self, attrs): short_name = self.context["org_short_name"] + # Since this endpoint is for an org, allowed_orgs must be [] or their org only + allow_all_within_org = attrs["launch_gate"].pop("allow_all_within_org") + attrs["launch_gate"]["allowed_orgs"] = ( + [short_name] if allow_all_within_org else [] + ) + try: # Since we're validating it we may as well store it attrs["org"] = Organization.objects.get(short_name=short_name) @@ -94,8 +213,26 @@ def validate(self, attrs): return attrs + def update(self, instance, validated_data): + """Update object and launch gate, creating launch gate if necessary""" + # Update LtiTool object + launch_gate_data = validated_data.pop("launch_gate") + for attr, value in validated_data.items(): + setattr(instance, attr, value) + instance.save() + + # Update LaunchGate + launch_gate = instance.launch_gate + for attr, value in launch_gate_data.items(): + setattr(launch_gate, attr, value) + launch_gate.save() + + return instance + def create(self, validated_data): lti_org = validated_data.pop("org") - tool = super().create(validated_data) + launch_gate_data = validated_data.pop("launch_gate") + tool = LtiTool.objects.create(**validated_data) LtiToolOrg.objects.create(tool=tool, org=lti_org) + LaunchGate.objects.create(tool=tool, **launch_gate_data) return tool diff --git a/src/lti_1p3_provider/api/tests/test_views.py b/src/lti_1p3_provider/api/tests/test_views.py index 08efa89..8a79678 100644 --- a/src/lti_1p3_provider/api/tests/test_views.py +++ b/src/lti_1p3_provider/api/tests/test_views.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy from typing import Any import pytest @@ -114,7 +115,7 @@ def test_create_name_already_exists_in_org_returns_400(self, client, admin_token resp = self.request(client, "post", endpoint, data=payload, token=admin_token) - assert resp.json() == ["Tool name: 'test' already exists"] + assert resp.json() == ["Key name: 'test' already exists"] assert resp.status_code == 400 def test_create_same_name_in_multiple_orgs_succeeds_200(self, client, admin_token): @@ -235,6 +236,24 @@ def test_detail_returns_200(self, client, admin_token): } assert resp.status_code == 200 + def test_update_key_name_already_exists_returns_400(self, client, admin_token): + """Test updating a tool name that already exists in org returns 400""" + org1 = OrganizationFactory() + org_key1 = factories.LtiKeyOrgFactory( + org=org1, key__name=f"{org1.short_name}-test1" + ) + org_key = factories.LtiKeyOrgFactory( + org=org1, key__name=f"{org1.short_name}-test" + ) + # try to rename it to test which is already taken in this org + payload = {"name": "test1"} + endpoint = self._get_detail_endpoint(org1.short_name, org_key.key.id) + + resp = self.request(client, "put", endpoint, data=payload, token=admin_token) + + assert resp.json() == ["Key name: 'test1' already exists"] + assert resp.status_code == 400 + def test_update_returns_200(self, client, admin_token): """Update updates name and returns 200""" org = OrganizationFactory() @@ -282,7 +301,6 @@ def setup_method(self): "is_active": True, "issuer": "https://issuer.local", "client_id": "12345", - "use_by_default": False, "auth_login_url": "https://issuer.local/auth", "auth_token_url": "https://issuer.local/token", "auth_audience": "", @@ -290,6 +308,13 @@ def setup_method(self): "key_set": "", "tool_key": self.key.id, "deployment_ids": [1, "test", 1234, "5"], + "launch_gate": { + "allowed_keys": [ + f"block-v1:{self.org.short_name}+course+run+type@obj+block@uuid" + ], + "allowed_courses": [f"course-v1:{self.org.short_name}+course+run"], + "allow_all_within_org": False, + }, } @pytest.mark.parametrize("pop_key_field", ("key_set_url", "key_set")) @@ -306,6 +331,7 @@ def test_create_returns_201(self, pop_key_field, client, admin_token): client, "post", endpoint, data=self.payload, token=admin_token ) + assert resp.status_code == 201, resp.json() tool = LtiTool.objects.get(client_id="12345") expected = self.payload.copy() expected["id"] = tool.id @@ -313,6 +339,51 @@ def test_create_returns_201(self, pop_key_field, client, admin_token): expected["deployment_ids"] = [str(x) for x in self.payload["deployment_ids"]] assert resp.json() == expected assert tool.tool_org.org == self.org + # Validate launchg gate created + launch_gate = tool.launch_gate + assert launch_gate.allowed_keys == self.payload["launch_gate"]["allowed_keys"] + assert ( + launch_gate.allowed_courses + == self.payload["launch_gate"]["allowed_courses"] + ) + assert not launch_gate.allowed_orgs + + def test_create_invalid_course_key_returns_400(self, client, admin_token): + """If course_key is invalid, returns 400""" + endpoint = self._get_list_endpoint(self.org.short_name) + self.payload["launch_gate"]["allowed_courses"] = ["bad-key"] + + resp = self.request( + client, "post", endpoint, data=self.payload, token=admin_token + ) + + assert resp.status_code == 400, resp.json() + assert resp.json() == { + "launch_gate": { + "allowed_courses": [ + "Invalid Course Key. Format is: course-v1:++" + ] + } + } + + def test_create_invalid_usage_key_returns_400(self, client, admin_token): + """If usage_key is invalid, returns 400""" + endpoint = self._get_list_endpoint(self.org.short_name) + self.payload["launch_gate"]["allowed_keys"] = ["bad-key"] + + resp = self.request( + client, "post", endpoint, data=self.payload, token=admin_token + ) + + assert resp.status_code == 400, resp.json() + assert resp.json() == { + "launch_gate": { + "allowed_keys": [ + "Invalid Usage Key. Format is: " + "block-v1:+++type@+block@" + ] + } + } def test_create_using_non_supplied_defaults_returns_201_with_defaults_set( self, client, admin_token @@ -321,7 +392,6 @@ def test_create_using_non_supplied_defaults_returns_201_with_defaults_set( endpoint = self._get_list_endpoint(self.org.short_name) expected = self.payload.copy() self.payload.pop("is_active") - self.payload.pop("use_by_default") self.payload.pop("auth_audience") self.payload.pop("key_set") @@ -355,7 +425,7 @@ def test_create_missing_both_key_set_and_key_set_url_returns_400( } assert resp.status_code == 400 - def test_create_org_dne_returns_400(self, client, admin_token): + def test_create_if_org_dne_returns_400(self, client, admin_token): """Test creating key for org that DNE returns 400""" endpoint = self._get_list_endpoint("dne") @@ -364,7 +434,41 @@ def test_create_org_dne_returns_400(self, client, admin_token): ) assert resp.json() == { - "tool_key": [f'Invalid pk "{self.key.id}" - object does not exist.'] + "tool_key": [f'Invalid pk "{self.key.id}" - object does not exist.'], + "launch_gate": { + "allowed_courses": ["Course Key must be within org: dne"], + "allowed_keys": ["Usage Key must be within org: dne"], + }, + } + assert resp.status_code == 400 + + def test_create_if_launch_gate_courses_or_keys_not_in_org_returns_400( + self, client, admin_token + ): + """If launch gate courses/keys not within target org, 400 is returned""" + org1 = OrganizationFactory() + org2 = OrganizationFactory() + self.payload["launch_gate"]["allowed_courses"] = [ + f"course-v1:{org1.short_name}+course+run" + ] + self.payload["launch_gate"]["allowed_keys"] = [ + f"block-v1:{org2.short_name}+course+run+type@obj+block@uuid" + ] + endpoint = self._get_list_endpoint(self.org.short_name) + + resp = self.request( + client, "post", endpoint, data=self.payload, token=admin_token + ) + + assert resp.json() == { + "launch_gate": { + "allowed_courses": [ + f"Course Key must be within org: {self.org.short_name}" + ], + "allowed_keys": [ + f"Usage Key must be within org: {self.org.short_name}" + ], + }, } assert resp.status_code == 400 @@ -488,33 +592,40 @@ def test_detail_returns_200(self, client, admin_token): def test_update_returns_200(self, client, admin_token): """Update updates entity and returns 200""" - org = OrganizationFactory() - key_org = factories.LtiToolOrgFactory(org=org, tool__tool_key=self.key) - new_key = factories.LtiKeyOrgFactory(org=org) - org = key_org.org + key_org = factories.LtiToolOrgFactory(org=self.org, tool__tool_key=self.key) + new_key = factories.LtiKeyOrgFactory(org=self.org) tool = key_org.tool - endpoint = self._get_detail_endpoint(org.short_name, tool.pk) + existing_gate = factories.LaunchGateFactory(tool=tool) + endpoint = self._get_detail_endpoint(self.org.short_name, tool.pk) self.payload["tool_key"] = new_key.key.id + self.payload["launch_gate"]["allow_all_within_org"] = True resp = self.request( client, "put", endpoint, data=self.payload, token=admin_token ) - expected = self.payload.copy() + expected = copy.deepcopy(self.payload) expected["id"] = tool.id expected["deployment_ids"] = [str(x) for x in self.payload["deployment_ids"]] + + assert resp.status_code == 200, resp.json() assert resp.json() == expected - assert resp.status_code == 200 + existing_gate.refresh_from_db() + assert existing_gate.allowed_keys == self.payload["launch_gate"]["allowed_keys"] + assert ( + existing_gate.allowed_courses + == self.payload["launch_gate"]["allowed_courses"] + ) + assert existing_gate.allowed_orgs == [self.org.short_name] def test_update_with_tool_key_from_other_org_returns_400(self, client, admin_token): """Update updates entity and returns 200""" - org = OrganizationFactory() - org2 = OrganizationFactory() - key_org = factories.LtiToolOrgFactory(org=org, tool__tool_key=self.key) - new_key = factories.LtiKeyOrgFactory(org=org2) - org = key_org.org - tool = key_org.tool - endpoint = self._get_detail_endpoint(org.short_name, tool.pk) + new_org = OrganizationFactory() + tool_org = factories.LtiToolOrgFactory(org=self.org, tool__tool_key=self.key) + new_key = factories.LtiKeyOrgFactory(org=new_org) + org = tool_org.org + tool = tool_org.tool + endpoint = self._get_detail_endpoint(self.org.short_name, tool.pk) self.payload["tool_key"] = new_key.key.id resp = self.request( @@ -523,5 +634,5 @@ def test_update_with_tool_key_from_other_org_returns_400(self, client, admin_tok assert resp.status_code == 400, resp.json() assert resp.json() == { - "tool_key": [f'Invalid pk "{new_key.key.id}" - object does not exist.'] + "tool_key": [f'Invalid pk "{new_key.key.id}" - object does not exist.'], } diff --git a/src/lti_1p3_provider/migrations/0004_ltikeyorg.py b/src/lti_1p3_provider/migrations/0004_auto_20240822_1832.py similarity index 74% rename from src/lti_1p3_provider/migrations/0004_ltikeyorg.py rename to src/lti_1p3_provider/migrations/0004_auto_20240822_1832.py index 1a1c88f..1924d84 100644 --- a/src/lti_1p3_provider/migrations/0004_ltikeyorg.py +++ b/src/lti_1p3_provider/migrations/0004_auto_20240822_1832.py @@ -1,4 +1,4 @@ -# Generated by Django 3.2.20 on 2024-08-21 21:43 +# Generated by Django 3.2.20 on 2024-08-22 18:32 from django.db import migrations, models import django.db.models.deletion @@ -13,6 +13,11 @@ class Migration(migrations.Migration): ] operations = [ + migrations.AddField( + model_name='launchgate', + name='allowed_courses', + field=models.JSONField(blank=True, default=list, help_text="Allows tool to access these specific CourseKey's"), + ), migrations.CreateModel( name='LtiKeyOrg', fields=[ diff --git a/src/lti_1p3_provider/models.py b/src/lti_1p3_provider/models.py index fafd167..b54b976 100644 --- a/src/lti_1p3_provider/models.py +++ b/src/lti_1p3_provider/models.py @@ -362,6 +362,11 @@ class LaunchGate(models.Model): help_text="Allows tool to access these specific UsageKeys", blank=True, ) + allowed_courses = models.JSONField( + default=list, + help_text="Allows tool to access these specific CourseKey's", + blank=True, + ) allowed_orgs = models.JSONField( default=list, help_text="Allows tools to access any content in these orgs", @@ -371,16 +376,19 @@ class LaunchGate(models.Model): def can_access_key(self, usage_key: UsageKey) -> bool: """Return True if tool can access usage_key - This is evaluated as an OR of allowed_keys and allowed_orgs + This is evaluated as an OR of allowed_keys, allowed_courses, allowed_orgs """ - allowed_keys, allowed_orgs = False, False + allowed_keys, allowed_courses, allowed_orgs = False, False, False if self.allowed_keys: allowed_keys = str(usage_key) in self.allowed_keys + if self.allowed_courses: + allowed_courses = str(usage_key.course_key) in self.allowed_courses + if self.allowed_orgs: allowed_orgs = usage_key.course_key.org in self.allowed_orgs - return allowed_keys or allowed_orgs + return allowed_keys or allowed_courses or allowed_orgs class LtiToolOrg(models.Model): diff --git a/src/lti_1p3_provider/tests/test_models.py b/src/lti_1p3_provider/tests/test_models.py index 4370312..04421a4 100644 --- a/src/lti_1p3_provider/tests/test_models.py +++ b/src/lti_1p3_provider/tests/test_models.py @@ -232,25 +232,29 @@ def test_can_access_key_empty_sets_returns_false(self): assert not gate.can_access_key(key) - def test_can_access_key_not_in_allowed_keys_or_allowed_orgs_returns_false(self): + def test_can_access_key_not_in_keys_or_courses_or_orgs_returns_false(self): """If target key not in allowed_keys or allowed_orgs, returns False""" allowed_keys = [ UsageKey.from_string("block-v1:no+course+run+type@some+block@html_id1"), UsageKey.from_string("block-v1:no+course+run+type@some+block@html_id2"), ] + allowed_courses = [str(allowed_keys[0].course_key)] allowed_keys = [str(key) for key in allowed_keys] allowed_orgs = ["org1", "org2"] target_key = UsageKey.from_string( "block-v1:bad_org+course+run+type@some+block@html_id" ) - # target org not org1/org2, target key not in allowed_keys + # target org not org1/org2, target key not in allowed_keys, not in + # allowed_courses gate = factories.LaunchGateFactory( - allowed_keys=allowed_keys, allowed_orgs=allowed_orgs + allowed_keys=allowed_keys, + allowed_courses=allowed_courses, + allowed_orgs=allowed_orgs, ) assert not gate.can_access_key(target_key) - def test_can_access_key_allowed_keys_match_returns_true(self): + def test_can_access_key_in_allowed_keys_returns_true(self): """If target key in allowed_keys, returns True""" allowed_keys = [ UsageKey.from_string("block-v1:org+course+run+type@some+block@html_id1"), @@ -270,8 +274,24 @@ def test_can_access_key_allowed_keys_match_returns_true(self): UsageKey.from_string("block-v1:org2+other+run+type@some+block@html_id"), ), ) - def test_can_access_key_allowed_orgs_match_returns_true(self, key): + def test_can_access_key_allowed_in_allowed_orgs_returns_true(self, key): """If target key in allowed_orgs, returns True""" gate = factories.LaunchGateFactory(allowed_orgs=["org1", "org2"]) assert gate.can_access_key(key) + + @pytest.mark.parametrize( + "key", + ( + UsageKey.from_string("block-v1:org1+course+run+type@some+block@html_id"), + UsageKey.from_string("block-v1:org1+course+run+type@verical+block@html_id"), + UsageKey.from_string("block-v1:org2+other+run+type@some+block@html_id"), + ), + ) + def test_can_access_key_allowed_in_allowed_courses_returns_true(self, key): + """If target key in allowed_courses, returns True""" + course1 = CourseKey.from_string("course-v1:org1+course+run") + course2 = CourseKey.from_string("course-v1:org2+other+run") + gate = factories.LaunchGateFactory(allowed_courses=[str(course1), str(course2)]) + + assert gate.can_access_key(key)