Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aed/better policy create #221

Merged
merged 1 commit into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion enterprise_access/apps/api/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Module for filters across all enterprise-access apps.
"""
from .base import NoFilterOnRetrieveBackend
from .base import NoFilterOnDetailBackend
from .subsidy_access_policy import SubsidyAccessPolicyFilter
from .subsidy_request import SubsidyRequestCustomerConfigurationFilterBackend, SubsidyRequestFilterBackend
8 changes: 4 additions & 4 deletions enterprise_access/apps/api/filters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ def filter_for_field(cls, field, field_name, lookup_expr=None):
return filter_obj


class NoFilterOnRetrieveBackend(drf_filters.DjangoFilterBackend):
class NoFilterOnDetailBackend(drf_filters.DjangoFilterBackend):
"""
Customized filter backend that simply doesn't use a filterset_class
on retrieve() actions.
on any of the detail actions of a viewset.
"""
def get_filterset_class(self, view, queryset=None):
"""
Returns None if this is a retrieve() operation.
Returns None if this is a detail operation.
"""
if view.action in ('retrieve', 'update', 'partial_update'):
if view.action in ('retrieve', 'create', 'update', 'partial_update', 'destroy'):
return None
return super().get_filterset_class(view, queryset)
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from enterprise_access.apps.subsidy_access_policy.utils import create_idempotency_key_for_transaction
from test_utils import APITestWithMocks

SUBSIDY_ACCESS_POLICY_ADMIN_LIST_ENDPOINT = reverse('api:v1:admin-policy-list')
SUBSIDY_ACCESS_POLICY_DEPR_LIST_ENDPOINT = reverse('api:v1:admin-policy-list')
SUBSIDY_ACCESS_POLICY_LIST_ENDPOINT = reverse('api:v1:subsidy-access-policies-list')

TEST_ENTERPRISE_UUID = uuid4()

Expand Down Expand Up @@ -186,6 +187,13 @@ def test_policy_crud_write_views_unauthorized_forbidden(self, role_context_dict,

request_kwargs = {'uuid': str(self.redeemable_policy.uuid)}

# Test the create endpoint.
response = self.client.post(
SUBSIDY_ACCESS_POLICY_LIST_ENDPOINT,
data={'enterprise_customer_uuid': str(TEST_ENTERPRISE_UUID)},
)
self.assertEqual(response.status_code, expected_response_code)

# Test the delete endpoint.
response = self.client.delete(reverse('api:v1:subsidy-access-policies-detail', kwargs=request_kwargs))
self.assertEqual(response.status_code, expected_response_code)
Expand Down Expand Up @@ -509,6 +517,8 @@ def test_update_view_validates_fields_vs_policy_type(self, policy_class, request
class TestAdminPolicyCreateView(CRUDViewTestMixin, APITestWithMocks):
"""
Test the create view for subsidy access policy records.
This tests both the deprecated viewset and the preferred
``SubsidyAccessPolicyViewSet`` implementation.
"""

@ddt.data(
Expand Down Expand Up @@ -594,31 +604,31 @@ def test_create_view(self, policy_type, extra_fields, expected_response_code, ex
])

# Test the retrieve endpoint
create_url = SUBSIDY_ACCESS_POLICY_ADMIN_LIST_ENDPOINT
payload = {
'policy_type': policy_type,
'description': 'test description',
'active': True,
'enterprise_customer_uuid': str(TEST_ENTERPRISE_UUID),
'catalog_uuid': str(uuid4()),
'subsidy_uuid': str(uuid4()),
'access_method': AccessMethods.DIRECT,
'spend_limit': None,
}
payload.update(extra_fields)
response = self.client.post(create_url, payload)
assert response.status_code == expected_response_code

if expected_response_code == status.HTTP_201_CREATED:
response_json = response.json()
del response_json['uuid']
expected_response = payload.copy()
expected_response.setdefault("per_learner_enrollment_limit")
expected_response.setdefault("per_learner_spend_limit")
assert response_json == expected_response
elif expected_response_code == status.HTTP_400_BAD_REQUEST:
for expected_error_keyword in expected_error_keywords:
assert expected_error_keyword in response.content.decode("utf-8")
for create_url in (SUBSIDY_ACCESS_POLICY_DEPR_LIST_ENDPOINT, SUBSIDY_ACCESS_POLICY_LIST_ENDPOINT):
payload = {
'policy_type': policy_type,
'description': 'test description',
'active': True,
'enterprise_customer_uuid': str(TEST_ENTERPRISE_UUID),
'catalog_uuid': str(uuid4()),
'subsidy_uuid': str(uuid4()),
'access_method': AccessMethods.DIRECT,
'spend_limit': None,
}
payload.update(extra_fields)
response = self.client.post(create_url, payload)
assert response.status_code == expected_response_code

if expected_response_code == status.HTTP_201_CREATED:
response_json = response.json()
del response_json['uuid']
expected_response = payload.copy()
expected_response.setdefault("per_learner_enrollment_limit")
expected_response.setdefault("per_learner_spend_limit")
assert response_json == expected_response
elif expected_response_code == status.HTTP_400_BAD_REQUEST:
for expected_error_keyword in expected_error_keywords:
assert expected_error_keyword in response.content.decode("utf-8")

@ddt.data(
{
Expand All @@ -643,45 +653,45 @@ def test_idempotent_create_view(self, policy_type, extra_fields, expected_respon
])

# Test the retrieve endpoint
create_url = SUBSIDY_ACCESS_POLICY_ADMIN_LIST_ENDPOINT
enterprise_customer_uuid = str(TEST_ENTERPRISE_UUID)
catalog_uuid = str(uuid4())
subsidy_uuid = str(uuid4())
payload = {
'policy_type': policy_type,
'description': 'test description',
'active': True,
'enterprise_customer_uuid': enterprise_customer_uuid,
'catalog_uuid': catalog_uuid,
'subsidy_uuid': subsidy_uuid,
'access_method': AccessMethods.DIRECT,
'spend_limit': None,
}
payload.update(extra_fields)
response = self.client.post(create_url, payload)
assert response.status_code == expected_response_code

if expected_response_code == status.HTTP_201_CREATED:
response_json = response.json()
del response_json['uuid']
expected_response = payload.copy()
expected_response.setdefault("per_learner_enrollment_limit")
expected_response.setdefault("per_learner_spend_limit")
assert response_json == expected_response

# Test idempotency
response = self.client.post(create_url, payload)
duplicate_status_code = status.HTTP_200_OK

assert response.status_code == duplicate_status_code

if response.status_code == status.HTTP_200_OK:
response_json = response.json()
del response_json['uuid']
expected_response = payload.copy()
expected_response.setdefault("per_learner_enrollment_limit")
expected_response.setdefault("per_learner_spend_limit")
assert response_json == expected_response
for create_url in (SUBSIDY_ACCESS_POLICY_DEPR_LIST_ENDPOINT, SUBSIDY_ACCESS_POLICY_LIST_ENDPOINT):
enterprise_customer_uuid = str(TEST_ENTERPRISE_UUID)
catalog_uuid = str(uuid4())
subsidy_uuid = str(uuid4())
payload = {
'policy_type': policy_type,
'description': 'test description',
'active': True,
'enterprise_customer_uuid': enterprise_customer_uuid,
'catalog_uuid': catalog_uuid,
'subsidy_uuid': subsidy_uuid,
'access_method': AccessMethods.DIRECT,
'spend_limit': None,
}
payload.update(extra_fields)
response = self.client.post(create_url, payload)
assert response.status_code == expected_response_code

if expected_response_code == status.HTTP_201_CREATED:
response_json = response.json()
del response_json['uuid']
expected_response = payload.copy()
expected_response.setdefault("per_learner_enrollment_limit")
expected_response.setdefault("per_learner_spend_limit")
assert response_json == expected_response

# Test idempotency
response = self.client.post(create_url, payload)
duplicate_status_code = status.HTTP_200_OK

assert response.status_code == duplicate_status_code

if response.status_code == status.HTTP_200_OK:
response_json = response.json()
del response_json['uuid']
expected_response = payload.copy()
expected_response.setdefault("per_learner_enrollment_limit")
expected_response.setdefault("per_learner_spend_limit")
assert response_json == expected_response


@ddt.ddt
Expand Down
44 changes: 42 additions & 2 deletions enterprise_access/apps/api/v1/views/subsidy_access_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,37 @@ def policy_permission_detail_fn(request, *args, uuid=None, **kwargs):


class SubsidyAccessPolicyViewSet(
mixins.CreateModelMixin,
mixins.ListModelMixin,
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
viewsets.GenericViewSet,
):
"""
Viewset supporting some CRUD operations on ``SubsidyAccessPolicy`` records.
Viewset supporting all CRUD operations on ``SubsidyAccessPolicy`` records.
"""
permission_classes = (permissions.IsAuthenticated,)
serializer_class = serializers.SubsidyAccessPolicyResponseSerializer
authentication_classes = (JwtAuthentication, authentication.SessionAuthentication)
filter_backends = (filters.NoFilterOnRetrieveBackend,)
filter_backends = (filters.NoFilterOnDetailBackend,)
filterset_class = filters.SubsidyAccessPolicyFilter
pagination_class = PaginationWithPageCount
lookup_field = 'uuid'

def __init__(self, *args, **kwargs):
self.extra_context = {}
super().__init__(*args, **kwargs)

def set_policy_created(self, created):
"""
Helper function, used from within a related serializer for creation,
to help understand in the context of this viewset whether
a policy was created, or if a policy with the requested parameters
already existed when creation was attempted.
"""
self.extra_context['created'] = created

def get_queryset(self):
"""
A base queryset to list or retrieve `SubsidyAccessPolicy` records.
Expand All @@ -105,6 +119,8 @@ def get_serializer_class(self):
Overrides the default behavior to return different
serializers depending on the request action.
"""
if self.action == 'create':
return serializers.SubsidyAccessPolicyCRUDSerializer
if self.action in ('update', 'partial_update'):
return serializers.SubsidyAccessPolicyUpdateRequestSerializer
return self.serializer_class
Expand Down Expand Up @@ -135,6 +151,30 @@ def list(self, request, *args, **kwargs):
"""
return super().list(request, *args, **kwargs)

@extend_schema(
tags=[SUBSIDY_ACCESS_POLICY_CRUD_API_TAG],
summary='Create a new subsidy access policy.',
request=serializers.SubsidyAccessPolicyCRUDSerializer,
responses={
status.HTTP_200_OK: serializers.SubsidyAccessPolicyResponseSerializer,
status.HTTP_201_CREATED: serializers.SubsidyAccessPolicyResponseSerializer,
},
)
@permission_required(
SUBSIDY_ACCESS_POLICY_WRITE_PERMISSION,
fn=lambda request: request.data.get('enterprise_customer_uuid')
)
def create(self, request, *args, **kwargs):
"""
Creates a single `SubsidyAccessPolicy` record, or returns
an existing one if an **active** record with the requested (enterprise_customer_uuid,
subsidy_uuid, catalog_uuid, access_method) values already exists.
"""
response = super().create(request, *args, **kwargs)
if not self.extra_context.get('created'):
response.status_code = status.HTTP_200_OK
return response

@extend_schema(
tags=[SUBSIDY_ACCESS_POLICY_CRUD_API_TAG],
summary='Partially update (with a PUT) a subsidy access policy by UUID.',
Expand Down