Align websocket auth

This commit is contained in:
2026-03-12 18:27:02 -06:00
parent 9dcb4865fa
commit 79e8eb08b8
2 changed files with 72 additions and 27 deletions

View File

@@ -2,6 +2,7 @@ package com.petshop.backend.config;
import com.petshop.backend.entity.User;
import com.petshop.backend.repository.UserRepository;
import com.petshop.backend.security.AppPrincipal;
import com.petshop.backend.security.JwtUtil;
import com.petshop.backend.service.ChatService;
import org.springframework.messaging.Message;
@@ -10,11 +11,9 @@ import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.stereotype.Component;
import java.security.Principal;
import java.util.Collections;
import java.util.List;
@Component
@@ -46,41 +45,42 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor {
throw new IllegalArgumentException("Missing websocket token");
}
String username = jwtUtil.extractUsername(token);
User user = userRepository.findByUsername(username)
.orElseThrow(() -> new IllegalArgumentException("User not found"));
Long userId = jwtUtil.extractUserId(token);
User user = userId == null ? null : userRepository.findById(userId).orElse(null);
if (user == null) {
throw new IllegalArgumentException("User not found");
}
if (user.getActive() == null || !user.getActive()) {
throw new IllegalArgumentException("User account is inactive");
}
if (!jwtUtil.validateToken(token, user)) {
throw new IllegalArgumentException("Invalid websocket token");
}
UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken(
AppPrincipal principal = new AppPrincipal(
user.getId(),
user.getUsername(),
user.getRole(),
user.getTokenVersion()
);
UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken(
principal,
null,
Collections.singletonList(new SimpleGrantedAuthority("ROLE_" + user.getRole().name()))
principal.getAuthorities()
);
accessor.setUser(authentication);
accessor.getSessionAttributes().put("user", authentication);
return message;
}
if (StompCommand.DISCONNECT.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) {
return message;
}
Principal principal = accessor.getUser();
if (principal == null && accessor.getSessionAttributes() != null) {
Object sessionUser = accessor.getSessionAttributes().get("user");
if (sessionUser instanceof Principal storedPrincipal) {
accessor.setUser(storedPrincipal);
principal = storedPrincipal;
}
}
if (principal == null) {
User user = resolveUser(accessor.getUser(), accessor);
if (user == null) {
throw new IllegalArgumentException("Unauthenticated websocket session");
}
User user = userRepository.findByUsername(principal.getName())
.orElseThrow(() -> new IllegalArgumentException("User not found"));
if (StompCommand.DISCONNECT.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) {
return message;
}
if (StompCommand.SUBSCRIBE.equals(command)) {
authorizeSubscription(accessor.getDestination(), user);
@@ -91,6 +91,41 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor {
return message;
}
private User resolveUser(Principal principal, StompHeaderAccessor accessor) {
Principal currentPrincipal = principal;
if (currentPrincipal == null && accessor.getSessionAttributes() != null) {
Object sessionUser = accessor.getSessionAttributes().get("user");
if (sessionUser instanceof Principal storedPrincipal) {
accessor.setUser(storedPrincipal);
currentPrincipal = storedPrincipal;
}
}
if (currentPrincipal instanceof UsernamePasswordAuthenticationToken authenticationToken
&& authenticationToken.getPrincipal() instanceof AppPrincipal appPrincipal) {
return userRepository.findById(appPrincipal.getUserId())
.orElseThrow(() -> new IllegalArgumentException("User not found"));
}
if (currentPrincipal instanceof AppPrincipal appPrincipal) {
return userRepository.findById(appPrincipal.getUserId())
.orElseThrow(() -> new IllegalArgumentException("User not found"));
}
String tokenHeader = firstHeader(accessor, "Authorization");
String token = extractToken(tokenHeader != null ? tokenHeader : firstHeader(accessor, "token"));
if (token == null || token.isBlank()) {
return null;
}
Long userId = jwtUtil.extractUserId(token);
User user = userId == null ? null : userRepository.findById(userId).orElse(null);
if (user == null || user.getActive() == null || !user.getActive() || !jwtUtil.validateToken(token, user)) {
throw new IllegalArgumentException("User not found");
}
return user;
}
private void authorizeSubscription(String destination, User user) {
if (destination == null || destination.startsWith("/user/queue/")) {
return;

View File

@@ -4,6 +4,7 @@ import com.petshop.backend.dto.chat.MessageRequest;
import com.petshop.backend.dto.chat.MessageResponse;
import com.petshop.backend.entity.User;
import com.petshop.backend.repository.UserRepository;
import com.petshop.backend.security.AppPrincipal;
import com.petshop.backend.security.JwtUtil;
import com.petshop.backend.service.ChatRealtimeService;
import com.petshop.backend.service.ChatService;
@@ -43,8 +44,14 @@ public class ChatWebSocketController {
private User resolveUser(SimpMessageHeaderAccessor headerAccessor) {
Principal principal = headerAccessor.getUser();
if (principal != null) {
return userRepository.findByUsername(principal.getName())
if (principal instanceof org.springframework.security.authentication.UsernamePasswordAuthenticationToken authenticationToken
&& authenticationToken.getPrincipal() instanceof AppPrincipal appPrincipal) {
return userRepository.findById(appPrincipal.getUserId())
.orElseThrow(() -> new IllegalArgumentException("User not found"));
}
if (principal instanceof AppPrincipal appPrincipal) {
return userRepository.findById(appPrincipal.getUserId())
.orElseThrow(() -> new IllegalArgumentException("User not found"));
}
@@ -57,8 +64,11 @@ public class ChatWebSocketController {
}
String token = tokenHeader.startsWith("Bearer ") ? tokenHeader.substring(7) : tokenHeader;
String username = jwtUtil.extractUsername(token);
return userRepository.findByUsername(username)
.orElseThrow(() -> new IllegalArgumentException("User not found"));
Long userId = jwtUtil.extractUserId(token);
User user = userId == null ? null : userRepository.findById(userId).orElse(null);
if (user == null || user.getActive() == null || !user.getActive() || !jwtUtil.validateToken(token, user)) {
throw new IllegalArgumentException("User not found");
}
return user;
}
}