219 lines
7.3 KiB
Python
219 lines
7.3 KiB
Python
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
import sqlite3
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
from app.infrastructure.service.backend.config import PROJECT_ROOT
|
|||
|
|
from app.infrastructure.service.logging.log_service import log_event, new_trace_id
|
|||
|
|
|
|||
|
|
# In-memory copy of the settings persisted in SETTINGS_FILE (filled lazily).
_SETTING_CACHE = {}

# JSON file backing the settings cache, under the project's logs/state directory.
SETTINGS_FILE = PROJECT_ROOT / "logs" / "state" / "local_settings.json"

# Per-user application-data root. Fall back to ~/AppData/Local when LOCALAPPDATA
# is unset OR empty: an empty value would otherwise become Path("") (the current
# working directory). Using `or` also avoids calling Path.home() when it is not
# needed, since Path.home() can raise if no home directory can be determined.
LOCAL_APPDATA_DIR = Path(os.environ.get("LOCALAPPDATA") or str(Path.home() / "AppData" / "Local"))

# Location of the application's SQLite database file.
SQLITE_DB_PATH = LOCAL_APPDATA_DIR / "com.shiliu.aiassistant" / "ai_shiliu.sqlite3"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _SQLiteCursor:
|
|||
|
|
def __init__(self, cursor):
|
|||
|
|
self._cursor = cursor
|
|||
|
|
|
|||
|
|
def __enter__(self):
|
|||
|
|
return self
|
|||
|
|
|
|||
|
|
def __exit__(self, exc_type, exc, tb):
|
|||
|
|
self._cursor.close()
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _adapt_sql(sql: str) -> str:
|
|||
|
|
return sql.replace("%s", "?")
|
|||
|
|
|
|||
|
|
def execute(self, sql, params=None):
|
|||
|
|
sql = self._adapt_sql(sql)
|
|||
|
|
if params is None:
|
|||
|
|
self._cursor.execute(sql)
|
|||
|
|
else:
|
|||
|
|
self._cursor.execute(sql, params)
|
|||
|
|
return self
|
|||
|
|
|
|||
|
|
def fetchone(self):
|
|||
|
|
row = self._cursor.fetchone()
|
|||
|
|
return dict(row) if row is not None else None
|
|||
|
|
|
|||
|
|
def fetchall(self):
|
|||
|
|
return [dict(row) for row in self._cursor.fetchall()]
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def lastrowid(self):
|
|||
|
|
return self._cursor.lastrowid
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _SQLiteConn:
|
|||
|
|
def __init__(self, conn):
|
|||
|
|
self._conn = conn
|
|||
|
|
|
|||
|
|
def cursor(self):
|
|||
|
|
return _SQLiteCursor(self._conn.cursor())
|
|||
|
|
|
|||
|
|
def commit(self):
|
|||
|
|
self._conn.commit()
|
|||
|
|
|
|||
|
|
def rollback(self):
|
|||
|
|
self._conn.rollback()
|
|||
|
|
|
|||
|
|
def close(self):
|
|||
|
|
self._conn.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _bootstrap_sqlite_file():
    """Make sure the directory that holds ``SQLITE_DB_PATH`` exists.

    Missing parent directories are created; an existing directory is left alone.
    """
    db_dir = SQLITE_DB_PATH.parent
    db_dir.mkdir(exist_ok=True, parents=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_conn(db_name=None):
    """Open a new connection to the application's SQLite database.

    Args:
        db_name: Accepted but ignored; every call opens ``SQLITE_DB_PATH``.
            NOTE(review): presumably kept so older call sites keep working — confirm.

    Returns:
        A ``_SQLiteConn`` whose rows are ``sqlite3.Row`` objects and whose
        connection has ``PRAGMA foreign_keys = ON`` applied. Callers close it.
    """
    _bootstrap_sqlite_file()
    db_location = str(SQLITE_DB_PATH)
    raw_conn = sqlite3.connect(db_location)
    # Mapping-style rows so _SQLiteCursor can turn them into dicts.
    raw_conn.row_factory = sqlite3.Row
    # SQLite leaves foreign-key enforcement off unless enabled per connection.
    raw_conn.execute("PRAGMA foreign_keys = ON")
    return _SQLiteConn(raw_conn)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def init_db():
    """Create the SQLite schema if needed and migrate older ``messages`` tables.

    Ensures the ``messages`` table (with its ``idx_messages_user_time`` index)
    and the ``auto_reply_rules`` table exist. It then adds any of the listed
    ``messages`` columns that an existing table lacks.

    All statements run inside one explicit transaction, so on failure the
    rollback really does undo any partial schema changes. Start, success and
    failure are reported through ``log_event``.

    Raises:
        Exception: whatever the database raised, re-raised after rollback and
            logging.
    """
    trace_id = new_trace_id("db")
    log_event("INFO", "db", "db.init", trace_id, "start", "ok", "初始化数据库开始", extra={"path": str(SQLITE_DB_PATH)})
    # Columns that may be missing from a messages table created before they
    # were added to the CREATE TABLE below: (column name, ALTER statement).
    column_migrations = (
        ("reply_strategy", "ALTER TABLE messages ADD COLUMN reply_strategy TEXT NOT NULL DEFAULT ''"),
        ("reply_reason", "ALTER TABLE messages ADD COLUMN reply_reason TEXT NOT NULL DEFAULT ''"),
        ("ocr_confidence", "ALTER TABLE messages ADD COLUMN ocr_confidence TEXT NOT NULL DEFAULT ''"),
        ("ocr_bubble_side", "ALTER TABLE messages ADD COLUMN ocr_bubble_side TEXT NOT NULL DEFAULT ''"),
    )
    conn = get_conn()
    try:
        with conn.cursor() as cur:
            # Python's sqlite3 module only opens a transaction implicitly before
            # INSERT/UPDATE/DELETE/REPLACE. Without an explicit BEGIN, each DDL
            # statement below would autocommit and conn.rollback() could not
            # undo a half-applied migration.
            cur.execute("BEGIN")
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    wx_user_id TEXT NOT NULL DEFAULT '',
                    wx_nickname TEXT NOT NULL DEFAULT '',
                    direction TEXT NOT NULL DEFAULT 'in',
                    content TEXT NOT NULL,
                    is_ai_reply INTEGER NOT NULL DEFAULT 0,
                    rule_id INTEGER NULL,
                    is_friend_request INTEGER NOT NULL DEFAULT 0,
                    reply_strategy TEXT NOT NULL DEFAULT '',
                    reply_reason TEXT NOT NULL DEFAULT '',
                    ocr_confidence TEXT NOT NULL DEFAULT '',
                    ocr_bubble_side TEXT NOT NULL DEFAULT '',
                    created_at TEXT NOT NULL DEFAULT (datetime('now', 'localtime'))
                )
                """
            )
            cur.execute("CREATE INDEX IF NOT EXISTS idx_messages_user_time ON messages(wx_user_id, created_at)")
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS auto_reply_rules (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    keyword TEXT NOT NULL,
                    match_type TEXT NOT NULL DEFAULT 'contain',
                    reply_text TEXT NOT NULL,
                    is_active INTEGER NOT NULL DEFAULT 1,
                    created_at TEXT NOT NULL DEFAULT (datetime('now', 'localtime')),
                    updated_at TEXT NOT NULL DEFAULT (datetime('now', 'localtime'))
                )
                """
            )

            # Inspect the existing messages columns and add whatever is missing.
            cur.execute("PRAGMA table_info(messages)")
            existing_cols = {str(col.get('name') or '') for col in (cur.fetchall() or [])}
            for column, alter_sql in column_migrations:
                if column not in existing_cols:
                    cur.execute(alter_sql)

        conn.commit()
        log_event("INFO", "db", "db.init", trace_id, "done", "ok", "初始化数据库完成")
    except Exception as exc:
        conn.rollback()
        log_event("ERROR", "db", "db.init", trace_id, "done", "failed", "初始化数据库失败", reason="db_error", extra={"error": str(exc)})
        raise
    finally:
        conn.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _load_settings_file():
    """Fill the in-memory settings cache from ``SETTINGS_FILE``.

    Does nothing while the cache already holds entries. When the file exists
    and contains a JSON object, each key and value is stored as ``str``.

    Loading is best-effort: any error while reading or parsing the file is
    ignored, and the cache is left without the file's contents.
    """
    if not _SETTING_CACHE:
        try:
            if SETTINGS_FILE.exists():
                payload = json.loads(SETTINGS_FILE.read_text(encoding="utf-8"))
                if isinstance(payload, dict):
                    _SETTING_CACHE.update((str(name), str(setting)) for name, setting in payload.items())
        except Exception:
            # Deliberately best-effort: an unreadable settings file is treated as empty.
            pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _save_settings_file():
    """Write the in-memory settings cache to ``SETTINGS_FILE`` as indented UTF-8 JSON.

    The JSON goes to a sibling ``*.tmp`` file first, which then replaces
    ``SETTINGS_FILE`` via ``Path.replace`` (``os.replace`` semantics). An
    interrupted write therefore cannot leave a truncated settings file behind.

    This matters because ``_load_settings_file`` silently ignores a file it
    cannot parse. The next save would then overwrite every stored setting
    with only what is in memory.
    """
    SETTINGS_FILE.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(_SETTING_CACHE, ensure_ascii=False, indent=2)
    tmp_file = SETTINGS_FILE.with_name(SETTINGS_FILE.name + ".tmp")
    try:
        tmp_file.write_text(payload, encoding="utf-8")
        tmp_file.replace(SETTINGS_FILE)
    except BaseException:
        # The real settings file is untouched at this point; just try not to
        # leave the temporary file lying around before propagating the error.
        try:
            tmp_file.unlink()
        except OSError:
            pass
        raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_setting(key, default=None):
    """Return the stored setting for ``key``, seeding it from ``default`` if absent.

    The key is looked up as ``str(key)``, matching how ``set_setting`` and the
    settings-file loader store keys. This means non-string keys resolve
    consistently.

    Args:
        key: Setting name.
        default: Value to store and return when the key is missing. When it
            is None, nothing is stored and None is returned.

    Returns:
        The stored string value, ``str(default)`` after persisting it, or None.
    """
    _load_settings_file()
    key = str(key)
    if key in _SETTING_CACHE:
        return _SETTING_CACHE[key]
    if default is None:
        return None
    val = str(default)
    # Persist the default so later reads (and the settings file) see it.
    _SETTING_CACHE[key] = val
    _save_settings_file()
    return val
|
|||
|
|
|
|||
|
|
|
|||
|
|
def set_setting(key, value):
    """Store ``value`` under ``key`` and persist the whole settings cache.

    Both key and value are coerced to ``str``. The settings file is loaded
    first (when the cache is still empty), so existing entries are kept when
    the cache is written back.
    """
    _load_settings_file()
    text_value = str(value)
    _SETTING_CACHE[str(key)] = text_value
    _save_settings_file()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def normalize_text(text):
    """Return a canonical form of ``text`` for loose keyword matching.

    The text is stripped and lower-cased, and all whitespace is removed
    (including Unicode whitespace such as the ideographic space U+3000). The
    full-width colon U+FF1A is then folded into ASCII ":". None or empty
    input yields "".
    """
    t = (text or "").strip().lower()
    t = re.sub(r"\s+", "", t)
    # Fold the full-width colon typed by CJK input methods into ASCII ':' so
    # "价格：100" and "价格:100" normalize identically.
    t = t.replace("\uff1a", ":")
    return t
|
|||
|
|
|
|||
|
|
|
|||
|
|
def find_rule_reply(content):
    """Return the first active auto-reply rule whose keyword matches ``content``.

    Rules with ``is_active = 1`` are checked in ascending ``id`` order, and
    rules with a blank keyword are skipped. Matching depends on the rule's
    ``match_type``:

    * ``"equal"``: the stripped content equals the keyword, either
      case-insensitively or after ``normalize_text``.
    * anything else: treated as "contain"; the keyword occurs inside the
      content under the same two comparisons.

    A hit or a miss is reported through ``log_event``.

    Returns:
        The matching rule as a dict of its columns, or None if nothing matched.

    Raises:
        Exception: errors from the rule query are logged and re-raised.
    """
    trace_id = new_trace_id("db")
    conn = get_conn()
    try:
        with conn.cursor() as cur:
            # execute() returns the cursor adapter, so the fetch can be chained.
            rules = cur.execute("SELECT * FROM auto_reply_rules WHERE is_active = 1 ORDER BY id ASC").fetchall()
    except Exception as exc:
        log_event("ERROR", "db", "db.rule.query", trace_id, "query", "failed", "查询规则失败", reason="db_error", extra={"error": str(exc)})
        raise
    finally:
        conn.close()

    text = (content or "").strip()
    text_lower, text_norm = text.lower(), normalize_text(text)

    for rule in rules:
        keyword = (rule.get("keyword") or "").strip()
        if not keyword:
            continue
        keyword_lower, keyword_norm = keyword.lower(), normalize_text(keyword)
        match_type = rule.get("match_type")
        if match_type == "equal":
            matched = text_lower == keyword_lower or text_norm == keyword_norm
            reported_type = match_type
        else:
            matched = keyword_lower in text_lower or keyword_norm in text_norm
            reported_type = match_type or "contain"
        if matched:
            log_event("INFO", "db", "db.rule.match", trace_id, "match", "ok", "命中规则", reason="rule_hit", extra={"rule_id": rule.get("id"), "match_type": reported_type})
            return rule

    log_event("INFO", "db", "db.rule.match", trace_id, "match", "ok", "未命中规则", reason="rule_miss", extra={"rule_count": len(rules)})
    return None
|