import fnmatch
import re

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from linxyun.utils.logger import get_logger
from linxyun.utils.result import Result, SysCodes

logger = get_logger(__name__)


# Authentication / authorization middleware
class SecurityMiddleware(BaseHTTPMiddleware):
    """Reject requests whose ``token`` header does not grant access to the path.

    For every non-``OPTIONS`` request the ``token`` header is resolved to a
    user through ``linxyun.get_user_auth``. The user's role (``UserRoles``)
    and enterprise code (``EntCode``) are checked against ``linxyun.config``.
    The request path must be allowed by the role's entry in
    ``linxyun.config.role``. Any failure short-circuits with a
    ``JSONResponse`` whose body is a ``Result`` error dict. No explicit status
    code is set, so Starlette's default is used and the error is carried in
    the body.
    """

    def __init__(self, app, linxyun):
        """Register the middleware.

        Args:
            app: The ASGI app being wrapped.
            linxyun: Project object providing ``get_user_auth(token)`` and
                ``config`` (with ``entCode`` and ``role`` attributes).
        """
        logger.info("添加鉴权中间件 SecurityMiddleware")
        super().__init__(app)
        self.linxyun = linxyun
        # Login id embedded in "Session..." tokens: LoginID_<14 digits>_<6 digits>.
        self.pattern = re.compile(r"LoginID_\d{14}_\d{6}")

    @staticmethod
    def _error(code) -> JSONResponse:
        """Build the JSON error response for a ``SysCodes`` member."""
        return JSONResponse(content=Result.error(code).to_dict())

    def _normalize_token(self, token: str) -> str:
        """Return the embedded login id for ``Session...`` tokens, else ``token`` unchanged."""
        if token.startswith("Session"):
            match = self.pattern.search(token)
            if match:
                return match.group()
        return token

    async def dispatch(self, request, call_next):
        """Authenticate and authorize ``request``; forward it only if allowed."""
        path = request.url.path
        method = request.method
        # Preflight requests are passed through without any token check.
        if method == "OPTIONS":
            return await call_next(request)

        logger.info(f"鉴权拦截: {method} {path}")

        token = request.headers.get("token")
        if not token:
            return self._error(SysCodes.USER_NOT_LOGIN)
        token = self._normalize_token(token)

        # NOTE(review): get_user_auth is called synchronously inside an async
        # handler; if it performs blocking I/O it stalls the event loop —
        # consider starlette.concurrency.run_in_threadpool. TODO confirm.
        user_auth_result: dict = self.linxyun.get_user_auth(token).to_dict()
        # Fail closed: anything other than an explicit truthy success is rejected.
        if not user_auth_result.get("success"):
            logger.info(f"用户登录信息失效: {user_auth_result}")
            return JSONResponse(content=user_auth_result)

        user_auth = user_auth_result.get("data")
        # Guard against a success result without user data (would otherwise 500).
        if not isinstance(user_auth, dict):
            logger.info(f"用户登录信息为空: {user_auth_result}")
            return self._error(SysCodes.USER_NOT_LOGIN)

        user_role = user_auth.get("UserRoles")
        if user_role is None:
            logger.info(f"用户角色为空: {user_auth}")
            return self._error(SysCodes.USER_NO_AUTHORITY)

        ent_code = user_auth.get("EntCode")
        if self.linxyun.config.entCode != ent_code:
            logger.info(f"用户企业编码错误: {ent_code}")
            return self._error(SysCodes.LOGIN_ERROR)

        role_map = self.linxyun.config.role
        if user_role not in role_map:
            logger.info(f"用户权限未在系统权限中: {user_role}")
            return self._error(SysCodes.LOGIN_ERROR)

        # Is the role allowed to access this path?
        if not self.is_uri_authorized(path, role_map.get(user_role)):
            logger.info(f"用户没有权限访问该路径: {path}")
            return self._error(SysCodes.USER_NO_AUTHORITY)

        # The token is a live credential, so it is deliberately not logged.
        logger.info(f"鉴权通过: {path}")
        return await call_next(request)

    @staticmethod
    def is_uri_authorized(path: str, path_list: list[str]) -> bool:
        """Return True if ``path`` is allowed by any entry in ``path_list``.

        An entry ending in ``/**`` grants everything below its prefix at any
        depth. For example, ``/api/**`` matches ``/api/`` and ``/api/a/b`` but
        not ``/api``. Any other entry must equal ``path`` exactly. A ``None``
        or empty ``path_list`` grants nothing.
        """
        for authorized_path in path_list or ():
            if authorized_path.endswith("/**"):
                # fnmatch's "*" also matches "/", so "/prefix/*" covers every
                # depth. fnmatchcase keeps URL matching case-sensitive on every
                # platform (fnmatch.fnmatch applies os.path.normcase).
                # NOTE: "?" and "[...]" in a configured entry act as glob syntax.
                pattern = authorized_path.replace("/**", "/*")
                if fnmatch.fnmatchcase(path, pattern):
                    return True
            elif path == authorized_path:
                return True
        return False