Files
figmar 81115dc23d 初始提交:识流 AI 助手项目
微信自动回复机器人,基于截图+OCR识别消息,支持关键词规则和 AI(OpenAI/DeepSeek/Dify)自动回复。
技术栈:PySide6 + Flask + Vue3 + RapidOCR + SQLite

注:OCR大模型文件(.onnx / .pdiparams)不纳入版本控制,需单独下载。

🤖 Generated with [Qoder](https://qoder.com)
2026-05-30 15:09:40 +08:00

347 lines
16 KiB
Python

import base64
from io import BytesIO
from pathlib import Path
import cv2
import numpy as np
import requests
from PIL import Image
from app.infrastructure.service.logging.log_service import log_event, new_trace_id
from app.infrastructure.service.wechat.config import (
BAIDU_API_KEY,
BAIDU_SECRET_KEY,
OCR_PROVIDER,
RAPID_OCR_CLS_MODEL_PATH,
RAPID_OCR_DET_MODEL_PATH,
RAPID_OCR_REC_MODEL_PATH,
SESSION_NAME_OCR_EXTRA_SCALE,
SESSION_NAME_OCR_SCALE,
)
# Baidu OCR `error_code` values for which BaiduOCR.should_fallback_to_rapid()
# reports that the request should be retried with the local RapidOCR engine.
# NOTE(review): presumably 17/18 are daily-quota / QPS limits and 110/111 are
# invalid / expired access tokens — confirm against Baidu's error-code table.
BAIDU_FALLBACK_ERROR_CODES = {17, 18, 110, 111}
def _runtime_roots() -> list[Path]:
    """Return the candidate base directories used to resolve relative paths.

    Candidates are listed in priority order, without duplicates:

    1. ``sys._MEIPASS`` when that attribute is set (PyInstaller's unpack dir).
    2. The fifth ancestor of this source file, when the file is nested
       deeply enough to have one.
    3. The current working directory plus its ``resources`` and ``app``
       sub-directories.
    4. The directory of the running executable plus its ``resources`` and
       ``app`` sub-directories.

    Returns:
        A list of ``Path`` objects; the directories are not checked for
        existence here.
    """
    # Imported locally so this function does not depend on a module-level
    # ``import sys`` being present.
    import sys

    roots: list[Path] = []

    meipass = getattr(sys, "_MEIPASS", None)
    if meipass:
        roots.append(Path(meipass))

    # ``parents[4]`` raises IndexError for shallow install locations, so the
    # project-root guess is only added when that ancestor actually exists.
    ancestors = Path(__file__).resolve().parents
    if len(ancestors) > 4:
        roots.append(ancestors[4])

    cwd = Path.cwd().resolve()
    roots.extend((cwd, cwd / "resources", cwd / "app"))

    # ``sys.executable`` may be empty or None in embedded interpreters; treat
    # that (and any filesystem error while resolving it) as "no extra roots".
    try:
        exe_parent = Path(sys.executable).resolve().parent
    except (TypeError, ValueError, OSError, RuntimeError):
        exe_parent = None
    if exe_parent is not None:
        roots.extend((exe_parent, exe_parent / "resources", exe_parent / "app"))

    # A dict keyed by the string form keeps the first occurrence of each
    # path and preserves insertion order.
    return list({str(root): root for root in roots}.values())
def _resolve_project_path(path_str: str) -> str:
    """Turn a possibly-relative path into a concrete path string.

    Absolute inputs are returned unchanged. A relative input is joined onto
    every directory from ``_runtime_roots()``; the first joined path that
    exists on disk wins. When none exists, the path under the first root is
    returned so the caller still receives a deterministic location.

    Args:
        path_str: Absolute or relative path.

    Returns:
        The chosen path as a string.
    """
    requested = Path(path_str)
    if requested.is_absolute():
        return str(requested)

    resolved = [(base / requested).resolve() for base in _runtime_roots()]
    first_existing = next((option for option in resolved if option.exists()), None)
    chosen = first_existing if first_existing is not None else resolved[0]
    return str(chosen)
class OCRBase:
    """Interface shared by the OCR back-ends defined in this module."""

    # Short identifier of the back-end; subclasses override it.
    provider_name = "base"

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognise text in an encoded image.

        Args:
            image_data: Encoded image bytes.
            scene: Label attached to log records for this call.
            mode: Recognition mode selector.

        Raises:
            NotImplementedError: Always; subclasses supply the implementation.
        """
        raise NotImplementedError
class BaiduOCR(OCRBase):
    """OCR provider that calls Baidu's ``general_basic`` REST endpoint.

    The outcome of the most recent ``recognize`` call is exposed through
    ``last_error_code`` / ``last_error_msg`` so callers can decide whether to
    retry with another engine (see ``should_fallback_to_rapid``).
    """

    provider_name = "baidu"

    # Lines whose average confidence is not above this value are dropped.
    MIN_CONFIDENCE = 0.6
    # Confidence assumed when the response carries no usable probability.
    DEFAULT_CONFIDENCE = 0.9

    def __init__(self, api_key, secret_key):
        """Store the credentials and immediately try to obtain a token.

        Args:
            api_key: Baidu application API key.
            secret_key: Baidu application secret key.
        """
        self.api_key = api_key
        self.secret_key = secret_key
        self.access_token = None
        self.last_error_code = None
        self.last_error_msg = ""
        # Performs a network request; failures are logged, never raised.
        self.get_access_token()

    def get_access_token(self):
        """Request an OAuth token and store it in ``self.access_token``.

        Never raises: missing credentials, HTTP errors and request exceptions
        are logged and leave ``self.access_token`` unchanged.
        """
        trace_id = new_trace_id("ocr")
        if not self.api_key or not self.secret_key:
            log_event(
                "WARNING", "ocr", "ocr.baidu.token", trace_id, "token", "failed",
                "百度OCR凭据缺失", reason="credential_missing",
            )
            return

        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.secret_key,
        }
        try:
            response = requests.post(url, params=params, timeout=10)
            if response.status_code != 200:
                log_event(
                    "WARNING", "ocr", "ocr.baidu.token", trace_id, "token", "failed",
                    "百度OCR token获取失败", reason="http_error",
                    extra={"status_code": response.status_code},
                )
                return

            self.access_token = response.json().get("access_token")
            if self.access_token:
                log_event(
                    "INFO", "ocr", "ocr.baidu.token", trace_id, "token", "ok",
                    "百度OCR token获取成功",
                )
            else:
                log_event(
                    "WARNING", "ocr", "ocr.baidu.token", trace_id, "token", "failed",
                    "百度OCR token为空", reason="token_empty",
                )
        except Exception as e:
            log_event(
                "ERROR", "ocr", "ocr.baidu.token", trace_id, "token", "failed",
                "百度OCR token请求异常", reason="request_error",
                extra={"error": str(e)},
            )

    def _reset_last_error(self):
        """Clear the error state left by the previous ``recognize`` call."""
        self.last_error_code = None
        self.last_error_msg = ""

    def should_fallback_to_rapid(self):
        """Return True when the last Baidu error code is a fallback trigger."""
        return self.last_error_code in BAIDU_FALLBACK_ERROR_CODES

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognise text lines in ``image_data`` through the Baidu API.

        Args:
            image_data: Encoded image bytes; sent base64-encoded.
            scene: Label included in log records.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            The recognised lines whose confidence exceeds ``MIN_CONFIDENCE``.
            An empty list on any failure, in which case ``last_error_msg``
            (and ``last_error_code`` for API-level errors) describe the cause.
        """
        trace_id = new_trace_id("ocr")
        self._reset_last_error()
        if not self.access_token:
            self.last_error_msg = "no_access_token"
            log_event(
                "WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed",
                "百度OCR无可用token", reason="no_access_token", extra={"scene": scene},
            )
            return []

        url = f"https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic?access_token={self.access_token}"
        payload = {
            "image": base64.b64encode(image_data).decode(),
            "language_type": "CHN_ENG",
            "detect_direction": "true",
            "probability": "true",
        }
        try:
            response = requests.post(url, data=payload, timeout=10)
            if response.status_code != 200:
                self.last_error_msg = f"http_{response.status_code}"
                log_event(
                    "WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed",
                    "百度OCR请求失败", reason="http_error",
                    extra={"scene": scene, "status_code": response.status_code},
                )
                return []

            result = response.json()
            if "error_code" in result:
                self.last_error_code = result.get("error_code")
                self.last_error_msg = result.get("error_msg") or ""
                log_event(
                    "WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed",
                    "百度OCR返回错误码", reason="baidu_error",
                    extra={
                        "scene": scene,
                        "error_code": self.last_error_code,
                        "error_msg": self.last_error_msg,
                    },
                )
                return []

            if "words_result" not in result:
                # A 200 response with neither an error nor a result used to be
                # dropped silently; record it so the failure is diagnosable.
                self.last_error_msg = "unexpected_response"
                log_event(
                    "WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed",
                    "百度OCR响应缺少识别结果", reason="unexpected_response",
                    extra={"scene": scene},
                )
                return []

            lines = []
            for item in result["words_result"]:
                text = item.get("words", "")
                # Tolerate a null/missing probability block instead of letting
                # one malformed entry discard every recognised line.
                confidence = (item.get("probability") or {}).get("average")
                if confidence is None:
                    confidence = self.DEFAULT_CONFIDENCE
                if text and confidence > self.MIN_CONFIDENCE:
                    lines.append(text)
            log_event(
                "INFO", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "ok",
                "百度OCR识别完成", extra={"scene": scene, "line_count": len(lines)},
            )
            return lines
        except Exception as e:
            self.last_error_msg = str(e)
            log_event(
                "ERROR", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed",
                "百度OCR请求异常", reason="request_error",
                extra={"scene": scene, "error": str(e)},
            )
        return []
class RapidLocalOCR(OCRBase):
    """OCR provider that runs the ``rapidocr_onnxruntime`` engine in-process."""

    provider_name = "rapid"

    def __init__(self):
        """Create the provider and attempt to load the engine right away."""
        self.ready = False
        self.engine = None
        self._init_engine()

    def ensure_ready(self):
        """Report whether the engine was initialised and is available."""
        return self.ready and (self.engine is not None)

    def _init_engine(self):
        """Instantiate RapidOCR, using the configured model files if all exist.

        When any configured model file is missing the engine is created with
        its default arguments instead. Initialisation errors (including a
        failed import) are logged and leave the provider not ready.
        """
        trace_id = new_trace_id("ocr")
        try:
            from rapidocr_onnxruntime import RapidOCR

            model_paths = {
                "det_model_path": _resolve_project_path(RAPID_OCR_DET_MODEL_PATH),
                "rec_model_path": _resolve_project_path(RAPID_OCR_REC_MODEL_PATH),
                "cls_model_path": _resolve_project_path(RAPID_OCR_CLS_MODEL_PATH),
            }
            missing = [location for location in model_paths.values() if not Path(location).exists()]
            if missing:
                # Custom models are all-or-nothing: fall back to the defaults.
                self.engine = RapidOCR()
                log_extra = {**model_paths, "missing_models": missing}
            else:
                self.engine = RapidOCR(**model_paths)
                log_extra = dict(model_paths)
            self.ready = True
            log_event(
                "INFO", "ocr", "ocr.rapid.init", trace_id, "init", "ok",
                "RapidOCR初始化成功", extra=log_extra,
            )
        except Exception as e:
            self.ready = False
            log_event(
                "WARNING", "ocr", "ocr.rapid.init", trace_id, "init", "failed",
                "RapidOCR初始化失败", reason="init_error", extra={"error": str(e)},
            )

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognise text lines in ``image_data`` with the local engine.

        Args:
            image_data: Encoded image bytes, decoded with ``cv2.imdecode``.
            scene: Label included in log records.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            The non-empty, stripped text of each detection, or an empty list
            when the engine is not ready, the image cannot be decoded, nothing
            is detected, or an exception occurs.
        """
        trace_id = new_trace_id("ocr")
        if not (self.ready and self.engine is not None):
            log_event(
                "WARNING", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "failed",
                "RapidOCR未就绪", reason="not_ready", extra={"scene": scene},
            )
            return []

        try:
            raw = np.frombuffer(image_data, dtype=np.uint8)
            frame = cv2.imdecode(raw, cv2.IMREAD_COLOR)
            if frame is None:
                log_event(
                    "WARNING", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "failed",
                    "RapidOCR图像解码失败", reason="decode_failed", extra={"scene": scene},
                )
                return []

            output = self.engine(frame)
            if not output or len(output) < 1:
                log_event(
                    "INFO", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "ok",
                    "RapidOCR识别结果为空", reason="empty_result", extra={"scene": scene},
                )
                return []

            # The first element holds the detections; index 1 of each
            # detection is read as its text.
            detections = output[0] or []
            stripped = (
                str(entry[1]).strip()
                for entry in detections
                if entry and len(entry) >= 2
            )
            lines = [text for text in stripped if text]
            log_event(
                "INFO", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "ok",
                "RapidOCR识别完成", extra={"scene": scene, "line_count": len(lines)},
            )
            return lines
        except Exception as e:
            log_event(
                "ERROR", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "failed",
                "RapidOCR识别异常", reason="recognize_error",
                extra={"scene": scene, "error": str(e)},
            )
            return []
class OCRService(OCRBase):
    """Facade that selects an OCR provider and adds mode-specific handling.

    Both providers are constructed eagerly. When the active provider is Baidu
    and it yields nothing because of a missing token or a fallback-worthy
    error code, the same image is retried with the local RapidOCR provider.
    """

    provider_name = "service"

    def __init__(self, provider=None):
        """Build both providers and choose the active one.

        Args:
            provider: Provider name (``"baidu"``, ``"rapid"``, ``"auto"`` or
                their aliases). Falls back to ``OCR_PROVIDER`` and then to
                ``"baidu"`` when empty.
        """
        self.provider_requested = (provider or OCR_PROVIDER or "baidu").strip().lower()
        self.baidu_provider = BaiduOCR(BAIDU_API_KEY, BAIDU_SECRET_KEY)
        self.rapid_provider = RapidLocalOCR()
        self.provider = self._build_provider(self.provider_requested)

    def _build_provider(self, provider_name: str):
        """Map a normalised provider name to one of the provider instances.

        ``"auto"`` prefers Baidu when it holds a token, otherwise Rapid when
        its engine is ready. Every other or unknown name resolves to Baidu.
        """
        if provider_name in {"rapid", "rapidocr"}:
            return self.rapid_provider
        if (
            provider_name == "auto"
            and not self.baidu_provider.access_token
            and self.rapid_provider.ensure_ready()
        ):
            return self.rapid_provider
        return self.baidu_provider

    def _provider_recognize(self, image_data, scene):
        """Run the active provider, falling back from Baidu to Rapid if needed.

        Returns:
            The provider's lines; or, when Baidu produced nothing for a
            fallback-worthy reason and Rapid is ready, Rapid's lines.
        """
        trace_id = new_trace_id("ocr")
        lines = self.provider.recognize(image_data, scene=scene)
        if lines or self.provider.provider_name != "baidu":
            return lines

        no_token = self.baidu_provider.last_error_msg == "no_access_token"
        if not (self.baidu_provider.should_fallback_to_rapid() or no_token):
            return lines

        if not self.rapid_provider.ensure_ready():
            log_event(
                "WARNING", "ocr", "ocr.fallback", trace_id, "fallback", "failed",
                "触发Rapid回退但引擎未就绪", reason="rapid_not_ready",
                extra={"scene": scene, "baidu_error": self.baidu_provider.last_error_msg or ""},
            )
            return lines

        rapid_lines = self.rapid_provider.recognize(image_data, scene=f"{scene}_rapid_fallback")
        if rapid_lines:
            log_event(
                "INFO", "ocr", "ocr.fallback", trace_id, "fallback", "ok",
                "百度OCR回退Rapid成功", reason="fallback_success",
                extra={"scene": scene, "line_count": len(rapid_lines)},
            )
        else:
            log_event(
                "WARNING", "ocr", "ocr.fallback", trace_id, "fallback", "failed",
                "百度OCR回退Rapid失败", reason="fallback_empty", extra={"scene": scene},
            )
        return rapid_lines

    def _encode_image(self, image_obj):
        """Serialise a PIL image to PNG bytes."""
        buf = BytesIO()
        image_obj.save(buf, format="PNG")
        return buf.getvalue()

    def _normalize_lines(self, lines, min_len=1, exclude=None):
        """Strip each line and drop empty, too-short or excluded entries.

        Args:
            lines: Iterable of values convertible with ``str``; may be None.
            min_len: Minimum length of the stripped text to keep.
            exclude: Optional collection of exact texts to drop.

        Returns:
            The surviving stripped strings in their original order.
        """
        excluded = set(exclude or [])
        stripped = (str(line).strip() for line in lines or [])
        return [
            text
            for text in stripped
            if text and len(text) >= min_len and text not in excluded
        ]

    def _build_session_name_variants(self, image_data):
        """Produce preprocessed versions of a session-name crop for OCR.

        The variants are, in order: the untouched input, then for the base
        scale an enlarged grayscale image, its histogram-equalised version,
        and Otsu-binarised normal/inverted versions; then, only when the extra
        scale is larger than the base scale, the enlarged and equalised images
        at that scale.

        Returns:
            A list of ``(variant_name, png_bytes)`` tuples.

        Raises:
            Exception: Whatever PIL / OpenCV raise for undecodable input.
        """
        image = Image.open(BytesIO(image_data)).convert("RGB")
        gray = image.convert("L")
        base_scale = max(2, int(SESSION_NAME_OCR_SCALE))
        extra_scale = max(base_scale, int(SESSION_NAME_OCR_EXTRA_SCALE))

        def upscale(factor):
            return gray.resize(
                (gray.width * factor, gray.height * factor),
                resample=Image.Resampling.LANCZOS,
            )

        enlarged = upscale(base_scale)
        contrast = cv2.equalizeHist(np.array(enlarged))
        binary = cv2.threshold(contrast, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        binary_inv = cv2.threshold(contrast, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

        variants = [
            ("name_crop", image_data),
            (f"name_orig_{base_scale}x", self._encode_image(enlarged)),
            (f"name_eq_{base_scale}x", self._encode_image(Image.fromarray(contrast))),
            (f"name_bin_{base_scale}x", self._encode_image(Image.fromarray(binary))),
            (f"name_bin_inv_{base_scale}x", self._encode_image(Image.fromarray(binary_inv))),
        ]

        # When the extra scale collapses onto the base scale its variants are
        # byte-identical to ones already listed; emitting them again would
        # only repeat OCR calls that have already come back empty.
        if extra_scale > base_scale:
            extra_enlarged = upscale(extra_scale)
            extra_contrast = cv2.equalizeHist(np.array(extra_enlarged))
            variants.append((f"name_orig_{extra_scale}x", self._encode_image(extra_enlarged)))
            variants.append((f"name_eq_{extra_scale}x", self._encode_image(Image.fromarray(extra_contrast))))
        return variants

    def _recognize_session_name(self, image_data, scene):
        """Try each session-name variant until one yields text.

        If preprocessing fails (for example on undecodable image bytes) the
        failure is logged and only the raw input is attempted, so this path
        returns a list instead of raising.

        Returns:
            The normalised lines of the first variant that produced any text,
            or an empty list.
        """
        trace_id = new_trace_id("ocr")
        try:
            variants = self._build_session_name_variants(image_data)
        except Exception as e:
            log_event(
                "WARNING", "ocr", "ocr.session_name", trace_id, "preprocess", "failed",
                "会话名图像预处理失败", reason="preprocess_error",
                extra={"scene": scene, "error": str(e)},
            )
            variants = [("name_crop", image_data)]

        for variant_name, variant_bytes in variants:
            lines = self._normalize_lines(
                self._provider_recognize(variant_bytes, scene=f"{scene}_{variant_name}"),
                min_len=1,
            )
            if lines:
                log_event(
                    "INFO", "ocr", "ocr.session_name", trace_id, "recognize", "ok",
                    "会话名识别成功",
                    extra={"scene": scene, "variant": variant_name, "line_count": len(lines)},
                )
                return lines

        log_event(
            "INFO", "ocr", "ocr.session_name", trace_id, "recognize", "failed",
            "会话名识别为空", reason="empty_result", extra={"scene": scene},
        )
        return []

    def _recognize_session_title(self, image_data, scene):
        """Recognise a session title from the image as given (no variants)."""
        trace_id = new_trace_id("ocr")
        lines = self._provider_recognize(image_data, scene=scene)
        normalized = self._normalize_lines(lines, min_len=1)
        if normalized:
            log_event(
                "INFO", "ocr", "ocr.session_title", trace_id, "recognize", "ok",
                "会话标题识别成功", extra={"scene": scene, "line_count": len(normalized)},
            )
        else:
            log_event(
                "INFO", "ocr", "ocr.session_title", trace_id, "recognize", "failed",
                "会话标题识别为空", reason="empty_result", extra={"scene": scene},
            )
        return normalized

    def recognize_session_name(self, image_data, scene="session_name"):
        """Public entry point for session-name recognition."""
        return self._recognize_session_name(image_data, scene=scene)

    def recognize_session_title(self, image_data, scene="session_title"):
        """Public entry point for session-title recognition."""
        return self._recognize_session_title(image_data, scene=scene)

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognise text, dispatching on ``mode``.

        Args:
            image_data: Encoded image bytes.
            scene: Label included in log records.
            mode: ``"session_name"`` or ``"session_title"`` select the
                dedicated flows; any other value runs generic recognition.

        Returns:
            A list of normalised text lines (possibly empty).
        """
        trace_id = new_trace_id("ocr")
        if mode == "session_name":
            return self.recognize_session_name(image_data, scene=scene)
        if mode == "session_title":
            return self.recognize_session_title(image_data, scene=scene)

        lines = self._normalize_lines(self._provider_recognize(image_data, scene=scene), min_len=1)
        log_event(
            "INFO", "ocr", "ocr.generic", trace_id, "recognize", "ok",
            "通用OCR识别完成", extra={"scene": scene, "line_count": len(lines)},
        )
        return lines