计算机系统应用教程网站

网站首页 > 技术文章 正文

利用 MyBatis 插件与 Druid SQL Parser 改写 SQL 语句,实现数据权限控制

btikc 2024-09-16 13:04:19 技术文章 23 ℃ 0 评论

最近,笔者收到一个业务系统的需求:需要对系统中部分模块的数据按部门进行权限控制,即用户只能查看本部门内的数据。由于公司已经开发了大量程序,如果逐一调整原有程序,将带来大量额外工作。因此笔者决定利用 MyBatis 插件(Plugin)对 SQL 查询语句进行改写,从而实现数据权限过滤。

目前支持四种数据权限的 SQL 改写:全部数据、个人数据、无权限以及按部门查看数据。

一.添加权限注解

由于业务需要,不同的表通常使用不同的字段进行权限控制。因此,采用注解的方式为 Mapper 方法的 SQL 语句配置权限规则。

单表权限控制,即sql语句中只有一个表需要权限控制的情况:

/**
 * Declares data-permission filtering for a mapper method whose SQL contains exactly one
 * protected table.
 */
@Retention(value = RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface AccessTable {
	/** Module name; combined with {@link #resource()} to look up the user's permissions. */
	String module();

	/** Resource name inside the module. */
	String resource();

	/** Name of the protected table as it appears in the SQL. */
	String table();

	/** Department column(s) of the table; separate several columns with ",". */
	String deptColumn() default "dept_id";

	/** User-account column(s) of the table; separate several columns with ",". */
	String accountColumn() default "account";
}

多表权限控制,即sql语句中有多个表需要权限控制的情况:

/**
 * Declares data-permission filtering for a mapper method whose SQL contains several protected
 * tables; each element configures one table.
 */
@Retention(value = RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface Access {
	/** One {@link AccessTable} per protected table. */
	AccessTable[] accessTable();
}

二.使用 Druid SQL Parser 改写查询语句

2.1 引入 Druid 依赖(包含 SQL Parser)

<!-- Druid (here via its Spring Boot starter) supplies MySqlStatementParser and the SQL AST classes used below. -->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>druid-spring-boot-starter</artifactId>
    <version>1.1.21</version>
</dependency>

2.2 对sql语句进行分析及修改


@Slf4j
public class AccessSqlParse {
	private final String dbType=JdbcConstants.MYSQL;
	private User user;
	private AccessTable accessTable;

	public AccessSqlParse() {
		super();
	}
	public static void main(String[] args) {
		AccessSqlParse parse=new AccessSqlParse();
		parse.process("SELECT * FROM dept 	WHERE status=1 ");
	}
	public AccessSqlParse(User user, AccessTable accessTable) {
		super();
		this.user = user;
		this.accessTable = accessTable;
	}

	public String process(String sql) {
		MySqlStatementParser parse=new MySqlStatementParser(sql);
		SQLStatement stmt = parse.parseStatement();
		if (stmt instanceof SQLSelectStatement) {//只对查询语句进行处理
			SQLSelect sqlSelect = ((SQLSelectStatement) stmt).getSelect();
			SQLSelectQuery sqlSelectQuery = sqlSelect.getQuery();
			parseSelect(sqlSelectQuery);
			return stmt.toString();
		}
		return "";//非查询语句不处理
		
	}

	private void select(SQLSelectQueryBlock sqlSelectQueryBlock) {
		// 获取表
		SQLTableSource table = sqlSelectQueryBlock.getFrom();
		List<SQLExprTableSource> tables = new ArrayList<>();
		tableParse(table, tables);
		Set<String> tableNames=new HashSet<String>();
		for(SQLExprTableSource t:tables) {
			if(t.getName().toString().equals(accessTable.table())) {
				if(t.getAlias()!=null) {
					tableNames.add(t.getAlias());
				}else {
					tableNames.add(t.getName().toString());
				}
			}
			
		}
		String module = accessTable.module();
		String resource = accessTable.resource();
		
		SQLExpr where = sqlSelectQueryBlock.getWhere();
		log.info(where.toString());
		SQLExpr newWhere=null;
		Map<String, Set<Integer>> resources = user.getResources();
		
		if (resources != null && !resources.isEmpty()) {
			Set<Integer> deptIds = resources.get(module + "#34; + resource);
			if (deptIds != null && !deptIds.isEmpty()) {
				if (deptIds.contains(-3)) {// 无权
					
					newWhere=buildNotAccess();
				} else if (deptIds.contains(-2)) {// 个人
					newWhere=buildPersonWhere(tableNames);
				} else if(deptIds.contains(-1)) {//全部,不需要添加权限处理
					
				}else {
					newWhere=buildDeptWhere(tableNames, deptIds);
				}
			}
		}
		if(newWhere!=null) {
			where=buildNewCondition(SQLBinaryOperator.BooleanAnd, newWhere, false, where);
		}
		sqlSelectQueryBlock.setWhere(where);
	}
	private SQLExpr buildNotAccess() {
		SQLBinaryOpExpr newWhere=new SQLBinaryOpExpr(dbType);
		newWhere.setOperator(SQLBinaryOperator.Equality);
		newWhere.setLeft(new SQLIntegerExpr(1));
		newWhere.setRight(new SQLIntegerExpr(0));
		return newWhere;
	}
	private SQLExpr buildDeptWhere(Set<String> tableNames,Set<Integer> deptIds) {
		String columnStr=accessTable.deptColumn();
		String[] columns=columnStr.split(",");

		SQLExpr where=null;
		for(String tableName:tableNames) {
			SQLExpr accountWhere=null;
			for(String column:columns) {
				SQLInListExpr newWhere=new SQLInListExpr();
				newWhere.setExpr(new SQLIdentifierExpr(tableName+"."+column));
				List<SQLExpr> list=new ArrayList<>();
				for(Integer deptId:deptIds) {
					list.add(new SQLIntegerExpr(deptId));
				}
				newWhere.setTargetList(list);
				accountWhere= buildNewCondition(SQLBinaryOperator.BooleanOr, newWhere, false, accountWhere); 
			}
			where= buildNewCondition(SQLBinaryOperator.BooleanAnd, accountWhere, false, where);
		}
		
		return where;
		
	}
	private SQLExpr buildPersonWhere(Set<String> tableNames) {
		String columnStr=accessTable.accountColumn();
		String[] columns=columnStr.split(",");
		String account=user.getAccount();
		SQLBinaryOpExpr where=null;
		for(String tableName:tableNames) {
			SQLBinaryOpExpr accountWhere=null;
			for(String column:columns) {
				SQLBinaryOpExpr newWhere=new SQLBinaryOpExpr(dbType);
				newWhere.setOperator(SQLBinaryOperator.Equality);
				newWhere.setLeft(new SQLIdentifierExpr(tableName+"."+column));
				newWhere.setRight(new SQLCharExpr(account));
				accountWhere=(SQLBinaryOpExpr) buildNewCondition(SQLBinaryOperator.BooleanOr, newWhere, false, accountWhere); 
			}
			where=(SQLBinaryOpExpr) buildNewCondition(SQLBinaryOperator.BooleanAnd, accountWhere, false, where);
		}
		
		return where;
		
	}
	private void tableParse(SQLTableSource table, List<SQLExprTableSource> tables) {

		if (table instanceof SQLExprTableSource) {// 普通单表
			
			SQLExprTableSource tableSource = (SQLExprTableSource) table;
			tables.add(tableSource);
		} else if (table instanceof SQLJoinTableSource) { // join多表
			join(table, tables);
		} else if (table instanceof SQLSubqueryTableSource) {// 子查询作为表
			child((SQLSubqueryTableSource) table);
		}
	}

	private void child(SQLSubqueryTableSource table) {
		SQLSelect sqlSelect = table.getSelect();
		SQLSelectQuery sqlSelectQuery = sqlSelect.getQuery();
		// log.info(sqlSelect.getQuery().getClass().getName());
		parseSelect(sqlSelectQuery);
	}

	private void join(SQLTableSource table, List<SQLExprTableSource> tables) {
		SQLJoinTableSource joinTable = (SQLJoinTableSource) table;
		SQLTableSource left = joinTable.getLeft();
		SQLTableSource right = joinTable.getRight();
		tableParse(left, tables);
		tableParse(right, tables);

	}

	private SQLExpr buildNewCondition(SQLBinaryOperator op, SQLExpr condition, boolean left, SQLExpr where) {
		if (where == null) {
			return condition;
		}

		SQLBinaryOpExpr newCondition;
		if (left) {
			newCondition = new SQLBinaryOpExpr(condition, op, where);
		} else {
			newCondition = new SQLBinaryOpExpr(where, op, condition);
		}
		return newCondition;
	}

	private void union(SQLUnionQuery unionQuery) {
		SQLSelectQuery left=unionQuery.getLeft();
		SQLSelectQuery right=unionQuery.getRight();
		parseSelect(left);
		parseSelect(right);
	}
	private void parseSelect(SQLSelectQuery sqlSelectQuery) {
		if (sqlSelectQuery instanceof SQLUnionQuery) {
			SQLUnionQuery unionQuery = (SQLUnionQuery) sqlSelectQuery;
			union(unionQuery);
		} else {
			SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) sqlSelectQuery;
			select(sqlSelectQueryBlock);
		}
	}
	public User getUser() {
		return user;
	}

	public void setUser(User user) {
		this.user = user;
	}

	public AccessTable getAccessTable() {
		return accessTable;
	}

	public void setAccessTable(AccessTable accessTable) {
		this.accessTable = accessTable;
	}

}


三.实现 MyBatis 插件的 Interceptor 接口,拦截 SQL 语句并进行权限处理

MyBatis 插件支持拦截 Executor、ParameterHandler、ResultSetHandler、StatementHandler 等对象。由于本次的主要目标是修改 SQL 语句,因此选择拦截 StatementHandler 的 prepare 方法。具体如下:

/**
 * MyBatis plugin that intercepts {@code StatementHandler.prepare} and rewrites the SQL of mapper
 * methods annotated with {@link AccessTable} or {@link Access}. The rewrite is done by
 * {@link AccessSqlParse} for the currently authenticated {@link User}.
 *
 * <p>The SQL is left unchanged in these cases:
 * <ul>
 *   <li>there is no authentication (typically scheduled jobs);</li>
 *   <li>the principal is a plain String (not logged in);</li>
 *   <li>the account is {@code root};</li>
 *   <li>the statement is not a SELECT;</li>
 *   <li>the mapper namespace is not a class;</li>
 *   <li>the mapper method has no access annotation.</li>
 * </ul>
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class,Integer.class})})
@Slf4j
public class MapperInterceptor implements Interceptor {

	private Properties properties;
	/** Shared across calls so reflector metadata is cached instead of rebuilt per statement. */
	private final DefaultReflectorFactory reflectorFactory = new DefaultReflectorFactory();

	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		if (SecurityContextHolder.getContext().getAuthentication() == null) {// background task, usually a scheduled job
			return invocation.proceed();
		}

		Object principal = SecurityContextHolder.getContext().getAuthentication().getPrincipal();
		if (principal instanceof String) {// not logged in
			return invocation.proceed();
		}
		User user = (User) principal;
		// constant-first comparison so a null account does not throw
		if ("root".equals(user.getAccount())) {// super administrator: no filtering
			return invocation.proceed();
		}
		log.debug("sql权限控制");

		StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
		MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY,
				SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, reflectorFactory);
		// The target is a RoutingStatementHandler whose "delegate" (a BaseStatementHandler)
		// holds the MappedStatement.
		MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
		// Full name of the mapper method, e.g. com.uv.dao.UserMapper.insertUser
		String id = mappedStatement.getId();
		// select / delete / insert / update; only queries are ever rewritten
		String sqlCommandType = mappedStatement.getSqlCommandType().toString();
		if (!"SELECT".equals(sqlCommandType)) {
			return invocation.proceed();
		}
		BoundSql boundSql = statementHandler.getBoundSql();
		log.debug("method:" + id);
		String sql = boundSql.getSql();
		log.debug("转换前的sql:" + sql);

		// Only methods carrying an access annotation are rewritten.
		int lastDot = id.lastIndexOf('.');
		Class<?> mapperType;
		try {
			mapperType = Class.forName(id.substring(0, lastDot));
		} catch (ClassNotFoundException e) {
			// XML-only mapper whose namespace is not a class: there are no annotations to honour.
			log.debug("mapper namespace is not a class, skipping: {}", id);
			return invocation.proceed();
		}
		String methodName = id.substring(lastDot + 1);

		String newSql = "";
		for (Method method : mapperType.getMethods()) {
			if (!method.getName().equals(methodName)) {
				continue;
			}
			if (method.isAnnotationPresent(AccessTable.class)) {
				AccessTable accessTable = method.getAnnotation(AccessTable.class);
				log.debug("{}", accessTable);
				newSql = processSql(sql, user, accessTable);
			} else if (method.isAnnotationPresent(Access.class)) {
				// Several protected tables: each annotation is applied to the previous result.
				newSql = sql;
				for (AccessTable accessTable : method.getAnnotation(Access.class).accessTable()) {
					newSql = processSql(newSql, user, accessTable);
				}
			}
		}

		if (StringUtils.isEmpty(newSql)) {
			log.debug("没有转换:");
		} else {
			log.info("转换后的sql:" + newSql);
			// BoundSql exposes no setter for its SQL text, so the field is replaced reflectively.
			Field field = BoundSql.class.getDeclaredField("sql");
			field.setAccessible(true);
			field.set(boundSql, newSql);
		}
		return invocation.proceed();
	}

	/**
	 * Validates {@code access} and rewrites {@code sql} with {@link AccessSqlParse}.
	 *
	 * @return the rewritten SQL, or {@code ""} when {@link AccessSqlParse#process} did not rewrite it
	 * @throws RuntimeException when module, resource or table is blank in the annotation
	 */
	private String processSql(String sql, User user, AccessTable access) {
		requireNonEmpty(access.module(), "@AccessTable注解中未设置模块字段module");
		requireNonEmpty(access.resource(), "@AccessTable注解中未设置资源字段resource");
		requireNonEmpty(access.table(), "@AccessTable注解中未设置表字段table");
		return new AccessSqlParse(user, access).process(sql);
	}

	/** Throws a RuntimeException with {@code message} when {@code value} is null or empty. */
	private static void requireNonEmpty(String value, String message) {
		if (value == null || value.isEmpty()) {
			throw new RuntimeException(message);
		}
	}

	/** SqlSource that always returns a fixed {@link BoundSql}. */
	public static class BoundSqlSqlSource implements SqlSource {
		private BoundSql boundSql;

		public BoundSqlSqlSource(BoundSql boundSql) {
			this.boundSql = boundSql;
		}

		@Override
		public BoundSql getBoundSql(Object parameterObject) {
			return boundSql;
		}
	}

	/** Wraps only StatementHandler targets; everything else is returned unchanged. */
	@Override
	public Object plugin(Object target) {
		if (target instanceof StatementHandler) {
			return Plugin.wrap(target, this);
		}
		return target;
	}

	@Override
	public void setProperties(Properties properties0) {
		this.properties = properties0;
	}
}

创作不易,支持请关注

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表