From efe35de8eeed590381ffbb8ac7bb6913f2850a94 Mon Sep 17 00:00:00 2001 From: Chris Pappas Date: Wed, 2 Oct 2019 12:51:36 -0400 Subject: [PATCH] ENT-2310 | Adding a post_save receiver to listen to CourseEnrollments table (#589) Adding in a test file I forgot. Updating a few test doctrings Actually adding in the tasks file too Updating task to prevent duplicate recreation of EnterpriseCourseEnrollment Appeasing pylint bumping version --- CHANGELOG.rst | 5 ++ enterprise/__init__.py | 2 +- enterprise/signals.py | 36 ++++++++++++ enterprise/tasks.py | 42 ++++++++++++++ tests/test_enterprise/test_signals.py | 69 +++++++++++++++++++++- tests/test_enterprise/test_tasks.py | 82 +++++++++++++++++++++++++++ 6 files changed, 234 insertions(+), 2 deletions(-) create mode 100644 enterprise/tasks.py create mode 100644 tests/test_enterprise/test_tasks.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d896a359f5..ccc84d0fd5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,11 @@ Change Log Unreleased ---------- +[1.11.0] - 2019-10-02 +--------------------- + +* Adding post-save receiver to spin off EnterpriseCourseEnrollment creation tasks on CourseEnrollment creation signals + [1.10.8] - 2019-10-01 --------------------- diff --git a/enterprise/__init__.py b/enterprise/__init__.py index 7f41d1d100..b2e8888238 100644 --- a/enterprise/__init__.py +++ b/enterprise/__init__.py @@ -4,6 +4,6 @@ from __future__ import absolute_import, unicode_literals -__version__ = "1.10.8" +__version__ = "1.11.0" default_app_config = "enterprise.apps.EnterpriseConfig" # pylint: disable=invalid-name diff --git a/enterprise/signals.py b/enterprise/signals.py index 94f2212989..0361d2b998 100644 --- a/enterprise/signals.py +++ b/enterprise/signals.py @@ -6,6 +6,7 @@ from logging import getLogger +from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.models.signals import post_delete, post_save from django.dispatch import receiver @@ -19,8 +20,14 @@ SystemWideEnterpriseRole, SystemWideEnterpriseUserRoleAssignment, ) +from enterprise.tasks import create_enterprise_enrollment from enterprise.utils import get_default_catalog_content_filter, track_enrollment +try: + from student.models import CourseEnrollment +except ImportError: + CourseEnrollment = None + logger = getLogger(__name__) # pylint: disable=invalid-name @@ -116,3 +123,32 @@ def delete_enterprise_learner_role_assignment(sender, instance, **kwargs): # except SystemWideEnterpriseUserRoleAssignment.DoesNotExist: # Do nothing if no role assignment is present for the enterprise customer user. pass + + +def create_enterprise_enrollment_receiver(sender, instance, **kwargs): # pylint: disable=unused-argument + """ + Watches for post_save signal for creates on the CourseEnrollment table. + + Spin off an async task to generate an EnterpriseCourseEnrollment if appropriate. + """ + if kwargs.get('created') and instance.user: + user_id = instance.user.id + try: + ecu = EnterpriseCustomerUser.objects.get(user_id=user_id) + except ObjectDoesNotExist: + return + logger.info(( + "User %s is an EnterpriseCustomerUser. " + "Spinning off task to check if course is within User's " + "Enterprise's EnterpriseCustomerCatalog." + ), user_id) + + create_enterprise_enrollment.delay( + instance.course_id, + ecu, + ) + + +# Don't connect this receiver if we dont have access to CourseEnrollment model +if CourseEnrollment is not None: + post_save.connect(create_enterprise_enrollment_receiver, sender=CourseEnrollment) diff --git a/enterprise/tasks.py b/enterprise/tasks.py new file mode 100644 index 0000000000..a44ced86c9 --- /dev/null +++ b/enterprise/tasks.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +""" +Django tasks. +""" +from __future__ import absolute_import, unicode_literals + +from logging import getLogger + +from celery import shared_task + +from enterprise.models import EnterpriseCourseEnrollment + +LOGGER = getLogger(__name__) + + +@shared_task +def create_enterprise_enrollment(course_id, enterprise_customer_user): + """ + Create enterprise enrollment for user if course_id part of catalog for the ENT customer. + """ + # Prevent duplicate records from being created if possible + # before we need to make a call to discovery + if EnterpriseCourseEnrollment.objects.filter( + enterprise_customer_user=enterprise_customer_user, + course_id=course_id, + ).exists(): + LOGGER.info(( + "EnterpriseCourseEnrollment record exists for user %s " + "on course %s. Exiting task." + ), enterprise_customer_user.user_id, course_id) + return + + enterprise_customer = enterprise_customer_user.enterprise_customer + if enterprise_customer.catalog_contains_course(course_id): + LOGGER.info(( + "Creating EnterpriseCourseEnrollment for user %s " + "on course %s for enterprise_customer %s" + ), enterprise_customer_user.user_id, course_id, enterprise_customer) + EnterpriseCourseEnrollment.objects.create( + course_id=course_id, + enterprise_customer_user=enterprise_customer_user, + ) diff --git a/tests/test_enterprise/test_signals.py b/tests/test_enterprise/test_signals.py index 1fb7ed101c..8e85b164f5 100644 --- a/tests/test_enterprise/test_signals.py +++ b/tests/test_enterprise/test_signals.py @@ -23,7 +23,7 @@ SystemWideEnterpriseRole, SystemWideEnterpriseUserRoleAssignment, ) -from enterprise.signals import handle_user_post_save +from enterprise.signals import create_enterprise_enrollment_receiver, handle_user_post_save from test_utils.factories import ( EnterpriseCustomerCatalogFactory, EnterpriseCustomerFactory, @@ -465,3 +465,70 @@ def test_delete_enterprise_learner_role_assignment_no_user_associated(self): role=self.enterprise_learner_role ) self.assertFalse(learner_role_assignment.exists()) + + +@mark.django_db +class TestCourseEnrollmentSignals(unittest.TestCase): + """ + Tests signals associated with CourseEnrollments (that are found in edx-platform). + """ + def setUp(self): + """ + Setup for `TestCourseEnrollmentSignals` test. + """ + self.user = UserFactory(id=2, email='user@example.com') + self.enterprise_customer = EnterpriseCustomerFactory( + name='Team Titans', + ) + self.enterprise_customer_user = EnterpriseCustomerUserFactory( + user_id=self.user.id, + enterprise_customer=self.enterprise_customer, + ) + self.non_enterprise_user = UserFactory(id=999, email='user999@example.com') + super(TestCourseEnrollmentSignals, self).setUp() + + @mock.patch('enterprise.tasks.create_enterprise_enrollment.delay') + def test_receiver_calls_task_if_ecu_exists(self, mock_task): + """ + Receiver should call a task + if user tied to the CourseEnrollment that is handed into the function + is an EnterpriseCustomerUser + """ + sender = mock.Mock() # This would be a CourseEnrollment class + instance = mock.Mock() # This would be a CourseEnrollment instance + instance.user = self.user + instance.course_id = "fake:course_id" + # Signal metadata (note: 'signal' would be an actual object, but we dont need it here) + kwargs = { + 'update_fields': None, + 'raw': False, + 'signal': '', + 'using': 'default', + 'created': True, + } + + create_enterprise_enrollment_receiver(sender, instance, **kwargs) + mock_task.assert_called_once_with(instance.course_id, self.enterprise_customer_user) + + @mock.patch('enterprise.tasks.create_enterprise_enrollment.delay') + def test_receiver_does_not_call_task_if_ecu_not_exists(self, mock_task): + """ + Receiver should NOT call a task + if user tied to the CourseEnrollment that is handed into the function + is NOT an EnterpriseCustomerUser + """ + sender = mock.Mock() # This would be a CourseEnrollment class + instance = mock.Mock() # This would be a CourseEnrollment instance + instance.user = self.non_enterprise_user + instance.course_id = "fake:course_id" + # Signal metadata (note: 'signal' would be an actual object, but we dont need it here) + kwargs = { + 'update_fields': None, + 'raw': False, + 'signal': '', + 'using': 'default', + 'created': True, + } + + create_enterprise_enrollment_receiver(sender, instance, **kwargs) + mock_task.assert_not_called() diff --git a/tests/test_enterprise/test_tasks.py b/tests/test_enterprise/test_tasks.py new file mode 100644 index 0000000000..a32e9a3226 --- /dev/null +++ b/tests/test_enterprise/test_tasks.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" +Tests for the `edx-enterprise` tasks module. +""" +from __future__ import absolute_import, unicode_literals, with_statement + +import unittest + +import mock +from pytest import mark + +from enterprise.models import EnterpriseCourseEnrollment +from enterprise.tasks import create_enterprise_enrollment +from test_utils.factories import EnterpriseCustomerFactory, EnterpriseCustomerUserFactory, UserFactory + + +@mark.django_db +class TestEnterpriseTasks(unittest.TestCase): + """ + Tests tasks associated with Enterprise. + """ + def setUp(self): + """ + Setup for `TestEnterpriseTasks` test. + """ + self.user = UserFactory(id=2, email='user@example.com') + self.enterprise_customer = EnterpriseCustomerFactory( + name='Team Titans', + ) + self.enterprise_customer_user = EnterpriseCustomerUserFactory( + user_id=self.user.id, + enterprise_customer=self.enterprise_customer, + ) + super(TestEnterpriseTasks, self).setUp() + + @mock.patch('enterprise.models.EnterpriseCustomer.catalog_contains_course') + def test_create_enrollment_task_course_in_catalog(self, mock_contains_course): + """ + Task should create an enterprise enrollment if the course_id handed to + the function is part of the EnterpriseCustomer's catalogs + """ + mock_contains_course.return_value = True + + assert EnterpriseCourseEnrollment.objects.count() == 0 + create_enterprise_enrollment( + 'fake:course', + self.enterprise_customer_user + ) + assert EnterpriseCourseEnrollment.objects.count() == 1 + + @mock.patch('enterprise.models.EnterpriseCustomer.catalog_contains_course') + def test_create_enrollment_task_course_not_in_catalog(self, mock_contains_course): + """ + Task should NOT create an enterprise enrollment if the course_id handed + to the function is NOT part of the EnterpriseCustomer's catalogs + """ + mock_contains_course.return_value = False + + assert EnterpriseCourseEnrollment.objects.count() == 0 + create_enterprise_enrollment( + 'fake:course', + self.enterprise_customer_user + ) + assert EnterpriseCourseEnrollment.objects.count() == 0 + + def test_create_enrollment_task_no_create_duplicates(self): + """ + Task should return without creating a new EnterpriseCourseEnrollment + if one with the course_id and enterprise_customer_user specified + already exists. + """ + EnterpriseCourseEnrollment.objects.create( + course_id='fake:course', + enterprise_customer_user=self.enterprise_customer_user, + ) + + assert EnterpriseCourseEnrollment.objects.count() == 1 + create_enterprise_enrollment( + 'fake:course', + self.enterprise_customer_user + ) + assert EnterpriseCourseEnrollment.objects.count() == 1