Files
ai-shiliu/app/infrastructure/service/wechat/ocr.py

347 lines
16 KiB
Python
Raw Normal View History

import base64
from io import BytesIO
from pathlib import Path
import cv2
import numpy as np
import requests
from PIL import Image
from app.infrastructure.service.logging.log_service import log_event, new_trace_id
from app.infrastructure.service.wechat.config import (
BAIDU_API_KEY,
BAIDU_SECRET_KEY,
OCR_PROVIDER,
RAPID_OCR_CLS_MODEL_PATH,
RAPID_OCR_DET_MODEL_PATH,
RAPID_OCR_REC_MODEL_PATH,
SESSION_NAME_OCR_EXTRA_SCALE,
SESSION_NAME_OCR_SCALE,
)
# Baidu OCR error codes that make the service retry the same image with the
# local RapidOCR engine (consulted by ``BaiduOCR.should_fallback_to_rapid``).
# NOTE(review): these look like Baidu's quota / QPS-limit codes (17, 18) and
# invalid / expired access-token codes (110, 111) — confirm against the
# current Baidu error-code table.
BAIDU_FALLBACK_ERROR_CODES = {17, 18, 110, 111}
def _runtime_roots() -> list[Path]:
    """Return the de-duplicated base directories used to resolve relative paths.

    Candidates are collected in priority order:

    1. ``sys._MEIPASS`` when that attribute is set (presumably the extraction
       directory of a PyInstaller-style frozen bundle).
    2. The directory five levels above this module file, when the module is
       nested deeply enough for that directory to exist.
    3. The current working directory and its ``resources`` / ``app`` children.
    4. The directory holding ``sys.executable`` and its ``resources`` / ``app``
       children, when the interpreter reports an executable path.

    Returns:
        The candidate directories with duplicates removed (compared by their
        string form), keeping the first occurrence of each.
    """
    # Local import: ``sys`` is not imported at module level in this file.
    import sys

    roots: list[Path] = []

    bundle_dir = getattr(sys, "_MEIPASS", None)
    if bundle_dir:
        roots.append(Path(bundle_dir))

    # ``parents[4]`` raises IndexError when the module lives fewer than five
    # directories deep (e.g. a flattened or frozen layout), so check first.
    module_parents = Path(__file__).resolve().parents
    if len(module_parents) > 4:
        roots.append(module_parents[4])

    cwd = Path.cwd().resolve()
    roots.extend((cwd, cwd / "resources", cwd / "app"))

    # ``sys.executable`` may be an empty string or None when Python cannot
    # determine it; resolving "" would silently yield the CWD's parent.
    executable = sys.executable
    if executable:
        try:
            exe_parent = Path(executable).resolve().parent
        except (OSError, RuntimeError, ValueError):
            # Best effort only: an unresolvable executable path just means
            # these optional roots are not offered.
            exe_parent = None
        if exe_parent is not None:
            roots.extend((exe_parent, exe_parent / "resources", exe_parent / "app"))

    unique_roots: dict[str, Path] = {}
    for root in roots:
        unique_roots.setdefault(str(root), root)
    return list(unique_roots.values())
def _resolve_project_path(path_str: str) -> str:
    """Turn a configured path into a concrete filesystem path string.

    Absolute paths are returned as given. A relative path is joined onto each
    directory from ``_runtime_roots()``; the first joined path that exists is
    returned. When none exists, the path under the first root is returned so
    the caller still gets a deterministic location.
    """
    requested = Path(path_str)
    if requested.is_absolute():
        return str(requested)

    resolved = [(base / requested).resolve() for base in _runtime_roots()]
    chosen = next((option for option in resolved if option.exists()), resolved[0])
    return str(chosen)
class OCRBase:
    """Interface shared by the OCR providers defined in this module."""

    # Short identifier for the provider; subclasses override it.
    provider_name: str = "base"

    def recognize(self, image_data, scene: str = "generic", mode: str = "generic"):
        """Recognize text in ``image_data``.

        This base implementation does nothing useful; subclasses override it.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class BaiduOCR(OCRBase):
    """OCR provider that calls Baidu's cloud ``general_basic`` endpoint.

    An access token is requested once, when the instance is constructed.
    Failures are logged and recorded in ``last_error_code`` /
    ``last_error_msg`` instead of being raised, so ``recognize`` always
    returns a list.
    """

    provider_name = "baidu"

    # Timeout in seconds for both the token and the recognition request.
    REQUEST_TIMEOUT = 10
    # Lines whose average probability is not above this value are dropped.
    MIN_PROBABILITY = 0.6
    # Probability assumed for a line when the response carries none.
    DEFAULT_PROBABILITY = 0.9

    def __init__(self, api_key, secret_key):
        """Store the credentials and immediately try to obtain an access token."""
        self.api_key = api_key
        self.secret_key = secret_key
        self.access_token = None
        self.last_error_code = None
        self.last_error_msg = ""
        self.get_access_token()

    def get_access_token(self):
        """Request an access token and store it in ``self.access_token``.

        Nothing is raised: missing credentials, HTTP errors, empty tokens and
        request exceptions are only logged. On an HTTP error or an exception
        the previously stored token is left untouched.
        """
        trace_id = new_trace_id("ocr")
        if not self.api_key or not self.secret_key:
            log_event("WARNING", "ocr", "ocr.baidu.token", trace_id, "token", "failed", "百度OCR凭据缺失", reason="credential_missing")
            return

        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {"grant_type": "client_credentials", "client_id": self.api_key, "client_secret": self.secret_key}
        try:
            response = requests.post(url, params=params, timeout=self.REQUEST_TIMEOUT)
            if response.status_code != 200:
                log_event("WARNING", "ocr", "ocr.baidu.token", trace_id, "token", "failed", "百度OCR token获取失败", reason="http_error", extra={"status_code": response.status_code})
                return
            self.access_token = response.json().get("access_token")
            if self.access_token:
                log_event("INFO", "ocr", "ocr.baidu.token", trace_id, "token", "ok", "百度OCR token获取成功")
            else:
                log_event("WARNING", "ocr", "ocr.baidu.token", trace_id, "token", "failed", "百度OCR token为空", reason="token_empty")
        except Exception as e:
            log_event("ERROR", "ocr", "ocr.baidu.token", trace_id, "token", "failed", "百度OCR token请求异常", reason="request_error", extra={"error": str(e)})

    def _reset_last_error(self):
        """Clear the error state left by the previous ``recognize`` call."""
        self.last_error_code = None
        self.last_error_msg = ""

    def should_fallback_to_rapid(self):
        """Return True when the last Baidu error code is one of the fallback codes."""
        return self.last_error_code in BAIDU_FALLBACK_ERROR_CODES

    def _extract_lines(self, words_result):
        """Return the recognized texts whose probability exceeds the threshold.

        A missing or null ``probability`` entry is treated as
        ``DEFAULT_PROBABILITY`` so one incomplete item cannot discard the
        whole response.
        """
        lines = []
        for item in words_result or []:
            text = item.get("words", "")
            probability = item.get("probability") or {}
            average = probability.get("average", self.DEFAULT_PROBABILITY)
            if text and average > self.MIN_PROBABILITY:
                lines.append(text)
        return lines

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognize text in an encoded image via the Baidu endpoint.

        Args:
            image_data: Encoded image bytes; sent base64-encoded.
            scene: Label attached to log records.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            The recognized text lines, or ``[]`` on any failure. After a
            failure ``last_error_msg`` (and ``last_error_code`` for
            Baidu-reported errors) describes what went wrong.
        """
        trace_id = new_trace_id("ocr")
        self._reset_last_error()
        if not self.access_token:
            self.last_error_msg = "no_access_token"
            log_event("WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR无可用token", reason="no_access_token", extra={"scene": scene})
            return []

        url = f"https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic?access_token={self.access_token}"
        payload = {"image": base64.b64encode(image_data).decode(), "language_type": "CHN_ENG", "detect_direction": "true", "probability": "true"}
        try:
            response = requests.post(url, data=payload, timeout=self.REQUEST_TIMEOUT)
            if response.status_code != 200:
                self.last_error_msg = f"http_{response.status_code}"
                log_event("WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR请求失败", reason="http_error", extra={"scene": scene, "status_code": response.status_code})
                return []

            result = response.json()
            if "error_code" in result:
                self.last_error_code = result.get("error_code")
                self.last_error_msg = result.get("error_msg") or ""
                log_event("WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR返回错误码", reason="baidu_error", extra={"scene": scene, "error_code": self.last_error_code, "error_msg": self.last_error_msg})
                return []

            if "words_result" not in result:
                # Previously this case returned [] with no trace at all;
                # record it so an unexpected payload can be diagnosed.
                self.last_error_msg = "missing_words_result"
                log_event("WARNING", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR响应缺少words_result", reason="missing_words_result", extra={"scene": scene})
                return []

            lines = self._extract_lines(result["words_result"])
            log_event("INFO", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "ok", "百度OCR识别完成", extra={"scene": scene, "line_count": len(lines)})
            return lines
        except Exception as e:
            self.last_error_msg = str(e)
            log_event("ERROR", "ocr", "ocr.baidu.recognize", trace_id, "recognize", "failed", "百度OCR请求异常", reason="request_error", extra={"scene": scene, "error": str(e)})
        return []
class RapidLocalOCR(OCRBase):
    """OCR provider that runs the ``rapidocr_onnxruntime`` engine in-process.

    The engine is created during construction. Initialization and recognition
    failures are logged and never raised.
    """

    provider_name = "rapid"

    def __init__(self):
        """Create the provider and try to initialize the engine right away."""
        self.ready = False
        self.engine = None
        self._init_engine()

    def ensure_ready(self):
        """Report whether the engine initialized successfully and is available."""
        return self.ready and self.engine is not None

    def _init_engine(self):
        """Instantiate the RapidOCR engine.

        The configured model files are used only when all three exist;
        otherwise the engine is created with its own defaults and the missing
        paths are included in the log record.
        """
        trace_id = new_trace_id("ocr")
        try:
            from rapidocr_onnxruntime import RapidOCR

            configured = {
                "det_model_path": _resolve_project_path(RAPID_OCR_DET_MODEL_PATH),
                "rec_model_path": _resolve_project_path(RAPID_OCR_REC_MODEL_PATH),
                "cls_model_path": _resolve_project_path(RAPID_OCR_CLS_MODEL_PATH),
            }
            missing = [location for location in configured.values() if not Path(location).exists()]
            if missing:
                self.engine = RapidOCR()
                details = {**configured, "missing_models": missing}
            else:
                self.engine = RapidOCR(**configured)
                details = configured
            self.ready = True
            log_event("INFO", "ocr", "ocr.rapid.init", trace_id, "init", "ok", "RapidOCR初始化成功", extra=details)
        except Exception as exc:
            self.ready = False
            log_event("WARNING", "ocr", "ocr.rapid.init", trace_id, "init", "failed", "RapidOCR初始化失败", reason="init_error", extra={"error": str(exc)})

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognize text in an encoded image with the local engine.

        Args:
            image_data: Encoded image bytes, decoded here with OpenCV.
            scene: Label attached to log records.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            The non-empty, stripped text of each recognized entry, or ``[]``
            when the engine is not ready, decoding fails, nothing is
            recognized, or an exception occurs.
        """
        trace_id = new_trace_id("ocr")
        if not self.ensure_ready():
            log_event("WARNING", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "failed", "RapidOCR未就绪", reason="not_ready", extra={"scene": scene})
            return []

        try:
            raw = np.frombuffer(image_data, dtype=np.uint8)
            decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
            if decoded is None:
                log_event("WARNING", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "failed", "RapidOCR图像解码失败", reason="decode_failed", extra={"scene": scene})
                return []

            output = self.engine(decoded)
            if not output or len(output) < 1:
                log_event("INFO", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "ok", "RapidOCR识别结果为空", reason="empty_result", extra={"scene": scene})
                return []

            # The first element of the engine output holds the recognized
            # entries; the text is read from index 1 of each entry.
            candidates = (
                str(entry[1]).strip()
                for entry in (output[0] or [])
                if entry and len(entry) >= 2
            )
            lines = [text for text in candidates if text]
            log_event("INFO", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "ok", "RapidOCR识别完成", extra={"scene": scene, "line_count": len(lines)})
            return lines
        except Exception as exc:
            log_event("ERROR", "ocr", "ocr.rapid.recognize", trace_id, "recognize", "failed", "RapidOCR识别异常", reason="recognize_error", extra={"scene": scene, "error": str(exc)})
            return []
class OCRService(OCRBase):
    """Facade that picks an OCR provider and adds fallback and preprocessing.

    Both the Baidu and the RapidOCR providers are constructed up front; the
    configured provider name decides which one is primary. When Baidu is
    primary and fails in a recoverable way, the same image is retried with
    RapidOCR.
    """

    provider_name = "service"

    def __init__(self, provider=None):
        """Build both providers and select the primary one.

        Args:
            provider: Provider name (``"baidu"``, ``"rapid"``, ``"auto"`` or
                an alias). Falls back to ``OCR_PROVIDER`` and then ``"baidu"``.
        """
        self.provider_requested = (provider or OCR_PROVIDER or "baidu").strip().lower()
        self.baidu_provider = BaiduOCR(BAIDU_API_KEY, BAIDU_SECRET_KEY)
        self.rapid_provider = RapidLocalOCR()
        self.provider = self._build_provider(self.provider_requested)

    def _build_provider(self, provider_name: str):
        """Map a normalized provider name onto one of the provider instances.

        ``"auto"`` prefers Baidu when it holds a token and otherwise RapidOCR
        when that engine is ready. Every other case, including unknown names,
        resolves to Baidu.
        """
        if provider_name in {"rapid", "rapidocr"}:
            return self.rapid_provider
        if (
            provider_name == "auto"
            and not self.baidu_provider.access_token
            and self.rapid_provider.ensure_ready()
        ):
            return self.rapid_provider
        return self.baidu_provider

    def _provider_recognize(self, image_data, scene):
        """Run the primary provider, falling back to RapidOCR for Baidu failures.

        The fallback is attempted only when Baidu is primary, returned no
        lines, and either reported a fallback error code or had no access
        token.
        """
        trace_id = new_trace_id("ocr")
        lines = self.provider.recognize(image_data, scene=scene)
        if lines or self.provider.provider_name != "baidu":
            return lines

        baidu = self.baidu_provider
        if not (baidu.should_fallback_to_rapid() or baidu.last_error_msg == "no_access_token"):
            return lines

        if not self.rapid_provider.ensure_ready():
            log_event("WARNING", "ocr", "ocr.fallback", trace_id, "fallback", "failed", "触发Rapid回退但引擎未就绪", reason="rapid_not_ready", extra={"scene": scene, "baidu_error": self.baidu_provider.last_error_msg or ""})
            return lines

        rapid_lines = self.rapid_provider.recognize(image_data, scene=f"{scene}_rapid_fallback")
        if rapid_lines:
            log_event("INFO", "ocr", "ocr.fallback", trace_id, "fallback", "ok", "百度OCR回退Rapid成功", reason="fallback_success", extra={"scene": scene, "line_count": len(rapid_lines)})
        else:
            log_event("WARNING", "ocr", "ocr.fallback", trace_id, "fallback", "failed", "百度OCR回退Rapid失败", reason="fallback_empty", extra={"scene": scene})
        return rapid_lines

    def _encode_image(self, image_obj):
        """Serialize a PIL image to PNG bytes."""
        buf = BytesIO()
        image_obj.save(buf, format="PNG")
        return buf.getvalue()

    def _normalize_lines(self, lines, min_len=1, exclude=None):
        """Strip each line and drop empty, too-short or excluded entries.

        Args:
            lines: Iterable of recognized lines; ``None`` is treated as empty.
            min_len: Minimum length of the stripped text to keep.
            exclude: Optional collection of exact texts to discard.
        """
        excluded = set(exclude or [])
        stripped = (str(line).strip() for line in lines or [])
        return [
            text
            for text in stripped
            if text and len(text) >= min_len and text not in excluded
        ]

    @staticmethod
    def _upscale(gray_image, factor):
        """Return ``gray_image`` enlarged by ``factor`` with Lanczos resampling."""
        return gray_image.resize(
            (gray_image.width * factor, gray_image.height * factor),
            resample=Image.Resampling.LANCZOS,
        )

    def _iter_session_name_variants(self, image_data):
        """Lazily yield ``(variant_name, png_bytes)`` pairs for session-name OCR.

        The untouched crop comes first, so no image processing happens unless
        the consumer asks for a second variant. The extra-scale variants are
        skipped when the extra scale equals the base scale, because they would
        repeat the base-scale variants and cause redundant OCR calls.

        Raises:
            Exception: whatever PIL / OpenCV raise when ``image_data`` cannot
                be decoded or processed; raised on the second ``next()`` at
                the earliest.
        """
        yield "name_crop", image_data

        gray = Image.open(BytesIO(image_data)).convert("RGB").convert("L")
        base_scale = max(2, int(SESSION_NAME_OCR_SCALE))
        extra_scale = max(base_scale, int(SESSION_NAME_OCR_EXTRA_SCALE))

        enlarged = self._upscale(gray, base_scale)
        yield f"name_orig_{base_scale}x", self._encode_image(enlarged)

        contrast = cv2.equalizeHist(np.array(enlarged))
        yield f"name_eq_{base_scale}x", self._encode_image(Image.fromarray(contrast))

        binary = cv2.threshold(contrast, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        yield f"name_bin_{base_scale}x", self._encode_image(Image.fromarray(binary))

        binary_inv = cv2.threshold(contrast, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
        yield f"name_bin_inv_{base_scale}x", self._encode_image(Image.fromarray(binary_inv))

        if extra_scale == base_scale:
            return

        extra_enlarged = self._upscale(gray, extra_scale)
        yield f"name_orig_{extra_scale}x", self._encode_image(extra_enlarged)

        extra_contrast = cv2.equalizeHist(np.array(extra_enlarged))
        yield f"name_eq_{extra_scale}x", self._encode_image(Image.fromarray(extra_contrast))

    def _build_session_name_variants(self, image_data):
        """Return all session-name variants as a list of ``(name, bytes)`` pairs.

        Kept for callers that need the full list; recognition itself consumes
        the lazy iterator so later variants are only built when needed.
        """
        return list(self._iter_session_name_variants(image_data))

    def _recognize_session_name(self, image_data, scene):
        """Try each image variant in turn and return the first non-empty result.

        A preprocessing failure (for example an undecodable image) is logged
        and ends the search, so this method returns ``[]`` instead of raising.
        """
        trace_id = new_trace_id("ocr")
        variants = self._iter_session_name_variants(image_data)
        while True:
            try:
                variant_name, variant_bytes = next(variants)
            except StopIteration:
                break
            except Exception as e:
                log_event("ERROR", "ocr", "ocr.session_name", trace_id, "recognize", "failed", "会话名图像预处理失败", reason="preprocess_error", extra={"scene": scene, "error": str(e)})
                break

            lines = self._normalize_lines(
                self._provider_recognize(variant_bytes, scene=f"{scene}_{variant_name}"),
                min_len=1,
            )
            if lines:
                log_event("INFO", "ocr", "ocr.session_name", trace_id, "recognize", "ok", "会话名识别成功", extra={"scene": scene, "variant": variant_name, "line_count": len(lines)})
                return lines

        log_event("INFO", "ocr", "ocr.session_name", trace_id, "recognize", "failed", "会话名识别为空", reason="empty_result", extra={"scene": scene})
        return []

    def _recognize_session_title(self, image_data, scene):
        """Recognize a session title from the image as-is and log the outcome."""
        trace_id = new_trace_id("ocr")
        normalized = self._normalize_lines(self._provider_recognize(image_data, scene=scene), min_len=1)
        if normalized:
            log_event("INFO", "ocr", "ocr.session_title", trace_id, "recognize", "ok", "会话标题识别成功", extra={"scene": scene, "line_count": len(normalized)})
        else:
            log_event("INFO", "ocr", "ocr.session_title", trace_id, "recognize", "failed", "会话标题识别为空", reason="empty_result", extra={"scene": scene})
        return normalized

    def recognize_session_name(self, image_data, scene="session_name"):
        """Public entry point for session-name recognition."""
        return self._recognize_session_name(image_data, scene=scene)

    def recognize_session_title(self, image_data, scene="session_title"):
        """Public entry point for session-title recognition."""
        return self._recognize_session_title(image_data, scene=scene)

    def recognize(self, image_data, scene="generic", mode="generic"):
        """Recognize text, dispatching on ``mode``.

        Args:
            image_data: Encoded image bytes.
            scene: Label attached to log records.
            mode: ``"session_name"`` or ``"session_title"`` selects the
                specialized pipeline; any other value runs generic OCR.

        Returns:
            The normalized, non-empty recognized lines (possibly empty).
        """
        if mode == "session_name":
            return self.recognize_session_name(image_data, scene=scene)
        if mode == "session_title":
            return self.recognize_session_title(image_data, scene=scene)

        trace_id = new_trace_id("ocr")
        lines = self._normalize_lines(self._provider_recognize(image_data, scene=scene), min_len=1)
        log_event("INFO", "ocr", "ocr.generic", trace_id, "recognize", "ok", "通用OCR识别完成", extra={"scene": scene, "line_count": len(lines)})
        return lines