Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.example.solidconnection.chat.config;

import java.security.Principal;
import java.util.Map;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;

// WebSocket 세션의 Principal을 결정한다.
@Component
public class CustomHandshakeHandler extends DefaultHandshakeHandler {

@Override
protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler,
Map<String, Object> attributes) {

Object userAttribute = attributes.get("user");

if (userAttribute instanceof Principal) {
Principal principal = (Principal) userAttribute;
return principal;
}
Comment on lines +18 to +23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

attributes.get("user")를 그대로 리턴하는 게 어떤 의미인지 알려주실 수 있나요!?
지금 보기에는 super.determineUser()와 똑같아보여서요..!

Copy link
Member Author

@whqtker whqtker Aug 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<3줄 요약>

  1. 두 로직은 핸드셰이크가 수행된 쓰레드와 WebSocket 관련 쓰레드가 같은 경우 같은 동작을 수행합니다.
  2. 그럼에도 attributes.get("user") 를 사용하여 Principal 을 리턴한 이유는, Principal 객체를 안전하게 전달하기 위해 attributes 에 담았고, 그걸 꺼내 리턴한 것입니다.
  3. super.determineUser() 는 혹시나 꺼낸 값이 Principal 이 아닌 경우 예외 터뜨리지 않고 기본 동작을 수행하도록 하기 위해 작성하였습니다.

저 코드 구현 당시 제 생각은 ...

Principal 객체를 어떻게 WebSocket 세션에 전달할지 찾아보았고, 커스텀 핸들러에서 determineUser 메서드를 구현하여 전달한다는 것을 알게 되었습니다.

image

determineUser 메서드에 대해 잘 몰라서 정의를 살펴보았고, attribute 파라미터가 WebSocket 세션에 전달하고 싶은 데이터를 저장하는 임시 저장소 역할을 한다는 것을 알게 되었습니다.

@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                               WebSocketHandler wsHandler, Map<String, Object> attributes) {
    Principal principal = request.getPrincipal();

    if (principal != null) {
        attributes.put("user", principal);
        return true;
    }

    return false;
}

그래서 인터셉터에서 핸드셰이크 전 Principaluser 라는 키로 attributes 에 저장하였고, 핸들러에서 user 에 해당하는 값을 꺼내 간단히 Principal 인지 검증 후 리턴하도록 구현하였습니다.
혹시 만약에 꺼낸 값이 Principal 이 아닌 경우 관련 처리도 필요했고, 그 경우 부모 클래스의 기본 동작을 따르도록 하였습니다.


제가 첨부한 블로그는 어쨌거나 attributes 인자를 사용하지 않았고, 잘 동작한 것 같습니다. 그럼 attributes 는 사용 안 해도 되는 거 아닐까 ? 해서 찾아보았는데, WebSocket 핸드셰이크를 수행하는 쓰레드와 실제 WebSocket을 사용하여 메시지를 주고받는 쓰레드가 다를 수 있으며, 이 경우 SecurityContextHolder 로부터 Principal 을 가져올 수 없습니다. 생성된 Principal 은 같은 쓰레드 내에서만 유효합니다. [참고]


return super.determineUser(request, wsHandler, attributes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

import static com.example.solidconnection.common.exception.ErrorCode.AUTHENTICATION_FAILED;

import com.example.solidconnection.auth.token.JwtTokenProvider;
import com.example.solidconnection.chat.service.ChatService;
import com.example.solidconnection.common.exception.CustomException;
import com.example.solidconnection.common.exception.ErrorCode;
import io.jsonwebtoken.Claims;
import com.example.solidconnection.security.authentication.TokenAuthentication;
import com.example.solidconnection.security.userdetails.SiteUserDetails;
import java.security.Principal;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.RequiredArgsConstructor;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
Expand All @@ -18,47 +22,48 @@
@RequiredArgsConstructor
public class StompHandler implements ChannelInterceptor {

private final JwtTokenProvider jwtTokenProvider;
private static final Pattern ROOM_ID_PATTERN = Pattern.compile("^/topic/chat/(\\d+)$");
private final ChatService chatService;

@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
final StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);

if (StompCommand.CONNECT.equals(accessor.getCommand())) {
Claims claims = validateAndExtractClaims(accessor, AUTHENTICATION_FAILED);
Principal user = accessor.getUser();
if (user == null) {
throw new CustomException(AUTHENTICATION_FAILED);
}
}

if (StompCommand.SUBSCRIBE.equals(accessor.getCommand())) {
Claims claims = validateAndExtractClaims(accessor, AUTHENTICATION_FAILED);
Principal user = accessor.getUser();
if (user == null) {
throw new CustomException(AUTHENTICATION_FAILED);
}

String email = claims.getSubject();
String destination = accessor.getDestination();
TokenAuthentication tokenAuthentication = (TokenAuthentication) user;
SiteUserDetails siteUserDetails = (SiteUserDetails) tokenAuthentication.getPrincipal();
Comment on lines +45 to +46
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

타입 캐스팅 안전성을 개선해주세요.

ChatMessageController와 동일하게, PrincipalTokenAuthentication으로, 그리고 그 안의 principal을 SiteUserDetails로 캐스팅하는 부분에서 타입 안전성이 보장되지 않습니다.

안전한 타입 검증을 추가하세요:

-            TokenAuthentication tokenAuthentication = (TokenAuthentication) user;
-            SiteUserDetails siteUserDetails = (SiteUserDetails) tokenAuthentication.getPrincipal();
+            if (!(user instanceof TokenAuthentication tokenAuthentication)) {
+                throw new CustomException(AUTHENTICATION_FAILED);
+            }
+            
+            Object principalObject = tokenAuthentication.getPrincipal();
+            if (!(principalObject instanceof SiteUserDetails siteUserDetails)) {
+                throw new CustomException(AUTHENTICATION_FAILED);
+            }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
TokenAuthentication tokenAuthentication = (TokenAuthentication) user;
SiteUserDetails siteUserDetails = (SiteUserDetails) tokenAuthentication.getPrincipal();
if (!(user instanceof TokenAuthentication tokenAuthentication)) {
throw new CustomException(AUTHENTICATION_FAILED);
}
Object principalObject = tokenAuthentication.getPrincipal();
if (!(principalObject instanceof SiteUserDetails siteUserDetails)) {
throw new CustomException(AUTHENTICATION_FAILED);
}
🤖 Prompt for AI Agents
In src/main/java/com/example/solidconnection/chat/config/StompHandler.java
around lines 42 to 43, the current casting of user to TokenAuthentication and
then its principal to SiteUserDetails lacks type safety. Add explicit type
checks using instanceof before casting to ensure the objects are of the expected
types, and handle the cases where they are not to prevent ClassCastException.


String roomId = extractRoomId(destination);
String destination = accessor.getDestination();
long roomId = Long.parseLong(extractRoomId(destination));

// todo: roomId 기반 실제 구독 권한 검사 로직 추가
chatService.validateChatRoomParticipant(siteUserDetails.getSiteUser().getId(), roomId);
}

return message;
}

private Claims validateAndExtractClaims(StompHeaderAccessor accessor, ErrorCode errorCode) {
String bearerToken = accessor.getFirstNativeHeader("Authorization");
if (bearerToken == null || !bearerToken.startsWith("Bearer ")) {
throw new CustomException(errorCode);
}
String token = bearerToken.substring(7);
return jwtTokenProvider.parseClaims(token);
}

private String extractRoomId(String destination) {
if (destination == null) {
throw new CustomException(ErrorCode.INVALID_ROOM_ID);
}
String[] parts = destination.split("/");
if (parts.length < 3 || !parts[1].equals("topic")) {

Matcher matcher = ROOM_ID_PATTERN.matcher(destination);
if (!matcher.matches()) {
throw new CustomException(ErrorCode.INVALID_ROOM_ID);
}
return parts[2];

return matcher.group(1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@ public class StompWebSocketConfig implements WebSocketMessageBrokerConfigurer {
private final StompHandler stompHandler;
private final StompProperties stompProperties;
private final CorsProperties corsProperties;
private final WebSocketHandshakeInterceptor webSocketHandshakeInterceptor;
private final CustomHandshakeHandler customHandshakeHandler;

@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
List<String> strings = corsProperties.allowedOrigins();
String[] allowedOrigins = strings.toArray(String[]::new);
registry.addEndpoint("/connect").setAllowedOrigins(allowedOrigins).withSockJS();
registry.addEndpoint("/connect")
.setAllowedOrigins(allowedOrigins)
.addInterceptors(webSocketHandshakeInterceptor)
.setHandshakeHandler(customHandshakeHandler)
.withSockJS();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.example.solidconnection.chat.config;

import java.security.Principal;
import java.util.Map;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

// Principal을 WebSocket 세션에 저장하는 것에만 집중한다.
@Component
public class WebSocketHandshakeInterceptor implements HandshakeInterceptor {

@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) {
Principal principal = request.getPrincipal();

if (principal != null) {
attributes.put("user", principal);
return true;
}

return false;
}

@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception exception) {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.example.solidconnection.chat.controller;

import com.example.solidconnection.chat.dto.ChatMessageSendRequest;
import com.example.solidconnection.chat.service.ChatService;
import com.example.solidconnection.security.authentication.TokenAuthentication;
import com.example.solidconnection.security.userdetails.SiteUserDetails;
import jakarta.validation.Valid;
import java.security.Principal;
import lombok.RequiredArgsConstructor;
import org.springframework.messaging.handler.annotation.DestinationVariable;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.stereotype.Controller;

@Controller
@RequiredArgsConstructor
public class ChatMessageController {

private final ChatService chatService;

@MessageMapping("/chat/{roomId}")
public void sendChatMessage(
@DestinationVariable Long roomId,
@Valid @Payload ChatMessageSendRequest chatMessageSendRequest,
Principal principal
) {
TokenAuthentication tokenAuthentication = (TokenAuthentication) principal;
SiteUserDetails siteUserDetails = (SiteUserDetails) tokenAuthentication.getPrincipal();
Comment on lines +27 to +28
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

타입 캐스팅의 안전성을 보장해주세요.

PrincipalTokenAuthentication으로, 그리고 그 안의 principal을 SiteUserDetails로 캐스팅하는 부분에서 ClassCastException이 발생할 가능성이 있습니다. WebSocket 세션에서 예상과 다른 타입의 Principal이 전달될 경우를 대비한 검증 로직이 필요합니다.

다음과 같이 안전한 타입 검증을 추가하는 것을 권장합니다:

-        TokenAuthentication tokenAuthentication = (TokenAuthentication) principal;
-        SiteUserDetails siteUserDetails = (SiteUserDetails) tokenAuthentication.getPrincipal();
+        if (!(principal instanceof TokenAuthentication tokenAuthentication)) {
+            throw new CustomException(ErrorCode.AUTHENTICATION_FAILED);
+        }
+        
+        Object principalObject = tokenAuthentication.getPrincipal();
+        if (!(principalObject instanceof SiteUserDetails siteUserDetails)) {
+            throw new CustomException(ErrorCode.AUTHENTICATION_FAILED);
+        }
🤖 Prompt for AI Agents
In
src/main/java/com/example/solidconnection/chat/controller/ChatMessageController.java
around lines 26 to 27, the code casts Principal to TokenAuthentication and then
to SiteUserDetails without type checks, risking ClassCastException. Add
instanceof checks before each cast to verify the object's type, and handle the
case where the type is unexpected, such as by logging an error or throwing a
controlled exception, to ensure safe type casting.


chatService.sendChatMessage(chatMessageSendRequest, siteUserDetails.getSiteUser().getId(), roomId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.example.solidconnection.chat.dto;

import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;

public record ChatMessageSendRequest(
@NotNull(message = "메시지를 입력해주세요.")
@Size(max = 500, message = "메시지는 500자를 초과할 수 없습니다")
String content
) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.example.solidconnection.chat.dto;

import com.example.solidconnection.chat.domain.ChatMessage;

public record ChatMessageSendResponse(
long messageId,
String content,
long senderId
) {

public static ChatMessageSendResponse from(ChatMessage chatMessage) {
return new ChatMessageSendResponse(
chatMessage.getId(),
chatMessage.getContent(),
chatMessage.getSenderId()
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.example.solidconnection.chat.domain.ChatRoom;
import com.example.solidconnection.chat.dto.ChatAttachmentResponse;
import com.example.solidconnection.chat.dto.ChatMessageResponse;
import com.example.solidconnection.chat.dto.ChatMessageSendRequest;
import com.example.solidconnection.chat.dto.ChatMessageSendResponse;
import com.example.solidconnection.chat.dto.ChatParticipantResponse;
import com.example.solidconnection.chat.dto.ChatRoomListResponse;
import com.example.solidconnection.chat.dto.ChatRoomResponse;
Expand All @@ -24,13 +26,13 @@
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Lazy;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.messaging.simp.SimpMessageSendingOperations;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

@RequiredArgsConstructor
@Service
public class ChatService {

Expand All @@ -40,6 +42,22 @@ public class ChatService {
private final ChatReadStatusRepository chatReadStatusRepository;
private final SiteUserRepository siteUserRepository;

private final SimpMessageSendingOperations simpMessageSendingOperations;

public ChatService(ChatRoomRepository chatRoomRepository,
ChatMessageRepository chatMessageRepository,
ChatParticipantRepository chatParticipantRepository,
ChatReadStatusRepository chatReadStatusRepository,
SiteUserRepository siteUserRepository,
@Lazy SimpMessageSendingOperations simpMessageSendingOperations) {
this.chatRoomRepository = chatRoomRepository;
this.chatMessageRepository = chatMessageRepository;
this.chatParticipantRepository = chatParticipantRepository;
this.chatReadStatusRepository = chatReadStatusRepository;
this.siteUserRepository = siteUserRepository;
this.simpMessageSendingOperations = simpMessageSendingOperations;
}

@Transactional(readOnly = true)
public ChatRoomListResponse getChatRooms(long siteUserId) {
// todo : n + 1 문제 해결 필요!
Expand Down Expand Up @@ -89,6 +107,13 @@ public SliceResponse<ChatMessageResponse> getChatMessages(long siteUserId, long
return SliceResponse.of(content, chatMessages);
}

public void validateChatRoomParticipant(long siteUserId, long roomId) {
boolean isParticipant = chatParticipantRepository.existsByChatRoomIdAndSiteUserId(roomId, siteUserId);
if (!isParticipant) {
throw new CustomException(CHAT_PARTICIPANT_NOT_FOUND);
}
}

private ChatMessageResponse toChatMessageResponse(ChatMessage message) {
List<ChatAttachmentResponse> attachments = message.getChatAttachments().stream()
.map(attachment -> ChatAttachmentResponse.of(
Expand All @@ -109,13 +134,6 @@ private ChatMessageResponse toChatMessageResponse(ChatMessage message) {
);
}

private void validateChatRoomParticipant(long siteUserId, long roomId) {
boolean isParticipant = chatParticipantRepository.existsByChatRoomIdAndSiteUserId(roomId, siteUserId);
if (!isParticipant) {
throw new CustomException(CHAT_PARTICIPANT_NOT_FOUND);
}
}

@Transactional
public void markChatMessagesAsRead(long siteUserId, long roomId) {
ChatParticipant participant = chatParticipantRepository
Expand All @@ -124,4 +142,24 @@ public void markChatMessagesAsRead(long siteUserId, long roomId) {

chatReadStatusRepository.upsertReadStatus(roomId, participant.getId());
}

@Transactional
public void sendChatMessage(ChatMessageSendRequest chatMessageSendRequest, long siteUserId, long roomId) {
long senderId = chatParticipantRepository.findByChatRoomIdAndSiteUserId(roomId, siteUserId)
.orElseThrow(() -> new CustomException(CHAT_PARTICIPANT_NOT_FOUND))
.getId();

ChatMessage chatMessage = new ChatMessage(
chatMessageSendRequest.content(),
senderId,
chatRoomRepository.findById(roomId)
.orElseThrow(() -> new CustomException(INVALID_CHAT_ROOM_STATE))
);

chatMessageRepository.save(chatMessage);

ChatMessageSendResponse chatMessageResponse = ChatMessageSendResponse.from(chatMessage);

simpMessageSendingOperations.convertAndSend("/topic/chat/" + roomId, chatMessageResponse);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import com.example.solidconnection.common.exception.CustomAccessDeniedHandler;
import com.example.solidconnection.common.exception.CustomAuthenticationEntryPoint;
import com.example.solidconnection.security.filter.ExceptionHandlerFilter;
import com.example.solidconnection.security.filter.TokenAuthenticationFilter;
import com.example.solidconnection.security.filter.SignOutCheckFilter;
import com.example.solidconnection.security.filter.TokenAuthenticationFilter;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -62,6 +62,7 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
.cors(corsConfigurer -> corsConfigurer.configurationSource(corsConfigurationSource()))
.sessionManagement((session) -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
.authorizeHttpRequests(auth -> auth
.requestMatchers("/connect/**").authenticated()
.requestMatchers("/admin/**").hasRole(ADMIN.name())
.anyRequest().permitAll()
)
Expand Down
Loading