最近，笔者收到一个业务系统的需求：系统中部分模块的数据需要按部门进行权限控制，即用户只能查看本部门内用户的数据。由于公司已经开发了大量程序，若逐一调整原有代码将带来大量额外工作，因此笔者决定利用 MyBatis 插件（Plugins）拦截并修改 SQL 查询语句，从而实现数据权限过滤。
目前支持四种数据权限的 SQL 改写：全部数据、个人数据、无权限以及按部门数据（在下文代码中分别对应用户权限集合中的 -1、-2、-3 以及具体的部门 ID）。
一.添加权限注解
由于业务需要，不同的表通常使用不同的字段进行权限控制（如部门字段、用户帐号字段）。因此，这里采用注解的方式，在 Mapper 方法上配置其 SQL 语句的权限规则。
单表权限控制,即sql语句中只有一个表需要权限控制的情况:
/**
 * Data-permission rule for a mapper method whose SQL contains a single table that needs
 * filtering.
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface AccessTable {

    /** Module name; together with {@link #resource()} it identifies the permission to look up. */
    String module();

    /** Resource name within the module. */
    String resource();

    /** Name of the table that must be filtered. */
    String table();

    /** Department column(s) of the table; separate several columns with commas. */
    String deptColumn() default "dept_id";

    /** User-account column(s) of the table; separate several columns with commas. */
    String accountColumn() default "account";
}
多表权限控制,即sql语句中有多个表需要权限控制的情况:
/**
 * Groups several {@link AccessTable} rules on one mapper method, for SQL in which more than one
 * table needs data-permission filtering; each contained rule is applied in turn.
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Access {

    /** The single-table rules that make up this multi-table rule. */
    AccessTable[] accessTable();
}
二.使用 Druid SQL 解析器（SQL Parser）修改查询语句
2.1 引入 Druid 依赖（包含 SQL 解析器）
<!-- Druid starter; AccessSqlParse uses Druid's SQL parser classes (MySqlStatementParser,
     SQLSelectStatement, ...). NOTE(review): if only the parser is needed, depending on the plain
     com.alibaba:druid artifact may be sufficient - confirm. -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid-spring-boot-starter</artifactId>
<version>1.1.21</version>
</dependency>
2.2 对sql语句进行分析及修改
@Slf4j
public class AccessSqlParse {
private final String dbType=JdbcConstants.MYSQL;
private User user;
private AccessTable accessTable;
public AccessSqlParse() {
super();
}
public static void main(String[] args) {
AccessSqlParse parse=new AccessSqlParse();
parse.process("SELECT * FROM dept WHERE status=1 ");
}
public AccessSqlParse(User user, AccessTable accessTable) {
super();
this.user = user;
this.accessTable = accessTable;
}
public String process(String sql) {
MySqlStatementParser parse=new MySqlStatementParser(sql);
SQLStatement stmt = parse.parseStatement();
if (stmt instanceof SQLSelectStatement) {//只对查询语句进行处理
SQLSelect sqlSelect = ((SQLSelectStatement) stmt).getSelect();
SQLSelectQuery sqlSelectQuery = sqlSelect.getQuery();
parseSelect(sqlSelectQuery);
return stmt.toString();
}
return "";//非查询语句不处理
}
private void select(SQLSelectQueryBlock sqlSelectQueryBlock) {
// 获取表
SQLTableSource table = sqlSelectQueryBlock.getFrom();
List<SQLExprTableSource> tables = new ArrayList<>();
tableParse(table, tables);
Set<String> tableNames=new HashSet<String>();
for(SQLExprTableSource t:tables) {
if(t.getName().toString().equals(accessTable.table())) {
if(t.getAlias()!=null) {
tableNames.add(t.getAlias());
}else {
tableNames.add(t.getName().toString());
}
}
}
String module = accessTable.module();
String resource = accessTable.resource();
SQLExpr where = sqlSelectQueryBlock.getWhere();
log.info(where.toString());
SQLExpr newWhere=null;
Map<String, Set<Integer>> resources = user.getResources();
if (resources != null && !resources.isEmpty()) {
Set<Integer> deptIds = resources.get(module + "#34; + resource);
if (deptIds != null && !deptIds.isEmpty()) {
if (deptIds.contains(-3)) {// 无权
newWhere=buildNotAccess();
} else if (deptIds.contains(-2)) {// 个人
newWhere=buildPersonWhere(tableNames);
} else if(deptIds.contains(-1)) {//全部,不需要添加权限处理
}else {
newWhere=buildDeptWhere(tableNames, deptIds);
}
}
}
if(newWhere!=null) {
where=buildNewCondition(SQLBinaryOperator.BooleanAnd, newWhere, false, where);
}
sqlSelectQueryBlock.setWhere(where);
}
private SQLExpr buildNotAccess() {
SQLBinaryOpExpr newWhere=new SQLBinaryOpExpr(dbType);
newWhere.setOperator(SQLBinaryOperator.Equality);
newWhere.setLeft(new SQLIntegerExpr(1));
newWhere.setRight(new SQLIntegerExpr(0));
return newWhere;
}
private SQLExpr buildDeptWhere(Set<String> tableNames,Set<Integer> deptIds) {
String columnStr=accessTable.deptColumn();
String[] columns=columnStr.split(",");
SQLExpr where=null;
for(String tableName:tableNames) {
SQLExpr accountWhere=null;
for(String column:columns) {
SQLInListExpr newWhere=new SQLInListExpr();
newWhere.setExpr(new SQLIdentifierExpr(tableName+"."+column));
List<SQLExpr> list=new ArrayList<>();
for(Integer deptId:deptIds) {
list.add(new SQLIntegerExpr(deptId));
}
newWhere.setTargetList(list);
accountWhere= buildNewCondition(SQLBinaryOperator.BooleanOr, newWhere, false, accountWhere);
}
where= buildNewCondition(SQLBinaryOperator.BooleanAnd, accountWhere, false, where);
}
return where;
}
private SQLExpr buildPersonWhere(Set<String> tableNames) {
String columnStr=accessTable.accountColumn();
String[] columns=columnStr.split(",");
String account=user.getAccount();
SQLBinaryOpExpr where=null;
for(String tableName:tableNames) {
SQLBinaryOpExpr accountWhere=null;
for(String column:columns) {
SQLBinaryOpExpr newWhere=new SQLBinaryOpExpr(dbType);
newWhere.setOperator(SQLBinaryOperator.Equality);
newWhere.setLeft(new SQLIdentifierExpr(tableName+"."+column));
newWhere.setRight(new SQLCharExpr(account));
accountWhere=(SQLBinaryOpExpr) buildNewCondition(SQLBinaryOperator.BooleanOr, newWhere, false, accountWhere);
}
where=(SQLBinaryOpExpr) buildNewCondition(SQLBinaryOperator.BooleanAnd, accountWhere, false, where);
}
return where;
}
private void tableParse(SQLTableSource table, List<SQLExprTableSource> tables) {
if (table instanceof SQLExprTableSource) {// 普通单表
SQLExprTableSource tableSource = (SQLExprTableSource) table;
tables.add(tableSource);
} else if (table instanceof SQLJoinTableSource) { // join多表
join(table, tables);
} else if (table instanceof SQLSubqueryTableSource) {// 子查询作为表
child((SQLSubqueryTableSource) table);
}
}
private void child(SQLSubqueryTableSource table) {
SQLSelect sqlSelect = table.getSelect();
SQLSelectQuery sqlSelectQuery = sqlSelect.getQuery();
// log.info(sqlSelect.getQuery().getClass().getName());
parseSelect(sqlSelectQuery);
}
private void join(SQLTableSource table, List<SQLExprTableSource> tables) {
SQLJoinTableSource joinTable = (SQLJoinTableSource) table;
SQLTableSource left = joinTable.getLeft();
SQLTableSource right = joinTable.getRight();
tableParse(left, tables);
tableParse(right, tables);
}
private SQLExpr buildNewCondition(SQLBinaryOperator op, SQLExpr condition, boolean left, SQLExpr where) {
if (where == null) {
return condition;
}
SQLBinaryOpExpr newCondition;
if (left) {
newCondition = new SQLBinaryOpExpr(condition, op, where);
} else {
newCondition = new SQLBinaryOpExpr(where, op, condition);
}
return newCondition;
}
private void union(SQLUnionQuery unionQuery) {
SQLSelectQuery left=unionQuery.getLeft();
SQLSelectQuery right=unionQuery.getRight();
parseSelect(left);
parseSelect(right);
}
private void parseSelect(SQLSelectQuery sqlSelectQuery) {
if (sqlSelectQuery instanceof SQLUnionQuery) {
SQLUnionQuery unionQuery = (SQLUnionQuery) sqlSelectQuery;
union(unionQuery);
} else {
SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) sqlSelectQuery;
select(sqlSelectQueryBlock);
}
}
public User getUser() {
return user;
}
public void setUser(User user) {
this.user = user;
}
public AccessTable getAccessTable() {
return accessTable;
}
public void setAccessTable(AccessTable accessTable) {
this.accessTable = accessTable;
}
}
三.实现 MyBatis 插件的 Interceptor 接口，拦截 SQL 语句并进行权限处理
MyBatis 插件支持拦截 Executor、ParameterHandler、ResultSetHandler、StatementHandler 等对象的方法。由于本次的主要目标是修改 SQL 语句，因此选择拦截 StatementHandler.prepare 方法。具体实现如下：
/**
 * MyBatis plugin that intercepts {@code StatementHandler.prepare} and rewrites the SQL of SELECT
 * statements whose mapper method carries {@link AccessTable} or {@link Access}, adding the
 * current user's data-permission filter.
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
@Slf4j
public class MapperInterceptor implements Interceptor {
    /** Properties passed in by MyBatis through {@link #setProperties}; not read by this class. */
    private Properties properties;

    /**
     * Rewrites the statement's SQL when applicable, then lets the invocation proceed.
     *
     * <p>The SQL is left untouched in these cases:
     * <ul>
     *   <li>there is no authentication (background task, usually a scheduled job);</li>
     *   <li>the principal is a String (not logged in);</li>
     *   <li>the account is {@code root};</li>
     *   <li>the statement is not a SELECT;</li>
     *   <li>the mapper method cannot be resolved or carries no access annotation.</li>
     * </ul>
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (SecurityContextHolder.getContext().getAuthentication() == null) {
            return invocation.proceed(); // background task, usually a scheduled job
        }
        Object principal = SecurityContextHolder.getContext().getAuthentication().getPrincipal();
        if (principal instanceof String) {
            return invocation.proceed(); // not logged in
        }
        User user = (User) principal;
        if ("root".equals(user.getAccount())) {
            return invocation.proceed(); // super-administrator account
        }
        log.debug("sql权限控制");
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY,
                SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, new DefaultReflectorFactory());
        // The intercepted target is a RoutingStatementHandler. Its "delegate" (a
        // BaseStatementHandler) holds the MappedStatement.
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        // Fully-qualified mapper method name, e.g. com.uv.dao.UserMapper.insertUser
        String id = mappedStatement.getId();
        log.debug("method:{}", id);
        // Only queries are filtered; AccessSqlParse returns "" for any other statement type.
        if (!"SELECT".equals(mappedStatement.getSqlCommandType().toString())) {
            return invocation.proceed();
        }
        BoundSql boundSql = statementHandler.getBoundSql();
        String sql = boundSql.getSql();
        log.debug("转换前的sql:{}", sql);

        String newSql = "";
        Method method = findAnnotatedMapperMethod(id);
        if (method != null) {
            AccessTable accessTable = method.getAnnotation(AccessTable.class);
            if (accessTable != null) {
                log.debug(accessTable.toString());
                newSql = processSql(sql, user, accessTable);
            } else {
                // Multi-table rule: apply every single-table rule to the output of the previous one.
                newSql = sql;
                for (AccessTable table : method.getAnnotation(Access.class).accessTable()) {
                    newSql = processSql(newSql, user, table);
                }
            }
        }
        if (newSql.isEmpty()) {
            log.info("没有转换:");
        } else {
            log.info("转换后的sql:{}", newSql);
            // Write the rewritten SQL into BoundSql's private "sql" field via reflection.
            Field field = boundSql.getClass().getDeclaredField("sql");
            field.setAccessible(true);
            field.set(boundSql, newSql);
        }
        return invocation.proceed();
    }

    /**
     * Resolves the mapper interface method behind a statement id ({@code namespace.methodName}).
     *
     * @return the first public method with that name that carries {@link AccessTable} or
     *         {@link Access}, or null when the id has no namespace, the namespace is not a
     *         loadable class, or no such annotated method exists
     */
    private Method findAnnotatedMapperMethod(String statementId) {
        int dot = statementId.lastIndexOf('.');
        if (dot <= 0) {
            return null;
        }
        String className = statementId.substring(0, dot);
        String methodName = statementId.substring(dot + 1);
        Class<?> mapperType;
        try {
            mapperType = Class.forName(className);
        } catch (ClassNotFoundException e) {
            // The namespace need not be a Java type (e.g. an XML-only mapper); it cannot carry annotations.
            log.debug("mapper namespace is not a class: {}", className);
            return null;
        }
        for (Method candidate : mapperType.getMethods()) {
            if (candidate.getName().equals(methodName)
                    && (candidate.isAnnotationPresent(AccessTable.class) || candidate.isAnnotationPresent(Access.class))) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Validates one {@link AccessTable} rule and applies it to {@code sql}.
     *
     * @return the rewritten SQL, or "" when {@code sql} is not a SELECT statement
     * @throws RuntimeException if module, resource or table is not set on the rule
     */
    private String processSql(String sql, User user, AccessTable access) {
        String module = access.module();
        String resource = access.resource();
        String table = access.table();
        if (module == null || module.isEmpty()) {
            throw new RuntimeException("@AccessTable注解中未设置模块字段module");
        }
        if (resource == null || resource.isEmpty()) {
            throw new RuntimeException("@AccessTable注解中未设置资源字段resource");
        }
        if (table == null || table.isEmpty()) {
            throw new RuntimeException("@AccessTable注解中未设置表字段table");
        }
        return new AccessSqlParse(user, access).process(sql);
    }

    /** {@link SqlSource} that always returns one fixed {@link BoundSql}; not referenced in this class. */
    public static class BoundSqlSqlSource implements SqlSource {
        private BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /** Wraps only StatementHandler instances; every other target is returned as-is. */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties0) {
        this.properties = properties0;
    }
}
创作不易,支持请关注
本文暂时没有评论,来添加一个吧(●'◡'●)