289 lines
13 KiB
Python
289 lines
13 KiB
Python
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from io import BytesIO
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
from typing import Callable
|
|||
|
|
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
|
|||
|
|
from app.infrastructure.service.logging.log_service import log_event, new_trace_id
|
|||
|
|
from app.infrastructure.service.wechat.chat_snapshot_analyzer import analyze_pil_image
|
|||
|
|
from app.infrastructure.service.wechat.unread_session_analyzer import UnreadSessionAnalyzer
|
|||
|
|
from app.infrastructure.service.wechat.config import (
|
|||
|
|
BLOCKED_SESSION_KEYWORDS,
|
|||
|
|
CONTACT_ROW_HEIGHT,
|
|||
|
|
OCR_SAVE_DIR,
|
|||
|
|
OCR_SAVE_IMAGES,
|
|||
|
|
SESSION_NAME_HEIGHT,
|
|||
|
|
SESSION_NAME_LEFT_OFFSET,
|
|||
|
|
SESSION_NAME_TOP_OFFSET,
|
|||
|
|
SESSION_NAME_WIDTH,
|
|||
|
|
UI_NOISE_KEYWORDS,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass()
class SessionScanResult:
    """Result of one scan over the WeChat session (contact) list.

    Holds every detected session row together with the rows that the
    unread analyzer flagged as carrying an unread marker.
    """

    # All session rows found in the contact list during this scan (per-row dicts).
    sessions: list[dict]
    # Session rows flagged as unread by the analyzer (per-row dicts).
    unread_sessions: list[dict]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass()
class ChatAnalyzeResult:
    """Outcome of analysing a chat-area snapshot taken after opening a session."""

    # Whether the analysis produced a usable result.
    ok: bool
    # File name of the captured chat screenshot.
    file_name: str
    # Latest message text extracted from the screenshot ('' when none).
    latest_text: str
    # Analyzer confidence; may be numeric or a string such as '' when unavailable.
    confidence: str | float
    # Side of the latest message bubble, as reported by the analyzer.
    bubble_side: str
    # Path of the saved screenshot (presumably relative to the debug-image root — confirm).
    screenshot_path: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class WechatSessionService:
    """WeChat session service.

    Scans the session list, detects unread red dots, OCRs session titles and
    names, filters blocked sessions by keyword, and analyses chat screenshots.
    Screen capture and OCR are delegated to the injected services; red-dot
    logic lives in ``UnreadSessionAnalyzer``.
    """

    # Seconds a successfully OCR'd session title is reused before OCR runs again.
    SESSION_TITLE_CACHE_TTL = 1.2

    def __init__(self, screenshot_service, ocr_service, save_debug_image: Callable | None = None):
        """Wire up collaborators and the session-title cache.

        Args:
            screenshot_service: Provides the capture/crop helpers used below
                (``capture_*``, ``get_*_box``, ``crop_session_name``, ``is_valid_box``).
            ocr_service: Provides ``recognize_session_title`` and
                ``recognize_session_name``, both fed PNG bytes.
            save_debug_image: Optional ``(image_obj, relative_path)`` callback
                used to persist debug images; when falsy nothing is saved.
        """
        self.screenshot = screenshot_service
        self.ocr = ocr_service
        self.save_debug_image = save_debug_image
        # Last good title plus the monotonic timestamp when it was read.
        self._session_title_cache = {"value": "", "ts": 0.0}
        self.unread_analyzer = UnreadSessionAnalyzer()
        log_event("INFO", "bot", "bot.session_service.init", new_trace_id("bot"), "init", "ok", "会话服务初始化完成")

    def get_contact_list_rect(self, window_rect):
        """Return the session-list area for ``window_rect`` as a left/top/right/bottom dict."""
        box = self.screenshot.get_contact_list_box(window_rect)
        return {
            'left': box.left,
            'top': box.top,
            'right': box.right,
            'bottom': box.bottom,
        }

    def extract_session_name_image(self, row_img):
        """Crop the session-name region out of a session-row image."""
        return self.screenshot.crop_session_name(row_img)

    def detect_red_dots(self, window_rect):
        """Capture the session list and return the red (unread) dot detections from the analyzer."""
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)
        return self.unread_analyzer.detect_red_dots(contact_rect, screenshot)

    def row_has_red_dot(self, row_img, relaxed=False):
        """Return the analyzer's red-dot verdict for one row image (``relaxed`` is forwarded as-is)."""
        return self.unread_analyzer.row_has_red_dot(row_img, relaxed=relaxed)

    def row_has_red_dot_weak(self, row_img):
        """Return the analyzer's weak (lenient) red-dot verdict for one row image."""
        return self.unread_analyzer.row_has_red_dot_weak(row_img)

    def get_all_sessions_with_unread(self, window_rect, round_count):
        """Scan every session row and split out those flagged as unread.

        Also writes a scan-debug JSON (when ``OCR_SAVE_IMAGES``) and logs totals.

        Returns:
            SessionScanResult with all rows and the unread subset.
        """
        trace_id = new_trace_id("bot")
        contact_rect = self.get_contact_list_rect(window_rect)
        screenshot = self.screenshot.capture_contact_list_default(window_rect)

        sessions, unread_sessions = self.unread_analyzer.get_all_sessions_with_unread(
            contact_rect=contact_rect,
            screenshot=screenshot,
            round_count=round_count,
            # Bound method has the same (image_obj, filename) signature as the old lambda.
            save_debug_image=self._save_debug_image,
        )
        self._save_session_scan_debug(round_count=round_count, sessions=sessions, unread_sessions=unread_sessions, contact_rect=contact_rect)
        log_event("INFO", "bot", "bot.session_scan", trace_id, "scan", "ok", "会话扫描完成", extra={"round": int(round_count), "total": len(sessions), "unread": len(unread_sessions)})
        return SessionScanResult(sessions=sessions, unread_sessions=unread_sessions)

    def normalize_match_text(self, text):
        """Normalize text for matching: lower-case it and drop all whitespace.

        Falsy input (``None``, ``""``, ``0``) yields ``""``.
        """
        if not text:
            return ""
        # str.split() with no separator splits on exactly the chars str.isspace() accepts.
        return "".join(str(text).lower().split())

    def make_block_key(self, text):
        """Return the blocked-row cache key ``"title:<normalized>"``, or ``""`` for blank text."""
        normalized = self.normalize_match_text(text)
        if not normalized:
            return ""
        return f"title:{normalized}"

    def reset_session_title_cache(self):
        """Forget the cached current-session title so the next lookup re-runs OCR."""
        self._session_title_cache = {"value": "", "ts": 0.0}

    @staticmethod
    def _image_to_png_bytes(image_obj) -> bytes:
        """Encode an image object exposing ``save(fp, format=...)`` to PNG bytes."""
        with BytesIO() as buf:
            image_obj.save(buf, format='PNG')
            return buf.getvalue()

    def get_session_title_by_ocr(self, window_rect):
        """OCR the title bar of the currently open session.

        Returns the first recognized line, or ``""`` when ``window_rect`` is
        falsy, OCR yields nothing, or any error occurs (errors are logged).
        """
        trace_id = new_trace_id("bot")
        try:
            if not window_rect:
                return ""
            area_name = "main"
            screenshot = self.screenshot.capture_session_title(window_rect)
            valid = self.ocr.recognize_session_title(self._image_to_png_bytes(screenshot), scene=f"session_title_{area_name}")
            if valid:
                title = valid[0]
                log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "ok", "会话标题识别成功", extra={"title": title})
                return title
            log_event("INFO", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别为空", reason="empty_result")
            return ""
        except Exception as e:
            log_event("ERROR", "bot", "bot.session_title", trace_id, "ocr", "failed", "会话标题识别异常", reason="ocr_error", extra={"error": str(e)})
            return ""

    def get_current_session_title(self, window_rect):
        """Return the current session title, reusing a recent cached value.

        A cached title younger than ``SESSION_TITLE_CACHE_TTL`` seconds is
        returned without OCR. Otherwise OCR runs; a non-empty title that is not
        in ``UI_NOISE_KEYWORDS`` is cached. Returns ``""`` on error (logged).
        """
        import time

        try:
            # Monotonic clock: immune to wall-clock jumps when measuring cache age.
            now_ts = time.monotonic()
            cached_title = (self._session_title_cache.get("value") or "").strip()
            cached_ts = float(self._session_title_cache.get("ts") or 0.0)
            if cached_title and now_ts - cached_ts <= self.SESSION_TITLE_CACHE_TTL:
                return cached_title

            title = (self.get_session_title_by_ocr(window_rect) or "").strip()
            if title and title not in UI_NOISE_KEYWORDS:
                self._session_title_cache = {"value": title, "ts": now_ts}
            # NOTE(review): a title found in UI_NOISE_KEYWORDS is not cached but is
            # still returned to the caller — confirm that is intended.
            return title
        except Exception as e:
            log_event("ERROR", "bot", "bot.session_title", new_trace_id("bot"), "cache", "failed", "当前会话标题获取异常", reason="title_error", extra={"error": str(e)})
            return ""

    def _match_blocked_keywords(self, text, blocked_row_cache, save_blocked_row_cache, keywords=None):
        """Return True if ``text`` is a known-blocked row or contains a blocked keyword.

        On a fresh keyword hit the normalized key is recorded in
        ``blocked_row_cache`` and ``save_blocked_row_cache()`` is called.
        ``keywords`` defaults to ``BLOCKED_SESSION_KEYWORDS``.
        """
        block_key = self.make_block_key(text)
        if block_key and block_key in blocked_row_cache:
            return True

        normalized = self.normalize_match_text(text)
        if not normalized:
            return False

        if keywords is None:
            keywords = BLOCKED_SESSION_KEYWORDS
        for keyword in keywords:
            needle = self.normalize_match_text(keyword)
            # Skip blank keywords: "" is a substring of everything and would block all sessions.
            if needle and needle in normalized:
                if block_key:
                    blocked_row_cache[block_key] = text or keyword
                    save_blocked_row_cache()
                return True
        return False

    def should_skip_current_session(self, window_rect, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """Decide, after clicking a session, whether its title marks it as blocked.

        ``session`` is accepted for interface compatibility but not used.
        """
        title = self.get_current_session_title(window_rect)
        return self._match_blocked_keywords(title, blocked_row_cache, save_blocked_row_cache)

    def is_same_session(self, expected_session, current_session):
        """Return True if either normalized name contains the other (both must be non-blank)."""
        expected = self.normalize_match_text(expected_session)
        current = self.normalize_match_text(current_session)
        if not expected or not current:
            return False
        return expected in current or current in expected

    def should_skip_session_by_ocr(self, session, blocked_row_cache, save_blocked_row_cache: Callable):
        """OCR a session row's name in the list and decide whether it is blocked.

        Stores the recognized text in ``session['list_ocr_title']``. Returns
        False when there is no row image, no name crop, or on error (logged).
        """
        image_obj = session.get('row_img')
        if image_obj is None:
            return False
        try:
            name_img = self.extract_session_name_image(image_obj)
            if name_img is None:
                return False

            row_idx = session.get('row_idx')
            if OCR_SAVE_IMAGES:
                # `or 0` keeps a None row_idx from aborting the whole check.
                file_name = f"row_{int(row_idx or 0):03d}_name_raw.png"
                self._save_debug_image(name_img, os.path.join('sessions', 'name_ocr', file_name))

            lines = self.ocr.recognize_session_name(self._image_to_png_bytes(name_img), scene=f"session_row_{row_idx}")
            line_text = ' '.join(lines or [])
            session['list_ocr_title'] = line_text
            return self._match_blocked_keywords(line_text, blocked_row_cache, save_blocked_row_cache)
        except Exception as e:
            log_event("WARNING", "bot", "bot.session_name_ocr", new_trace_id("bot"), "ocr", "failed", "会话名称识别异常", reason="ocr_error", extra={"row_idx": session.get('row_idx'), "error": str(e)})
            return False

    def analyze_clicked_session(self, window_rect, round_count, row_idx):
        """Capture and analyse the chat area after a session row was clicked.

        The screenshot is saved under ``sessions/clicked/`` via the debug-image
        callback and passed to ``analyze_pil_image``.

        Returns:
            ChatAnalyzeResult; ``ok`` is False with empty fields when the chat
            capture box is invalid, otherwise True iff non-empty text was found.
        """
        trace_id = new_trace_id("bot")
        chat_box = self.screenshot.get_chat_capture_box(window_rect)
        if not self.screenshot.is_valid_box(chat_box):
            log_event("WARNING", "bot", "bot.chat_analyze", trace_id, "capture", "failed", "聊天区截图区域无效", reason="invalid_box")
            return ChatAnalyzeResult(ok=False, file_name='', latest_text='', confidence='', bubble_side='', screenshot_path='')

        screenshot = self.screenshot.capture_chat_area(window_rect)
        round_no = int(round_count)
        row_no = int(row_idx)
        file_name = f"round_{round_no:04d}_row_{row_no:03d}_chat.png"
        rel_path = os.path.join('sessions', 'clicked', file_name)
        self._save_debug_image(screenshot, rel_path)

        result = analyze_pil_image(screenshot, stem=os.path.splitext(file_name)[0], file_name=file_name)
        latest_text = (getattr(result, 'latest_text', None) or '').strip()
        confidence = getattr(result, 'confidence', '')
        bubble_side = getattr(result, 'bubble_side', '')
        log_event("INFO", "bot", "bot.chat_analyze", trace_id, "analyze", "ok", "聊天截图分析完成", extra={"round": round_no, "row_idx": row_no, "has_text": bool(latest_text), "bubble_side": bubble_side or "", "confidence": confidence})
        return ChatAnalyzeResult(
            ok=bool(latest_text),
            file_name=file_name,
            latest_text=latest_text,
            confidence=confidence,
            bubble_side=bubble_side,
            screenshot_path=rel_path,
        )

    def _save_session_scan_debug(self, round_count: int, sessions: list[dict], unread_sessions: list[dict], contact_rect: dict):
        """Write a per-round JSON summary of the session scan (only when ``OCR_SAVE_IMAGES``).

        Best-effort: failures are logged, never raised.
        """
        if not OCR_SAVE_IMAGES:
            return
        try:
            debug_dir = os.path.join(OCR_SAVE_DIR, 'sessions', 'scan_debug')
            os.makedirs(debug_dir, exist_ok=True)
            file_path = os.path.join(debug_dir, f"round_{int(round_count):04d}_scan.json")

            rows = [
                {
                    'row_idx': s.get('row_idx'),
                    'has_red_dot': bool(s.get('has_red_dot')),
                    'has_red_by_global': bool(s.get('has_red_by_global')),
                    'has_red_by_row': bool(s.get('has_red_by_row')),
                    'has_red_by_row_weak': bool(s.get('has_red_by_row_weak')),
                    'click_x': s.get('click_x'),
                    'click_y': s.get('click_y'),
                    'list_ocr_title': s.get('list_ocr_title', ''),
                }
                for s in sessions
            ]
            payload = {
                'round': int(round_count),
                'contact_rect': contact_rect,
                'total_sessions': len(sessions),
                'unread_count': len(unread_sessions),
                'unread_rows': [s.get('row_idx') for s in unread_sessions],
                'rows': rows,
            }
            # Serialize before opening so a non-serializable value cannot leave a truncated file.
            text = json.dumps(payload, ensure_ascii=False, indent=2)
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(text)
        except Exception as e:
            log_event("WARNING", "bot", "bot.session_scan", new_trace_id("bot"), "debug_save", "failed", "会话扫描调试数据保存失败", reason="debug_save_error", extra={"error": str(e)})

    def _save_debug_image(self, image_obj, filename):
        """Forward to the injected debug-image callback, if one was provided."""
        if not self.save_debug_image:
            return
        self.save_debug_image(image_obj, filename)
|