diff --git a/src/main/java/com/petshop/backend/config/WebSocketAuthChannelInterceptor.java b/src/main/java/com/petshop/backend/config/WebSocketAuthChannelInterceptor.java index 6f01c384..b62dfe34 100644 --- a/src/main/java/com/petshop/backend/config/WebSocketAuthChannelInterceptor.java +++ b/src/main/java/com/petshop/backend/config/WebSocketAuthChannelInterceptor.java @@ -2,6 +2,7 @@ package com.petshop.backend.config; import com.petshop.backend.entity.User; import com.petshop.backend.repository.UserRepository; +import com.petshop.backend.security.AppPrincipal; import com.petshop.backend.security.JwtUtil; import com.petshop.backend.service.ChatService; import org.springframework.messaging.Message; @@ -10,11 +11,9 @@ import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; -import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.stereotype.Component; import java.security.Principal; -import java.util.Collections; import java.util.List; @Component @@ -46,41 +45,42 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor { throw new IllegalArgumentException("Missing websocket token"); } - String username = jwtUtil.extractUsername(token); - User user = userRepository.findByUsername(username) - .orElseThrow(() -> new IllegalArgumentException("User not found")); + Long userId = jwtUtil.extractUserId(token); + User user = userId == null ? null : userRepository.findById(userId).orElse(null); + if (user == null) { + throw new IllegalArgumentException("User not found"); + } if (user.getActive() == null || !user.getActive()) { throw new IllegalArgumentException("User account is inactive"); } + if (!jwtUtil.validateToken(token, user)) { + throw new IllegalArgumentException("Invalid websocket token"); + } - UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken( + AppPrincipal principal = new AppPrincipal( + user.getId(), user.getUsername(), + user.getRole(), + user.getTokenVersion() + ); + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken( + principal, null, - Collections.singletonList(new SimpleGrantedAuthority("ROLE_" + user.getRole().name())) + principal.getAuthorities() ); accessor.setUser(authentication); accessor.getSessionAttributes().put("user", authentication); return message; } - if (StompCommand.DISCONNECT.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) { - return message; - } - - Principal principal = accessor.getUser(); - if (principal == null && accessor.getSessionAttributes() != null) { - Object sessionUser = accessor.getSessionAttributes().get("user"); - if (sessionUser instanceof Principal storedPrincipal) { - accessor.setUser(storedPrincipal); - principal = storedPrincipal; - } - } - if (principal == null) { + User user = resolveUser(accessor.getUser(), accessor); + if (user == null) { throw new IllegalArgumentException("Unauthenticated websocket session"); } - User user = userRepository.findByUsername(principal.getName()) - .orElseThrow(() -> new IllegalArgumentException("User not found")); + if (StompCommand.DISCONNECT.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) { + return message; + } if (StompCommand.SUBSCRIBE.equals(command)) { authorizeSubscription(accessor.getDestination(), user); @@ -91,6 +91,41 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor { return message; } + private User resolveUser(Principal principal, StompHeaderAccessor accessor) { + Principal currentPrincipal = principal; + if (currentPrincipal == null && accessor.getSessionAttributes() != null) { + Object sessionUser = accessor.getSessionAttributes().get("user"); + if (sessionUser instanceof Principal storedPrincipal) { + accessor.setUser(storedPrincipal); + currentPrincipal = storedPrincipal; + } + } + + if (currentPrincipal instanceof UsernamePasswordAuthenticationToken authenticationToken + && authenticationToken.getPrincipal() instanceof AppPrincipal appPrincipal) { + return userRepository.findById(appPrincipal.getUserId()) + .orElseThrow(() -> new IllegalArgumentException("User not found")); + } + + if (currentPrincipal instanceof AppPrincipal appPrincipal) { + return userRepository.findById(appPrincipal.getUserId()) + .orElseThrow(() -> new IllegalArgumentException("User not found")); + } + + String tokenHeader = firstHeader(accessor, "Authorization"); + String token = extractToken(tokenHeader != null ? tokenHeader : firstHeader(accessor, "token")); + if (token == null || token.isBlank()) { + return null; + } + + Long userId = jwtUtil.extractUserId(token); + User user = userId == null ? null : userRepository.findById(userId).orElse(null); + if (user == null || user.getActive() == null || !user.getActive() || !jwtUtil.validateToken(token, user)) { + throw new IllegalArgumentException("User not found"); + } + return user; + } + private void authorizeSubscription(String destination, User user) { if (destination == null || destination.startsWith("/user/queue/")) { return; diff --git a/src/main/java/com/petshop/backend/controller/ChatWebSocketController.java b/src/main/java/com/petshop/backend/controller/ChatWebSocketController.java index bb30bd44..d7f1e3b6 100644 --- a/src/main/java/com/petshop/backend/controller/ChatWebSocketController.java +++ b/src/main/java/com/petshop/backend/controller/ChatWebSocketController.java @@ -4,6 +4,7 @@ import com.petshop.backend.dto.chat.MessageRequest; import com.petshop.backend.dto.chat.MessageResponse; import com.petshop.backend.entity.User; import com.petshop.backend.repository.UserRepository; +import com.petshop.backend.security.AppPrincipal; import com.petshop.backend.security.JwtUtil; import com.petshop.backend.service.ChatRealtimeService; import com.petshop.backend.service.ChatService; @@ -43,8 +44,14 @@ public class ChatWebSocketController { private User resolveUser(SimpMessageHeaderAccessor headerAccessor) { Principal principal = headerAccessor.getUser(); - if (principal != null) { - return userRepository.findByUsername(principal.getName()) + if (principal instanceof org.springframework.security.authentication.UsernamePasswordAuthenticationToken authenticationToken + && authenticationToken.getPrincipal() instanceof AppPrincipal appPrincipal) { + return userRepository.findById(appPrincipal.getUserId()) + .orElseThrow(() -> new IllegalArgumentException("User not found")); + } + + if (principal instanceof AppPrincipal appPrincipal) { + return userRepository.findById(appPrincipal.getUserId()) .orElseThrow(() -> new IllegalArgumentException("User not found")); } @@ -57,8 +64,11 @@ public class ChatWebSocketController { } String token = tokenHeader.startsWith("Bearer ") ? tokenHeader.substring(7) : tokenHeader; - String username = jwtUtil.extractUsername(token); - return userRepository.findByUsername(username) - .orElseThrow(() -> new IllegalArgumentException("User not found")); + Long userId = jwtUtil.extractUserId(token); + User user = userId == null ? null : userRepository.findById(userId).orElse(null); + if (user == null || user.getActive() == null || !user.getActive() || !jwtUtil.validateToken(token, user)) { + throw new IllegalArgumentException("User not found"); + } + return user; } }