diff --git a/src/main/java/com/petshop/backend/config/TomcatPathToleranceConfig.java b/src/main/java/com/petshop/backend/config/TomcatPathToleranceConfig.java new file mode 100644 index 00000000..9a89c5ab --- /dev/null +++ b/src/main/java/com/petshop/backend/config/TomcatPathToleranceConfig.java @@ -0,0 +1,19 @@ +package com.petshop.backend.config; + +import org.springframework.boot.tomcat.servlet.TomcatServletWebServerFactory; +import org.springframework.boot.web.server.WebServerFactoryCustomizer; +import org.springframework.stereotype.Component; + +@Component +public class TomcatPathToleranceConfig implements WebServerFactoryCustomizer { + + @Override + public void customize(TomcatServletWebServerFactory factory) { + factory.addConnectorCustomizers(connector -> { + connector.setAllowBackslash(true); + connector.setEncodedReverseSolidusHandling("decode"); + connector.setProperty("relaxedPathChars", "\\"); + connector.setProperty("relaxedQueryChars", "\\"); + }); + } +} diff --git a/src/main/java/com/petshop/backend/config/TrailingSlashNormalizationFilter.java b/src/main/java/com/petshop/backend/config/TrailingSlashNormalizationFilter.java new file mode 100644 index 00000000..38ececb9 --- /dev/null +++ b/src/main/java/com/petshop/backend/config/TrailingSlashNormalizationFilter.java @@ -0,0 +1,91 @@ +package com.petshop.backend.config; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; +import org.springframework.web.filter.OncePerRequestFilter; + +import java.io.IOException; + +@Component +@Order(Ordered.HIGHEST_PRECEDENCE) +public class TrailingSlashNormalizationFilter extends OncePerRequestFilter { + + @Override + protected boolean shouldNotFilter(HttpServletRequest request) { + String requestUri = request.getRequestURI(); + if (requestUri == null || requestUri.isBlank()) { + return true; + } + return requestUri.equals(normalizePath(requestUri)); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + String normalizedUri = normalizePath(request.getRequestURI()); + String normalizedServletPath = normalizePath(request.getServletPath()); + String normalizedPathInfo = normalizePath(request.getPathInfo()); + + HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request) { + @Override + public String getRequestURI() { + return normalizedUri; + } + + @Override + public StringBuffer getRequestURL() { + String original = super.getRequestURL().toString(); + int schemeSeparator = original.indexOf("://"); + int pathStart = schemeSeparator >= 0 ? original.indexOf('/', schemeSeparator + 3) : original.indexOf('/'); + if (pathStart < 0) { + return new StringBuffer(original); + } + String prefix = original.substring(0, pathStart); + return new StringBuffer(prefix + normalizedUri); + } + + @Override + public String getServletPath() { + return normalizedServletPath; + } + + @Override + public String getPathInfo() { + return normalizedPathInfo; + } + }; + + filterChain.doFilter(wrapper, response); + } + + private String normalizePath(String value) { + if (value == null) { + return null; + } + String normalized = value.replace('\\', '/'); + while (normalized.contains("//")) { + normalized = normalized.replace("//", "/"); + } + if (shouldLowercase(normalized)) { + normalized = normalized.toLowerCase(java.util.Locale.ROOT); + } + int end = normalized.length(); + while (end > 1 && normalized.charAt(end - 1) == '/') { + end--; + } + return normalized.substring(0, end); + } + + private boolean shouldLowercase(String path) { + String lower = path.toLowerCase(java.util.Locale.ROOT); + return lower.startsWith("/api/") + || lower.equals("/api") + || lower.startsWith("/ws/") + || lower.equals("/ws"); + } +} diff --git a/src/main/java/com/petshop/backend/config/WebSocketAuthChannelInterceptor.java b/src/main/java/com/petshop/backend/config/WebSocketAuthChannelInterceptor.java index b62dfe34..c7f23fc4 100644 --- a/src/main/java/com/petshop/backend/config/WebSocketAuthChannelInterceptor.java +++ b/src/main/java/com/petshop/backend/config/WebSocketAuthChannelInterceptor.java @@ -5,6 +5,7 @@ import com.petshop.backend.repository.UserRepository; import com.petshop.backend.security.AppPrincipal; import com.petshop.backend.security.JwtUtil; import com.petshop.backend.service.ChatService; +import io.jsonwebtoken.JwtException; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.simp.stomp.StompCommand; @@ -14,7 +15,10 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio import org.springframework.stereotype.Component; import java.security.Principal; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Locale; +import java.util.Map; @Component public class WebSocketAuthChannelInterceptor implements ChannelInterceptor { @@ -45,7 +49,7 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor { throw new IllegalArgumentException("Missing websocket token"); } - Long userId = jwtUtil.extractUserId(token); + Long userId = extractUserId(token); User user = userId == null ? null : userRepository.findById(userId).orElse(null); if (user == null) { throw new IllegalArgumentException("User not found"); @@ -73,15 +77,15 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor { return message; } + if (StompCommand.DISCONNECT.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) { + return message; + } + User user = resolveUser(accessor.getUser(), accessor); if (user == null) { throw new IllegalArgumentException("Unauthenticated websocket session"); } - if (StompCommand.DISCONNECT.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) { - return message; - } - if (StompCommand.SUBSCRIBE.equals(command)) { authorizeSubscription(accessor.getDestination(), user); } else if (StompCommand.SEND.equals(command)) { @@ -118,15 +122,22 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor { return null; } - Long userId = jwtUtil.extractUserId(token); + Long userId = extractUserId(token); User user = userId == null ? null : userRepository.findById(userId).orElse(null); - if (user == null || user.getActive() == null || !user.getActive() || !jwtUtil.validateToken(token, user)) { + if (user == null) { throw new IllegalArgumentException("User not found"); } + if (user.getActive() == null || !user.getActive()) { + throw new IllegalArgumentException("User account is inactive"); + } + if (!jwtUtil.validateToken(token, user)) { + throw new IllegalArgumentException("Invalid websocket token"); + } return user; } private void authorizeSubscription(String destination, User user) { + destination = normalizeDestination(destination); if (destination == null || destination.startsWith("/user/queue/")) { return; } @@ -147,6 +158,7 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor { } private void authorizeSend(String destination, User user) { + destination = normalizeDestination(destination); Long conversationId = extractConversationId(destination, "/app/chat/conversations/"); if (conversationId != null && destination.endsWith("/messages") && chatService.hasConversationAccess(conversationId, user.getId(), user.getRole())) { return; @@ -175,13 +187,51 @@ public class WebSocketAuthChannelInterceptor implements ChannelInterceptor { private String firstHeader(StompHeaderAccessor accessor, String name) { List values = accessor.getNativeHeader(name); - return values == null || values.isEmpty() ? null : values.get(0); + if (values != null && !values.isEmpty()) { + return values.get(0); + } + for (String headerName : accessor.toNativeHeaderMap().keySet()) { + if (headerName.equalsIgnoreCase(name)) { + List alternateValues = accessor.getNativeHeader(headerName); + return alternateValues == null || alternateValues.isEmpty() ? null : alternateValues.get(0); + } + } + return null; } private String extractToken(String rawValue) { if (rawValue == null || rawValue.isBlank()) { return null; } - return rawValue.startsWith("Bearer ") ? rawValue.substring(7) : rawValue; + String normalized = rawValue.trim(); + return normalized.regionMatches(true, 0, "Bearer ", 0, 7) ? normalized.substring(7) : normalized; + } + + private String normalizeDestination(String destination) { + if (destination == null || destination.isBlank()) { + return destination; + } + String normalized = destination.replace('\\', '/'); + while (normalized.contains("//")) { + normalized = normalized.replace("//", "/"); + } + return normalized.toLowerCase(Locale.ROOT); + } + + private Long extractUserId(String token) { + try { + return jwtUtil.extractUserId(token); + } catch (JwtException | IllegalArgumentException ex) { + throw new IllegalArgumentException("Invalid websocket token: " + ex.getMessage(), ex); + } + } + + public Map buildErrorPayload(Exception ex, String destination, Principal principal) { + Map response = new LinkedHashMap<>(); + response.put("message", ex.getMessage() == null || ex.getMessage().isBlank() ? "WebSocket request failed" : ex.getMessage()); + response.put("details", ex.getClass().getSimpleName()); + response.put("destination", normalizeDestination(destination)); + response.put("authenticated", principal != null); + return response; } } diff --git a/src/main/java/com/petshop/backend/config/WebSocketConfig.java b/src/main/java/com/petshop/backend/config/WebSocketConfig.java index 27526bec..67dc1048 100644 --- a/src/main/java/com/petshop/backend/config/WebSocketConfig.java +++ b/src/main/java/com/petshop/backend/config/WebSocketConfig.java @@ -33,8 +33,13 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { public void registerStompEndpoints(StompEndpointRegistry registry) { registry.addEndpoint("/ws/chat") .setAllowedOriginPatterns("*"); + registry.addEndpoint("/ws/chat/") + .setAllowedOriginPatterns("*"); registry.addEndpoint("/ws/chat-sockjs") .setAllowedOriginPatterns("*") .withSockJS(); + registry.addEndpoint("/ws/chat-sockjs/") + .setAllowedOriginPatterns("*") + .withSockJS(); } } diff --git a/src/main/java/com/petshop/backend/controller/ChatWebSocketController.java b/src/main/java/com/petshop/backend/controller/ChatWebSocketController.java index d7f1e3b6..ed0a3718 100644 --- a/src/main/java/com/petshop/backend/controller/ChatWebSocketController.java +++ b/src/main/java/com/petshop/backend/controller/ChatWebSocketController.java @@ -1,5 +1,6 @@ package com.petshop.backend.controller; +import com.petshop.backend.config.WebSocketAuthChannelInterceptor; import com.petshop.backend.dto.chat.MessageRequest; import com.petshop.backend.dto.chat.MessageResponse; import com.petshop.backend.entity.User; @@ -10,6 +11,7 @@ import com.petshop.backend.service.ChatRealtimeService; import com.petshop.backend.service.ChatService; import jakarta.validation.Valid; import org.springframework.messaging.handler.annotation.DestinationVariable; +import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.Payload; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -17,6 +19,9 @@ import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.stereotype.Controller; import java.security.Principal; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; @Controller public class ChatWebSocketController { @@ -25,12 +30,20 @@ public class ChatWebSocketController { private final ChatRealtimeService chatRealtimeService; private final UserRepository userRepository; private final JwtUtil jwtUtil; + private final WebSocketAuthChannelInterceptor webSocketAuthChannelInterceptor; - public ChatWebSocketController(ChatService chatService, ChatRealtimeService chatRealtimeService, UserRepository userRepository, JwtUtil jwtUtil) { + public ChatWebSocketController( + ChatService chatService, + ChatRealtimeService chatRealtimeService, + UserRepository userRepository, + JwtUtil jwtUtil, + WebSocketAuthChannelInterceptor webSocketAuthChannelInterceptor + ) { this.chatService = chatService; this.chatRealtimeService = chatRealtimeService; this.userRepository = userRepository; this.jwtUtil = jwtUtil; + this.webSocketAuthChannelInterceptor = webSocketAuthChannelInterceptor; } @MessageMapping("/chat/conversations/{id}/messages") @@ -42,6 +55,12 @@ public class ChatWebSocketController { chatRealtimeService.publishConversationUpdate(id); } + @MessageExceptionHandler({IllegalArgumentException.class, RuntimeException.class}) + @SendToUser("/queue/chat/errors") + public Map handleMessageException(Exception ex, SimpMessageHeaderAccessor headerAccessor) { + return webSocketAuthChannelInterceptor.buildErrorPayload(ex, headerAccessor.getDestination(), headerAccessor.getUser()); + } + private User resolveUser(SimpMessageHeaderAccessor headerAccessor) { Principal principal = headerAccessor.getUser(); if (principal instanceof org.springframework.security.authentication.UsernamePasswordAuthenticationToken authenticationToken @@ -55,20 +74,50 @@ public class ChatWebSocketController { .orElseThrow(() -> new IllegalArgumentException("User not found")); } - String tokenHeader = headerAccessor.getFirstNativeHeader("Authorization"); + String tokenHeader = firstHeader(headerAccessor, "Authorization"); if (tokenHeader == null || tokenHeader.isBlank()) { - tokenHeader = headerAccessor.getFirstNativeHeader("token"); + tokenHeader = firstHeader(headerAccessor, "token"); } if (tokenHeader == null || tokenHeader.isBlank()) { throw new IllegalArgumentException("User not authenticated"); } - String token = tokenHeader.startsWith("Bearer ") ? tokenHeader.substring(7) : tokenHeader; - Long userId = jwtUtil.extractUserId(token); + String token = extractToken(tokenHeader); + Long userId; + try { + userId = jwtUtil.extractUserId(token); + } catch (RuntimeException ex) { + throw new IllegalArgumentException("Invalid websocket token", ex); + } User user = userId == null ? null : userRepository.findById(userId).orElse(null); - if (user == null || user.getActive() == null || !user.getActive() || !jwtUtil.validateToken(token, user)) { + if (user == null) { throw new IllegalArgumentException("User not found"); } + if (user.getActive() == null || !user.getActive()) { + throw new IllegalArgumentException("User account is inactive"); + } + if (!jwtUtil.validateToken(token, user)) { + throw new IllegalArgumentException("Invalid websocket token"); + } return user; } + + private String firstHeader(SimpMessageHeaderAccessor headerAccessor, String name) { + List values = headerAccessor.getNativeHeader(name); + if (values != null && !values.isEmpty()) { + return values.get(0); + } + Map> headers = headerAccessor.toNativeHeaderMap(); + for (Map.Entry> entry : headers.entrySet()) { + if (entry.getKey().equalsIgnoreCase(name)) { + return entry.getValue() == null || entry.getValue().isEmpty() ? null : entry.getValue().get(0); + } + } + return null; + } + + private String extractToken(String rawValue) { + String normalized = rawValue.trim(); + return normalized.regionMatches(true, 0, "Bearer ", 0, 7) ? normalized.substring(7) : normalized; + } } diff --git a/src/main/java/com/petshop/backend/exception/ApiErrorResponder.java b/src/main/java/com/petshop/backend/exception/ApiErrorResponder.java new file mode 100644 index 00000000..39f4d66c --- /dev/null +++ b/src/main/java/com/petshop/backend/exception/ApiErrorResponder.java @@ -0,0 +1,32 @@ +package com.petshop.backend.exception; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Component; + +import java.io.IOException; +import java.time.LocalDateTime; + +@Component +public class ApiErrorResponder { + + private final ObjectMapper objectMapper = JsonMapper.builder().findAndAddModules().build(); + + public void write(HttpServletResponse response, HttpStatus status, String message, String details, String path) throws IOException { + response.setStatus(status.value()); + response.setContentType(MediaType.APPLICATION_JSON_VALUE); + objectMapper.writeValue( + response.getWriter(), + new ApiErrorResponse( + status.value(), + message, + details, + path, + LocalDateTime.now() + ) + ); + } +} diff --git a/src/main/java/com/petshop/backend/exception/ApiErrorResponse.java b/src/main/java/com/petshop/backend/exception/ApiErrorResponse.java new file mode 100644 index 00000000..b3aea542 --- /dev/null +++ b/src/main/java/com/petshop/backend/exception/ApiErrorResponse.java @@ -0,0 +1,12 @@ +package com.petshop.backend.exception; + +import java.time.LocalDateTime; + +public record ApiErrorResponse( + int status, + String message, + String details, + String path, + LocalDateTime timestamp +) { +} diff --git a/src/main/java/com/petshop/backend/exception/GlobalExceptionHandler.java b/src/main/java/com/petshop/backend/exception/GlobalExceptionHandler.java index e055ea44..b41f8789 100644 --- a/src/main/java/com/petshop/backend/exception/GlobalExceptionHandler.java +++ b/src/main/java/com/petshop/backend/exception/GlobalExceptionHandler.java @@ -1,14 +1,15 @@ package com.petshop.backend.exception; +import jakarta.servlet.http.HttpServletRequest; +import org.springframework.dao.DataIntegrityViolationException; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; -import org.springframework.dao.DataIntegrityViolationException; -import org.springframework.web.method.annotation.MethodArgumentTypeMismatchException; -import org.springframework.web.server.ResponseStatusException; import org.springframework.validation.FieldError; import org.springframework.web.bind.MethodArgumentNotValidException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.RestControllerAdvice; +import org.springframework.web.method.annotation.MethodArgumentTypeMismatchException; +import org.springframework.web.server.ResponseStatusException; import java.time.LocalDateTime; import java.util.HashMap; @@ -18,27 +19,17 @@ import java.util.Map; public class GlobalExceptionHandler { @ExceptionHandler(ResourceNotFoundException.class) - public ResponseEntity handleResourceNotFound(ResourceNotFoundException ex) { - ErrorResponse error = new ErrorResponse( - HttpStatus.NOT_FOUND.value(), - ex.getMessage(), - LocalDateTime.now() - ); - return ResponseEntity.status(HttpStatus.NOT_FOUND).body(error); + public ResponseEntity handleResourceNotFound(ResourceNotFoundException ex, HttpServletRequest request) { + return buildErrorResponse(HttpStatus.NOT_FOUND, ex.getMessage(), ex, request); } @ExceptionHandler(BusinessException.class) - public ResponseEntity handleBusinessException(BusinessException ex) { - ErrorResponse error = new ErrorResponse( - HttpStatus.BAD_REQUEST.value(), - ex.getMessage(), - LocalDateTime.now() - ); - return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(error); + public ResponseEntity handleBusinessException(BusinessException ex, HttpServletRequest request) { + return buildErrorResponse(HttpStatus.BAD_REQUEST, ex.getMessage(), ex, request); } @ExceptionHandler(MethodArgumentNotValidException.class) - public ResponseEntity> handleValidationExceptions(MethodArgumentNotValidException ex) { + public ResponseEntity> handleValidationExceptions(MethodArgumentNotValidException ex, HttpServletRequest request) { Map errors = new HashMap<>(); ex.getBindingResult().getAllErrors().forEach((error) -> { String fieldName = ((FieldError) error).getField(); @@ -48,72 +39,74 @@ public class GlobalExceptionHandler { Map response = new HashMap<>(); response.put("status", HttpStatus.BAD_REQUEST.value()); + response.put("message", "Validation failed"); response.put("errors", errors); + response.put("details", buildDetails(ex)); + response.put("path", request.getRequestURI()); response.put("timestamp", LocalDateTime.now()); return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(response); } @ExceptionHandler(org.springframework.security.access.AccessDeniedException.class) - public ResponseEntity handleAccessDeniedException(org.springframework.security.access.AccessDeniedException ex) { - ErrorResponse error = new ErrorResponse( - HttpStatus.FORBIDDEN.value(), - ex.getMessage(), - LocalDateTime.now() - ); - return ResponseEntity.status(HttpStatus.FORBIDDEN).body(error); + public ResponseEntity handleAccessDeniedException(org.springframework.security.access.AccessDeniedException ex, HttpServletRequest request) { + return buildErrorResponse(HttpStatus.FORBIDDEN, ex.getMessage(), ex, request); } @ExceptionHandler(IllegalArgumentException.class) - public ResponseEntity handleIllegalArgumentException(IllegalArgumentException ex) { - ErrorResponse error = new ErrorResponse( - HttpStatus.BAD_REQUEST.value(), - ex.getMessage(), - LocalDateTime.now() - ); - return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(error); + public ResponseEntity handleIllegalArgumentException(IllegalArgumentException ex, HttpServletRequest request) { + return buildErrorResponse(HttpStatus.BAD_REQUEST, ex.getMessage(), ex, request); } @ExceptionHandler(DataIntegrityViolationException.class) - public ResponseEntity handleDataIntegrityViolationException(DataIntegrityViolationException ex) { - ErrorResponse error = new ErrorResponse( - HttpStatus.BAD_REQUEST.value(), - "Operation violates existing data relationships", - LocalDateTime.now() - ); - return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(error); + public ResponseEntity handleDataIntegrityViolationException(DataIntegrityViolationException ex, HttpServletRequest request) { + return buildErrorResponse(HttpStatus.BAD_REQUEST, "Operation violates existing data relationships", ex, request); } @ExceptionHandler(MethodArgumentTypeMismatchException.class) - public ResponseEntity handleMethodArgumentTypeMismatchException(MethodArgumentTypeMismatchException ex) { - ErrorResponse error = new ErrorResponse( - HttpStatus.BAD_REQUEST.value(), - "Invalid value for parameter: " + ex.getName(), - LocalDateTime.now() - ); - return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(error); + public ResponseEntity handleMethodArgumentTypeMismatchException(MethodArgumentTypeMismatchException ex, HttpServletRequest request) { + String message = "Invalid value for parameter: " + ex.getName(); + if (ex.getValue() != null) { + message += " (" + ex.getValue() + ")"; + } + return buildErrorResponse(HttpStatus.BAD_REQUEST, message, ex, request); } @ExceptionHandler(ResponseStatusException.class) - public ResponseEntity handleResponseStatusException(ResponseStatusException ex) { + public ResponseEntity handleResponseStatusException(ResponseStatusException ex, HttpServletRequest request) { String message = ex.getReason() != null ? ex.getReason() : ex.getMessage(); - ErrorResponse error = new ErrorResponse( - ex.getStatusCode().value(), - message, - LocalDateTime.now() - ); - return ResponseEntity.status(ex.getStatusCode()).body(error); + return buildErrorResponse(HttpStatus.valueOf(ex.getStatusCode().value()), message, ex, request); } @ExceptionHandler(Exception.class) - public ResponseEntity handleGenericException(Exception ex) { - ErrorResponse error = new ErrorResponse( - HttpStatus.INTERNAL_SERVER_ERROR.value(), - "An unexpected error occurred: " + ex.getMessage(), + public ResponseEntity handleGenericException(Exception ex, HttpServletRequest request) { + String message = ex.getMessage() == null || ex.getMessage().isBlank() + ? "Unexpected server error" + : ex.getMessage(); + return buildErrorResponse(HttpStatus.INTERNAL_SERVER_ERROR, message, ex, request); + } + + private ResponseEntity buildErrorResponse(HttpStatus status, String message, Exception ex, HttpServletRequest request) { + ApiErrorResponse error = new ApiErrorResponse( + status.value(), + message, + buildDetails(ex), + request.getRequestURI(), LocalDateTime.now() ); - return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(error); + return ResponseEntity.status(status).body(error); + } + + private String buildDetails(Exception ex) { + Throwable rootCause = ex; + while (rootCause.getCause() != null && rootCause.getCause() != rootCause) { + rootCause = rootCause.getCause(); + } + + String rootMessage = rootCause.getMessage(); + if (rootMessage == null || rootMessage.isBlank()) { + return rootCause.getClass().getSimpleName(); + } + return rootCause.getClass().getSimpleName() + ": " + rootMessage; } } - -record ErrorResponse(int status, String message, LocalDateTime timestamp) {} diff --git a/src/main/java/com/petshop/backend/security/JwtAuthenticationFilter.java b/src/main/java/com/petshop/backend/security/JwtAuthenticationFilter.java index d804a3b7..a4caaaef 100644 --- a/src/main/java/com/petshop/backend/security/JwtAuthenticationFilter.java +++ b/src/main/java/com/petshop/backend/security/JwtAuthenticationFilter.java @@ -1,6 +1,7 @@ package com.petshop.backend.security; import com.petshop.backend.entity.User; +import com.petshop.backend.exception.ApiErrorResponder; import com.petshop.backend.repository.UserRepository; import io.jsonwebtoken.JwtException; import jakarta.servlet.FilterChain; @@ -15,17 +16,17 @@ import org.springframework.stereotype.Component; import org.springframework.web.filter.OncePerRequestFilter; import java.io.IOException; -import java.time.LocalDateTime; - @Component public class JwtAuthenticationFilter extends OncePerRequestFilter { private final JwtUtil jwtUtil; private final UserRepository userRepository; + private final ApiErrorResponder apiErrorResponder; - public JwtAuthenticationFilter(JwtUtil jwtUtil, UserRepository userRepository) { + public JwtAuthenticationFilter(JwtUtil jwtUtil, UserRepository userRepository, ApiErrorResponder apiErrorResponder) { this.jwtUtil = jwtUtil; this.userRepository = userRepository; + this.apiErrorResponder = apiErrorResponder; } @Override @@ -46,18 +47,18 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { try { userId = jwtUtil.extractUserId(jwt); } catch (JwtException | IllegalArgumentException ex) { - writeUnauthorized(response, "Invalid or expired token"); + writeUnauthorized(request, response, "Invalid or expired token", ex); return; } if (userId != null && SecurityContextHolder.getContext().getAuthentication() == null) { User user = userRepository.findById(userId).orElse(null); if (user == null || user.getActive() == null || !user.getActive()) { - writeUnauthorized(response, "User account is inactive"); + writeUnauthorized(request, response, "User account is inactive", null); return; } if (!jwtUtil.validateToken(jwt, user)) { - writeUnauthorized(response, "Invalid or expired token"); + writeUnauthorized(request, response, "Invalid or expired token", null); return; } @@ -78,11 +79,8 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { filterChain.doFilter(request, response); } - private void writeUnauthorized(HttpServletResponse response, String message) throws IOException { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); - response.setContentType("application/json"); - response.getWriter().write( - "{\"status\":401,\"message\":\"" + message + "\",\"timestamp\":\"" + LocalDateTime.now() + "\"}" - ); + private void writeUnauthorized(HttpServletRequest request, HttpServletResponse response, String message, Exception ex) throws IOException { + String details = ex == null ? message : ex.getClass().getSimpleName() + ": " + ex.getMessage(); + apiErrorResponder.write(response, org.springframework.http.HttpStatus.UNAUTHORIZED, message, details, request.getRequestURI()); } } diff --git a/src/main/java/com/petshop/backend/security/RestAccessDeniedHandler.java b/src/main/java/com/petshop/backend/security/RestAccessDeniedHandler.java new file mode 100644 index 00000000..2ef240e9 --- /dev/null +++ b/src/main/java/com/petshop/backend/security/RestAccessDeniedHandler.java @@ -0,0 +1,33 @@ +package com.petshop.backend.security; + +import com.petshop.backend.exception.ApiErrorResponder; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.HttpStatus; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.web.access.AccessDeniedHandler; +import org.springframework.stereotype.Component; + +import java.io.IOException; + +@Component +public class RestAccessDeniedHandler implements AccessDeniedHandler { + + private final ApiErrorResponder apiErrorResponder; + + public RestAccessDeniedHandler(ApiErrorResponder apiErrorResponder) { + this.apiErrorResponder = apiErrorResponder; + } + + @Override + public void handle(HttpServletRequest request, HttpServletResponse response, AccessDeniedException accessDeniedException) throws IOException, ServletException { + apiErrorResponder.write( + response, + HttpStatus.FORBIDDEN, + "Access Denied", + accessDeniedException.getClass().getSimpleName() + ": " + accessDeniedException.getMessage(), + request.getRequestURI() + ); + } +} diff --git a/src/main/java/com/petshop/backend/security/RestAuthenticationEntryPoint.java b/src/main/java/com/petshop/backend/security/RestAuthenticationEntryPoint.java new file mode 100644 index 00000000..2ae541b4 --- /dev/null +++ b/src/main/java/com/petshop/backend/security/RestAuthenticationEntryPoint.java @@ -0,0 +1,33 @@ +package com.petshop.backend.security; + +import com.petshop.backend.exception.ApiErrorResponder; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.HttpStatus; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.stereotype.Component; + +import java.io.IOException; + +@Component +public class RestAuthenticationEntryPoint implements AuthenticationEntryPoint { + + private final ApiErrorResponder apiErrorResponder; + + public RestAuthenticationEntryPoint(ApiErrorResponder apiErrorResponder) { + this.apiErrorResponder = apiErrorResponder; + } + + @Override + public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) throws IOException, ServletException { + apiErrorResponder.write( + response, + HttpStatus.UNAUTHORIZED, + "Authentication required", + authException.getClass().getSimpleName() + ": " + authException.getMessage(), + request.getRequestURI() + ); + } +} diff --git a/src/main/java/com/petshop/backend/security/SecurityConfig.java b/src/main/java/com/petshop/backend/security/SecurityConfig.java index 0a893c18..00ce63f8 100644 --- a/src/main/java/com/petshop/backend/security/SecurityConfig.java +++ b/src/main/java/com/petshop/backend/security/SecurityConfig.java @@ -25,10 +25,19 @@ public class SecurityConfig { private final JwtAuthenticationFilter jwtAuthFilter; private final UserDetailsService userDetailsService; + private final RestAuthenticationEntryPoint restAuthenticationEntryPoint; + private final RestAccessDeniedHandler restAccessDeniedHandler; - public SecurityConfig(JwtAuthenticationFilter jwtAuthFilter, UserDetailsService userDetailsService) { + public SecurityConfig( + JwtAuthenticationFilter jwtAuthFilter, + UserDetailsService userDetailsService, + RestAuthenticationEntryPoint restAuthenticationEntryPoint, + RestAccessDeniedHandler restAccessDeniedHandler + ) { this.jwtAuthFilter = jwtAuthFilter; this.userDetailsService = userDetailsService; + this.restAuthenticationEntryPoint = restAuthenticationEntryPoint; + this.restAccessDeniedHandler = restAccessDeniedHandler; } @Bean @@ -47,6 +56,10 @@ public class SecurityConfig { .requestMatchers(HttpMethod.GET, "/api/v1/appointments/availability").permitAll() .anyRequest().authenticated() ) + .exceptionHandling(ex -> ex + .authenticationEntryPoint(restAuthenticationEntryPoint) + .accessDeniedHandler(restAccessDeniedHandler) + ) .sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS)) .authenticationProvider(daoAuthenticationProvider()) .addFilterBefore(jwtAuthFilter, UsernamePasswordAuthenticationFilter.class); diff --git a/src/main/java/com/petshop/backend/service/EmployeeService.java b/src/main/java/com/petshop/backend/service/EmployeeService.java index 2199e441..baf83bb8 100644 --- a/src/main/java/com/petshop/backend/service/EmployeeService.java +++ b/src/main/java/com/petshop/backend/service/EmployeeService.java @@ -126,22 +126,22 @@ public class EmployeeService { } private EmployeeResponse mapToResponse(Employee employee) { - User user = requireLinkedUser(employee); + User user = employee.getUserId() == null ? null : userRepository.findById(employee.getUserId()).orElse(null); return mapToResponse(employee, user); } private EmployeeResponse mapToResponse(Employee employee, User user) { EmployeeResponse response = new EmployeeResponse(); response.setEmployeeId(employee.getEmployeeId()); - response.setUserId(user.getId()); - response.setUsername(user.getUsername()); + response.setUserId(user != null ? user.getId() : employee.getUserId()); + response.setUsername(user != null ? user.getUsername() : null); response.setFirstName(employee.getFirstName()); response.setLastName(employee.getLastName()); - response.setFullName(user.getFullName()); - response.setEmail(user.getEmail()); - response.setPhone(user.getPhone()); - response.setRole(user.getRole().name()); - response.setActive(user.getActive()); + response.setFullName(user != null ? user.getFullName() : fullName(employee)); + response.setEmail(user != null ? user.getEmail() : employee.getEmail()); + response.setPhone(user != null ? user.getPhone() : null); + response.setRole(user != null ? user.getRole().name() : normalizeRole(employee.getRole())); + response.setActive(user != null ? user.getActive() : employee.getIsActive()); response.setCreatedAt(employee.getCreatedAt()); response.setUpdatedAt(employee.getUpdatedAt()); return response; @@ -165,6 +165,14 @@ public class EmployeeService { return (request.getFirstName().trim() + " " + request.getLastName().trim()).trim(); } + private String fullName(Employee employee) { + return (employee.getFirstName().trim() + " " + employee.getLastName().trim()).trim(); + } + + private String normalizeRole(String role) { + return role == null ? null : role.trim().toUpperCase(java.util.Locale.ROOT); + } + private String trimToNull(String value) { if (value == null) { return null; diff --git a/src/test/java/com/petshop/backend/security/JwtAuthenticationFilterTest.java b/src/test/java/com/petshop/backend/security/JwtAuthenticationFilterTest.java index fa8b429c..4d7ce01a 100644 --- a/src/test/java/com/petshop/backend/security/JwtAuthenticationFilterTest.java +++ b/src/test/java/com/petshop/backend/security/JwtAuthenticationFilterTest.java @@ -1,6 +1,7 @@ package com.petshop.backend.security; import com.petshop.backend.entity.User; +import com.petshop.backend.exception.ApiErrorResponder; import com.petshop.backend.repository.UserRepository; import jakarta.servlet.FilterChain; import org.junit.jupiter.api.AfterEach; @@ -42,7 +43,7 @@ class JwtAuthenticationFilterTest { User user = buildUser(); String token = jwtUtil.generateToken(user); AtomicBoolean chainCalled = new AtomicBoolean(false); - JwtAuthenticationFilter filter = new JwtAuthenticationFilter(jwtUtil, userRepositoryFor(user)); + JwtAuthenticationFilter filter = new JwtAuthenticationFilter(jwtUtil, userRepositoryFor(user), new ApiErrorResponder()); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + token); @@ -63,7 +64,7 @@ class JwtAuthenticationFilterTest { User user = buildUser(); user.setActive(false); String token = jwtUtil.generateToken(user); - JwtAuthenticationFilter filter = new JwtAuthenticationFilter(jwtUtil, userRepositoryFor(user)); + JwtAuthenticationFilter filter = new JwtAuthenticationFilter(jwtUtil, userRepositoryFor(user), new ApiErrorResponder()); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + token); @@ -73,6 +74,8 @@ class JwtAuthenticationFilterTest { }); assertEquals(401, response.getStatus()); + assertTrue(response.getContentAsString().contains("\"message\":\"User account is inactive\"")); + assertTrue(response.getContentAsString().contains("\"path\":\"\"")); assertNull(SecurityContextHolder.getContext().getAuthentication()); } @@ -81,7 +84,7 @@ class JwtAuthenticationFilterTest { User user = buildUser(); String token = jwtUtil.generateToken(user); user.setTokenVersion(4); - JwtAuthenticationFilter filter = new JwtAuthenticationFilter(jwtUtil, userRepositoryFor(user)); + JwtAuthenticationFilter filter = new JwtAuthenticationFilter(jwtUtil, userRepositoryFor(user), new ApiErrorResponder()); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + token); @@ -91,6 +94,7 @@ class JwtAuthenticationFilterTest { }); assertEquals(401, response.getStatus()); + assertTrue(response.getContentAsString().contains("\"message\":\"Invalid or expired token\"")); assertNull(SecurityContextHolder.getContext().getAuthentication()); }