"""Two-stage assistant source retrieval helpers (jecio/services/assistant_retrieval.py)."""
import hashlib
import json
import re
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set
# Common English filler words (plus a few assistant-domain terms such as
# "tomorrow"/"today") that are stripped from queries before token matching.
DEFAULT_QUERY_STOPWORDS: Set[str] = set(
    """
    the and for with that this from have has had are was were will
    would should could can you your about what when where which who whom
    why how tomorrow today please any there need want know does did done
    """.split()
)


@dataclass(frozen=True)
class AssistantRetrievalConfig:
    """Tunables for the two-stage assistant source retrieval pipeline."""

    # Minimum number of shared tokens for a source to count as a match.
    min_token_overlap: int = 1
    # Over-fetch factor: candidate pool size relative to requested sources.
    source_candidate_multiplier: int = 4
    # Fraction of query tokens a source must cover to be a strong match.
    source_min_coverage: float = 0.6
    # Tokens ignored when comparing query text against source text.
    query_stopwords: Set[str] = frozenset(DEFAULT_QUERY_STOPWORDS)
def query_tokens(text: str, stopwords: Set[str]) -> Set[str]:
    """Extract lowercase alphanumeric tokens (length >= 3) from *text*, minus stopwords."""
    found = re.findall(r"[a-z0-9]{3,}", (text or "").lower())
    return set(found).difference(stopwords)
def source_text_for_match(src: Dict[str, Any]) -> str:
    """Concatenate a source document's matchable text fields into one string.

    Missing or falsy fields contribute an empty string, so the output always
    joins exactly five values with single spaces.
    """
    fields = ("text", "description", "summary", "display_name", "canonical_name")
    return " ".join(str(src.get(name) or "") for name in fields)
def is_strong_source_match(query: str, src: Dict[str, Any], cfg: AssistantRetrievalConfig) -> bool:
    """Decide whether *src* shares enough query tokens to count as relevant.

    A query that yields no usable tokens matches everything (returning True
    rather than filtering every source out). Multi-token queries require at
    least two shared tokens, and the shared fraction must reach
    ``cfg.source_min_coverage``.
    """
    wanted = query_tokens(query, cfg.query_stopwords)
    if not wanted:
        return True
    found = query_tokens(source_text_for_match(src), cfg.query_stopwords)
    hits = len(wanted & found)
    denom = max(1, len(wanted))
    # Single-token queries use the configured floor; otherwise demand >= 2.
    required = cfg.min_token_overlap if denom < 2 else max(cfg.min_token_overlap, 2)
    return hits >= required and (hits / denom) >= cfg.source_min_coverage
async def retrieve_sources_two_stage(
    query: str,
    release_name: Optional[str],
    max_sources: int,
    include_release_recent_fallback: bool,
    cfg: AssistantRetrievalConfig,
    es_search_hits: Callable[[str, int, Optional[str]], Awaitable[List[Dict[str, Any]]]],
    es_recent_by_release: Callable[[str, int], Awaitable[List[Dict[str, Any]]]],
) -> List[Dict[str, Any]]:
    """Retrieve up to ``max_sources`` search hits relevant to ``query``.

    Stage 1 gathers deduplicated candidates best-effort: a release-scoped
    search, then a global search if still short, then (optionally) recent
    documents for the release. Each search failure is logged and tolerated.
    Stage 2 ranks candidates by query-token overlap (search score as
    tie-break) and keeps only those passing the strong-match filter.

    Returns raw hit dicts; the ``_source`` / ``_score`` shape is assumed to
    match what the injected callables produce (Elasticsearch-style hits).
    Returns an empty list when no candidate passes the filter.
    """
    # Over-fetch so the strong-match filter still has enough to choose from.
    candidate_size = max(max_sources, max_sources * max(2, cfg.source_candidate_multiplier))
    seen_keys: set[str] = set()
    candidates: List[Dict[str, Any]] = []

    def add_hits(hs: List[Dict[str, Any]]) -> None:
        # Deduplicate by stable id; fall back to a content hash when absent.
        for h in hs:
            src = h.get("_source", {}) or {}
            key = str(src.get("concept_id") or src.get("source_pk") or "")
            if not key:
                key = hashlib.sha256(
                    json.dumps(src, ensure_ascii=False, sort_keys=True).encode("utf-8")
                ).hexdigest()[:20]
            if key in seen_keys:
                continue
            seen_keys.add(key)
            candidates.append(h)

    # Stage 1a: release-scoped search.
    try:
        add_hits(await es_search_hits(query, candidate_size, release_name))
    except Exception as e:
        print(f"[WARN] stage1 release search failed: {e}")
    # Stage 1b: widen to a global search if the release search came up short.
    if len(candidates) < max_sources:
        try:
            add_hits(await es_search_hits(query, candidate_size, None))
        except Exception as e:
            print(f"[WARN] stage1 global search failed: {e}")
    # Stage 1c: optional recency fallback within the release.
    if len(candidates) < max_sources and include_release_recent_fallback and release_name:
        try:
            add_hits(await es_recent_by_release(release_name, candidate_size))
        except Exception as e:
            print(f"[WARN] stage1 release-recent fallback failed: {e}")

    # Stage 2: rank by query-token overlap, breaking ties on the base score.
    q_tokens = query_tokens(query, cfg.query_stopwords)
    ranked: List[Dict[str, Any]] = []
    for h in candidates:
        src = h.get("_source", {}) or {}
        s_tokens = query_tokens(source_text_for_match(src), cfg.query_stopwords)
        overlap = len(q_tokens.intersection(s_tokens)) if q_tokens else 0
        base_score = float(h.get("_score") or 0.0)
        ranked.append({"hit": h, "overlap": overlap, "base_score": base_score})
    ranked.sort(key=lambda x: (x["overlap"], x["base_score"]), reverse=True)

    # Keep only strong matches; stop scanning as soon as we have enough
    # (the original filtered every ranked candidate before slicing).
    relevant_hits: List[Dict[str, Any]] = []
    for x in ranked:
        src = x["hit"].get("_source", {}) or {}
        if is_strong_source_match(query, src, cfg):
            relevant_hits.append(x["hit"])
            if len(relevant_hits) == max_sources:
                break
    # Slice defensively so degenerate max_sources values behave as before.
    return relevant_hits[:max_sources]