Align websocket auth

This commit is contained in:
2026-03-12 18:27:02 -06:00
parent 9dcb4865fa
commit 79e8eb08b8
2 changed files with 72 additions and 27 deletions

View File

@@ -2,6 +2,7 @@ package com.petshop.backend.config;
import com.petshop.backend.entity.User; import com.petshop.backend.entity.User;
import com.petshop.backend.repository.UserRepository; import com.petshop.backend.repository.UserRepository;
import com.petshop.backend.security.AppPrincipal;
import com.petshop.backend.security.JwtUtil; import com.petshop.backend.security.JwtUtil;
import com.petshop.backend.service.ChatService; import com.petshop.backend.service.ChatService;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
@@ -10,11 +11,9 @@ import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.security.Principal; import java.security.Principal;
import java.util.Collections;
import java.util.List; import java.util.List;
@Component @Component
@@ -46,41 +45,42 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor {
throw new IllegalArgumentException("Missing websocket token"); throw new IllegalArgumentException("Missing websocket token");
} }
String username = jwtUtil.extractUsername(token); Long userId = jwtUtil.extractUserId(token);
User user = userRepository.findByUsername(username) User user = userId == null ? null : userRepository.findById(userId).orElse(null);
.orElseThrow(() -> new IllegalArgumentException("User not found")); if (user == null) {
throw new IllegalArgumentException("User not found");
}
if (user.getActive() == null || !user.getActive()) { if (user.getActive() == null || !user.getActive()) {
throw new IllegalArgumentException("User account is inactive"); throw new IllegalArgumentException("User account is inactive");
} }
if (!jwtUtil.validateToken(token, user)) {
throw new IllegalArgumentException("Invalid websocket token");
}
UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken( AppPrincipal principal = new AppPrincipal(
user.getId(),
user.getUsername(), user.getUsername(),
user.getRole(),
user.getTokenVersion()
);
UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken(
principal,
null, null,
Collections.singletonList(new SimpleGrantedAuthority("ROLE_" + user.getRole().name())) principal.getAuthorities()
); );
accessor.setUser(authentication); accessor.setUser(authentication);
accessor.getSessionAttributes().put("user", authentication); accessor.getSessionAttributes().put("user", authentication);
return message; return message;
} }
if (StompCommand.DISCONNECT.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) { User user = resolveUser(accessor.getUser(), accessor);
return message; if (user == null) {
}
Principal principal = accessor.getUser();
if (principal == null && accessor.getSessionAttributes() != null) {
Object sessionUser = accessor.getSessionAttributes().get("user");
if (sessionUser instanceof Principal storedPrincipal) {
accessor.setUser(storedPrincipal);
principal = storedPrincipal;
}
}
if (principal == null) {
throw new IllegalArgumentException("Unauthenticated websocket session"); throw new IllegalArgumentException("Unauthenticated websocket session");
} }
User user = userRepository.findByUsername(principal.getName()) if (StompCommand.DISCONNECT.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) {
.orElseThrow(() -> new IllegalArgumentException("User not found")); return message;
}
if (StompCommand.SUBSCRIBE.equals(command)) { if (StompCommand.SUBSCRIBE.equals(command)) {
authorizeSubscription(accessor.getDestination(), user); authorizeSubscription(accessor.getDestination(), user);
@@ -91,6 +91,41 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor {
return message; return message;
} }
private User resolveUser(Principal principal, StompHeaderAccessor accessor) {
Principal currentPrincipal = principal;
if (currentPrincipal == null && accessor.getSessionAttributes() != null) {
Object sessionUser = accessor.getSessionAttributes().get("user");
if (sessionUser instanceof Principal storedPrincipal) {
accessor.setUser(storedPrincipal);
currentPrincipal = storedPrincipal;
}
}
if (currentPrincipal instanceof UsernamePasswordAuthenticationToken authenticationToken
&& authenticationToken.getPrincipal() instanceof AppPrincipal appPrincipal) {
return userRepository.findById(appPrincipal.getUserId())
.orElseThrow(() -> new IllegalArgumentException("User not found"));
}
if (currentPrincipal instanceof AppPrincipal appPrincipal) {
return userRepository.findById(appPrincipal.getUserId())
.orElseThrow(() -> new IllegalArgumentException("User not found"));
}
String tokenHeader = firstHeader(accessor, "Authorization");
String token = extractToken(tokenHeader != null ? tokenHeader : firstHeader(accessor, "token"));
if (token == null || token.isBlank()) {
return null;
}
Long userId = jwtUtil.extractUserId(token);
User user = userId == null ? null : userRepository.findById(userId).orElse(null);
if (user == null || user.getActive() == null || !user.getActive() || !jwtUtil.validateToken(token, user)) {
throw new IllegalArgumentException("User not found");
}
return user;
}
private void authorizeSubscription(String destination, User user) { private void authorizeSubscription(String destination, User user) {
if (destination == null || destination.startsWith("/user/queue/")) { if (destination == null || destination.startsWith("/user/queue/")) {
return; return;

View File

@@ -4,6 +4,7 @@ import com.petshop.backend.dto.chat.MessageRequest;
import com.petshop.backend.dto.chat.MessageResponse; import com.petshop.backend.dto.chat.MessageResponse;
import com.petshop.backend.entity.User; import com.petshop.backend.entity.User;
import com.petshop.backend.repository.UserRepository; import com.petshop.backend.repository.UserRepository;
import com.petshop.backend.security.AppPrincipal;
import com.petshop.backend.security.JwtUtil; import com.petshop.backend.security.JwtUtil;
import com.petshop.backend.service.ChatRealtimeService; import com.petshop.backend.service.ChatRealtimeService;
import com.petshop.backend.service.ChatService; import com.petshop.backend.service.ChatService;
@@ -43,8 +44,14 @@ public class ChatWebSocketController {
private User resolveUser(SimpMessageHeaderAccessor headerAccessor) { private User resolveUser(SimpMessageHeaderAccessor headerAccessor) {
Principal principal = headerAccessor.getUser(); Principal principal = headerAccessor.getUser();
if (principal != null) { if (principal instanceof org.springframework.security.authentication.UsernamePasswordAuthenticationToken authenticationToken
return userRepository.findByUsername(principal.getName()) && authenticationToken.getPrincipal() instanceof AppPrincipal appPrincipal) {
return userRepository.findById(appPrincipal.getUserId())
.orElseThrow(() -> new IllegalArgumentException("User not found"));
}
if (principal instanceof AppPrincipal appPrincipal) {
return userRepository.findById(appPrincipal.getUserId())
.orElseThrow(() -> new IllegalArgumentException("User not found")); .orElseThrow(() -> new IllegalArgumentException("User not found"));
} }
@@ -57,8 +64,11 @@ public class ChatWebSocketController {
} }
String token = tokenHeader.startsWith("Bearer ") ? tokenHeader.substring(7) : tokenHeader; String token = tokenHeader.startsWith("Bearer ") ? tokenHeader.substring(7) : tokenHeader;
String username = jwtUtil.extractUsername(token); Long userId = jwtUtil.extractUserId(token);
return userRepository.findByUsername(username) User user = userId == null ? null : userRepository.findById(userId).orElse(null);
.orElseThrow(() -> new IllegalArgumentException("User not found")); if (user == null || user.getActive() == null || !user.getActive() || !jwtUtil.validateToken(token, user)) {
throw new IllegalArgumentException("User not found");
}
return user;
} }
} }