From d43942fb76d9b89610b8970507f757ed450231a6 Mon Sep 17 00:00:00 2001 From: Harkamal Randhawa Date: Tue, 14 Apr 2026 15:23:26 -0600 Subject: [PATCH] add rate limiting --- .../backend/security/RateLimitFilter.java | 62 +++++++++++++++++++ .../backend/security/RateLimiterService.java | 45 ++++++++++++++ .../backend/security/SecurityConfig.java | 4 ++ 3 files changed, 111 insertions(+) create mode 100644 backend/src/main/java/com/petshop/backend/security/RateLimitFilter.java create mode 100644 backend/src/main/java/com/petshop/backend/security/RateLimiterService.java diff --git a/backend/src/main/java/com/petshop/backend/security/RateLimitFilter.java b/backend/src/main/java/com/petshop/backend/security/RateLimitFilter.java new file mode 100644 index 00000000..567d4219 --- /dev/null +++ b/backend/src/main/java/com/petshop/backend/security/RateLimitFilter.java @@ -0,0 +1,62 @@ +package com.petshop.backend.security; + +import com.petshop.backend.exception.ApiErrorResponder; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.HttpStatus; +import org.springframework.lang.NonNull; +import org.springframework.stereotype.Component; +import org.springframework.web.filter.OncePerRequestFilter; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; + +@Component +public class RateLimitFilter extends OncePerRequestFilter { + + private static final Map RULES = Map.of( + "/api/v1/auth/login", new int[]{10, 15}, + "/api/v1/auth/register", new int[]{5, 60}, + "/api/v1/auth/forgot-password", new int[]{3, 10}, + "/api/v1/auth/reset-password", new int[]{10, 15} + ); + + private final RateLimiterService rateLimiterService; + private final ApiErrorResponder apiErrorResponder; + + public RateLimitFilter(RateLimiterService rateLimiterService, ApiErrorResponder apiErrorResponder) { + this.rateLimiterService = rateLimiterService; + this.apiErrorResponder = apiErrorResponder; + } + + @Override + protected void doFilterInternal(@NonNull HttpServletRequest request, + @NonNull HttpServletResponse response, + @NonNull FilterChain filterChain) throws ServletException, IOException { + String path = request.getRequestURI(); + int[] rule = RULES.get(path); + + if (rule != null) { + String ip = extractIp(request); + String key = path + ":" + ip; + if (!rateLimiterService.isAllowed(key, rule[0], Duration.ofMinutes(rule[1]))) { + apiErrorResponder.write(response, HttpStatus.TOO_MANY_REQUESTS, + "Too many requests. Please try again later.", null, path); + return; + } + } + + filterChain.doFilter(request, response); + } + + private String extractIp(HttpServletRequest request) { + String forwarded = request.getHeader("X-Forwarded-For"); + if (forwarded != null && !forwarded.isBlank()) { + return forwarded.split(",")[0].trim(); + } + return request.getRemoteAddr(); + } +} diff --git a/backend/src/main/java/com/petshop/backend/security/RateLimiterService.java b/backend/src/main/java/com/petshop/backend/security/RateLimiterService.java new file mode 100644 index 00000000..4f6eb94f --- /dev/null +++ b/backend/src/main/java/com/petshop/backend/security/RateLimiterService.java @@ -0,0 +1,45 @@ +package com.petshop.backend.security; + +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Service; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +@Service +public class RateLimiterService { + + private final Map> buckets = new ConcurrentHashMap<>(); + + public boolean isAllowed(String key, int maxRequests, Duration window) { + Instant now = Instant.now(); + Instant windowStart = now.minus(window); + + Deque timestamps = buckets.computeIfAbsent(key, k -> new ArrayDeque<>()); + synchronized (timestamps) { + while (!timestamps.isEmpty() && timestamps.peekFirst().isBefore(windowStart)) { + timestamps.pollFirst(); + } + if (timestamps.size() >= maxRequests) { + return false; + } + timestamps.addLast(now); + return true; + } + } + + @Scheduled(fixedDelay = 300_000) + public void evictStale() { + Instant cutoff = Instant.now().minus(Duration.ofHours(2)); + buckets.entrySet().removeIf(entry -> { + Deque timestamps = entry.getValue(); + synchronized (timestamps) { + return timestamps.isEmpty() || timestamps.peekLast().isBefore(cutoff); + } + }); + } +} diff --git a/backend/src/main/java/com/petshop/backend/security/SecurityConfig.java b/backend/src/main/java/com/petshop/backend/security/SecurityConfig.java index b15d4a96..12f784fb 100644 --- a/backend/src/main/java/com/petshop/backend/security/SecurityConfig.java +++ b/backend/src/main/java/com/petshop/backend/security/SecurityConfig.java @@ -31,15 +31,18 @@ import java.util.List; public class SecurityConfig { private final JwtAuthenticationFilter jwtAuthFilter; + private final RateLimitFilter rateLimitFilter; private final UserDetailsService userDetailsService; private final RestAuthenticationEntryPoint restAuthenticationEntryPoint; private final RestAccessDeniedHandler restAccessDeniedHandler; public SecurityConfig(JwtAuthenticationFilter jwtAuthFilter, + RateLimitFilter rateLimitFilter, UserDetailsService userDetailsService, RestAuthenticationEntryPoint restAuthenticationEntryPoint, RestAccessDeniedHandler restAccessDeniedHandler) { this.jwtAuthFilter = jwtAuthFilter; + this.rateLimitFilter = rateLimitFilter; this.userDetailsService = userDetailsService; this.restAuthenticationEntryPoint = restAuthenticationEntryPoint; this.restAccessDeniedHandler = restAccessDeniedHandler; @@ -75,6 +78,7 @@ public class SecurityConfig { .sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS)) .authenticationProvider(daoAuthenticationProvider()) .addFilterBefore(jwtAuthFilter, UsernamePasswordAuthenticationFilter.class); + http.addFilterBefore(rateLimitFilter, JwtAuthenticationFilter.class); http.addFilterAfter(activityLoggingFilter, JwtAuthenticationFilter.class); return http.build();