steam-server/linxyun/sso/middleware/security.py
2024-12-16 14:31:28 +08:00

78 lines
3.1 KiB
Python

from starlette.middleware.base import BaseHTTPMiddleware
import fnmatch
from starlette.responses import JSONResponse
from linxyun.utils.logger import get_logger
from linxyun.utils.result import Result, SysCodes
import re
logger = get_logger(__name__)
class SecurityMiddleware(BaseHTTPMiddleware):
    """Authentication / authorization middleware.

    For every request the middleware reads the ``token`` header, resolves it
    to a user auth record through ``linxyun.get_user_auth``, and then checks
    the record's enterprise code, role, and the role's allowed paths before
    forwarding the request. Every rejection is returned as a ``JSONResponse``
    whose body is a ``Result`` error dict (or the failed auth result itself).
    """

    def __init__(self, app, linxyun):
        """Register the middleware.

        Args:
            app: The ASGI application being wrapped (passed to the base class).
            linxyun: Client object used by ``dispatch``; it must expose
                ``get_user_auth(token)`` and ``config`` with ``entCode`` and
                ``role`` attributes.
        """
        logger.info("添加鉴权中间件 SecurityMiddleware")
        super().__init__(app)
        self.linxyun = linxyun
        # Extracts the bare login id from a "Session..."-prefixed token value.
        self.pattern = re.compile(r"LoginID_\d{14}_\d{6}")

    async def dispatch(self, request, call_next):
        """Authorize ``request`` and forward it, or return an error response.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request on to the next handler.

        Returns:
            The downstream response when every check passes, otherwise a
            ``JSONResponse`` describing why the request was rejected.
        """
        path = request.url.path
        method = request.method
        logger.info(f"鉴权拦截: {method} {path}")

        # The credential travels in the "token" request header.
        token = request.headers.get("token")
        if not token:
            return JSONResponse(content=Result.error(SysCodes.USER_NOT_LOGIN).to_dict())

        # A "Session..." value wraps the login id; use the embedded id when
        # one is found, otherwise keep the header value unchanged.
        if token.startswith("Session"):
            search_resp = self.pattern.search(token)
            if search_resp:
                token = search_resp.group()

        user_auth_result: dict = self.linxyun.get_user_auth(token).to_dict()
        if user_auth_result.get("success") is False:
            logger.info(f"用户登录信息失效: {user_auth_result}")
            return JSONResponse(content=user_auth_result)

        user_auth = user_auth_result.get("data")
        if user_auth is None:
            # A result without a data payload cannot be authorized; reject it
            # explicitly instead of failing with AttributeError below.
            logger.info(f"用户登录信息为空: {user_auth_result}")
            return JSONResponse(content=Result.error(SysCodes.USER_NOT_LOGIN).to_dict())

        user_role = user_auth.get("UserRoles")
        if user_role is None:
            logger.info(f"用户角色为空: {user_auth}")
            return JSONResponse(content=Result.error(SysCodes.USER_NO_AUTHORITY).to_dict())

        # The user must belong to the enterprise this service is configured for.
        if self.linxyun.config.entCode != user_auth.get("EntCode"):
            logger.info(f"用户企业编码错误: {user_auth.get('EntCode')}")
            return JSONResponse(content=Result.error(SysCodes.LOGIN_ERROR).to_dict())

        role_map = self.linxyun.config.role
        if user_role not in role_map:
            logger.info(f"用户权限未在系统权限中: {user_role}")
            return JSONResponse(content=Result.error(SysCodes.LOGIN_ERROR).to_dict())

        # Check whether the role is allowed to access this path.
        path_list = role_map.get(user_role)
        if not self.is_uri_authorized(path, path_list):
            logger.info(f"用户没有权限访问该路径: {path}")
            return JSONResponse(content=Result.error(SysCodes.USER_NO_AUTHORITY).to_dict())

        # NOTE(review): this writes the session token to the log; consider
        # masking it if logs are readable by anyone who should not hold
        # credentials.
        logger.info(f"鉴权通过: {path} {token}")
        return await call_next(request)

    @staticmethod
    def is_uri_authorized(path: str, path_list: list[str]) -> bool:
        """Return True when ``path`` is covered by an entry of ``path_list``.

        An entry ending in ``/**`` is a wildcard: every ``/**`` in it is
        rewritten to ``/*`` and the result is matched with ``fnmatch``, whose
        ``*`` also matches ``/``, so nested paths are covered. Note that
        ``/api/**`` therefore does not match ``/api`` itself. Any other entry
        must equal ``path`` exactly.

        Args:
            path: Request path to check.
            path_list: Allowed paths for the role; ``None`` or an empty list
                authorizes nothing.
        """
        if not path_list:
            return False
        for authorized_path in path_list:
            if authorized_path.endswith("/**"):
                pattern = authorized_path.replace("/**", "/*")
                # fnmatchcase skips os.path.normcase, so matching stays
                # case-sensitive and identical on every platform.
                if fnmatch.fnmatchcase(path, pattern):
                    return True
            elif path == authorized_path:
                # Exact match with an allowed path.
                return True
        return False