Files
ai-shiliu/app/infrastructure/service/wechat/session_service.py
figmar 81115dc23d 初始提交:识流 AI 助手项目
微信自动回复机器人,基于截图+OCR识别消息,支持关键词规则和 AI(OpenAI/DeepSeek/Dify)自动回复。
技术栈:PySide6 + Flask + Vue3 + RapidOCR + SQLite

注:OCR大模型文件(.onnx / .pdiparams)不纳入版本控制,需单独下载。

🤖 Generated with [Qoder](https://qoder.com)
2026-05-30 15:09:40 +08:00

289 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from dataclasses import dataclass
from io import BytesIO
import json
import os
from typing import Callable
import cv2
import numpy as np
from app.infrastructure.service.logging.log_service import log_event, new_trace_id
from app.infrastructure.service.wechat.chat_snapshot_analyzer import analyze_pil_image
from app.infrastructure.service.wechat.unread_session_analyzer import UnreadSessionAnalyzer
from app.infrastructure.service.wechat.config import (
BLOCKED_SESSION_KEYWORDS,
CONTACT_ROW_HEIGHT,
OCR_SAVE_DIR,
OCR_SAVE_IMAGES,
SESSION_NAME_HEIGHT,
SESSION_NAME_LEFT_OFFSET,
SESSION_NAME_TOP_OFFSET,
SESSION_NAME_WIDTH,
UI_NOISE_KEYWORDS,
)
# Result of one contact-list scan: all session rows plus the unread subset.
@dataclass
class SessionScanResult:
    """Outcome of a single pass over the WeChat contact list.

    Both lists hold the per-row dicts returned by the unread-session analyzer;
    this container adds no validation of its own.
    """

    # Every session row found in the scan.
    sessions: list[dict]
    # The rows the analyzer reported as unread.
    unread_sessions: list[dict]
# Result of analyzing the chat-area screenshot taken after clicking a session.
@dataclass
class ChatAnalyzeResult:
    """Summary of one chat screenshot analysis."""

    # True when a non-empty latest message text was extracted.
    ok: bool
    # File name of the saved chat screenshot ('' when no capture was made).
    file_name: str
    # Latest message text, stripped of surrounding whitespace.
    latest_text: str
    # Confidence reported by the snapshot analyzer; '' when unavailable.
    confidence: str | float
    # Bubble side reported by the snapshot analyzer; '' when unavailable.
    bubble_side: str
    # Relative path under which the screenshot was handed to the debug saver.
    screenshot_path: str
# WeChat session service: contact-list scanning, red-dot detection and chat screenshot analysis.
class WechatSessionService:
    """Coordinate screenshot capture, OCR and unread analysis for WeChat sessions.

    The service itself holds very little logic: geometry and capture are
    delegated to ``screenshot_service``, text recognition to ``ocr_service``
    and red-dot detection to ``UnreadSessionAnalyzer``.
    """

    # Seconds a recognized session title is served from cache before OCR runs again.
    _TITLE_CACHE_TTL_SECONDS = 1.2

    def __init__(self, screenshot_service, ocr_service, save_debug_image: Callable | None = None):
        """Store collaborators and initialise the title cache.

        Args:
            screenshot_service: Object providing the capture/crop helpers used below.
            ocr_service: Object providing ``recognize_session_title`` and
                ``recognize_session_name``.
            save_debug_image: Optional ``(image_obj, filename)`` callback used to
                persist debug images; when falsy, debug images are not saved.
        """
        self.screenshot = screenshot_service
        self.ocr = ocr_service
        self.save_debug_image = save_debug_image
        self._session_title_cache = {"value": "", "ts": 0.0}
        self.unread_analyzer = UnreadSessionAnalyzer()
        log_event("INFO", "bot", "bot.session_service.init", new_trace_id("bot"), "init", "ok", "会话服务初始化完成")

    def get_contact_list_rect(self, window_rect):
        """Return the contact-list area of ``window_rect`` as a left/top/right/bottom dict."""
        box = self.screenshot.get_contact_list_box(window_rect)
        return {
            'left': box.left,
            'top': box.top,
            'right': box.right,
            'bottom': box.bottom,
        }

    def extract_session_name_image(self, row_img):
        """Crop the session-name region out of a contact-list row image."""
        return self.screenshot.crop_session_name(row_img)

    def detect_red_dots(self, window_rect):
        """Capture the contact list and return the analyzer's red-dot detections."""
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)
        return self.unread_analyzer.detect_red_dots(contact_rect, screenshot)

    def row_has_red_dot(self, row_img, relaxed=False):
        """Report whether a row image carries an unread red dot (strict unless ``relaxed``)."""
        return self.unread_analyzer.row_has_red_dot(row_img, relaxed=relaxed)

    def row_has_red_dot_weak(self, row_img):
        """Report whether a row image carries an unread red dot using the weak check."""
        return self.unread_analyzer.row_has_red_dot_weak(row_img)

    def get_all_sessions_with_unread(self, window_rect, round_count):
        """Scan every contact-list row and return all sessions plus the unread ones.

        Args:
            window_rect: Window rectangle passed through to the screenshot service.
            round_count: Scan round number, used for debug file names and logging.

        Returns:
            SessionScanResult with the analyzer's session and unread-session lists.
        """
        trace_id = new_trace_id("bot")
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)
        sessions, unread_sessions = self.unread_analyzer.get_all_sessions_with_unread(
            contact_rect=contact_rect,
            screenshot=screenshot,
            round_count=round_count,
            # The bound method has the same (image_obj, filename) signature the
            # previous forwarding lambda exposed.
            save_debug_image=self._save_debug_image,
        )
        self._save_session_scan_debug(
            round_count=round_count,
            sessions=sessions,
            unread_sessions=unread_sessions,
            contact_rect=contact_rect,
        )
        log_event("INFO", "bot", "bot.session_scan", trace_id, "scan", "ok", "会话扫描完成", extra={"round": int(round_count), "total": len(sessions), "unread": len(unread_sessions)})
        return SessionScanResult(sessions=sessions, unread_sessions=unread_sessions)

    def normalize_match_text(self, text):
        """Normalize text for matching: lower-case it and drop all whitespace.

        Falsy input (``None``, ``""``, ``0``) yields ``""``.
        """
        if not text:
            return ""
        return "".join(ch for ch in str(text).lower() if not ch.isspace())

    def make_block_key(self, text):
        """Build the cache key (``title:<normalized>``) for a blocked session title.

        Returns ``""`` when the text normalizes to nothing.
        """
        normalized = self.normalize_match_text(text)
        if not normalized:
            return ""
        return f"title:{normalized}"

    def reset_session_title_cache(self):
        """Clear the cached current-session title."""
        self._session_title_cache = {"value": "", "ts": 0.0}

    def get_session_title_by_ocr(self, window_rect):
        """Recognize the title of the currently open session via OCR.

        Returns the first recognized title, or ``""`` when ``window_rect`` is
        falsy, nothing is recognized, or any error occurs (errors are logged).
        """
        trace_id = new_trace_id("bot")
        try:
            if not window_rect:
                return ""
            screenshot = self.screenshot.capture_session_title(window_rect)
            # The OCR service takes encoded image bytes, so serialize to PNG in memory.
            img_bytes = BytesIO()
            screenshot.save(img_bytes, format='PNG')
            valid = self.ocr.recognize_session_title(img_bytes.getvalue(), scene="session_title_main")
            if valid:
                title = valid[0]
                log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "ok", "会话标题识别成功", extra={"title": title})
                return title
            log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别为空", reason="empty_result")
            return ""
        except Exception as e:
            log_event("ERROR", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别异常", reason="ocr_error", extra={"error": str(e)})
            return ""

    def get_current_session_title(self, window_rect):
        """Return the current session title, preferring a short-lived cache over OCR.

        A cached title is reused while it is younger than
        ``_TITLE_CACHE_TTL_SECONDS``. Freshly recognized titles are cached only
        when non-empty and not listed in ``UI_NOISE_KEYWORDS``; the recognized
        title is returned either way. Any error is logged and yields ``""``.
        """
        import time  # kept function-local, as in the original implementation

        try:
            now_ts = time.time()
            cached_title = (self._session_title_cache.get("value") or "").strip()
            cached_ts = float(self._session_title_cache.get("ts") or 0.0)
            if cached_title and now_ts - cached_ts <= self._TITLE_CACHE_TTL_SECONDS:
                return cached_title
            title = (self.get_session_title_by_ocr(window_rect) or "").strip()
            if title and title not in UI_NOISE_KEYWORDS:
                self._session_title_cache = {"value": title, "ts": now_ts}
            return title
        except Exception as e:
            # Previously swallowed silently; keep the best-effort "" but leave a trace.
            log_event("ERROR", "bot", "bot.session_title", new_trace_id("bot"), "cache", "failed", "获取当前会话标题异常", reason="title_error", extra={"error": str(e)})
            return ""

    def _is_blocked_title(self, title, blocked_row_cache, save_blocked_row_cache: Callable, keywords=None):
        """Decide whether ``title`` belongs to a blocked session.

        A title is blocked when its key is already in ``blocked_row_cache`` or
        when it contains one of the blocked keywords (after normalization). A
        new keyword hit is recorded in the cache and persisted through
        ``save_blocked_row_cache``.

        Args:
            title: Recognized session title or name text.
            blocked_row_cache: Mutable mapping of block key -> display text.
            save_blocked_row_cache: Zero-argument callback that persists the cache.
            keywords: Keywords to match; defaults to ``BLOCKED_SESSION_KEYWORDS``.
        """
        block_key = self.make_block_key(title)
        if block_key and block_key in blocked_row_cache:
            return True
        if keywords is None:
            keywords = BLOCKED_SESSION_KEYWORDS
        normalized_title = self.normalize_match_text(title)
        for keyword in keywords:
            normalized_keyword = self.normalize_match_text(keyword)
            if not normalized_keyword:
                # An empty keyword is a substring of every title and would
                # block every session; ignore it.
                continue
            if normalized_keyword in normalized_title:
                if block_key:
                    blocked_row_cache[block_key] = title or keyword
                    save_blocked_row_cache()
                return True
        return False

    def should_skip_current_session(self, window_rect, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """Decide, after clicking a session, whether it must be skipped based on its title.

        ``session`` is accepted for interface compatibility and is not used.
        """
        title = self.get_current_session_title(window_rect)
        return self._is_blocked_title(title, blocked_row_cache, save_blocked_row_cache)

    def is_same_session(self, expected_session, current_session):
        """Compare two session names loosely: normalized, either may contain the other.

        Returns ``False`` when either name normalizes to an empty string.
        """
        expected = self.normalize_match_text(expected_session)
        current = self.normalize_match_text(current_session)
        if not expected or not current:
            return False
        return expected in current or current in expected

    def should_skip_session_by_ocr(self, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """Decide from the row's OCR'd name whether a contact-list session must be skipped.

        Reads ``session['row_img']`` and writes the recognized text to
        ``session['list_ocr_title']``. Returns ``False`` when there is no row
        image, no name crop, or any error occurs (errors are logged).
        """
        image_obj = session.get('row_img')
        if image_obj is None:
            return False
        try:
            name_img = self.extract_session_name_image(image_obj)
            if name_img is None:
                return False
            if OCR_SAVE_IMAGES:
                file_name = f"row_{int(session.get('row_idx', 0)):03d}_name_raw.png"
                self._save_debug_image(name_img, os.path.join('sessions', 'name_ocr', file_name))
            img_bytes = BytesIO()
            name_img.save(img_bytes, format='PNG')
            lines = self.ocr.recognize_session_name(img_bytes.getvalue(), scene=f"session_row_{session.get('row_idx')}")
            # Treat a None result like an empty one instead of raising in join().
            line_text = ' '.join(lines or [])
            session['list_ocr_title'] = line_text
            return self._is_blocked_title(line_text, blocked_row_cache, save_blocked_row_cache)
        except Exception as e:
            # Previously swallowed silently; still best-effort (do not skip), but logged.
            log_event("ERROR", "bot", "bot.session_name", new_trace_id("bot"), "ocr", "failed", "会话名称识别异常", reason="ocr_error", extra={"row_idx": session.get('row_idx'), "error": str(e)})
            return False

    def analyze_clicked_session(self, window_rect, round_count, row_idx):
        """Capture the chat area of the clicked session and extract its latest message.

        Args:
            window_rect: Window rectangle passed through to the screenshot service.
            round_count: Scan round number (used in the screenshot file name).
            row_idx: Index of the clicked row (used in the screenshot file name).

        Returns:
            ChatAnalyzeResult; ``ok`` is False when the capture box is invalid or
            no message text was extracted.
        """
        trace_id = new_trace_id("bot")
        chat_box = self.screenshot.get_chat_capture_box(window_rect)
        if not self.screenshot.is_valid_box(chat_box):
            log_event("WARNING", "bot", "bot.chat_analyze", trace_id, "capture", "failed", "聊天区截图区域无效", reason="invalid_box")
            return ChatAnalyzeResult(ok=False, file_name='', latest_text='', confidence='', bubble_side='', screenshot_path='')
        screenshot = self.screenshot.capture_chat_area(window_rect)
        # int() keeps the ":04d"/":03d" formats valid for integral floats/strings,
        # matching the int() coercion already used when logging below.
        file_name = f"round_{int(round_count):04d}_row_{int(row_idx):03d}_chat.png"
        rel_path = os.path.join('sessions', 'clicked', file_name)
        self._save_debug_image(screenshot, rel_path)
        result = analyze_pil_image(screenshot, stem=os.path.splitext(file_name)[0], file_name=file_name)
        latest_text = (getattr(result, 'latest_text', None) or '').strip()
        confidence = getattr(result, 'confidence', '')
        bubble_side = getattr(result, 'bubble_side', '')
        log_event("INFO", "bot", "bot.chat_analyze", trace_id, "analyze", "ok", "聊天截图分析完成", extra={"round": int(round_count), "row_idx": int(row_idx), "has_text": bool(latest_text), "bubble_side": bubble_side or "", "confidence": confidence})
        return ChatAnalyzeResult(
            ok=bool(latest_text),
            file_name=file_name,
            latest_text=latest_text,
            confidence=confidence,
            bubble_side=bubble_side,
            screenshot_path=rel_path,
        )

    def _save_session_scan_debug(self, round_count: int, sessions: list[dict], unread_sessions: list[dict], contact_rect: dict):
        """Write a JSON summary of one scan round under ``OCR_SAVE_DIR/sessions/scan_debug``.

        Does nothing unless ``OCR_SAVE_IMAGES`` is enabled. Failures never
        propagate (this is debug output only) but are logged.
        """
        if not OCR_SAVE_IMAGES:
            return
        try:
            debug_dir = os.path.join(OCR_SAVE_DIR, 'sessions', 'scan_debug')
            os.makedirs(debug_dir, exist_ok=True)
            file_path = os.path.join(debug_dir, f"round_{int(round_count):04d}_scan.json")
            rows = [
                {
                    'row_idx': session.get('row_idx'),
                    'has_red_dot': bool(session.get('has_red_dot')),
                    'has_red_by_global': bool(session.get('has_red_by_global')),
                    'has_red_by_row': bool(session.get('has_red_by_row')),
                    'has_red_by_row_weak': bool(session.get('has_red_by_row_weak')),
                    'click_x': session.get('click_x'),
                    'click_y': session.get('click_y'),
                    'list_ocr_title': session.get('list_ocr_title', ''),
                }
                for session in sessions
            ]
            payload = {
                'round': int(round_count),
                'contact_rect': contact_rect,
                'total_sessions': len(sessions),
                'unread_count': len(unread_sessions),
                'unread_rows': [s.get('row_idx') for s in unread_sessions],
                'rows': rows,
            }
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
        except Exception as e:
            # Debug output must never break a scan, but a silent failure hides
            # misconfigured paths and unserializable payloads.
            log_event("WARNING", "bot", "bot.session_scan_debug", new_trace_id("bot"), "save", "failed", "会话扫描调试数据保存失败", reason="save_error", extra={"error": str(e)})

    def _save_debug_image(self, image_obj, filename):
        """Forward a debug image to the injected saver, if one was provided."""
        if not self.save_debug_image:
            return
        self.save_debug_image(image_obj, filename)