初始提交:识流 AI 助手项目
微信自动回复机器人,基于截图+OCR识别消息,支持关键词规则和 AI(OpenAI/DeepSeek/Dify)自动回复。 技术栈:PySide6 + Flask + Vue3 + RapidOCR + SQLite 注:OCR大模型文件(.onnx / .pdiparams)不纳入版本控制,需单独下载。 🤖 Generated with [Qoder](https://qoder.com)
This commit is contained in:
288
app/infrastructure/service/wechat/session_service.py
Normal file
288
app/infrastructure/service/wechat/session_service.py
Normal file
@@ -0,0 +1,288 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
import json
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.infrastructure.service.logging.log_service import log_event, new_trace_id
|
||||
from app.infrastructure.service.wechat.chat_snapshot_analyzer import analyze_pil_image
|
||||
from app.infrastructure.service.wechat.unread_session_analyzer import UnreadSessionAnalyzer
|
||||
from app.infrastructure.service.wechat.config import (
|
||||
BLOCKED_SESSION_KEYWORDS,
|
||||
CONTACT_ROW_HEIGHT,
|
||||
OCR_SAVE_DIR,
|
||||
OCR_SAVE_IMAGES,
|
||||
SESSION_NAME_HEIGHT,
|
||||
SESSION_NAME_LEFT_OFFSET,
|
||||
SESSION_NAME_TOP_OFFSET,
|
||||
SESSION_NAME_WIDTH,
|
||||
UI_NOISE_KEYWORDS,
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Result of one session-list scan: every detected row plus the unread subset.
@dataclass
class SessionScanResult:
    """Value object returned by ``WechatSessionService.get_all_sessions_with_unread``."""

    # All session rows reported by the unread analyzer, one dict per row.
    sessions: list[dict]
    # The rows the analyzer reported as unread, using the same dict shape.
    unread_sessions: list[dict]
|
||||
|
||||
|
||||
# Result of analyzing a chat-area snapshot.
@dataclass
class ChatAnalyzeResult:
    """Value object returned by ``WechatSessionService.analyze_clicked_session``."""

    # True when a non-empty latest message text was extracted.
    ok: bool
    # File name generated for the chat screenshot ('' when capture was skipped).
    file_name: str
    # Latest message text with surrounding whitespace stripped ('' when none).
    latest_text: str
    # Confidence as reported by the snapshot analyzer; '' when unavailable.
    confidence: str | float
    # Bubble side as reported by the snapshot analyzer; '' when unavailable.
    bubble_side: str
    # Relative path under which the debug screenshot is saved.
    screenshot_path: str
|
||||
|
||||
|
||||
# WeChat session service: session-list scanning, unread red-dot detection and
# chat-screenshot analysis.
class WechatSessionService:
    """Coordinate screenshot capture, OCR and unread analysis for WeChat sessions.

    The service delegates capturing/cropping to ``screenshot_service``, text
    recognition to ``ocr_service`` and red-dot detection to an internally
    created ``UnreadSessionAnalyzer``.
    """

    # Seconds a successfully recognized session title is reused before OCR is
    # run again.
    SESSION_TITLE_CACHE_TTL = 1.2

    def __init__(self, screenshot_service, ocr_service, save_debug_image: Callable | None = None):
        """Store collaborators and initialise the title cache.

        Args:
            screenshot_service: Object providing the capture/crop helpers used
                by this class (``get_contact_list_box``, ``capture_chat_area``, ...).
            ocr_service: Object providing ``recognize_session_title`` and
                ``recognize_session_name``.
            save_debug_image: Optional callable invoked as
                ``save_debug_image(image_obj, filename)``; when falsy, debug
                images are not saved.
        """
        self.screenshot = screenshot_service
        self.ocr = ocr_service
        self.save_debug_image = save_debug_image
        self._session_title_cache = {"value": "", "ts": 0.0}
        self.unread_analyzer = UnreadSessionAnalyzer()
        log_event("INFO", "bot", "bot.session_service.init", new_trace_id("bot"), "init", "ok", "会话服务初始化完成")

    def get_contact_list_rect(self, window_rect):
        """Return the contact-list box for ``window_rect`` as a plain dict.

        The dict has the keys ``left``, ``top``, ``right`` and ``bottom``,
        copied from the box returned by the screenshot service.
        """
        box = self.screenshot.get_contact_list_box(window_rect)
        return {
            'left': box.left,
            'top': box.top,
            'right': box.right,
            'bottom': box.bottom,
        }

    def extract_session_name_image(self, row_img):
        """Crop the session-name region out of a session-row image."""
        return self.screenshot.crop_session_name(row_img)

    def detect_red_dots(self, window_rect):
        """Capture the contact list and return the analyzer's red-dot detections."""
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)
        return self.unread_analyzer.detect_red_dots(contact_rect, screenshot)

    def row_has_red_dot(self, row_img, relaxed=False):
        """Delegate the per-row red-dot check to the analyzer.

        ``relaxed`` is forwarded unchanged; it defaults to the strict check.
        """
        return self.unread_analyzer.row_has_red_dot(row_img, relaxed=relaxed)

    def row_has_red_dot_weak(self, row_img):
        """Delegate the weak (lenient) per-row red-dot check to the analyzer."""
        return self.unread_analyzer.row_has_red_dot_weak(row_img)

    def get_all_sessions_with_unread(self, window_rect, round_count):
        """Scan the contact list and return all rows plus the unread subset.

        Args:
            window_rect: Window rectangle passed to the screenshot service.
            round_count: Scan round number; must be convertible with ``int()``.

        Returns:
            SessionScanResult with the analyzer's ``sessions`` and
            ``unread_sessions`` lists.
        """
        trace_id = new_trace_id("bot")
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)

        sessions, unread_sessions = self.unread_analyzer.get_all_sessions_with_unread(
            contact_rect=contact_rect,
            screenshot=screenshot,
            round_count=round_count,
            # The bound method already has the (image_obj, filename) signature.
            save_debug_image=self._save_debug_image,
        )
        self._save_session_scan_debug(
            round_count=round_count,
            sessions=sessions,
            unread_sessions=unread_sessions,
            contact_rect=contact_rect,
        )
        log_event("INFO", "bot", "bot.session_scan", trace_id, "scan", "ok", "会话扫描完成", extra={"round": int(round_count), "total": len(sessions), "unread": len(unread_sessions)})
        return SessionScanResult(sessions=sessions, unread_sessions=unread_sessions)

    def normalize_match_text(self, text):
        """Return ``text`` lower-cased with all whitespace removed.

        Falsy input (``None``, ``""``, ``0``) yields ``""``; other values are
        converted with ``str()`` first.
        """
        if not text:
            return ""
        # split() with no separator drops every whitespace run, including
        # leading and trailing whitespace.
        return "".join(str(text).lower().split())

    def make_block_key(self, text):
        """Return the blocked-cache key ``"title:<normalized>"`` for ``text``.

        Returns ``""`` when the normalized text is empty.
        """
        normalized = self.normalize_match_text(text)
        return f"title:{normalized}" if normalized else ""

    def reset_session_title_cache(self):
        """Clear the cached current-session title."""
        self._session_title_cache = {"value": "", "ts": 0.0}

    def get_session_title_by_ocr(self, window_rect):
        """Recognize the title of the currently open session via OCR.

        Returns the first OCR result, or ``""`` when ``window_rect`` is falsy,
        OCR yields nothing, or any error occurs (errors are logged).
        """
        trace_id = new_trace_id("bot")
        if not window_rect:
            return ""
        try:
            screenshot = self.screenshot.capture_session_title(window_rect)
            img_bytes = BytesIO()
            screenshot.save(img_bytes, format='PNG')
            valid = self.ocr.recognize_session_title(img_bytes.getvalue(), scene="session_title_main")
            if valid:
                title = valid[0]
                log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "ok", "会话标题识别成功", extra={"title": title})
                return title
            log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别为空", reason="empty_result")
            return ""
        except Exception as e:
            log_event("ERROR", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别异常", reason="ocr_error", extra={"error": str(e)})
            return ""

    def get_current_session_title(self, window_rect):
        """Return the current session title, reusing a recent cached value.

        A cached non-empty title younger than ``SESSION_TITLE_CACHE_TTL``
        seconds is returned without running OCR. A fresh title is cached only
        when it is non-empty and not listed in ``UI_NOISE_KEYWORDS``.
        Returns ``""`` on any unexpected error (which is logged).
        """
        import time

        try:
            now_ts = time.time()
            cached_title = (self._session_title_cache.get("value") or "").strip()
            cached_ts = float(self._session_title_cache.get("ts") or 0.0)
            if cached_title and now_ts - cached_ts <= self.SESSION_TITLE_CACHE_TTL:
                return cached_title

            title = (self.get_session_title_by_ocr(window_rect) or "").strip()
            if title and title not in UI_NOISE_KEYWORDS:
                self._session_title_cache = {"value": title, "ts": now_ts}
            # NOTE(review): the title is returned even when it was not cached
            # (empty or UI noise) — confirm against the original indentation.
            return title
        except Exception as e:
            # Best-effort lookup: report the failure instead of hiding it.
            log_event("ERROR", "bot", "bot.session_title", new_trace_id("bot"), "cache", "failed", "获取当前会话标题异常", reason="title_error", extra={"error": str(e)})
            return ""

    def _find_blocked_keyword(self, normalized_text, keywords):
        """Return the first keyword whose normalized form occurs in the text.

        Keywords that normalize to ``""`` are ignored; an empty string is a
        substring of everything and would otherwise block every session.
        Returns ``None`` when nothing matches.
        """
        for keyword in keywords:
            needle = self.normalize_match_text(keyword)
            if needle and needle in normalized_text:
                return keyword
        return None

    def _is_blocked_text(self, text, blocked_row_cache, save_blocked_row_cache):
        """Decide whether ``text`` names a blocked session.

        Returns True when the text's block key is already cached or when it
        contains one of ``BLOCKED_SESSION_KEYWORDS``; a new keyword match is
        written to ``blocked_row_cache`` and persisted via the callback.
        """
        block_key = self.make_block_key(text)
        if block_key and block_key in blocked_row_cache:
            return True
        keyword = self._find_blocked_keyword(self.normalize_match_text(text), BLOCKED_SESSION_KEYWORDS)
        if keyword is None:
            return False
        if block_key:
            blocked_row_cache[block_key] = text or keyword
            save_blocked_row_cache()
        return True

    def should_skip_current_session(self, window_rect, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """Return True when the currently open session must be skipped.

        The decision is based on the title of the open session (cached or
        OCR-recognized). ``session`` is accepted for interface compatibility
        and is not used.
        """
        title = self.get_current_session_title(window_rect)
        return self._is_blocked_text(title, blocked_row_cache, save_blocked_row_cache)

    def is_same_session(self, expected_session, current_session):
        """Return True when one normalized name contains the other.

        Returns False when either name normalizes to an empty string.
        """
        expected = self.normalize_match_text(expected_session)
        current = self.normalize_match_text(current_session)
        if not expected or not current:
            return False
        return expected in current or current in expected

    def should_skip_session_by_ocr(self, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """Return True when a session-list row must be skipped, based on OCR.

        Reads ``session['row_img']``, OCRs the cropped name region, stores the
        joined text in ``session['list_ocr_title']`` and matches it against
        the blocked cache and ``BLOCKED_SESSION_KEYWORDS``. Returns False when
        there is no image, no name crop, or any error occurs (errors are logged).
        """
        image_obj = session.get('row_img')
        if image_obj is None:
            return False
        row_idx = session.get('row_idx')
        try:
            name_img = self.extract_session_name_image(image_obj)
            if name_img is None:
                return False

            if OCR_SAVE_IMAGES:
                # "or 0" also covers an explicit None stored under 'row_idx'.
                file_name = f"row_{int(row_idx or 0):03d}_name_raw.png"
                self._save_debug_image(name_img, os.path.join('sessions', 'name_ocr', file_name))

            img_bytes = BytesIO()
            name_img.save(img_bytes, format='PNG')
            lines = self.ocr.recognize_session_name(img_bytes.getvalue(), scene=f"session_row_{row_idx}")
            # Treat a None OCR result as "no text" rather than raising in join().
            line_text = ' '.join(lines or [])
            session['list_ocr_title'] = line_text
            return self._is_blocked_text(line_text, blocked_row_cache, save_blocked_row_cache)
        except Exception as e:
            # Best-effort check: never block the scan loop, but report the error.
            log_event("ERROR", "bot", "bot.session_name", new_trace_id("bot"), "ocr", "failed", "会话名称识别异常", reason="ocr_error", extra={"row_idx": row_idx, "error": str(e)})
            return False

    def analyze_clicked_session(self, window_rect, round_count, row_idx):
        """Capture the chat area of the clicked session and extract its latest text.

        Args:
            window_rect: Window rectangle passed to the screenshot service.
            round_count: Scan round number; must be convertible with ``int()``.
            row_idx: Row index of the clicked session; must be convertible
                with ``int()``.

        Returns:
            ChatAnalyzeResult; ``ok`` is False when the capture box is invalid
            or no latest text was found.
        """
        trace_id = new_trace_id("bot")
        chat_box = self.screenshot.get_chat_capture_box(window_rect)
        if not self.screenshot.is_valid_box(chat_box):
            log_event("WARNING", "bot", "bot.chat_analyze", trace_id, "capture", "failed", "聊天区截图区域无效", reason="invalid_box")
            return ChatAnalyzeResult(ok=False, file_name='', latest_text='', confidence='', bubble_side='', screenshot_path='')

        # Coerce once so the file name and the log entry agree.
        round_no = int(round_count)
        row_no = int(row_idx)

        screenshot = self.screenshot.capture_chat_area(window_rect)
        file_name = f"round_{round_no:04d}_row_{row_no:03d}_chat.png"
        rel_path = os.path.join('sessions', 'clicked', file_name)
        self._save_debug_image(screenshot, rel_path)

        result = analyze_pil_image(screenshot, stem=os.path.splitext(file_name)[0], file_name=file_name)
        latest_text = (getattr(result, 'latest_text', None) or '').strip()
        confidence = getattr(result, 'confidence', '')
        bubble_side = getattr(result, 'bubble_side', '')
        log_event("INFO", "bot", "bot.chat_analyze", trace_id, "analyze", "ok", "聊天截图分析完成", extra={"round": round_no, "row_idx": row_no, "has_text": bool(latest_text), "bubble_side": bubble_side or "", "confidence": confidence})
        return ChatAnalyzeResult(
            ok=bool(latest_text),
            file_name=file_name,
            latest_text=latest_text,
            confidence=confidence,
            bubble_side=bubble_side,
            screenshot_path=rel_path,
        )

    def _save_session_scan_debug(self, round_count: int, sessions: list[dict], unread_sessions: list[dict], contact_rect: dict):
        """Write a JSON summary of one scan round when ``OCR_SAVE_IMAGES`` is set.

        The file is ``<OCR_SAVE_DIR>/sessions/scan_debug/round_NNNN_scan.json``.
        This is a best-effort debug aid: failures are logged and never raised.
        """
        if not OCR_SAVE_IMAGES:
            return
        try:
            round_no = int(round_count)
            debug_dir = os.path.join(OCR_SAVE_DIR, 'sessions', 'scan_debug')
            os.makedirs(debug_dir, exist_ok=True)
            file_path = os.path.join(debug_dir, f"round_{round_no:04d}_scan.json")
            rows = [
                {
                    'row_idx': session.get('row_idx'),
                    'has_red_dot': bool(session.get('has_red_dot')),
                    'has_red_by_global': bool(session.get('has_red_by_global')),
                    'has_red_by_row': bool(session.get('has_red_by_row')),
                    'has_red_by_row_weak': bool(session.get('has_red_by_row_weak')),
                    'click_x': session.get('click_x'),
                    'click_y': session.get('click_y'),
                    'list_ocr_title': session.get('list_ocr_title', ''),
                }
                for session in sessions
            ]
            payload = {
                'round': round_no,
                'contact_rect': contact_rect,
                'total_sessions': len(sessions),
                'unread_count': len(unread_sessions),
                'unread_rows': [s.get('row_idx') for s in unread_sessions],
                'rows': rows,
            }
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
        except Exception as e:
            # Debug output must not break scanning, but a failure should be visible.
            log_event("WARNING", "bot", "bot.session_scan", new_trace_id("bot"), "debug_dump", "failed", "会话扫描调试数据保存失败", reason="debug_dump_error", extra={"round": round_count, "error": str(e)})

    def _save_debug_image(self, image_obj, filename):
        """Forward to the configured ``save_debug_image`` callback, if any."""
        if not self.save_debug_image:
            return
        self.save_debug_image(image_obj, filename)
|
||||
Reference in New Issue
Block a user