# ai-shiliu/app/infrastructure/service/wechat/session_service.py

from __future__ import annotations

import json
import os
import time
from dataclasses import dataclass
from io import BytesIO
from typing import Callable

import cv2
import numpy as np

from app.infrastructure.service.logging.log_service import log_event, new_trace_id
from app.infrastructure.service.wechat.chat_snapshot_analyzer import analyze_pil_image
from app.infrastructure.service.wechat.unread_session_analyzer import UnreadSessionAnalyzer
from app.infrastructure.service.wechat.config import (
    BLOCKED_SESSION_KEYWORDS,
    CONTACT_ROW_HEIGHT,
    OCR_SAVE_DIR,
    OCR_SAVE_IMAGES,
    SESSION_NAME_HEIGHT,
    SESSION_NAME_LEFT_OFFSET,
    SESSION_NAME_TOP_OFFSET,
    SESSION_NAME_WIDTH,
    UI_NOISE_KEYWORDS,
)
@dataclass
class SessionScanResult:
    """Outcome of one session-list scan.

    Holds every session row that was detected, plus the subset of those rows
    that were classified as having unread messages.
    """

    # One dict per detected session row.
    sessions: list[dict]
    # The rows from ``sessions`` that carry an unread marker.
    unread_sessions: list[dict]
@dataclass
class ChatAnalyzeResult:
    """Outcome of analysing a chat-area screenshot taken after clicking a session."""

    # Whether a usable latest message was extracted.
    ok: bool
    # Screenshot file name ('' when no screenshot was taken).
    file_name: str
    # Latest message text ('' when nothing was recognised).
    latest_text: str
    # Analyzer confidence; '' when unavailable.
    confidence: str | float
    # Which side the latest bubble was on, as reported by the analyzer.
    bubble_side: str
    # Relative path the screenshot was saved under ('' when none).
    screenshot_path: str
class WechatSessionService:
    """WeChat session handling: session-list scanning, red-dot (unread) detection
    and chat-screenshot analysis.

    Screen capture is delegated to ``screenshot_service``, text recognition to
    ``ocr_service`` and red-dot detection to ``UnreadSessionAnalyzer``.
    """

    # Seconds during which a previously OCR'd session title is reused instead
    # of running OCR again (used by get_current_session_title).
    SESSION_TITLE_CACHE_TTL = 1.2

    def __init__(self, screenshot_service, ocr_service, save_debug_image: Callable | None = None):
        """Wire up collaborators.

        Args:
            screenshot_service: Capture/crop helper (``capture_*``, ``get_*_box``,
                ``crop_session_name``, ``is_valid_box``).
            ocr_service: Provides ``recognize_session_title`` and
                ``recognize_session_name``.
            save_debug_image: Optional ``(image_obj, filename)`` callback used to
                persist debug images; when falsy, debug images are skipped.
        """
        self.screenshot = screenshot_service
        self.ocr = ocr_service
        self.save_debug_image = save_debug_image
        # Last accepted OCR'd title and the time.monotonic() value it was read at.
        self._session_title_cache = {"value": "", "ts": 0.0}
        self.unread_analyzer = UnreadSessionAnalyzer()
        log_event("INFO", "bot", "bot.session_service.init", new_trace_id("bot"), "init", "ok", "会话服务初始化完成")

    def get_contact_list_rect(self, window_rect):
        """Return the session-list area of ``window_rect`` as a plain dict.

        Returns:
            ``{'left', 'top', 'right', 'bottom'}`` copied from
            ``screenshot.get_contact_list_box(window_rect)``.
        """
        box = self.screenshot.get_contact_list_box(window_rect)
        return {
            'left': box.left,
            'top': box.top,
            'right': box.right,
            'bottom': box.bottom,
        }

    def extract_session_name_image(self, row_img):
        """Crop the session-name region out of a single session-row image."""
        return self.screenshot.crop_session_name(row_img)

    def detect_red_dots(self, window_rect):
        """Detect red-dot (unread badge) positions across the whole session list."""
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)
        return self.unread_analyzer.detect_red_dots(contact_rect, screenshot)

    def row_has_red_dot(self, row_img, relaxed=False):
        """Return whether one session-row image carries an unread red dot.

        ``relaxed`` is forwarded unchanged to
        ``UnreadSessionAnalyzer.row_has_red_dot``; the thresholds live there.
        """
        return self.unread_analyzer.row_has_red_dot(row_img, relaxed=relaxed)

    def row_has_red_dot_weak(self, row_img):
        """Weak/lenient red-dot check for one row, delegated to the analyzer."""
        return self.unread_analyzer.row_has_red_dot_weak(row_img)

    def get_all_sessions_with_unread(self, window_rect, round_count):
        """Scan every visible session row and split out the unread ones.

        Captures the session list, lets ``UnreadSessionAnalyzer`` classify the
        rows, dumps scan debug JSON when ``OCR_SAVE_IMAGES`` is on, and logs a
        summary.

        Returns:
            SessionScanResult with all rows and the unread subset.
        """
        trace_id = new_trace_id("bot")
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)
        sessions, unread_sessions = self.unread_analyzer.get_all_sessions_with_unread(
            contact_rect=contact_rect,
            screenshot=screenshot,
            round_count=round_count,
            # The bound method already has the (image_obj, filename) signature.
            save_debug_image=self._save_debug_image,
        )
        self._save_session_scan_debug(
            round_count=round_count,
            sessions=sessions,
            unread_sessions=unread_sessions,
            contact_rect=contact_rect,
        )
        log_event(
            "INFO", "bot", "bot.session_scan", trace_id, "scan", "ok", "会话扫描完成",
            extra={"round": int(round_count), "total": len(sessions), "unread": len(unread_sessions)},
        )
        return SessionScanResult(sessions=sessions, unread_sessions=unread_sessions)

    def normalize_match_text(self, text):
        """Normalize text for matching: lower-case it and drop all whitespace.

        Falsy input (``None``, ``""``) yields ``""``; other values go through ``str()``.
        """
        if not text:
            return ""
        return "".join(str(text).lower().split())

    def make_block_key(self, text):
        """Build the ``blocked_row_cache`` key for a session title.

        Returns ``""`` when the normalized text is empty.
        """
        normalized = self.normalize_match_text(text)
        if not normalized:
            return ""
        return f"title:{normalized}"

    def reset_session_title_cache(self):
        """Forget the cached current-session title."""
        self._session_title_cache = {"value": "", "ts": 0.0}

    def get_session_title_by_ocr(self, window_rect):
        """OCR the title bar of the currently open chat.

        Returns:
            The first line returned by ``ocr.recognize_session_title``. Returns
            ``""`` when ``window_rect`` is falsy, when OCR finds nothing, or when
            an error occurs (errors are logged, not raised).
        """
        trace_id = new_trace_id("bot")
        try:
            if not window_rect:
                return ""
            area_name = "main"
            screenshot = self.screenshot.capture_session_title(window_rect)
            img_bytes = BytesIO()
            screenshot.save(img_bytes, format='PNG')
            valid = self.ocr.recognize_session_title(img_bytes.getvalue(), scene=f"session_title_{area_name}")
            if valid:
                title = valid[0]
                log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "ok", "会话标题识别成功", extra={"title": title})
                return title
            log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别为空", reason="empty_result")
            return ""
        except Exception as e:
            log_event("ERROR", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别异常", reason="ocr_error", extra={"error": str(e)})
            return ""

    def get_current_session_title(self, window_rect):
        """Return the current chat title, reusing a recent OCR result when possible.

        A cached title is reused for ``SESSION_TITLE_CACHE_TTL`` seconds. The
        monotonic clock is used so wall-clock jumps cannot keep a stale title
        alive; a negative elapsed time is treated as expired.

        Titles found in ``UI_NOISE_KEYWORDS`` are returned but never cached.

        Returns:
            The stripped title, or ``""`` on failure (failures are logged).
        """
        try:
            now_ts = time.monotonic()
            cached_title = (self._session_title_cache.get("value") or "").strip()
            cached_ts = float(self._session_title_cache.get("ts") or 0.0)
            if cached_title and 0.0 <= now_ts - cached_ts <= self.SESSION_TITLE_CACHE_TTL:
                return cached_title
            title = (self.get_session_title_by_ocr(window_rect) or "").strip()
            if title and title not in UI_NOISE_KEYWORDS:
                self._session_title_cache = {"value": title, "ts": now_ts}
            return title
        except Exception as e:
            log_event("WARNING", "bot", "bot.session_title", new_trace_id("bot"), "cache", "failed", "获取当前会话标题异常", reason="title_error", extra={"error": str(e)})
            return ""

    def _is_blocked_text(self, text, blocked_row_cache, save_blocked_row_cache: Callable, keywords=None):
        """Return True if ``text`` is already blocked or contains a blocked keyword.

        On a new keyword hit, ``text`` is stored in ``blocked_row_cache`` under
        ``make_block_key(text)`` and ``save_blocked_row_cache()`` is called.

        Empty or whitespace-only keywords are ignored. ``"" in s`` is always
        True, so a single blank config entry would otherwise block every session.

        Args:
            text: Session title/name as read by OCR.
            blocked_row_cache: Mapping of block key -> original text; mutated on a hit.
            save_blocked_row_cache: Zero-argument callback that persists the cache.
            keywords: Keywords to test; ``None`` means ``BLOCKED_SESSION_KEYWORDS``.
        """
        block_key = self.make_block_key(text)
        if block_key and block_key in blocked_row_cache:
            return True
        normalized_text = self.normalize_match_text(text)
        if not normalized_text:
            return False
        if keywords is None:
            keywords = BLOCKED_SESSION_KEYWORDS
        for keyword in keywords:
            normalized_keyword = self.normalize_match_text(keyword)
            if normalized_keyword and normalized_keyword in normalized_text:
                # normalized_text is non-empty here, so block_key is non-empty too.
                blocked_row_cache[block_key] = text
                save_blocked_row_cache()
                return True
        return False

    def should_skip_current_session(self, window_rect, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """Decide, after clicking a session, whether its chat title is blocked.

        ``session`` is accepted for interface compatibility and is not used.
        """
        title = self.get_current_session_title(window_rect)
        return self._is_blocked_text(title, blocked_row_cache, save_blocked_row_cache)

    def is_same_session(self, expected_session, current_session):
        """Fuzzy session-name comparison.

        The comparison ignores case and whitespace. It returns True when either
        normalized name contains the other, and False when either is empty.
        """
        expected = self.normalize_match_text(expected_session)
        current = self.normalize_match_text(current_session)
        if not expected or not current:
            return False
        return expected in current or current in expected

    def should_skip_session_by_ocr(self, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """Decide from the session-list row itself (before clicking) whether to skip it.

        The row's name region is OCR'd, and the text is stored in
        ``session['list_ocr_title']`` and checked against the block cache and
        keywords.

        Returns:
            False when there is no row image, no name crop, or any error
            occurs (errors are logged).
        """
        image_obj = session.get('row_img')
        if image_obj is None:
            return False
        row_idx = session.get('row_idx')
        try:
            name_img = self.extract_session_name_image(image_obj)
            if name_img is None:
                return False
            if OCR_SAVE_IMAGES:
                # `or 0` also covers an explicit None row_idx.
                file_name = f"row_{int(row_idx or 0):03d}_name_raw.png"
                self._save_debug_image(name_img, os.path.join('sessions', 'name_ocr', file_name))
            img_bytes = BytesIO()
            name_img.save(img_bytes, format='PNG')
            lines = self.ocr.recognize_session_name(img_bytes.getvalue(), scene=f"session_row_{row_idx}")
            line_text = ' '.join(lines or [])
            session['list_ocr_title'] = line_text
            return self._is_blocked_text(line_text, blocked_row_cache, save_blocked_row_cache)
        except Exception as e:
            log_event("WARNING", "bot", "bot.session_row_ocr", new_trace_id("bot"), "ocr", "failed", "会话行名称识别异常", reason="ocr_error", extra={"row_idx": str(row_idx), "error": str(e)})
            return False

    def analyze_clicked_session(self, window_rect, round_count, row_idx):
        """Capture the chat area after a click and extract its latest message.

        The screenshot is passed to the debug-image callback as
        ``sessions/clicked/<file>`` and analysed with ``analyze_pil_image``.

        Returns:
            ChatAnalyzeResult. ``ok`` is True only when non-empty latest text was
            found. When the chat capture box is invalid, all fields are empty and
            ``ok`` is False.
        """
        trace_id = new_trace_id("bot")
        chat_box = self.screenshot.get_chat_capture_box(window_rect)
        if not self.screenshot.is_valid_box(chat_box):
            log_event("WARNING", "bot", "bot.chat_analyze", trace_id, "capture", "failed", "聊天区截图区域无效", reason="invalid_box")
            return ChatAnalyzeResult(ok=False, file_name='', latest_text='', confidence='', bubble_side='', screenshot_path='')
        screenshot = self.screenshot.capture_chat_area(window_rect)
        # int() matches the int() used in the log payload below and accepts integer-like counters.
        file_name = f"round_{int(round_count):04d}_row_{int(row_idx):03d}_chat.png"
        rel_path = os.path.join('sessions', 'clicked', file_name)
        self._save_debug_image(screenshot, rel_path)
        result = analyze_pil_image(screenshot, stem=os.path.splitext(file_name)[0], file_name=file_name)
        latest_text = (getattr(result, 'latest_text', None) or '').strip()
        confidence = getattr(result, 'confidence', '')
        bubble_side = getattr(result, 'bubble_side', '')
        log_event("INFO", "bot", "bot.chat_analyze", trace_id, "analyze", "ok", "聊天截图分析完成", extra={"round": int(round_count), "row_idx": int(row_idx), "has_text": bool(latest_text), "bubble_side": bubble_side or "", "confidence": confidence})
        return ChatAnalyzeResult(
            ok=bool(latest_text),
            file_name=file_name,
            latest_text=latest_text,
            confidence=confidence,
            bubble_side=bubble_side,
            screenshot_path=rel_path,
        )

    def _save_session_scan_debug(self, round_count: int, sessions: list[dict], unread_sessions: list[dict], contact_rect: dict):
        """Write a per-round JSON dump of the session scan when ``OCR_SAVE_IMAGES`` is on.

        The file goes to ``OCR_SAVE_DIR/sessions/scan_debug/round_NNNN_scan.json``.

        The payload is serialized before the file is opened, so a serialization
        error cannot leave a truncated file behind. Failures are logged; they
        never abort the scan.
        """
        if not OCR_SAVE_IMAGES:
            return
        try:
            debug_dir = os.path.join(OCR_SAVE_DIR, 'sessions', 'scan_debug')
            os.makedirs(debug_dir, exist_ok=True)
            file_path = os.path.join(debug_dir, f"round_{int(round_count):04d}_scan.json")
            rows = [
                {
                    'row_idx': session.get('row_idx'),
                    'has_red_dot': bool(session.get('has_red_dot')),
                    'has_red_by_global': bool(session.get('has_red_by_global')),
                    'has_red_by_row': bool(session.get('has_red_by_row')),
                    'has_red_by_row_weak': bool(session.get('has_red_by_row_weak')),
                    'click_x': session.get('click_x'),
                    'click_y': session.get('click_y'),
                    'list_ocr_title': session.get('list_ocr_title', ''),
                }
                for session in sessions
            ]
            payload = {
                'round': int(round_count),
                'contact_rect': contact_rect,
                'total_sessions': len(sessions),
                'unread_count': len(unread_sessions),
                'unread_rows': [s.get('row_idx') for s in unread_sessions],
                'rows': rows,
            }
            text = json.dumps(payload, ensure_ascii=False, indent=2, default=self._json_default)
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(text)
        except Exception as e:
            log_event("WARNING", "bot", "bot.session_scan_debug", new_trace_id("bot"), "save", "failed", "会话扫描调试数据保存失败", reason="save_error", extra={"error": str(e)})

    @staticmethod
    def _json_default(obj):
        """``json.dumps`` ``default=`` hook for NumPy values in debug payloads.

        NumPy scalars and arrays are converted with ``tolist()``. Anything else
        raises ``TypeError``, as ``json`` expects.

        NOTE(review): click_x/click_y and rect values come from the screenshot
        service and unread analyzer and may be NumPy scalars. TODO confirm.
        """
        if isinstance(obj, (np.generic, np.ndarray)):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _save_debug_image(self, image_obj, filename):
        """Forward a debug image to the optional ``save_debug_image`` callback.

        Does nothing when no callback was given.
        """
        if self.save_debug_image:
            self.save_debug_image(image_obj, filename)