diff --git a/src/main/java/com/example/solidconnection/chat/config/CustomHandshakeHandler.java b/src/main/java/com/example/solidconnection/chat/config/CustomHandshakeHandler.java new file mode 100644 index 000000000..6c3054355 --- /dev/null +++ b/src/main/java/com/example/solidconnection/chat/config/CustomHandshakeHandler.java @@ -0,0 +1,27 @@ +package com.example.solidconnection.chat.config; + +import java.security.Principal; +import java.util.Map; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.stereotype.Component; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.support.DefaultHandshakeHandler; + +// WebSocket 세션의 Principal을 결정한다. +@Component +public class CustomHandshakeHandler extends DefaultHandshakeHandler { + + @Override + protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler, + Map attributes) { + + Object userAttribute = attributes.get("user"); + + if (userAttribute instanceof Principal) { + Principal principal = (Principal) userAttribute; + return principal; + } + + return super.determineUser(request, wsHandler, attributes); + } +} diff --git a/src/main/java/com/example/solidconnection/chat/config/StompHandler.java b/src/main/java/com/example/solidconnection/chat/config/StompHandler.java index 660f01f28..2e99bf9c4 100644 --- a/src/main/java/com/example/solidconnection/chat/config/StompHandler.java +++ b/src/main/java/com/example/solidconnection/chat/config/StompHandler.java @@ -2,10 +2,14 @@ import static com.example.solidconnection.common.exception.ErrorCode.AUTHENTICATION_FAILED; -import com.example.solidconnection.auth.token.JwtTokenProvider; +import com.example.solidconnection.chat.service.ChatService; import com.example.solidconnection.common.exception.CustomException; import com.example.solidconnection.common.exception.ErrorCode; -import io.jsonwebtoken.Claims; +import com.example.solidconnection.security.authentication.TokenAuthentication; +import com.example.solidconnection.security.userdetails.SiteUserDetails; +import java.security.Principal; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import lombok.RequiredArgsConstructor; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; @@ -18,47 +22,48 @@ @RequiredArgsConstructor public class StompHandler implements ChannelInterceptor { - private final JwtTokenProvider jwtTokenProvider; + private static final Pattern ROOM_ID_PATTERN = Pattern.compile("^/topic/chat/(\\d+)$"); + private final ChatService chatService; @Override public Message preSend(Message message, MessageChannel channel) { final StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message); if (StompCommand.CONNECT.equals(accessor.getCommand())) { - Claims claims = validateAndExtractClaims(accessor, AUTHENTICATION_FAILED); + Principal user = accessor.getUser(); + if (user == null) { + throw new CustomException(AUTHENTICATION_FAILED); + } } if (StompCommand.SUBSCRIBE.equals(accessor.getCommand())) { - Claims claims = validateAndExtractClaims(accessor, AUTHENTICATION_FAILED); + Principal user = accessor.getUser(); + if (user == null) { + throw new CustomException(AUTHENTICATION_FAILED); + } - String email = claims.getSubject(); - String destination = accessor.getDestination(); + TokenAuthentication tokenAuthentication = (TokenAuthentication) user; + SiteUserDetails siteUserDetails = (SiteUserDetails) tokenAuthentication.getPrincipal(); - String roomId = extractRoomId(destination); + String destination = accessor.getDestination(); + long roomId = Long.parseLong(extractRoomId(destination)); - // todo: roomId 기반 실제 구독 권한 검사 로직 추가 + chatService.validateChatRoomParticipant(siteUserDetails.getSiteUser().getId(), roomId); } return message; } - private Claims validateAndExtractClaims(StompHeaderAccessor accessor, ErrorCode errorCode) { - String bearerToken = accessor.getFirstNativeHeader("Authorization"); - if (bearerToken == null || !bearerToken.startsWith("Bearer ")) { - throw new CustomException(errorCode); - } - String token = bearerToken.substring(7); - return jwtTokenProvider.parseClaims(token); - } - private String extractRoomId(String destination) { if (destination == null) { throw new CustomException(ErrorCode.INVALID_ROOM_ID); } - String[] parts = destination.split("/"); - if (parts.length < 3 || !parts[1].equals("topic")) { + + Matcher matcher = ROOM_ID_PATTERN.matcher(destination); + if (!matcher.matches()) { throw new CustomException(ErrorCode.INVALID_ROOM_ID); } - return parts[2]; + + return matcher.group(1); } } diff --git a/src/main/java/com/example/solidconnection/chat/config/StompWebSocketConfig.java b/src/main/java/com/example/solidconnection/chat/config/StompWebSocketConfig.java index 86b6eef5d..51259a0e1 100644 --- a/src/main/java/com/example/solidconnection/chat/config/StompWebSocketConfig.java +++ b/src/main/java/com/example/solidconnection/chat/config/StompWebSocketConfig.java @@ -22,12 +22,18 @@ public class StompWebSocketConfig implements WebSocketMessageBrokerConfigurer { private final StompHandler stompHandler; private final StompProperties stompProperties; private final CorsProperties corsProperties; + private final WebSocketHandshakeInterceptor webSocketHandshakeInterceptor; + private final CustomHandshakeHandler customHandshakeHandler; @Override public void registerStompEndpoints(StompEndpointRegistry registry) { List strings = corsProperties.allowedOrigins(); String[] allowedOrigins = strings.toArray(String[]::new); - registry.addEndpoint("/connect").setAllowedOrigins(allowedOrigins).withSockJS(); + registry.addEndpoint("/connect") + .setAllowedOrigins(allowedOrigins) + .addInterceptors(webSocketHandshakeInterceptor) + .setHandshakeHandler(customHandshakeHandler) + .withSockJS(); } @Override diff --git a/src/main/java/com/example/solidconnection/chat/config/WebSocketHandshakeInterceptor.java b/src/main/java/com/example/solidconnection/chat/config/WebSocketHandshakeInterceptor.java new file mode 100644 index 000000000..9e8aafe2d --- /dev/null +++ b/src/main/java/com/example/solidconnection/chat/config/WebSocketHandshakeInterceptor.java @@ -0,0 +1,32 @@ +package com.example.solidconnection.chat.config; + +import java.security.Principal; +import java.util.Map; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.stereotype.Component; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; + +// Principal을 WebSocket 세션에 저장하는 것에만 집중한다. +@Component +public class WebSocketHandshakeInterceptor implements HandshakeInterceptor { + + @Override + public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) { + Principal principal = request.getPrincipal(); + + if (principal != null) { + attributes.put("user", principal); + return true; + } + + return false; + } + + @Override + public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Exception exception) { + } +} diff --git a/src/main/java/com/example/solidconnection/chat/controller/ChatMessageController.java b/src/main/java/com/example/solidconnection/chat/controller/ChatMessageController.java new file mode 100644 index 000000000..a7e158224 --- /dev/null +++ b/src/main/java/com/example/solidconnection/chat/controller/ChatMessageController.java @@ -0,0 +1,32 @@ +package com.example.solidconnection.chat.controller; + +import com.example.solidconnection.chat.dto.ChatMessageSendRequest; +import com.example.solidconnection.chat.service.ChatService; +import com.example.solidconnection.security.authentication.TokenAuthentication; +import com.example.solidconnection.security.userdetails.SiteUserDetails; +import jakarta.validation.Valid; +import java.security.Principal; +import lombok.RequiredArgsConstructor; +import org.springframework.messaging.handler.annotation.DestinationVariable; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.handler.annotation.Payload; +import org.springframework.stereotype.Controller; + +@Controller +@RequiredArgsConstructor +public class ChatMessageController { + + private final ChatService chatService; + + @MessageMapping("/chat/{roomId}") + public void sendChatMessage( + @DestinationVariable Long roomId, + @Valid @Payload ChatMessageSendRequest chatMessageSendRequest, + Principal principal + ) { + TokenAuthentication tokenAuthentication = (TokenAuthentication) principal; + SiteUserDetails siteUserDetails = (SiteUserDetails) tokenAuthentication.getPrincipal(); + + chatService.sendChatMessage(chatMessageSendRequest, siteUserDetails.getSiteUser().getId(), roomId); + } +} diff --git a/src/main/java/com/example/solidconnection/chat/dto/ChatMessageSendRequest.java b/src/main/java/com/example/solidconnection/chat/dto/ChatMessageSendRequest.java new file mode 100644 index 000000000..22d652a35 --- /dev/null +++ b/src/main/java/com/example/solidconnection/chat/dto/ChatMessageSendRequest.java @@ -0,0 +1,12 @@ +package com.example.solidconnection.chat.dto; + +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; + +public record ChatMessageSendRequest( + @NotNull(message = "메시지를 입력해주세요.") + @Size(max = 500, message = "메시지는 500자를 초과할 수 없습니다") + String content +) { + +} diff --git a/src/main/java/com/example/solidconnection/chat/dto/ChatMessageSendResponse.java b/src/main/java/com/example/solidconnection/chat/dto/ChatMessageSendResponse.java new file mode 100644 index 000000000..065c7ba1c --- /dev/null +++ b/src/main/java/com/example/solidconnection/chat/dto/ChatMessageSendResponse.java @@ -0,0 +1,19 @@ +package com.example.solidconnection.chat.dto; + +import com.example.solidconnection.chat.domain.ChatMessage; + +public record ChatMessageSendResponse( + long messageId, + String content, + long senderId +) { + + public static ChatMessageSendResponse from(ChatMessage chatMessage) { + return new ChatMessageSendResponse( + chatMessage.getId(), + chatMessage.getContent(), + chatMessage.getSenderId() + ); + } + +} diff --git a/src/main/java/com/example/solidconnection/chat/service/ChatService.java b/src/main/java/com/example/solidconnection/chat/service/ChatService.java index c378f6b50..fadd284fe 100644 --- a/src/main/java/com/example/solidconnection/chat/service/ChatService.java +++ b/src/main/java/com/example/solidconnection/chat/service/ChatService.java @@ -10,6 +10,8 @@ import com.example.solidconnection.chat.domain.ChatRoom; import com.example.solidconnection.chat.dto.ChatAttachmentResponse; import com.example.solidconnection.chat.dto.ChatMessageResponse; +import com.example.solidconnection.chat.dto.ChatMessageSendRequest; +import com.example.solidconnection.chat.dto.ChatMessageSendResponse; import com.example.solidconnection.chat.dto.ChatParticipantResponse; import com.example.solidconnection.chat.dto.ChatRoomListResponse; import com.example.solidconnection.chat.dto.ChatRoomResponse; @@ -24,13 +26,13 @@ import java.time.ZonedDateTime; import java.util.List; import java.util.Optional; -import lombok.RequiredArgsConstructor; +import org.springframework.context.annotation.Lazy; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Slice; +import org.springframework.messaging.simp.SimpMessageSendingOperations; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -@RequiredArgsConstructor @Service public class ChatService { @@ -40,6 +42,22 @@ public class ChatService { private final ChatReadStatusRepository chatReadStatusRepository; private final SiteUserRepository siteUserRepository; + private final SimpMessageSendingOperations simpMessageSendingOperations; + + public ChatService(ChatRoomRepository chatRoomRepository, + ChatMessageRepository chatMessageRepository, + ChatParticipantRepository chatParticipantRepository, + ChatReadStatusRepository chatReadStatusRepository, + SiteUserRepository siteUserRepository, + @Lazy SimpMessageSendingOperations simpMessageSendingOperations) { + this.chatRoomRepository = chatRoomRepository; + this.chatMessageRepository = chatMessageRepository; + this.chatParticipantRepository = chatParticipantRepository; + this.chatReadStatusRepository = chatReadStatusRepository; + this.siteUserRepository = siteUserRepository; + this.simpMessageSendingOperations = simpMessageSendingOperations; + } + @Transactional(readOnly = true) public ChatRoomListResponse getChatRooms(long siteUserId) { // todo : n + 1 문제 해결 필요! @@ -89,6 +107,13 @@ public SliceResponse getChatMessages(long siteUserId, long return SliceResponse.of(content, chatMessages); } + public void validateChatRoomParticipant(long siteUserId, long roomId) { + boolean isParticipant = chatParticipantRepository.existsByChatRoomIdAndSiteUserId(roomId, siteUserId); + if (!isParticipant) { + throw new CustomException(CHAT_PARTICIPANT_NOT_FOUND); + } + } + private ChatMessageResponse toChatMessageResponse(ChatMessage message) { List attachments = message.getChatAttachments().stream() .map(attachment -> ChatAttachmentResponse.of( @@ -109,13 +134,6 @@ private ChatMessageResponse toChatMessageResponse(ChatMessage message) { ); } - private void validateChatRoomParticipant(long siteUserId, long roomId) { - boolean isParticipant = chatParticipantRepository.existsByChatRoomIdAndSiteUserId(roomId, siteUserId); - if (!isParticipant) { - throw new CustomException(CHAT_PARTICIPANT_NOT_FOUND); - } - } - @Transactional public void markChatMessagesAsRead(long siteUserId, long roomId) { ChatParticipant participant = chatParticipantRepository @@ -124,4 +142,24 @@ public void markChatMessagesAsRead(long siteUserId, long roomId) { chatReadStatusRepository.upsertReadStatus(roomId, participant.getId()); } + + @Transactional + public void sendChatMessage(ChatMessageSendRequest chatMessageSendRequest, long siteUserId, long roomId) { + long senderId = chatParticipantRepository.findByChatRoomIdAndSiteUserId(roomId, siteUserId) + .orElseThrow(() -> new CustomException(CHAT_PARTICIPANT_NOT_FOUND)) + .getId(); + + ChatMessage chatMessage = new ChatMessage( + chatMessageSendRequest.content(), + senderId, + chatRoomRepository.findById(roomId) + .orElseThrow(() -> new CustomException(INVALID_CHAT_ROOM_STATE)) + ); + + chatMessageRepository.save(chatMessage); + + ChatMessageSendResponse chatMessageResponse = ChatMessageSendResponse.from(chatMessage); + + simpMessageSendingOperations.convertAndSend("/topic/chat/" + roomId, chatMessageResponse); + } } diff --git a/src/main/java/com/example/solidconnection/security/config/SecurityConfiguration.java b/src/main/java/com/example/solidconnection/security/config/SecurityConfiguration.java index 3667e9d84..706fedd52 100644 --- a/src/main/java/com/example/solidconnection/security/config/SecurityConfiguration.java +++ b/src/main/java/com/example/solidconnection/security/config/SecurityConfiguration.java @@ -5,8 +5,8 @@ import com.example.solidconnection.common.exception.CustomAccessDeniedHandler; import com.example.solidconnection.common.exception.CustomAuthenticationEntryPoint; import com.example.solidconnection.security.filter.ExceptionHandlerFilter; -import com.example.solidconnection.security.filter.TokenAuthenticationFilter; import com.example.solidconnection.security.filter.SignOutCheckFilter; +import com.example.solidconnection.security.filter.TokenAuthenticationFilter; import lombok.RequiredArgsConstructor; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -62,6 +62,7 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { .cors(corsConfigurer -> corsConfigurer.configurationSource(corsConfigurationSource())) .sessionManagement((session) -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS)) .authorizeHttpRequests(auth -> auth + .requestMatchers("/connect/**").authenticated() .requestMatchers("/admin/**").hasRole(ADMIN.name()) .anyRequest().permitAll() )