116 lines
4.2 KiB
Python
116 lines
4.2 KiB
Python
|
|
import hashlib
|
||
|
|
import json
|
||
|
|
import re
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set
|
||
|
|
|
||
|
|
|
||
|
|
# Words that carry little retrieval signal; excluded by `query_tokens`
# when tokenizing both queries and source documents.
DEFAULT_QUERY_STOPWORDS: Set[str] = {
    "about", "and", "any", "are", "can", "could", "did", "does", "done",
    "for", "from", "had", "has", "have", "how", "know", "need", "please",
    "should", "that", "the", "there", "this", "today", "tomorrow", "want",
    "was", "were", "what", "when", "where", "which", "who", "whom", "why",
    "will", "with", "would", "you", "your",
}
|
@dataclass(frozen=True)
class AssistantRetrievalConfig:
    """Tuning knobs for the two-stage assistant source retrieval."""

    # Minimum number of query tokens a source must share with the query
    # (raised to 2 for multi-token queries by `is_strong_source_match`).
    min_token_overlap: int = 1
    # Stage 1 over-fetches `max_sources * multiplier` candidates so that
    # stage-2 filtering still has enough material to choose from.
    source_candidate_multiplier: int = 4
    # Fraction of the query's tokens that must appear in a source's text.
    source_min_coverage: float = 0.6
    # Tokens ignored during tokenization; a frozenset keeps the default
    # immutable, which a frozen dataclass default should be.
    query_stopwords: Set[str] = frozenset(DEFAULT_QUERY_STOPWORDS)
|
def query_tokens(text: str, stopwords: Set[str]) -> Set[str]:
    """Extract lowercase alphanumeric tokens (3+ chars) from *text*,
    excluding any token found in *stopwords*. None/empty input yields
    an empty set."""
    normalized = (text or "").lower()
    tokens = set(re.findall(r"[a-z0-9]{3,}", normalized))
    return tokens - stopwords
def source_text_for_match(src: Dict[str, Any]) -> str:
    """Concatenate the searchable text fields of an ES source document
    into a single space-joined string (missing/None fields become "")."""
    fields = ("text", "description", "summary", "display_name", "canonical_name")
    parts = [str(src.get(name) or "") for name in fields]
    return " ".join(parts)
def is_strong_source_match(query: str, src: Dict[str, Any], cfg: AssistantRetrievalConfig) -> bool:
    """Decide whether *src* overlaps *query* strongly enough to be cited.

    A query with no meaningful tokens (all stopwords or too short)
    matches everything — there is nothing to filter on. Otherwise the
    source must share at least `cfg.min_token_overlap` tokens (at least
    2 for multi-token queries) and cover `cfg.source_min_coverage` of
    the query's tokens.
    """
    wanted = query_tokens(query, cfg.query_stopwords)
    if not wanted:
        return True
    found = query_tokens(source_text_for_match(src), cfg.query_stopwords)
    shared = len(wanted & found)
    total = max(1, len(wanted))
    # Multi-token queries must share at least two tokens regardless of config.
    required = cfg.min_token_overlap if total < 2 else max(cfg.min_token_overlap, 2)
    return shared >= required and (shared / total) >= cfg.source_min_coverage
async def retrieve_sources_two_stage(
    query: str,
    release_name: Optional[str],
    max_sources: int,
    include_release_recent_fallback: bool,
    cfg: AssistantRetrievalConfig,
    es_search_hits: Callable[[str, int, Optional[str]], Awaitable[List[Dict[str, Any]]]],
    es_recent_by_release: Callable[[str, int], Awaitable[List[Dict[str, Any]]]],
) -> List[Dict[str, Any]]:
    """Retrieve up to *max_sources* ES hits relevant to *query*.

    Stage 1 builds an over-fetched, de-duplicated candidate pool:
    release-scoped search, then a global search, then (optionally) the
    most recent documents of the release. Stage 2 ranks candidates by
    query-token overlap (ties broken by ES ``_score``) and keeps only
    those passing `is_strong_source_match`.

    Each ES call is best-effort: a failure is logged via print and the
    stage simply contributes no candidates.

    Returns a list of raw ES hit dicts, or [] when nothing passes the
    relevance filter.
    """
    # Over-fetch so stage-2 filtering still has enough material.
    candidate_size = max(max_sources, max_sources * max(2, cfg.source_candidate_multiplier))
    seen_keys: Set[str] = set()
    candidates: List[Dict[str, Any]] = []

    def add_hits(hs: List[Dict[str, Any]]) -> None:
        """Append hits to `candidates`, de-duplicating by stable key."""
        for h in hs:
            src = h.get("_source", {}) or {}
            key = str(src.get("concept_id") or src.get("source_pk") or "")
            if not key:
                # No natural id: fall back to a content hash of the doc.
                key = hashlib.sha256(
                    json.dumps(src, ensure_ascii=False, sort_keys=True).encode("utf-8")
                ).hexdigest()[:20]
            if key in seen_keys:
                continue
            seen_keys.add(key)
            candidates.append(h)

    stage1_ok = False
    try:
        add_hits(await es_search_hits(query, candidate_size, release_name))
        stage1_ok = True
    except Exception as e:
        print(f"[WARN] stage1 release search failed: {e}")
    # Global search. When release_name is None the first call was already
    # global, so only repeat it if that call failed (was a redundant
    # duplicate ES round-trip before; dedup merely masked it).
    if len(candidates) < max_sources and (release_name or not stage1_ok):
        try:
            add_hits(await es_search_hits(query, candidate_size, None))
        except Exception as e:
            print(f"[WARN] stage1 global search failed: {e}")
    if len(candidates) < max_sources and include_release_recent_fallback and release_name:
        try:
            add_hits(await es_recent_by_release(release_name, candidate_size))
        except Exception as e:
            print(f"[WARN] stage1 release-recent fallback failed: {e}")

    # Stage 2: rank by query-token overlap, then by ES base score.
    q_tokens = query_tokens(query, cfg.query_stopwords)
    ranked: List[Dict[str, Any]] = []
    for h in candidates:
        src = h.get("_source", {}) or {}
        s_tokens = query_tokens(source_text_for_match(src), cfg.query_stopwords)
        overlap = len(q_tokens.intersection(s_tokens)) if q_tokens else 0
        base_score = float(h.get("_score") or 0.0)
        ranked.append({"hit": h, "overlap": overlap, "base_score": base_score})

    ranked.sort(key=lambda x: (x["overlap"], x["base_score"]), reverse=True)

    relevant = [
        x for x in ranked
        if is_strong_source_match(query, x["hit"].get("_source", {}) or {}, cfg)
    ]
    if relevant:
        return [x["hit"] for x in relevant[:max_sources]]
    return []