初始提交:识流 AI 助手项目
微信自动回复机器人,基于截图+OCR识别消息,支持关键词规则和 AI(OpenAI/DeepSeek/Dify)自动回复。 技术栈:PySide6 + Flask + Vue3 + RapidOCR + SQLite 注:OCR大模型文件(.onnx / .pdiparams)不纳入版本控制,需单独下载。 🤖 Generated with [Qoder](https://qoder.com)
This commit is contained in:
346
app/infrastructure/service/wechat/ocr.py
Normal file
346
app/infrastructure/service/wechat/ocr.py
Normal file
@@ -0,0 +1,346 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from app.infrastructure.service.logging.log_service import log_event, new_trace_id
|
||||
from app.infrastructure.service.wechat.config import (
|
||||
BAIDU_API_KEY,
|
||||
BAIDU_SECRET_KEY,
|
||||
OCR_PROVIDER,
|
||||
RAPID_OCR_CLS_MODEL_PATH,
|
||||
RAPID_OCR_DET_MODEL_PATH,
|
||||
RAPID_OCR_REC_MODEL_PATH,
|
||||
SESSION_NAME_OCR_EXTRA_SCALE,
|
||||
SESSION_NAME_OCR_SCALE,
|
||||
)
|
||||
|
||||
|
||||
# Baidu error codes for which BaiduOCR.should_fallback_to_rapid() returns True,
# i.e. the cases where OCRService retries the image with the local RapidOCR engine.
# NOTE(review): presumably 17/18 are quota/QPS limits and 110/111 are invalid or
# expired access tokens — confirm against Baidu's error-code table.
BAIDU_FALLBACK_ERROR_CODES = {17, 18, 110, 111}
|
||||
|
||||
|
||||
def _runtime_roots() -> list[Path]:
    """Return the base directories used to resolve relative resource paths.

    Candidates, in priority order:

    1. ``sys._MEIPASS`` when the attribute is set (frozen/bundled builds).
    2. The directory four levels above this file.
    3. The current working directory and its ``resources`` / ``app`` subfolders.
    4. The interpreter executable's directory and its ``resources`` / ``app``
       subfolders.

    Returns:
        The candidates with duplicates removed (compared by ``str(path)``);
        the first occurrence wins, so the priority order is preserved.
    """
    # Function-scope import keeps the block self-contained without editing the
    # file-level import section.
    import sys

    roots: list[Path] = []

    meipass = getattr(sys, "_MEIPASS", None)
    if meipass:
        roots.append(Path(meipass))

    # parents[4] only exists when this module sits at least five levels deep;
    # skip the root instead of raising IndexError in shallow layouts.
    file_parents = Path(__file__).resolve().parents
    if len(file_parents) > 4:
        roots.append(file_parents[4])

    cwd = Path.cwd().resolve()
    roots.extend((cwd, cwd / "resources", cwd / "app"))

    # sys.executable may be "" or None when Python cannot determine it.
    # Path("") would resolve to the cwd and add misleading parent-of-cwd roots,
    # so only use it when it actually carries a value.
    if sys.executable:
        try:
            exe_parent = Path(sys.executable).resolve().parent
        except (OSError, RuntimeError, ValueError):
            # Best effort: an unresolvable executable path just contributes no roots.
            exe_parent = None
        if exe_parent is not None:
            roots.extend((exe_parent, exe_parent / "resources", exe_parent / "app"))

    # dicts keep insertion order, so this is an order-preserving de-duplication.
    unique_by_text: dict[str, Path] = {}
    for root in roots:
        unique_by_text.setdefault(str(root), root)
    return list(unique_by_text.values())
|
||||
|
||||
|
||||
def _resolve_project_path(path_str: str) -> str:
    """Turn a configured path into a concrete filesystem path string.

    Absolute inputs are returned unchanged. A relative input is joined onto
    every directory from ``_runtime_roots()``; the first joined path that
    exists is returned, and when none exists the path under the
    highest-priority root is returned so callers still get a usable string.
    """
    target = Path(path_str)
    if target.is_absolute():
        return str(target)

    resolved = [(base / target).resolve() for base in _runtime_roots()]
    chosen = next((item for item in resolved if item.exists()), resolved[0])
    return str(chosen)
|
||||
|
||||
|
||||
class OCRBase:
    """Common interface implemented by every OCR provider in this module."""

    # Short identifier of the provider; subclasses override it.
    provider_name = "base"

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Run OCR on ``image_data``; concrete providers must override this.

        Args:
            image_data: Encoded image content to recognise.
            scene: Label attached to log events by the providers.
            mode: Recognition mode selector.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class BaiduOCR(OCRBase):
    """OCR provider backed by Baidu's cloud ``general_basic`` endpoint.

    The constructor immediately tries to exchange the key pair for an access
    token. ``last_error_code`` / ``last_error_msg`` describe the outcome of the
    most recent ``recognize`` call and are read by ``OCRService`` to decide
    whether to fall back to the local engine.
    """

    provider_name = "baidu"

    def __init__(self, api_key, secret_key):
        self.api_key = api_key
        self.secret_key = secret_key
        self.access_token = None
        self.last_error_code = None
        self.last_error_msg = ""
        self.get_access_token()

    def _redact(self, text):
        """Return ``text`` as a string with this instance's credentials masked.

        Both endpoints receive their secrets in the URL query string, and
        exception messages raised by ``requests`` can quote that URL verbatim,
        so anything derived from an exception is scrubbed before it is logged
        or stored.
        """
        cleaned = str(text)
        for secret in (self.secret_key, self.access_token, self.api_key):
            if secret:
                cleaned = cleaned.replace(str(secret), "***")
        return cleaned

    def get_access_token(self):
        """Request an access token and store it on ``self.access_token``.

        Never raises: missing credentials, HTTP errors, empty tokens and
        request exceptions are logged and leave the token unset/falsy.
        """
        trace_id = new_trace_id("ocr")
        if not self.api_key or not self.secret_key:
            log_event("WARNING", "ocr", "ocr.baidu.token", trace_id, "token", "failed", "百度OCR凭据缺失", reason="credential_missing")
            return
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {"grant_type": "client_credentials", "client_id": self.api_key, "client_secret": self.secret_key}
        try:
            response = requests.post(url, params=params, timeout=10)
            if response.status_code != 200:
                log_event("WARNING", "ocr", "ocr.baidu.token", trace_id, "token", "failed", "百度OCR token获取失败", reason="http_error", extra={"status_code": response.status_code})
                return
            self.access_token = response.json().get("access_token")
            if self.access_token:
                log_event("INFO", "ocr", "ocr.baidu.token", trace_id, "token", "ok", "百度OCR token获取成功")
            else:
                log_event("WARNING", "ocr", "ocr.baidu.token", trace_id, "token", "failed", "百度OCR token为空", reason="token_empty")
        except Exception as e:
            # The exception text may embed client_id/client_secret from the URL.
            log_event("ERROR", "ocr", "ocr.baidu.token", trace_id, "token", "failed", "百度OCR token请求异常", reason="request_error", extra={"error": self._redact(e)})

    def _reset_last_error(self):
        """Clear the error state left by the previous ``recognize`` call."""
        self.last_error_code = None
        self.last_error_msg = ""

    def should_fallback_to_rapid(self):
        """Return True when the last Baidu error code is in the fallback set."""
        return self.last_error_code in BAIDU_FALLBACK_ERROR_CODES

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognise text in ``image_data`` via the Baidu endpoint.

        Args:
            image_data: Encoded image bytes; sent base64-encoded.
            scene: Label included in log events.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            Recognised text lines whose average probability exceeds 0.6
            (items without a probability are treated as 0.9). Every failure
            yields ``[]`` with the cause recorded in ``last_error_code`` /
            ``last_error_msg``.
        """
        trace_id = new_trace_id("ocr")
        self._reset_last_error()
        if not self.access_token:
            # OCRService matches on this exact marker to trigger its fallback.
            self.last_error_msg = "no_access_token"
            log_event("WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR无可用token", reason="no_access_token", extra={"scene": scene})
            return []
        url = f"https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic?access_token={self.access_token}"
        payload = {"image": base64.b64encode(image_data).decode(), "language_type": "CHN_ENG", "detect_direction": "true", "probability": "true"}
        try:
            response = requests.post(url, data=payload, timeout=10)
            if response.status_code != 200:
                self.last_error_msg = f"http_{response.status_code}"
                log_event("WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR请求失败", reason="http_error", extra={"scene": scene, "status_code": response.status_code})
                return []

            result = response.json()
            if "error_code" in result:
                self.last_error_code = result.get("error_code")
                self.last_error_msg = result.get("error_msg") or ""
                log_event("WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR返回错误码", reason="baidu_error", extra={"scene": scene, "error_code": self.last_error_code, "error_msg": self.last_error_msg})
                return []

            if "words_result" not in result:
                # A 200 response with neither an error nor results used to
                # return [] silently; record it so the gap is diagnosable.
                self.last_error_msg = "unexpected_response"
                log_event("WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR响应缺少识别结果", reason="unexpected_response", extra={"scene": scene})
                return []

            lines = []
            for item in result["words_result"]:
                text = item.get("words", "")
                # `or {}` tolerates an explicit null probability instead of
                # aborting the whole response with AttributeError.
                prob = (item.get("probability") or {}).get("average", 0.9)
                if text and prob > 0.6:
                    lines.append(text)
            log_event("INFO", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "ok", "百度OCR识别完成", extra={"scene": scene, "line_count": len(lines)})
            return lines
        except Exception as e:
            # The exception text may embed the access token from the request URL.
            self.last_error_msg = self._redact(e)
            log_event("ERROR", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR请求异常", reason="request_error", extra={"scene": scene, "error": self.last_error_msg})
        return []
|
||||
|
||||
|
||||
class RapidLocalOCR(OCRBase):
    """OCR provider that runs the ``rapidocr_onnxruntime`` engine in-process.

    The engine is created in the constructor. If creation fails the instance
    stays usable but ``ready`` remains False and ``recognize`` returns ``[]``.
    """

    provider_name = "rapid"

    def __init__(self):
        self.ready = False
        self.engine = None
        self._init_engine()

    def ensure_ready(self):
        """Report whether initialisation succeeded and an engine is available."""
        if not self.ready:
            return self.ready
        return self.engine is not None

    def _init_engine(self):
        """Create the RapidOCR engine, using the configured models when all exist.

        When any of the three configured model files is missing, the engine is
        built with its own defaults and the missing paths are logged. Any
        exception (including a failed import) marks the provider as not ready.
        """
        trace_id = new_trace_id("ocr")
        try:
            # Imported lazily so a missing package only disables this provider.
            from rapidocr_onnxruntime import RapidOCR

            configured = {
                "det_model_path": _resolve_project_path(RAPID_OCR_DET_MODEL_PATH),
                "rec_model_path": _resolve_project_path(RAPID_OCR_REC_MODEL_PATH),
                "cls_model_path": _resolve_project_path(RAPID_OCR_CLS_MODEL_PATH),
            }
            absent = [location for location in configured.values() if not Path(location).exists()]
            if absent:
                self.engine = RapidOCR()
                details = {**configured, "missing_models": absent}
            else:
                self.engine = RapidOCR(**configured)
                details = dict(configured)
            self.ready = True
            log_event("INFO", "ocr", "ocr.rapid.init", trace_id, "init", "ok", "RapidOCR初始化成功", extra=details)
        except Exception as exc:
            self.ready = False
            log_event("WARNING", "ocr", "ocr.rapid.init", trace_id, "init", "failed", "RapidOCR初始化失败", reason="init_error", extra={"error": str(exc)})

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognise text in ``image_data`` with the local engine.

        Args:
            image_data: Encoded image bytes, decoded with ``cv2.imdecode``.
            scene: Label included in log events.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            The non-empty, stripped text of each recognised entry; ``[]`` when
            the engine is not ready, decoding fails, nothing is found, or an
            exception occurs.
        """
        trace_id = new_trace_id("ocr")
        if self.engine is None or not self.ready:
            log_event("WARNING", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "failed", "RapidOCR未就绪", reason="not_ready", extra={"scene": scene})
            return []
        try:
            raw = np.frombuffer(image_data, dtype=np.uint8)
            frame = cv2.imdecode(raw, cv2.IMREAD_COLOR)
            if frame is None:
                log_event("WARNING", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "failed", "RapidOCR图像解码失败", reason="decode_failed", extra={"scene": scene})
                return []

            output = self.engine(frame)
            if not output or len(output) == 0:
                log_event("INFO", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "ok", "RapidOCR识别结果为空", reason="empty_result", extra={"scene": scene})
                return []

            # The first element holds the recognised entries; index 1 of each
            # entry is read as its text.
            candidates = [
                str(entry[1]).strip()
                for entry in (output[0] or [])
                if entry and len(entry) >= 2
            ]
            lines = [text for text in candidates if text]
            log_event("INFO", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "ok", "RapidOCR识别完成", extra={"scene": scene, "line_count": len(lines)})
            return lines
        except Exception as exc:
            log_event("ERROR", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "failed", "RapidOCR识别异常", reason="recognize_error", extra={"scene": scene, "error": str(exc)})
            return []
|
||||
|
||||
|
||||
class OCRService(OCRBase):
    """Facade over the Baidu and RapidOCR providers.

    Selects the active provider from the requested name, falls back from Baidu
    to RapidOCR on specific Baidu failures, and adds image preprocessing for
    session-name crops.
    """

    provider_name = "service"

    def __init__(self, provider=None):
        """Create both providers and select the active one.

        Args:
            provider: Provider name; falls back to ``OCR_PROVIDER`` and then
                ``"baidu"``. Compared case-insensitively after stripping.
        """
        self.provider_requested = (provider or OCR_PROVIDER or "baidu").strip().lower()
        self.baidu_provider = BaiduOCR(BAIDU_API_KEY, BAIDU_SECRET_KEY)
        self.rapid_provider = RapidLocalOCR()
        self.provider = self._build_provider(self.provider_requested)

    def _build_provider(self, provider_name: str):
        """Map a provider name to a provider instance.

        ``"auto"`` prefers Baidu when it holds a token, then RapidOCR when it
        is ready. Unknown names and all other cases resolve to Baidu.
        """
        if provider_name in {"rapid", "rapidocr"}:
            return self.rapid_provider
        if provider_name in {"baidu", "baiduocr"}:
            return self.baidu_provider
        if provider_name == "auto":
            if self.baidu_provider.access_token:
                return self.baidu_provider
            if self.rapid_provider.ensure_ready():
                return self.rapid_provider
        return self.baidu_provider

    def _provider_recognize(self, image_data, scene):
        """Run the active provider, retrying with RapidOCR when Baidu asks for it.

        The fallback only applies when Baidu is the active provider, returned
        no lines, and either has no access token or reported an error code in
        ``BAIDU_FALLBACK_ERROR_CODES``.
        """
        trace_id = new_trace_id("ocr")
        lines = self.provider.recognize(image_data, scene=scene)
        if self.provider.provider_name != "baidu" or lines:
            return lines

        no_token_fallback = self.baidu_provider.last_error_msg == "no_access_token"
        should_fallback = self.baidu_provider.should_fallback_to_rapid() or no_token_fallback
        if not should_fallback:
            return lines

        if not self.rapid_provider.ensure_ready():
            log_event("WARNING", "ocr", "ocr.fallback", trace_id, "fallback", "failed", "触发Rapid回退但引擎未就绪", reason="rapid_not_ready", extra={"scene": scene, "baidu_error": self.baidu_provider.last_error_msg or ""})
            return lines

        rapid_lines = self.rapid_provider.recognize(image_data, scene=f"{scene}_rapid_fallback")
        if rapid_lines:
            log_event("INFO", "ocr", "ocr.fallback", trace_id, "fallback", "ok", "百度OCR回退Rapid成功", reason="fallback_success", extra={"scene": scene, "line_count": len(rapid_lines)})
        else:
            log_event("WARNING", "ocr", "ocr.fallback", trace_id, "fallback", "failed", "百度OCR回退Rapid失败", reason="fallback_empty", extra={"scene": scene})
        return rapid_lines

    def _encode_image(self, image_obj):
        """Serialise a PIL image to PNG bytes."""
        buf = BytesIO()
        image_obj.save(buf, format="PNG")
        return buf.getvalue()

    def _normalize_lines(self, lines, min_len=1, exclude=None):
        """Strip each line and drop empty, too-short or excluded entries.

        Args:
            lines: Iterable of values; ``None`` is treated as empty.
            min_len: Minimum length of the stripped text to keep.
            exclude: Optional collection of exact stripped texts to drop.
        """
        excluded = set(exclude or [])
        stripped = (str(line).strip() for line in lines or [])
        return [
            text
            for text in stripped
            if text and len(text) >= min_len and text not in excluded
        ]

    def _iter_session_name_variants(self, image_data):
        """Yield ``(variant_name, image_bytes)`` pairs lazily, cheapest first.

        The untouched crop is yielded before any decoding, so callers that stop
        at the first successful variant never pay for the upscaled renderings.
        If the image cannot be decoded, a warning is logged and only the raw
        crop is produced.
        """
        yield ("name_crop", image_data)

        try:
            gray = Image.open(BytesIO(image_data)).convert("RGB").convert("L")
        except Exception as e:
            log_event("WARNING", "ocr", "ocr.session_name", new_trace_id("ocr"), "preprocess", "failed", "会话名图像预处理失败", reason="preprocess_error", extra={"error": str(e)})
            return

        base_scale = max(2, int(SESSION_NAME_OCR_SCALE))
        extra_scale = max(base_scale, int(SESSION_NAME_OCR_EXTRA_SCALE))

        enlarged = gray.resize(
            (gray.width * base_scale, gray.height * base_scale),
            resample=Image.Resampling.LANCZOS,
        )
        yield (f"name_orig_{base_scale}x", self._encode_image(enlarged))

        contrast = cv2.equalizeHist(np.array(enlarged))
        yield (f"name_eq_{base_scale}x", self._encode_image(Image.fromarray(contrast)))

        binary = cv2.threshold(contrast, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        yield (f"name_bin_{base_scale}x", self._encode_image(Image.fromarray(binary)))

        binary_inv = cv2.threshold(contrast, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
        yield (f"name_bin_inv_{base_scale}x", self._encode_image(Image.fromarray(binary_inv)))

        # With equal scales the extra renderings would be identical to the base
        # ones (same size, same filter), so recognising them again is redundant.
        if extra_scale == base_scale:
            return

        extra_enlarged = gray.resize(
            (gray.width * extra_scale, gray.height * extra_scale),
            resample=Image.Resampling.LANCZOS,
        )
        yield (f"name_orig_{extra_scale}x", self._encode_image(extra_enlarged))

        extra_contrast = cv2.equalizeHist(np.array(extra_enlarged))
        yield (f"name_eq_{extra_scale}x", self._encode_image(Image.fromarray(extra_contrast)))

    def _build_session_name_variants(self, image_data):
        """Return every session-name variant as a list of ``(name, bytes)`` pairs."""
        return list(self._iter_session_name_variants(image_data))

    def _recognize_session_name(self, image_data, scene):
        """Try each image variant in turn and return the first non-empty result."""
        trace_id = new_trace_id("ocr")
        for variant_name, variant_bytes in self._iter_session_name_variants(image_data):
            lines = self._normalize_lines(
                self._provider_recognize(variant_bytes, scene=f"{scene}_{variant_name}"),
                min_len=1,
            )
            if lines:
                log_event("INFO", "ocr", "ocr.session_name", trace_id, "recognize", "ok", "会话名识别成功", extra={"scene": scene, "variant": variant_name, "line_count": len(lines)})
                return lines
        log_event("INFO", "ocr", "ocr.session_name", trace_id, "recognize", "failed", "会话名识别为空", reason="empty_result", extra={"scene": scene})
        return []

    def _recognize_session_title(self, image_data, scene):
        """Recognise a session title from the unmodified image and log the outcome."""
        trace_id = new_trace_id("ocr")
        lines = self._provider_recognize(image_data, scene=scene)
        normalized = self._normalize_lines(lines, min_len=1)
        if normalized:
            log_event("INFO", "ocr", "ocr.session_title", trace_id, "recognize", "ok", "会话标题识别成功", extra={"scene": scene, "line_count": len(normalized)})
        else:
            log_event("INFO", "ocr", "ocr.session_title", trace_id, "recognize", "failed", "会话标题识别为空", reason="empty_result", extra={"scene": scene})
        return normalized

    def recognize_session_name(self, image_data, scene="session_name"):
        """Public entry point for session-name recognition."""
        return self._recognize_session_name(image_data, scene=scene)

    def recognize_session_title(self, image_data, scene="session_title"):
        """Public entry point for session-title recognition."""
        return self._recognize_session_title(image_data, scene=scene)

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognise text, dispatching on ``mode``.

        Args:
            image_data: Encoded image bytes.
            scene: Label included in log events.
            mode: ``"session_name"`` or ``"session_title"`` select the
                specialised paths; any other value runs generic recognition.

        Returns:
            A list of normalised text lines (possibly empty).
        """
        trace_id = new_trace_id("ocr")
        if mode == "session_name":
            return self.recognize_session_name(image_data, scene=scene)
        if mode == "session_title":
            return self.recognize_session_title(image_data, scene=scene)
        lines = self._normalize_lines(self._provider_recognize(image_data, scene=scene), min_len=1)
        log_event("INFO", "ocr", "ocr.generic", trace_id, "recognize", "ok", "通用OCR识别完成", extra={"scene": scene, "line_count": len(lines)})
        return lines
|
||||
Reference in New Issue
Block a user