from __future__ import annotations

from dataclasses import dataclass
from io import BytesIO
import json
import os
import time
from typing import Callable

# NOTE(review): cv2 / np (and some config constants below) look unused in this
# module; kept because other code may rely on them — confirm before removing.
import cv2
import numpy as np

from app.infrastructure.service.logging.log_service import log_event, new_trace_id
from app.infrastructure.service.wechat.chat_snapshot_analyzer import analyze_pil_image
from app.infrastructure.service.wechat.unread_session_analyzer import UnreadSessionAnalyzer
from app.infrastructure.service.wechat.config import (
    BLOCKED_SESSION_KEYWORDS,
    CONTACT_ROW_HEIGHT,
    OCR_SAVE_DIR,
    OCR_SAVE_IMAGES,
    SESSION_NAME_HEIGHT,
    SESSION_NAME_LEFT_OFFSET,
    SESSION_NAME_TOP_OFFSET,
    SESSION_NAME_WIDTH,
    UI_NOISE_KEYWORDS,
)


@dataclass
class SessionScanResult:
    """Result of one session-list scan: all detected rows plus the rows flagged as unread."""

    sessions: list[dict]
    unread_sessions: list[dict]


@dataclass
class ChatAnalyzeResult:
    """Result of analyzing the chat-area screenshot taken after clicking a session."""

    ok: bool  # True when a non-empty latest_text was extracted
    file_name: str
    latest_text: str
    confidence: str | float
    bubble_side: str
    screenshot_path: str  # relative path that was handed to the debug-image saver


class WechatSessionService:
    """WeChat session service.

    Scans the session list, detects unread red dots, decides whether a session
    should be skipped based on blocked title keywords, and analyzes chat-area
    screenshots.
    """

    # Seconds a successfully OCR'd session title is reused before OCR runs again.
    SESSION_TITLE_CACHE_TTL = 1.2

    def __init__(self, screenshot_service, ocr_service, save_debug_image: Callable | None = None):
        """Create the service.

        Args:
            screenshot_service: provides the capture/crop/box helpers used below.
            ocr_service: provides ``recognize_session_title`` / ``recognize_session_name``.
            save_debug_image: optional ``(image, relative_filename)`` callback; when
                omitted, debug images are not saved.
        """
        self.screenshot = screenshot_service
        self.ocr = ocr_service
        self.save_debug_image = save_debug_image
        # "ts" is a time.monotonic() reading; 0.0 means "never cached".
        self._session_title_cache = {"value": "", "ts": 0.0}
        self.unread_analyzer = UnreadSessionAnalyzer()
        log_event("INFO", "bot", "bot.session_service.init", new_trace_id("bot"), "init", "ok", "会话服务初始化完成")

    # ------------------------------------------------------------------ #
    # Geometry / capture helpers
    # ------------------------------------------------------------------ #

    def get_contact_list_rect(self, window_rect):
        """Return the session-list area for ``window_rect`` as a left/top/right/bottom dict."""
        box = self.screenshot.get_contact_list_box(window_rect)
        return {
            'left': box.left,
            'top': box.top,
            'right': box.right,
            'bottom': box.bottom,
        }

    def extract_session_name_image(self, row_img):
        """Crop the session-name region out of a session-row image."""
        return self.screenshot.crop_session_name(row_img)

    @staticmethod
    def _image_to_png_bytes(image_obj) -> bytes:
        """Encode an image with a PIL-style ``save(fp, format=...)`` method to PNG bytes."""
        with BytesIO() as buffer:
            image_obj.save(buffer, format='PNG')
            return buffer.getvalue()

    # ------------------------------------------------------------------ #
    # Unread (red-dot) detection
    # ------------------------------------------------------------------ #

    def detect_red_dots(self, window_rect):
        """Detect all red-dot (unread message) positions in the session list."""
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)
        return self.unread_analyzer.detect_red_dots(contact_rect, screenshot)

    def row_has_red_dot(self, row_img, relaxed=False):
        """Check a single session-row image for an unread red dot (strict unless ``relaxed``)."""
        return self.unread_analyzer.row_has_red_dot(row_img, relaxed=relaxed)

    def row_has_red_dot_weak(self, row_img):
        """Check a single session-row image for an unread red dot using the weak/lenient check."""
        return self.unread_analyzer.row_has_red_dot_weak(row_img)

    def get_all_sessions_with_unread(self, window_rect, round_count):
        """Scan every session row and report which ones carry an unread marker.

        Also writes scan debug JSON (when ``OCR_SAVE_IMAGES`` is enabled) and logs
        a summary event.

        Returns:
            SessionScanResult with all rows and the unread subset.
        """
        trace_id = new_trace_id("bot")
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)
        sessions, unread_sessions = self.unread_analyzer.get_all_sessions_with_unread(
            contact_rect=contact_rect,
            screenshot=screenshot,
            round_count=round_count,
            save_debug_image=self._save_debug_image,
        )
        self._save_session_scan_debug(
            round_count=round_count,
            sessions=sessions,
            unread_sessions=unread_sessions,
            contact_rect=contact_rect,
        )
        log_event(
            "INFO", "bot", "bot.session_scan", trace_id, "scan", "ok", "会话扫描完成",
            extra={"round": int(round_count), "total": len(sessions), "unread": len(unread_sessions)},
        )
        return SessionScanResult(sessions=sessions, unread_sessions=unread_sessions)

    # ------------------------------------------------------------------ #
    # Text matching / blocking
    # ------------------------------------------------------------------ #

    def normalize_match_text(self, text):
        """Normalize text for matching: lower-case it and drop all whitespace.

        Falsy input yields ``""``.
        """
        if not text:
            return ""
        text = str(text).strip().lower()
        return "".join(ch for ch in text if not ch.isspace())

    def make_block_key(self, text):
        """Return the blocked-row cache key for ``text`` (``"title:<normalized>"``), or ``""``."""
        normalized = self.normalize_match_text(text)
        if not normalized:
            return ""
        return f"title:{normalized}"

    def _is_blocked_text(self, text, blocked_row_cache, save_blocked_row_cache: Callable) -> bool:
        """Return True if ``text`` is blocked by the cache or by a blocked keyword.

        On a new keyword hit, the text is recorded in ``blocked_row_cache`` under
        ``make_block_key(text)`` and ``save_blocked_row_cache()`` is called.
        """
        block_key = self.make_block_key(text)
        if block_key and block_key in blocked_row_cache:
            return True
        normalized_text = self.normalize_match_text(text)
        if not normalized_text:
            return False
        for keyword in BLOCKED_SESSION_KEYWORDS:
            normalized_keyword = self.normalize_match_text(keyword)
            # An empty keyword is a substring of every title and would block everything.
            if normalized_keyword and normalized_keyword in normalized_text:
                if block_key:
                    blocked_row_cache[block_key] = text or keyword
                    save_blocked_row_cache()
                return True
        return False

    def is_same_session(self, expected_session, current_session):
        """Fuzzy-compare two session names (case/whitespace-insensitive, substring either way)."""
        expected = self.normalize_match_text(expected_session)
        current = self.normalize_match_text(current_session)
        if not expected or not current:
            return False
        return expected in current or current in expected

    # ------------------------------------------------------------------ #
    # Session title (OCR + cache)
    # ------------------------------------------------------------------ #

    def reset_session_title_cache(self):
        """Forget the cached current-session title."""
        self._session_title_cache = {"value": "", "ts": 0.0}

    def get_session_title_by_ocr(self, window_rect):
        """OCR the title of the currently open chat window.

        Returns the first recognized line, or ``""`` when ``window_rect`` is falsy,
        nothing was recognized, or OCR raised (the error is logged).
        """
        trace_id = new_trace_id("bot")
        try:
            if not window_rect:
                return ""
            area_name = "main"
            screenshot = self.screenshot.capture_session_title(window_rect)
            valid = self.ocr.recognize_session_title(
                self._image_to_png_bytes(screenshot), scene=f"session_title_{area_name}"
            )
            if valid:
                title = valid[0]
                log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "ok", "会话标题识别成功", extra={"title": title})
                return title
            log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别为空", reason="empty_result")
            return ""
        except Exception as e:
            log_event("ERROR", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别异常", reason="ocr_error", extra={"error": str(e)})
            return ""

    def get_current_session_title(self, window_rect):
        """Return the current session title, reusing a recent cached value to avoid OCR.

        A cached title younger than ``SESSION_TITLE_CACHE_TTL`` seconds is returned
        directly. Otherwise OCR runs; a non-empty title that is not in
        ``UI_NOISE_KEYWORDS`` refreshes the cache. The OCR result is returned either
        way. Returns ``""`` on unexpected errors (logged).
        """
        try:
            # Monotonic clock: wall-clock jumps must not keep a stale title alive.
            now_ts = time.monotonic()
            cached_title = (self._session_title_cache.get("value") or "").strip()
            cached_ts = float(self._session_title_cache.get("ts") or 0.0)
            if cached_title and 0.0 <= now_ts - cached_ts <= self.SESSION_TITLE_CACHE_TTL:
                return cached_title
            title = (self.get_session_title_by_ocr(window_rect) or "").strip()
            if title and title not in UI_NOISE_KEYWORDS:
                self._session_title_cache = {"value": title, "ts": now_ts}
            return title
        except Exception as e:
            log_event("WARNING", "bot", "bot.session_title", new_trace_id("bot"), "cache", "failed", "获取当前会话标题异常", reason="title_error", extra={"error": str(e)})
            return ""

    # ------------------------------------------------------------------ #
    # Skip decisions
    # ------------------------------------------------------------------ #

    def should_skip_current_session(self, window_rect, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """Decide, after clicking a session, whether it should be skipped based on its title.

        ``session`` is not used here; it is kept for interface compatibility.
        """
        title = self.get_current_session_title(window_rect)
        return self._is_blocked_text(title, blocked_row_cache, save_blocked_row_cache)

    def should_skip_session_by_ocr(self, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """Decide from the session-list row (before clicking) whether a session should be skipped.

        OCRs the name area of ``session['row_img']``, stores the text in
        ``session['list_ocr_title']`` and checks it against the blocked cache and
        keywords. Returns False when there is no image/crop or on any error (logged).
        """
        image_obj = session.get('row_img')
        if image_obj is None:
            return False
        row_idx = session.get('row_idx')
        try:
            name_img = self.extract_session_name_image(image_obj)
            if name_img is None:
                return False
            if OCR_SAVE_IMAGES:
                # A missing/None row_idx must not abort the OCR step itself.
                file_name = f"row_{int(row_idx or 0):03d}_name_raw.png"
                self._save_debug_image(name_img, os.path.join('sessions', 'name_ocr', file_name))
            lines = self.ocr.recognize_session_name(
                self._image_to_png_bytes(name_img), scene=f"session_row_{row_idx}"
            )
            line_text = ' '.join(lines or [])
            session['list_ocr_title'] = line_text
            return self._is_blocked_text(line_text, blocked_row_cache, save_blocked_row_cache)
        except Exception as e:
            log_event("WARNING", "bot", "bot.session_name_ocr", new_trace_id("bot"), "ocr", "failed", "会话名称OCR异常", reason="ocr_error", extra={"row_idx": row_idx, "error": str(e)})
            return False

    # ------------------------------------------------------------------ #
    # Chat analysis
    # ------------------------------------------------------------------ #

    def analyze_clicked_session(self, window_rect, round_count, row_idx):
        """Capture the chat area after a click and extract the latest message text.

        Returns a ChatAnalyzeResult; ``ok`` is False when the capture box is invalid
        or no text was extracted.
        """
        trace_id = new_trace_id("bot")
        chat_box = self.screenshot.get_chat_capture_box(window_rect)
        if not self.screenshot.is_valid_box(chat_box):
            log_event("WARNING", "bot", "bot.chat_analyze", trace_id, "capture", "failed", "聊天区截图区域无效", reason="invalid_box")
            return ChatAnalyzeResult(ok=False, file_name='', latest_text='', confidence='', bubble_side='', screenshot_path='')

        screenshot = self.screenshot.capture_chat_area(window_rect)
        file_name = f"round_{round_count:04d}_row_{row_idx:03d}_chat.png"
        rel_path = os.path.join('sessions', 'clicked', file_name)
        self._save_debug_image(screenshot, rel_path)

        result = analyze_pil_image(screenshot, stem=os.path.splitext(file_name)[0], file_name=file_name)
        latest_text = (getattr(result, 'latest_text', None) or '').strip()
        confidence = getattr(result, 'confidence', '')
        bubble_side = getattr(result, 'bubble_side', '')
        log_event(
            "INFO", "bot", "bot.chat_analyze", trace_id, "analyze", "ok", "聊天截图分析完成",
            extra={
                "round": int(round_count),
                "row_idx": int(row_idx),
                "has_text": bool(latest_text),
                "bubble_side": bubble_side or "",
                "confidence": confidence,
            },
        )
        return ChatAnalyzeResult(
            ok=bool(latest_text),
            file_name=file_name,
            latest_text=latest_text,
            confidence=confidence,
            bubble_side=bubble_side,
            screenshot_path=rel_path,
        )

    # ------------------------------------------------------------------ #
    # Debug output
    # ------------------------------------------------------------------ #

    def _save_session_scan_debug(self, round_count: int, sessions: list[dict], unread_sessions: list[dict], contact_rect: dict):
        """Write a per-round JSON summary of the session scan (only if ``OCR_SAVE_IMAGES``).

        Best-effort: failures are logged and never propagate to the caller.
        """
        if not OCR_SAVE_IMAGES:
            return
        try:
            debug_dir = os.path.join(OCR_SAVE_DIR, 'sessions', 'scan_debug')
            os.makedirs(debug_dir, exist_ok=True)
            file_path = os.path.join(debug_dir, f"round_{round_count:04d}_scan.json")
            rows = [
                {
                    'row_idx': session.get('row_idx'),
                    'has_red_dot': bool(session.get('has_red_dot')),
                    'has_red_by_global': bool(session.get('has_red_by_global')),
                    'has_red_by_row': bool(session.get('has_red_by_row')),
                    'has_red_by_row_weak': bool(session.get('has_red_by_row_weak')),
                    'click_x': session.get('click_x'),
                    'click_y': session.get('click_y'),
                    'list_ocr_title': session.get('list_ocr_title', ''),
                }
                for session in sessions
            ]
            payload = {
                'round': int(round_count),
                'contact_rect': contact_rect,
                'total_sessions': len(sessions),
                'unread_count': len(unread_sessions),
                'unread_rows': [s.get('row_idx') for s in unread_sessions],
                'rows': rows,
            }
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
        except Exception as e:
            log_event("WARNING", "bot", "bot.session_scan_debug", new_trace_id("bot"), "save", "failed", "会话扫描调试数据保存失败", reason="save_error", extra={"error": str(e)})

    def _save_debug_image(self, image_obj, filename):
        """Forward to the optional ``save_debug_image`` callback; no-op if none was given."""
        if not self.save_debug_image:
            return
        self.save_debug_image(image_obj, filename)