Fix chat auth

This commit is contained in:
2026-03-10 20:04:26 -06:00
parent 84c5f3c7b1
commit 2e401a544f
2 changed files with 42 additions and 5 deletions

View File

@@ -59,10 +59,22 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor {
Collections.singletonList(new SimpleGrantedAuthority("ROLE_" + user.getRole().name())) Collections.singletonList(new SimpleGrantedAuthority("ROLE_" + user.getRole().name()))
); );
accessor.setUser(authentication); accessor.setUser(authentication);
accessor.getSessionAttributes().put("user", authentication);
return message;
}
if (StompCommand.DISCONNECT.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) {
return message; return message;
} }
Principal principal = accessor.getUser(); Principal principal = accessor.getUser();
if (principal == null && accessor.getSessionAttributes() != null) {
Object sessionUser = accessor.getSessionAttributes().get("user");
if (sessionUser instanceof Principal storedPrincipal) {
accessor.setUser(storedPrincipal);
principal = storedPrincipal;
}
}
if (principal == null) { if (principal == null) {
throw new IllegalArgumentException("Unauthenticated websocket session"); throw new IllegalArgumentException("Unauthenticated websocket session");
} }

View File

@@ -4,36 +4,61 @@ import com.petshop.backend.dto.chat.MessageRequest;
import com.petshop.backend.dto.chat.MessageResponse; import com.petshop.backend.dto.chat.MessageResponse;
import com.petshop.backend.entity.User; import com.petshop.backend.entity.User;
import com.petshop.backend.repository.UserRepository; import com.petshop.backend.repository.UserRepository;
import com.petshop.backend.security.JwtUtil;
import com.petshop.backend.service.ChatRealtimeService; import com.petshop.backend.service.ChatRealtimeService;
import com.petshop.backend.service.ChatService; import com.petshop.backend.service.ChatService;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.messaging.handler.annotation.DestinationVariable; import org.springframework.messaging.handler.annotation.DestinationVariable;
import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload; import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import java.security.Principal;
@Controller @Controller
public class ChatWebSocketController { public class ChatWebSocketController {
private final ChatService chatService; private final ChatService chatService;
private final ChatRealtimeService chatRealtimeService; private final ChatRealtimeService chatRealtimeService;
private final UserRepository userRepository; private final UserRepository userRepository;
private final JwtUtil jwtUtil;
public ChatWebSocketController(ChatService chatService, ChatRealtimeService chatRealtimeService, UserRepository userRepository) { public ChatWebSocketController(ChatService chatService, ChatRealtimeService chatRealtimeService, UserRepository userRepository, JwtUtil jwtUtil) {
this.chatService = chatService; this.chatService = chatService;
this.chatRealtimeService = chatRealtimeService; this.chatRealtimeService = chatRealtimeService;
this.userRepository = userRepository; this.userRepository = userRepository;
this.jwtUtil = jwtUtil;
} }
@MessageMapping("/chat/conversations/{id}/messages") @MessageMapping("/chat/conversations/{id}/messages")
@SendToUser("/queue/chat/errors") @SendToUser("/queue/chat/errors")
public void sendMessage(@DestinationVariable Long id, @Valid @Payload MessageRequest request, Authentication authentication) { public void sendMessage(@DestinationVariable Long id, @Valid @Payload MessageRequest request, SimpMessageHeaderAccessor headerAccessor) {
User user = userRepository.findByUsername(authentication.getName()) User user = resolveUser(headerAccessor);
.orElseThrow(() -> new IllegalArgumentException("User not found"));
MessageResponse message = chatService.sendMessage(id, user.getId(), user.getRole(), request); MessageResponse message = chatService.sendMessage(id, user.getId(), user.getRole(), request);
chatRealtimeService.publishMessage(id, message); chatRealtimeService.publishMessage(id, message);
chatRealtimeService.publishConversationUpdate(id); chatRealtimeService.publishConversationUpdate(id);
} }
private User resolveUser(SimpMessageHeaderAccessor headerAccessor) {
Principal principal = headerAccessor.getUser();
if (principal != null) {
return userRepository.findByUsername(principal.getName())
.orElseThrow(() -> new IllegalArgumentException("User not found"));
}
String tokenHeader = headerAccessor.getFirstNativeHeader("Authorization");
if (tokenHeader == null || tokenHeader.isBlank()) {
tokenHeader = headerAccessor.getFirstNativeHeader("token");
}
if (tokenHeader == null || tokenHeader.isBlank()) {
throw new IllegalArgumentException("User not authenticated");
}
String token = tokenHeader.startsWith("Bearer ") ? tokenHeader.substring(7) : tokenHeader;
String username = jwtUtil.extractUsername(token);
return userRepository.findByUsername(username)
.orElseThrow(() -> new IllegalArgumentException("User not found"));
}
} }