"""Two-stage retrieval of assistant sources with token-overlap relevance filtering.

Stage 1 gathers candidate hits from a release-scoped search, a global search,
and (optionally) a recent-by-release fallback, de-duplicating as it goes.
Stage 2 ranks candidates by query-token overlap plus the engine's own score
and keeps only "strong" matches.
"""

import hashlib
import json
import logging
import re
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set

logger = logging.getLogger(__name__)

# A token is a lowercase alphanumeric run of 3+ characters; compiled once
# instead of re-parsing the pattern on every query_tokens() call.
_TOKEN_RE = re.compile(r"[a-z0-9]{3,}")

# Common English filler words (plus a few assistant-style ones such as
# "tomorrow"/"please") that carry no matching signal.
DEFAULT_QUERY_STOPWORDS: Set[str] = {
    "the", "and", "for", "with", "that", "this", "from", "have", "has", "had",
    "are", "was", "were", "will", "would", "should", "could", "can", "you",
    "your", "about", "what", "when", "where", "which", "who", "whom", "why",
    "how", "tomorrow", "today", "please", "any", "there", "need", "want",
    "know", "does", "did", "done",
}


@dataclass(frozen=True)
class AssistantRetrievalConfig:
    """Tunables for candidate gathering and relevance filtering."""

    # Minimum number of tokens a source must share with the query.
    min_token_overlap: int = 1
    # Stage-1 fetch size = max_sources * this multiplier (floored at 2).
    source_candidate_multiplier: int = 4
    # Fraction of query tokens that must appear in the source text.
    source_min_coverage: float = 0.6
    # Words ignored when tokenizing queries and source text (immutable, so
    # it is safe as a shared dataclass default).
    query_stopwords: Set[str] = frozenset(DEFAULT_QUERY_STOPWORDS)


def query_tokens(text: str, stopwords: Set[str]) -> Set[str]:
    """Lowercase *text* and return its 3+-char alphanumeric tokens minus *stopwords*."""
    return {
        t for t in _TOKEN_RE.findall((text or "").lower()) if t not in stopwords
    }


def source_text_for_match(src: Dict[str, Any]) -> str:
    """Concatenate the matchable text fields of an ES ``_source`` document."""
    fields = ("text", "description", "summary", "display_name", "canonical_name")
    return " ".join(str(src.get(f) or "") for f in fields)


def _meets_threshold(overlap: int, q_len: int, cfg: AssistantRetrievalConfig) -> bool:
    """Return True when *overlap* shared tokens out of *q_len* query tokens is strong.

    Shared by is_strong_source_match() and the retrieval ranking loop so the
    threshold rule lives in exactly one place.
    """
    min_overlap = cfg.min_token_overlap
    if q_len >= 2:
        # Multi-word queries must share at least two tokens so a single
        # incidental word never counts as a match.
        min_overlap = max(min_overlap, 2)
    coverage = overlap / max(1, q_len)
    return overlap >= min_overlap and coverage >= cfg.source_min_coverage


def is_strong_source_match(query: str, src: Dict[str, Any], cfg: AssistantRetrievalConfig) -> bool:
    """Return True when *src* shares enough tokens with *query*.

    An empty/stopword-only query vacuously matches every source.
    """
    q_tokens = query_tokens(query, cfg.query_stopwords)
    if not q_tokens:
        return True
    s_tokens = query_tokens(source_text_for_match(src), cfg.query_stopwords)
    return _meets_threshold(len(q_tokens & s_tokens), len(q_tokens), cfg)


async def retrieve_sources_two_stage(
    query: str,
    release_name: Optional[str],
    max_sources: int,
    include_release_recent_fallback: bool,
    cfg: AssistantRetrievalConfig,
    es_search_hits: Callable[[str, int, Optional[str]], Awaitable[List[Dict[str, Any]]]],
    es_recent_by_release: Callable[[str, int], Awaitable[List[Dict[str, Any]]]],
) -> List[Dict[str, Any]]:
    """Gather, de-duplicate, rank, and filter search hits for *query*.

    Stage 1 cascade (each stage only runs while candidates are still short):
    release-scoped search, then global search, then (if enabled) a
    recent-by-release fallback.  Each stage is best-effort — failures are
    logged and the next stage still runs.

    Stage 2 ranks candidates by (token overlap, engine score) descending and
    keeps only hits passing the strong-match threshold.

    Returns at most *max_sources* raw ES hit dicts ([] when nothing passes).
    """
    if max_sources <= 0:
        # Original semantics: slicing with a non-positive bound yields [].
        return []

    candidate_size = max(max_sources, max_sources * max(2, cfg.source_candidate_multiplier))
    seen_keys: Set[str] = set()
    candidates: List[Dict[str, Any]] = []

    def add_hits(hits: List[Dict[str, Any]]) -> None:
        # De-duplicate by stable id, falling back to a content hash.
        for hit in hits:
            src = hit.get("_source", {}) or {}
            key = str(src.get("concept_id") or src.get("source_pk") or "")
            if not key:
                # default=str keeps the hash usable even when _source holds
                # values json cannot natively serialize (previously this
                # raised TypeError and silently dropped the rest of the
                # stage's hits via the stage-level except).
                payload = json.dumps(src, ensure_ascii=False, sort_keys=True, default=str)
                key = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:20]
            if key not in seen_keys:
                seen_keys.add(key)
                candidates.append(hit)

    try:
        add_hits(await es_search_hits(query, candidate_size, release_name))
    except Exception as e:
        logger.warning("stage1 release search failed: %s", e)

    if len(candidates) < max_sources:
        try:
            add_hits(await es_search_hits(query, candidate_size, None))
        except Exception as e:
            logger.warning("stage1 global search failed: %s", e)

    if len(candidates) < max_sources and include_release_recent_fallback and release_name:
        try:
            add_hits(await es_recent_by_release(release_name, candidate_size))
        except Exception as e:
            logger.warning("stage1 release-recent fallback failed: %s", e)

    q_tokens = query_tokens(query, cfg.query_stopwords)
    ranked: List[Dict[str, Any]] = []
    for hit in candidates:
        src = hit.get("_source", {}) or {}
        s_tokens = query_tokens(source_text_for_match(src), cfg.query_stopwords)
        overlap = len(q_tokens & s_tokens) if q_tokens else 0
        ranked.append(
            {"hit": hit, "overlap": overlap, "base_score": float(hit.get("_score") or 0.0)}
        )
    ranked.sort(key=lambda x: (x["overlap"], x["base_score"]), reverse=True)

    if not q_tokens:
        # Empty query matches everything: return the best-scored hits.
        return [entry["hit"] for entry in ranked[:max_sources]]

    relevant: List[Dict[str, Any]] = []
    q_len = len(q_tokens)
    for entry in ranked:
        # Reuse the overlap computed during ranking instead of re-tokenizing
        # every source a second time via is_strong_source_match().
        if _meets_threshold(entry["overlap"], q_len, cfg):
            relevant.append(entry["hit"])
            if len(relevant) == max_sources:
                break
    return relevant