Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature :: 소켓 애러 핸들링 #333 #345

Merged
merged 6 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package com.seugi.api.domain.chat.presentation.websocket.config

import com.seugi.api.domain.chat.application.service.chat.room.ChatRoomService
import com.seugi.api.domain.chat.presentation.websocket.handler.StompErrorHandler
import com.seugi.api.domain.chat.presentation.websocket.util.SecurityUtils
import com.seugi.api.domain.member.application.exception.MemberErrorCode
import com.seugi.api.global.auth.jwt.JwtUserDetails
import com.seugi.api.global.auth.jwt.JwtUtils
import com.seugi.api.global.exception.CustomException
import org.springframework.beans.factory.annotation.Value
Expand All @@ -15,7 +16,6 @@ import org.springframework.messaging.simp.config.ChannelRegistration
import org.springframework.messaging.simp.config.MessageBrokerRegistry
import org.springframework.messaging.simp.stomp.StompHeaderAccessor
import org.springframework.messaging.support.ChannelInterceptor
import org.springframework.messaging.support.MessageBuilder
import org.springframework.messaging.support.MessageHeaderAccessor
import org.springframework.util.AntPathMatcher
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker
Expand All @@ -29,11 +29,13 @@ class StompWebSocketConfig(
private val jwtUtils: JwtUtils,
private val chatRoomService: ChatRoomService,
@Value("\${spring.rabbitmq.host}") private val rabbitmqHost: String,
private val stompErrorHandler: StompErrorHandler
) : WebSocketMessageBrokerConfigurer {

override fun registerStompEndpoints(registry: StompEndpointRegistry) {
registry.addEndpoint("/stomp/chat")
.setAllowedOrigins("*")
registry.setErrorHandler(stompErrorHandler)
}

override fun configureMessageBroker(registry: MessageBrokerRegistry) {
Expand All @@ -42,6 +44,7 @@ class StompWebSocketConfig(
registry.enableStompBrokerRelay("/queue", "/topic", "/exchange", "/amq/queue")
.setRelayHost(rabbitmqHost)
.setVirtualHost("/")
registry.setUserDestinationPrefix("/user")
}

override fun configureClientInboundChannel(registration: ChannelRegistration) {
Expand All @@ -50,56 +53,46 @@ class StompWebSocketConfig(
val accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor::class.java)!!

when (accessor.messageType) {
SimpMessageType.CONNECT -> handleConnect(message, accessor)
SimpMessageType.CONNECT -> handleConnect(accessor)
SimpMessageType.SUBSCRIBE -> handleSubscribe(accessor)
SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT -> handleUnsubscribeOrDisconnect()
SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT -> handleUnsubscribeOrDisconnect(accessor)
else -> {}
}
return message
}
})
}

private fun handleConnect(message: Message<*>, accessor: StompHeaderAccessor) {
private fun handleConnect(accessor: StompHeaderAccessor) {
val authToken = accessor.getNativeHeader("Authorization")?.firstOrNull()
if (authToken != null && authToken.startsWith("Bearer ")) {
val auth = jwtUtils.getAuthentication(authToken)
val userDetails = auth.principal as? JwtUserDetails
val userId: String? = userDetails?.id?.value?.toString()

if (userId != null) {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
simpAttributes.setAttribute("user-id", userId)
MessageBuilder.createMessage(message.payload, accessor.messageHeaders)
} else {
throw CustomException(MemberErrorCode.MEMBER_NOT_FOUND)
}
accessor.user = auth
} else {
throw CustomException(MemberErrorCode.MEMBER_NOT_FOUND)
}
}

private fun handleSubscribe(accessor: StompHeaderAccessor) {
accessor.destination?.let {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
simpAttributes.setAttribute("sub", it.substringAfterLast("."))
val userId = simpAttributes.getAttribute("user-id") as String
chatRoomService.sub(
userId = userId.toLong(),
roomId = it.substringAfterLast(".")
)
if (it.contains(".")) {
chatRoomService.sub(
userId = SecurityUtils.getUserId(accessor.user),
roomId = it.substringAfterLast(".")
)
}
}
}

private fun handleUnsubscribeOrDisconnect() {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
val userId = simpAttributes.getAttribute("user-id") as String?
val roomId = simpAttributes.getAttribute("sub") as String?
userId?.let {
roomId?.let {
chatRoomService.unSub(
userId = userId.toLong(),
roomId = it
)
}
private fun handleUnsubscribeOrDisconnect(accessor: StompHeaderAccessor) {
accessor.destination?.let {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
chatRoomService.unSub(
userId = SecurityUtils.getUserId(accessor.user),
roomId = simpAttributes.getAttribute("sub").toString()
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,20 @@ package com.seugi.api.domain.chat.presentation.websocket.controller

import com.seugi.api.domain.chat.application.service.message.MessageService
import com.seugi.api.domain.chat.presentation.websocket.dto.ChatMessageDto
import com.seugi.api.domain.chat.presentation.websocket.util.SecurityUtils
import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.messaging.simp.SimpAttributesContextHolder
import org.springframework.stereotype.Controller
import java.security.Principal


@Controller
class StompRabbitMQController(
private val messageService: MessageService
private val messageService: MessageService,
) {

@MessageMapping("chat.message")
fun send(chat: ChatMessageDto) {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
val userId = simpAttributes.getAttribute("user-id") as String?
messageService.sendAndSaveMessage(chat, userId!!.toLong())
fun send(chat: ChatMessageDto, principal: Principal) {
messageService.sendAndSaveMessage(chat, SecurityUtils.getUserId(principal))
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.seugi.api.domain.chat.presentation.websocket.handler

import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
import com.seugi.api.global.response.ErrorResponse
import io.jsonwebtoken.ExpiredJwtException
import io.jsonwebtoken.MalformedJwtException
import io.jsonwebtoken.UnsupportedJwtException
import org.springframework.context.annotation.Configuration
import org.springframework.messaging.Message
import org.springframework.messaging.MessageDeliveryException
import org.springframework.messaging.simp.stomp.StompCommand
import org.springframework.messaging.simp.stomp.StompHeaderAccessor
import org.springframework.messaging.support.MessageBuilder
import org.springframework.security.access.AccessDeniedException
import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler
import java.nio.charset.StandardCharsets
import java.security.SignatureException

@Configuration
class StompErrorHandler(private val objectMapper: ObjectMapper) : StompSubProtocolErrorHandler() {

override fun handleClientMessageProcessingError(clientMessage: Message<ByteArray>?, ex: Throwable): Message<ByteArray>? {

return when (ex) {
is MessageDeliveryException -> {
when (val cause = ex.cause) {
is AccessDeniedException -> {
sendErrorMessage(ErrorResponse(status = 4403, message = "Access denied"))
}
else -> {
if (isJwtException(cause)) {
sendErrorMessage(ErrorResponse(status = 4403, message = cause?.message ?: "JWT Exception"))
} else {
sendErrorMessage(ErrorResponse(status = 4403, message = cause?.stackTraceToString() ?: "Unhandled exception"))
}
}
}
}
else -> {
sendErrorMessage(ErrorResponse(status = 4400, message = ex.message ?: "Unhandled root exception"))
}
}
}

private fun isJwtException(ex: Throwable?): Boolean {
return ex is SignatureException || ex is ExpiredJwtException || ex is MalformedJwtException || ex is UnsupportedJwtException || ex is IllegalArgumentException
}

private fun sendErrorMessage(errorResponse: ErrorResponse): Message<ByteArray> {
val headers = StompHeaderAccessor.create(StompCommand.ERROR).apply {
message = errorResponse.message
}
return try {
val json = objectMapper.writeValueAsString(errorResponse)
MessageBuilder.createMessage(json.toByteArray(StandardCharsets.UTF_8), headers.messageHeaders)
} catch (e: JsonProcessingException) {
MessageBuilder.createMessage(errorResponse.message.toByteArray(StandardCharsets.UTF_8), headers.messageHeaders)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.seugi.api.domain.chat.presentation.websocket.util

import com.seugi.api.global.auth.jwt.JwtUserDetails
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import java.security.Principal

object SecurityUtils {

fun getUserId(principal: Principal?): Long {
return (principal as? UsernamePasswordAuthenticationToken)?.principal.let { it as? JwtUserDetails }?.member?.id?.value
?: -1
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.seugi.api.domain.chat.presentation.websocket.config
package com.seugi.api.global.config

import org.springframework.amqp.core.BindingBuilder
import org.springframework.amqp.core.Queue
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.seugi.api.global.exception

import com.seugi.api.global.response.ErrorResponse
import org.springframework.messaging.handler.annotation.MessageExceptionHandler
import org.springframework.messaging.Message
import org.springframework.messaging.simp.SimpMessagingTemplate
import org.springframework.messaging.simp.stomp.StompHeaderAccessor
import org.springframework.web.bind.annotation.ControllerAdvice
import java.io.IOException
import java.net.SocketException
import java.security.Principal

@ControllerAdvice
class CustomSocketExceptionHandler(
private val template: SimpMessagingTemplate
) {

private val bindingUrl = "/queue/errors"

@MessageExceptionHandler(SocketException::class)
fun handleSocketException(message: Message<*>, principal: Principal, ex: Exception) {
removeSession(message)
template.convertAndSendToUser(
principal.name,
bindingUrl,
ErrorResponse(status = 4500, message = ex.cause?.message ?: "Socket Error")
)
}

@MessageExceptionHandler(RuntimeException::class)
fun handleRuntimeException(principal: Principal, ex: Exception) {
template.convertAndSendToUser(
principal.name,
bindingUrl,
ErrorResponse(status = 4500, message = ex.cause?.stackTraceToString() ?: "Socket Error")
)
}

@Throws(IOException::class)
private fun removeSession(message: Message<*>) {
val stompHeaderAccessor = StompHeaderAccessor.wrap(message)
val sessionId = stompHeaderAccessor.sessionId
stompHeaderAccessor.sessionAttributes?.remove(sessionId)
}


}
12 changes: 6 additions & 6 deletions src/main/kotlin/com/seugi/api/global/response/BaseResponse.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import org.springframework.http.HttpStatus
@JsonInclude(JsonInclude.Include.NON_NULL)
data class BaseResponse<T>(

val status: Int = HttpStatus.OK.value(),
val success: Boolean = true,
val state: String? = "OK",
val message: String,
val data: T? = null
override val status: Int = HttpStatus.OK.value(),
override val success: Boolean = true,
override val state: String = "OK",
override val message: String,
val data: T? = null,

) {
) : ResponseInterface {

// errorResponse constructor
constructor(code: CustomErrorCode) : this(
Expand Down
11 changes: 11 additions & 0 deletions src/main/kotlin/com/seugi/api/global/response/ErrorResponse.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.seugi.api.global.response

import com.fasterxml.jackson.annotation.JsonInclude

@JsonInclude(JsonInclude.Include.NON_NULL)
data class ErrorResponse(
override val status: Int = 4500,
override val success: Boolean = false,
override val state: String = "Error",
override val message: String,
) : ResponseInterface
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.seugi.api.global.response

interface ResponseInterface {
val status: Int
val success: Boolean
val state: String
val message: String
}
Loading