diff --git a/drf_spectacular/contrib/rest_auth.py b/drf_spectacular/contrib/rest_auth.py index bd8be038..aab21a62 100644 --- a/drf_spectacular/contrib/rest_auth.py +++ b/drf_spectacular/contrib/rest_auth.py @@ -1,4 +1,5 @@ from django.conf import settings +from django.utils.version import get_version_tuple from rest_framework import serializers from drf_spectacular.contrib.rest_framework_simplejwt import ( @@ -9,13 +10,33 @@ from drf_spectacular.utils import extend_schema +def get_dj_rest_auth_setting(class_name, setting_name): + from dj_rest_auth.__version__ import __version__ + + if get_version_tuple(__version__) < (3, 0, 0): + from dj_rest_auth import app_settings + + return getattr(app_settings, class_name) + else: + from dj_rest_auth.app_settings import api_settings + + return getattr(api_settings, setting_name) + + def get_token_serializer_class(): - from dj_rest_auth.app_settings import JWTSerializer, TokenSerializer + from dj_rest_auth.__version__ import __version__ - if getattr(settings, 'REST_USE_JWT', False): - return JWTSerializer + if get_version_tuple(__version__) < (3, 0, 0): + use_jwt = getattr(settings, 'REST_USE_JWT', False) else: - return TokenSerializer + from dj_rest_auth.app_settings import api_settings + + use_jwt = api_settings.USE_JWT + + if use_jwt: + return get_dj_rest_auth_setting('JWTSerializer', 'JWT_SERIALIZER') + else: + return get_dj_rest_auth_setting('TokenSerializer', 'TOKEN_SERIALIZER') class RestAuthDetailSerializer(serializers.Serializer): @@ -89,10 +110,8 @@ class RestAuthJWTSerializer(OpenApiSerializerExtension): target_class = 'dj_rest_auth.serializers.JWTSerializer' def map_serializer(self, auto_schema, direction): - from dj_rest_auth.app_settings import UserDetailsSerializer - class Fixed(self.target_class): - user = UserDetailsSerializer() + user = get_dj_rest_auth_setting('UserDetailsSerializer', 'USER_DETAILS_SERIALIZER')() return auto_schema._map_serializer(Fixed, direction) diff --git a/tests/contrib/test_rest_auth.py b/tests/contrib/test_rest_auth.py index f0bb6898..e7bf7d6a 100644 --- a/tests/contrib/test_rest_auth.py +++ b/tests/contrib/test_rest_auth.py @@ -30,8 +30,8 @@ def test_rest_auth(no_warnings): @pytest.mark.contrib('dj_rest_auth', 'allauth', 'rest_framework_simplejwt') @mock.patch('drf_spectacular.settings.spectacular_settings.SCHEMA_PATH_PREFIX', '') +@mock.patch('dj_rest_auth.app_settings.api_settings.USE_JWT', True) def test_rest_auth_token(no_warnings, settings): - settings.REST_USE_JWT = True # flush module import cache to re-evaluate conditional import import dj_rest_auth.urls reload(dj_rest_auth.urls)