From 8fdd28cbbd92114ffdb6c459e70d7f050ab5d9e8 Mon Sep 17 00:00:00 2001 From: Harkamal Randhawa Date: Thu, 12 Mar 2026 17:44:39 -0600 Subject: [PATCH] Use claims auth --- .../backend/controller/AuthController.java | 17 +-- .../backend/security/AppPrincipal.java | 51 +++++++ .../security/JwtAuthenticationFilter.java | 62 +++++---- .../com/petshop/backend/security/JwtUtil.java | 35 ++++- .../security/JwtAuthenticationFilterTest.java | 129 ++++++++++++++++++ .../petshop/backend/security/JwtUtilTest.java | 59 ++++++++ 6 files changed, 304 insertions(+), 49 deletions(-) create mode 100644 src/main/java/com/petshop/backend/security/AppPrincipal.java create mode 100644 src/test/java/com/petshop/backend/security/JwtAuthenticationFilterTest.java create mode 100644 src/test/java/com/petshop/backend/security/JwtUtilTest.java diff --git a/src/main/java/com/petshop/backend/controller/AuthController.java b/src/main/java/com/petshop/backend/controller/AuthController.java index bf907257..0aca7d24 100644 --- a/src/main/java/com/petshop/backend/controller/AuthController.java +++ b/src/main/java/com/petshop/backend/controller/AuthController.java @@ -24,7 +24,6 @@ import org.springframework.security.authentication.InternalAuthenticationService import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.web.bind.annotation.*; @@ -89,13 +88,7 @@ public class AuthController { // Create or link customer record userBusinessLinkageService.ensureLinkedCustomer(savedUser); - UserDetails userDetails = new org.springframework.security.core.userdetails.User( - savedUser.getUsername(), - savedUser.getPassword(), - java.util.Collections.emptyList() - ); - - String token = jwtUtil.generateToken(userDetails); + String token = jwtUtil.generateToken(savedUser); return ResponseEntity.status(HttpStatus.CREATED).body(new RegisterResponse( savedUser.getId(), @@ -116,13 +109,7 @@ public class AuthController { User user = userRepository.findByUsername(request.getUsername()) .orElseThrow(() -> new UsernameNotFoundException("User not found")); - UserDetails userDetails = new org.springframework.security.core.userdetails.User( - user.getUsername(), - user.getPassword(), - java.util.Collections.emptyList() - ); - - String token = jwtUtil.generateToken(userDetails); + String token = jwtUtil.generateToken(user); return ResponseEntity.ok(new LoginResponse( token, diff --git a/src/main/java/com/petshop/backend/security/AppPrincipal.java b/src/main/java/com/petshop/backend/security/AppPrincipal.java new file mode 100644 index 00000000..30ceca66 --- /dev/null +++ b/src/main/java/com/petshop/backend/security/AppPrincipal.java @@ -0,0 +1,51 @@ +package com.petshop.backend.security; + +import com.petshop.backend.entity.User; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; + +import java.security.Principal; +import java.util.Collection; +import java.util.List; + +public class AppPrincipal implements Principal { + + private final Long userId; + private final String username; + private final User.Role role; + private final Integer tokenVersion; + private final List authorities; + + public AppPrincipal(Long userId, String username, User.Role role, Integer tokenVersion) { + this.userId = userId; + this.username = username; + this.role = role; + this.tokenVersion = tokenVersion; + this.authorities = List.of(new SimpleGrantedAuthority("ROLE_" + role.name())); + } + + public Long getUserId() { + return userId; + } + + @Override + public String getName() { + return username; + } + + public String getUsername() { + return username; + } + + public User.Role getRole() { + return role; + } + + public Integer getTokenVersion() { + return tokenVersion; + } + + public Collection getAuthorities() { + return authorities; + } +} diff --git a/src/main/java/com/petshop/backend/security/JwtAuthenticationFilter.java b/src/main/java/com/petshop/backend/security/JwtAuthenticationFilter.java index d79c88f6..8d311f74 100644 --- a/src/main/java/com/petshop/backend/security/JwtAuthenticationFilter.java +++ b/src/main/java/com/petshop/backend/security/JwtAuthenticationFilter.java @@ -1,15 +1,14 @@ package com.petshop.backend.security; +import com.petshop.backend.entity.User; +import com.petshop.backend.repository.UserRepository; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.springframework.lang.NonNull; -import org.springframework.security.authentication.DisabledException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.stereotype.Component; import org.springframework.web.filter.OncePerRequestFilter; @@ -21,11 +20,11 @@ import java.time.LocalDateTime; public class JwtAuthenticationFilter extends OncePerRequestFilter { private final JwtUtil jwtUtil; - private final UserDetailsService userDetailsService; + private final UserRepository userRepository; - public JwtAuthenticationFilter(JwtUtil jwtUtil, UserDetailsService userDetailsService) { + public JwtAuthenticationFilter(JwtUtil jwtUtil, UserRepository userRepository) { this.jwtUtil = jwtUtil; - this.userDetailsService = userDetailsService; + this.userRepository = userRepository; } @Override @@ -36,38 +35,47 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { ) throws ServletException, IOException { final String authHeader = request.getHeader("Authorization"); final String jwt; - final String username; - if (authHeader == null || !authHeader.startsWith("Bearer ")) { filterChain.doFilter(request, response); return; } jwt = authHeader.substring(7); - username = jwtUtil.extractUsername(jwt); + Long userId = jwtUtil.extractUserId(jwt); - if (username != null && SecurityContextHolder.getContext().getAuthentication() == null) { - UserDetails userDetails; - try { - userDetails = userDetailsService.loadUserByUsername(username); - } catch (DisabledException ex) { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); - response.setContentType("application/json"); - response.getWriter().write( - "{\"status\":401,\"message\":\"" + ex.getMessage() + "\",\"timestamp\":\"" + LocalDateTime.now() + "\"}" - ); + if (userId != null && SecurityContextHolder.getContext().getAuthentication() == null) { + User user = userRepository.findById(userId).orElse(null); + if (user == null || user.getActive() == null || !user.getActive()) { + writeUnauthorized(response, "User account is inactive"); return; } - if (jwtUtil.validateToken(jwt, userDetails)) { - UsernamePasswordAuthenticationToken authToken = new UsernamePasswordAuthenticationToken( - userDetails, - null, - userDetails.getAuthorities() - ); - authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request)); - SecurityContextHolder.getContext().setAuthentication(authToken); + if (!jwtUtil.validateToken(jwt, user)) { + writeUnauthorized(response, "Invalid or expired token"); + return; } + + AppPrincipal principal = new AppPrincipal( + user.getId(), + user.getUsername(), + user.getRole(), + user.getTokenVersion() + ); + UsernamePasswordAuthenticationToken authToken = new UsernamePasswordAuthenticationToken( + principal, + null, + principal.getAuthorities() + ); + authToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request)); + SecurityContextHolder.getContext().setAuthentication(authToken); } filterChain.doFilter(request, response); } + + private void writeUnauthorized(HttpServletResponse response, String message) throws IOException { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); + response.setContentType("application/json"); + response.getWriter().write( + "{\"status\":401,\"message\":\"" + message + "\",\"timestamp\":\"" + LocalDateTime.now() + "\"}" + ); + } } diff --git a/src/main/java/com/petshop/backend/security/JwtUtil.java b/src/main/java/com/petshop/backend/security/JwtUtil.java index 9381369b..3d4541bc 100644 --- a/src/main/java/com/petshop/backend/security/JwtUtil.java +++ b/src/main/java/com/petshop/backend/security/JwtUtil.java @@ -1,10 +1,10 @@ package com.petshop.backend.security; +import com.petshop.backend.entity.User; import io.jsonwebtoken.Claims; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.security.Keys; import org.springframework.beans.factory.annotation.Value; -import org.springframework.security.core.userdetails.UserDetails; import org.springframework.stereotype.Component; import javax.crypto.SecretKey; @@ -28,7 +28,20 @@ public class JwtUtil { } public String extractUsername(String token) { - return extractClaim(token, Claims::getSubject); + return extractAllClaims(token).get("username", String.class); + } + + public Long extractUserId(String token) { + return Long.parseLong(extractClaim(token, Claims::getSubject)); + } + + public String extractRole(String token) { + return extractAllClaims(token).get("role", String.class); + } + + public Integer extractTokenVersion(String token) { + Number tokenVersion = extractAllClaims(token).get("tokenVersion", Number.class); + return tokenVersion == null ? null : tokenVersion.intValue(); } public Date extractExpiration(String token) { @@ -52,9 +65,12 @@ public class JwtUtil { return extractExpiration(token).before(new Date()); } - public String generateToken(UserDetails userDetails) { + public String generateToken(User user) { Map claims = new HashMap<>(); - return createToken(claims, userDetails.getUsername()); + claims.put("username", user.getUsername()); + claims.put("role", user.getRole().name()); + claims.put("tokenVersion", user.getTokenVersion()); + return createToken(claims, user.getId().toString()); } private String createToken(Map claims, String subject) { @@ -67,8 +83,13 @@ public class JwtUtil { .compact(); } - public Boolean validateToken(String token, UserDetails userDetails) { - final String username = extractUsername(token); - return (username.equals(userDetails.getUsername()) && !isTokenExpired(token)); + public Boolean validateToken(String token, User user) { + Long userId = extractUserId(token); + String role = extractRole(token); + Integer tokenVersion = extractTokenVersion(token); + return user.getId().equals(userId) + && user.getRole().name().equals(role) + && user.getTokenVersion().equals(tokenVersion) + && !isTokenExpired(token); } } diff --git a/src/test/java/com/petshop/backend/security/JwtAuthenticationFilterTest.java b/src/test/java/com/petshop/backend/security/JwtAuthenticationFilterTest.java new file mode 100644 index 00000000..fa8b429c --- /dev/null +++ b/src/test/java/com/petshop/backend/security/JwtAuthenticationFilterTest.java @@ -0,0 +1,129 @@ +package com.petshop.backend.security; + +import com.petshop.backend.entity.User; +import com.petshop.backend.repository.UserRepository; +import jakarta.servlet.FilterChain; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.test.util.ReflectionTestUtils; + +import java.lang.reflect.Proxy; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class JwtAuthenticationFilterTest { + + private JwtUtil jwtUtil; + + @BeforeEach + void setUp() { + jwtUtil = new JwtUtil(); + ReflectionTestUtils.setField(jwtUtil, "secret", "change_me_please_make_this_at_least_32_characters_long_for_security"); + ReflectionTestUtils.setField(jwtUtil, "expiration", 86400000L); + SecurityContextHolder.clearContext(); + } + + @AfterEach + void tearDown() { + SecurityContextHolder.clearContext(); + } + + @Test + void validTokenBuildsAppPrincipalAuthentication() throws Exception { + User user = buildUser(); + String token = jwtUtil.generateToken(user); + AtomicBoolean chainCalled = new AtomicBoolean(false); + JwtAuthenticationFilter filter = new JwtAuthenticationFilter(jwtUtil, userRepositoryFor(user)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Bearer " + token); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = (req, res) -> chainCalled.set(true); + + filter.doFilter(request, response, chain); + + Object principal = SecurityContextHolder.getContext().getAuthentication().getPrincipal(); + assertInstanceOf(AppPrincipal.class, principal); + assertEquals("staff-user", ((AppPrincipal) principal).getUsername()); + assertEquals(User.Role.STAFF, ((AppPrincipal) principal).getRole()); + assertTrue(chainCalled.get()); + } + + @Test + void inactiveUserReturnsUnauthorized() throws Exception { + User user = buildUser(); + user.setActive(false); + String token = jwtUtil.generateToken(user); + JwtAuthenticationFilter filter = new JwtAuthenticationFilter(jwtUtil, userRepositoryFor(user)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Bearer " + token); + MockHttpServletResponse response = new MockHttpServletResponse(); + + filter.doFilter(request, response, (req, res) -> { + }); + + assertEquals(401, response.getStatus()); + assertNull(SecurityContextHolder.getContext().getAuthentication()); + } + + @Test + void tokenVersionMismatchReturnsUnauthorized() throws Exception { + User user = buildUser(); + String token = jwtUtil.generateToken(user); + user.setTokenVersion(4); + JwtAuthenticationFilter filter = new JwtAuthenticationFilter(jwtUtil, userRepositoryFor(user)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Bearer " + token); + MockHttpServletResponse response = new MockHttpServletResponse(); + + filter.doFilter(request, response, (req, res) -> { + }); + + assertEquals(401, response.getStatus()); + assertNull(SecurityContextHolder.getContext().getAuthentication()); + } + + private UserRepository userRepositoryFor(User user) { + return (UserRepository) Proxy.newProxyInstance( + UserRepository.class.getClassLoader(), + new Class[]{UserRepository.class}, + (proxy, method, args) -> { + if ("findById".equals(method.getName())) { + return user.getId().equals(args[0]) ? Optional.of(user) : Optional.empty(); + } + if ("equals".equals(method.getName())) { + return proxy == args[0]; + } + if ("hashCode".equals(method.getName())) { + return System.identityHashCode(proxy); + } + if ("toString".equals(method.getName())) { + return "UserRepositoryProxy"; + } + throw new UnsupportedOperationException(method.getName()); + } + ); + } + + private User buildUser() { + User user = new User(); + user.setId(42L); + user.setUsername("staff-user"); + user.setPassword("encoded"); + user.setRole(User.Role.STAFF); + user.setActive(true); + user.setTokenVersion(3); + return user; + } +} diff --git a/src/test/java/com/petshop/backend/security/JwtUtilTest.java b/src/test/java/com/petshop/backend/security/JwtUtilTest.java new file mode 100644 index 00000000..14f9ba3e --- /dev/null +++ b/src/test/java/com/petshop/backend/security/JwtUtilTest.java @@ -0,0 +1,59 @@ +package com.petshop.backend.security; + +import com.petshop.backend.entity.User; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.test.util.ReflectionTestUtils; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class JwtUtilTest { + + private JwtUtil jwtUtil; + + @BeforeEach + void setUp() { + jwtUtil = new JwtUtil(); + ReflectionTestUtils.setField(jwtUtil, "secret", "change_me_please_make_this_at_least_32_characters_long_for_security"); + ReflectionTestUtils.setField(jwtUtil, "expiration", 86400000L); + } + + @Test + void generatedTokenContainsIdentityClaims() { + User user = buildUser(); + + String token = jwtUtil.generateToken(user); + + assertEquals(42L, jwtUtil.extractUserId(token)); + assertEquals("staff-user", jwtUtil.extractUsername(token)); + assertEquals("STAFF", jwtUtil.extractRole(token)); + assertEquals(7, jwtUtil.extractTokenVersion(token)); + assertTrue(jwtUtil.validateToken(token, user)); + } + + @Test + void validateTokenRejectsChangedRoleOrTokenVersion() { + User user = buildUser(); + String token = jwtUtil.generateToken(user); + + user.setRole(User.Role.ADMIN); + assertFalse(jwtUtil.validateToken(token, user)); + + user.setRole(User.Role.STAFF); + user.setTokenVersion(8); + assertFalse(jwtUtil.validateToken(token, user)); + } + + private User buildUser() { + User user = new User(); + user.setId(42L); + user.setUsername("staff-user"); + user.setPassword("encoded"); + user.setRole(User.Role.STAFF); + user.setActive(true); + user.setTokenVersion(7); + return user; + } +}