Sorry, your browser cannot access this site
This page requires browser support (enable) JavaScript
Learn more >
场景说明

需求:经常一些接口会有一些校验的问题,比如权限,重复访问等,这种需求不是很想写在业务逻辑里面,那就可以用注解+AOP 的防止进行实现。下面用拦截重复点击的例子,做下实现。

工程结构

image-20240115103016227

功能实现
注解实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/**
* @Author bestrookie
* @Date 2024/1/9 16:27
* @Desc
*/
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface AccessInterceptor {
//用哪个字段作为拦截标识,未配置则默认走全部
String key() default "all";

//限制频次(每秒请求次数)
double permitsPerSecond();

//黑名单拦截(多少次限制后加入黑名单)0不限制
double blacklistCount() default 0;
//拦截后的执行方法
String fallbackMethod();

}

自定义切面注解,提供了拦截的 key,限制频次,黑名单处理、拦截后的回调方法。

切面拦截
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
/**
* @Author bestrookie
* @Date 2024/1/9 17:09
* @Desc
*/
@Slf4j
@Aspect
public class RateLimiterAOP {
/**
* 个人限频记录 1 分钟
*/
private final Cache<String, RateLimiter> loginRecord = CacheBuilder.newBuilder().
expireAfterWrite(1, TimeUnit.MINUTES).
build();

/**
* 个人限频黑名单 24h
*/
private final Cache<String, Long> blackList = CacheBuilder.newBuilder()
.expireAfterWrite(24, TimeUnit.HOURS)
.build();

@Pointcut("@annotation(com.bestrookie.annotation.AccessInterceptor)")
public void aopPoint(){
}

@Around("aopPoint() && @annotation(accessInterceptor)")
public Object doRouter(ProceedingJoinPoint jp, AccessInterceptor accessInterceptor) throws Throwable{
String key = accessInterceptor.key();
if (StringUtils.isBlank(key)){
throw new RuntimeException("annotation RateLimiter uId is null");
}
//获取拦截字段
String keyAttr = getAttrValue(key, jp.getArgs());
log.info("aop attr {}", keyAttr);
//黑名单拦截
if (!"all".equals(keyAttr) && accessInterceptor.blacklistCount() != 0
&& null != blackList.getIfPresent(keyAttr) &&
blackList.getIfPresent(keyAttr) > accessInterceptor.blacklistCount()){
log.info("限流-黑名单拦截 24h:{}",keyAttr);
return fallbackMethodResult(jp, accessInterceptor.fallbackMethod());
}
//获取限流 ->Guava 缓存一分钟
RateLimiter rateLimiter = loginRecord.getIfPresent(keyAttr);
if (null == rateLimiter){
rateLimiter = RateLimiter.create(accessInterceptor.permitsPerSecond());
loginRecord.put(keyAttr, rateLimiter);
}
// 限流拦截
if (!rateLimiter.tryAcquire()){
if (accessInterceptor.blacklistCount() != 0){
if (null == blackList.getIfPresent(keyAttr)){
blackList.put(keyAttr, 1L);
}else {
blackList.put(keyAttr, blackList.getIfPresent(keyAttr) + 1L);
}
}
log.info("限流-超频次拦截:{}", keyAttr);
return fallbackMethodResult(jp, accessInterceptor.fallbackMethod());
}
return jp.proceed();

}

private Object fallbackMethodResult(JoinPoint jp, String fallbackMethod) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
Signature sig = jp.getSignature();
MethodSignature methodSignature = (MethodSignature) sig;
Method method = jp.getTarget().getClass().getMethod(fallbackMethod, methodSignature.getParameterTypes());
return method.invoke(jp.getThis(), jp.getArgs());
}

private Method getMethod(JoinPoint jp) throws NoSuchMethodException {
Signature sig = jp.getSignature();
MethodSignature methodSignature = (MethodSignature) sig;
return jp.getTarget().getClass().getMethod(methodSignature.getName(), methodSignature.getParameterTypes());
}

/**
* 根据自身业务调整,主要是为了获取某个值做拦截
*/
public String getAttrValue(String attr, Object[] args){
if (args[0] instanceof String){
return args[0].toString();
}
String filedValue = null;
for (Object arg : args) {
try {
if (StringUtils.isNotBlank(filedValue)){
break;
}
filedValue = String.valueOf(this.getValueByName(arg, attr));
}catch (Exception e){
log.error("获取路由属性值失败 attr:{}", attr,e);
}
}
return filedValue;
}


/**
* 获取对象的特定属性值
* @param item 对象
* @param name 属性名
* @return 属性值
*/
private Object getValueByName(Object item, String name){
try {
Field field = getFieldByName(item, name);
if (field == null){
return null;
}
field.setAccessible(true);
Object o = field.get(item);
field.setAccessible(false);
return o;

}catch (IllegalAccessException e){
return null;
}
}


/**
* 根据名称获取方法,该方法同时兼顾继承类获取父类的属性
* @param item item 对象
* @param name 属性名
* @return 该属性对应方法
*/
private Field getFieldByName(Object item, String name){
try {
Field field;
try {
field = item.getClass().getDeclaredField(name);
}catch (NoSuchFieldException e){
field = item.getClass().getSuperclass().getDeclaredField(name);
}
return field;
} catch (NoSuchFieldException e){
return null;
}
}
}

通过@Pointcut切入配置了定义注解的接口方法,通过自定义注解中配置的拦截字段,获取对应的值。这里的作为用户的标识使用,只对这个用户,当然也可以根据业务来定义

测试接口
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
@SpringBootApplication
@Slf4j
@Configurable
@RestController()
@RequestMapping("/api/ratelimiter/")
public class Application {
public static void main(String[] args) {
SpringApplication.run(Application.class);
}

@AccessInterceptor(key = "fingerprint", fallbackMethod = "loginErr", permitsPerSecond = 1.0d, blacklistCount = 10)
@RequestMapping(value = "login", method = RequestMethod.GET)
public String login(String fingerprint, String uId, String token) {
log.info("模拟登录 fingerprint:{}", fingerprint);
return "模拟登录:登录成功 " + uId;
}

public String loginErr(String fingerprint, String uId, String token) {
return "频次限制,请勿恶意访问!";
}

}
  • key:以用户 id 作为拦截标识限制用户访问
  • fallbackMethod:失败后的回调方法,方法出入参保持一致
  • permitsPerSecond:每秒的访问频次限制,1 秒一次
  • blacklistCount:超过十次限制后,加入黑名单

评论