通过AOP+自定义注解+Redis实现的限制ip访问接口次数

/ JavaSpring / 1080浏览

默认对切点匹配到的全部接口做限制:如果接口方法上没有 @RequestLimit 注解,RequestLimitAspect 会按照 application-dev.yml 中 request-limit.amount(时间窗口内允许的请求次数)和 request-limit.time(时间窗口长度,单位毫秒)设置的值做限制;也可以通过 @RequestLimit 对单个接口做定制,此时 RequestLimitAspect 以 @RequestLimit 中的值为准。

yml文件

# 请求限制参数
request-limit:
  amount: 100     # 时间窗口内允许的请求次数
  time: 30000     # 时间窗口长度,单位毫秒

通过@ConfigurationProperties读取yml文件的配置

@Data
@Component
@ConfigurationProperties(prefix = "request-limit")
public class RequestLimitConfig {

    /**
     * Global maximum number of requests a client may make to one URL within a single
     * {@code time} window (bound from {@code request-limit.amount}).
     */
    public int amount;

    /**
     * Global length of the counting window (bound from {@code request-limit.time}).
     * RequestLimitAspect passes this value to Redis as an expiry in milliseconds.
     */
    public long time;
}

自定义注解

/**
 * Overrides the global request limit for a single handler method.
 *
 * <p>RequestLimitAspect looks this annotation up reflectively on the intercepted method. When it
 * is present, its values replace the global {@code request-limit.amount} / {@code request-limit.time}
 * settings, so it must be retained at runtime.
 *
 * <p>To control when RequestLimitAspect runs relative to other aspects, put {@code @Order} on the
 * aspect class itself. On this annotation type it would not affect aspect ordering.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RequestLimit {

    /**
     * Maximum number of requests allowed within one {@link #time()} window. Defaults to 200.
     */
    int amount() default 200;

    /**
     * Length of the counting window in milliseconds. Defaults to 60000 (one minute).
     */
    long time() default 60000;
}

切面类

import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;

/**
 * Aspect that rate-limits calls per client IP and request URL using a Redis counter.
 *
 * <p>Every method matched by the {@code POINT} pointcut is counted. If the method carries
 * {@link RequestLimit}, its {@code amount}/{@code time} are used. Otherwise the global values from
 * {@link RequestLimitConfig} ({@code request-limit.*}) apply. When the limit is exceeded, the
 * target is not invoked and {@code ApiResult.requestTooFast(null)} is returned instead.
 *
 * @author songhaozhi
 * @since 2020/2/2
 */
@Aspect
@Component
@Slf4j
public class RequestLimitAspect {

    @Autowired
    private RedisUtil redisUtil;

    @Autowired
    private RequestLimitConfig requestLimitConfig;

    /** Pointcut expression selecting the intercepted API methods. */
    private static final String POINT = "execution(* app.xxx.api.*.*..*.*(..))";

    @Pointcut(POINT)
    public void pointcut() {
        // Pointcut declaration only; the body is never executed.
    }

    /**
     * Counts the current call and either lets it through or rejects it.
     *
     * @param point the intercepted join point
     * @return the target's result, or {@code ApiResult.requestTooFast(null)} when the limit is exceeded
     * @throws Throwable anything thrown by the target method or by the method lookup
     */
    @Around("pointcut()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        ServletRequestAttributes attributes = currentServletRequestAttributes();
        if (attributes == null) {
            // No servlet request is bound to the current thread, so there is no client IP or URL
            // to limit by. Previously this threw a NullPointerException; now the call just proceeds.
            return point.proceed();
        }
        HttpServletRequest request = attributes.getRequest();
        String ip = IpUtil.getIpAddress(request);
        String url = request.getRequestURL().toString();
        // One counter per (client IP, URL) pair.
        String key = RedisKey.REQUEST_LIMIT.concat(ip).concat(url);

        // A method-level @RequestLimit overrides the global request-limit.* settings.
        Method currentMethod = AspectUtil.INSTANCE.getMethod(point);
        RequestLimit requestLimit = currentMethod.getAnnotation(RequestLimit.class);
        int amount = requestLimit != null ? requestLimit.amount() : requestLimitConfig.getAmount();
        long time = requestLimit != null ? requestLimit.time() : requestLimitConfig.getTime();

        if (checkWithRedis(amount, time, key)) {
            log.info("requestLimited," + "[用户ip:{}],[访问地址:{}]超过了限定的次数[{}]次", ip, url, amount);
            return ApiResult.requestTooFast(null);
        }
        return point.proceed();
    }

    /**
     * Returns the servlet request attributes bound to the current thread, or {@code null} when
     * there are none or they are not servlet-based.
     */
    private static ServletRequestAttributes currentServletRequestAttributes() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        return attributes instanceof ServletRequestAttributes ? (ServletRequestAttributes) attributes : null;
    }

    /**
     * Increments the Redis counter for {@code key} and reports whether the limit has been exceeded.
     *
     * <p>The expiry is set only when the counter is created (count == 1). This gives a fixed window
     * that starts at the first request. NOTE(review): INCR and EXPIRE are two separate calls. If the
     * process dies between them, the key never expires and the client stays blocked permanently once
     * over the limit. An atomic Lua script would close that gap.
     *
     * @param amount maximum number of requests allowed in the window
     * @param time   window length in milliseconds
     * @param key    Redis key identifying the client IP and URL
     * @return {@code true} if this request exceeds the limit and must be rejected
     */
    private boolean checkWithRedis(int amount, long time, String key) {
        Long count = redisUtil.incrBy(key, 1);
        if (count == null) {
            // NOTE(review): presumably incrBy delegates to RedisTemplate, whose increment returns
            // null inside a pipeline/transaction. Fail open instead of throwing an unboxing NPE.
            return false;
        }
        if (count == 1) {
            redisUtil.expire(key, time, TimeUnit.MILLISECONDS);
        }
        return count > amount;
    }
}

其中 RedisUtil 来自开源项目 https://github.com/whvcse/RedisUtil 。
获取IP方法

/**
     * Resolves the client IP address of a request.
     *
     * <p>Proxy headers are preferred over {@code request.getRemoteAddr()} because, behind a proxy,
     * {@code getRemoteAddr()} returns the proxy's address rather than the client's. The headers are
     * consulted in this order: {@code X-Forwarded-For}, {@code Proxy-Client-IP},
     * {@code WL-Proxy-Client-IP}, {@code HTTP_CLIENT_IP}, {@code HTTP_X_FORWARDED_FOR},
     * {@code X-Real-IP}. The first value that is present, non-empty and not equal (ignoring case)
     * to {@code Constants.UNKNOWN} wins.
     *
     * <p>After several reverse proxies, the forwarded-for headers hold a comma-separated chain. Only
     * the first (left-most) entry, trimmed, is used as the client address. If no header yields a
     * usable value, {@code request.getRemoteAddr()} is returned.
     *
     * <p>NOTE(review): all of these headers come from the HTTP request and can be forged unless a
     * trusted reverse proxy overwrites them. If this value drives rate limiting, a client could rotate
     * fake X-Forwarded-For values to evade the limit. Consider trusting only headers set by your own proxy.
     *
     * @param request the current HTTP request
     * @return the resolved client address, or {@code request.getRemoteAddr()} as a last resort
     */
    public static String getIpAddress(HttpServletRequest request) {
        String ip = firstForwardedAddress(request.getHeader("x-forwarded-for"));
        if (isUnknownIp(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isUnknownIp(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnknownIp(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (isUnknownIp(ip)) {
            // May also carry a proxy chain, so take the first entry like X-Forwarded-For.
            ip = firstForwardedAddress(request.getHeader("HTTP_X_FORWARDED_FOR"));
        }
        if (isUnknownIp(ip)) {
            ip = request.getHeader("X-Real-IP");
        }
        if (isUnknownIp(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }

    /**
     * Returns the first comma-separated entry of a forwarded-for style header, trimmed.
     * Returns {@code null} if the header itself is {@code null}.
     */
    private static String firstForwardedAddress(String header) {
        if (header == null) {
            return null;
        }
        int comma = header.indexOf(StringPool.COMMA);
        String first = comma == -1 ? header : header.substring(0, comma);
        return first.trim();
    }

    /** Returns whether {@code ip} is missing, empty or the "unknown" placeholder. */
    private static boolean isUnknownIp(String ip) {
        return ip == null || ip.isEmpty() || Constants.UNKNOWN.equalsIgnoreCase(ip);
    }

切面工具类

public enum AspectUtil {
    /**
     * Singleton instance.
     */
    INSTANCE;

    /**
     * Returns the {@link Method} executing at the given join point.
     *
     * <p>The method is looked up on the target object's class, so annotations declared on the
     * implementation method are visible. {@code Class.getMethod} only finds public methods. If the
     * executing method is not public, or the join point has no target, the method from the join
     * point's {@link MethodSignature} is returned instead of failing.
     *
     * @param point the current join point; its signature must be a {@link MethodSignature}
     * @return the executing method
     * @throws NoSuchMethodException declared for source compatibility with existing callers; this
     *                               implementation falls back instead of throwing it
     */
    public Method getMethod(JoinPoint point) throws NoSuchMethodException {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Object target = point.getTarget();
        if (target == null) {
            return signature.getMethod();
        }
        try {
            return target.getClass().getMethod(signature.getName(), signature.getParameterTypes());
        } catch (NoSuchMethodException notPublic) {
            // Class.getMethod ignores non-public methods; the signature's method is still correct.
            return signature.getMethod();
        }
    }

}