import asyncio
import base64
import json
import os
import shlex
import tempfile
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from fastapi import HTTPException


@dataclass(frozen=True)
class RemoteOpsConfig:
    """Immutable connection settings and remote-script paths for SSH-based operations.

    The *_remote_script fields are paths (relative to ``remote_dir``) of the
    executables invoked on ``ssh_host`` by the corresponding RemoteOps methods.
    """

    ssh_host: str
    remote_dir: str
    ssh_bin: str
    ssh_opts: str
    scp_bin: str
    scp_opts: str
    timeout_sec: int
    projector_remote_script: str
    ingest_message_remote_script: str
    ingest_messages_batch_remote_script: str
    assistant_feedback_remote_script: str
    assistant_feedback_query_remote_script: str
    assistant_metrics_query_remote_script: str
    assistant_action_remote_script: str
    assistant_actions_query_remote_script: str
    assistant_proposals_remote_script: str
    assistant_proposals_query_remote_script: str
    runs_remote_script: str
    run_events_remote_script: str
    imap_checkpoint_remote_script: str
    create_messages_release_remote_script: str


def _tail(text: str, max_chars: int = 8000) -> str:
    """Return at most the last ``max_chars`` characters of ``text``.

    Used to keep stdout/stderr excerpts in HTTP error payloads bounded.
    """
    if len(text) <= max_chars:
        return text
    return text[-max_chars:]


def _b64(s: str) -> str:
    """Base64-encode ``s`` (UTF-8) so arbitrary text survives shell argv passing."""
    return base64.b64encode(s.encode("utf-8")).decode("ascii")


def _extract_json_array_from_text(text: str) -> List[Dict[str, Any]]:
    """Extract the outermost JSON array from mixed script output.

    Remote scripts may print log lines around the JSON payload, so we slice
    from the first ``[`` to the last ``]`` before parsing. Non-dict elements
    are silently dropped.

    Raises:
        ValueError: if no array is present or the parsed value is not a list.
    """
    start = text.find("[")
    end = text.rfind("]")
    if start == -1 or end == -1 or end < start:
        raise ValueError("No JSON array found in output")
    candidate = text[start : end + 1]
    obj = json.loads(candidate)
    if not isinstance(obj, list):
        raise ValueError("Parsed value is not a JSON array")
    return [item for item in obj if isinstance(item, dict)]


def _extract_json_object_from_text(text: str) -> Dict[str, Any]:
    """Extract the outermost JSON object from mixed script output.

    Counterpart of :func:`_extract_json_array_from_text` for single objects.

    Raises:
        ValueError: if no object is present or the parsed value is not a dict.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1 or end < start:
        raise ValueError("No JSON object found in output")
    candidate = text[start : end + 1]
    obj = json.loads(candidate)
    if not isinstance(obj, dict):
        raise ValueError("Parsed value is not a JSON object")
    return obj


class RemoteOps:
    """Runs project scripts on a remote host over SSH and maps failures to HTTP errors.

    Every public coroutine executes a script under ``cfg.remote_dir`` and either
    returns a result dict or raises :class:`HTTPException` (504 on timeout,
    502 on non-zero exit or unparseable output).
    """

    def __init__(self, cfg: RemoteOpsConfig):
        self.cfg = cfg

    def _ssh_args(self, command: str) -> List[str]:
        """Build the argv for running ``command`` on the configured host via ssh."""
        return [self.cfg.ssh_bin, *shlex.split(self.cfg.ssh_opts), self.cfg.ssh_host, command]

    async def _run_ssh_command(self, command: str, timeout_error: str) -> Dict[str, Any]:
        """Execute a raw shell ``command`` over SSH with a bounded timeout.

        On timeout the subprocess is killed and an HTTP 504 is raised with
        ``timeout_error`` as the detail. Returns ``{"code", "out", "err"}``.
        """
        proc = await asyncio.create_subprocess_exec(
            *self._ssh_args(command),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.cfg.timeout_sec)
        except asyncio.TimeoutError:
            # Kill the hung ssh process so it does not linger after the 504.
            proc.kill()
            await proc.wait()
            raise HTTPException(status_code=504, detail=timeout_error)
        return {
            "code": proc.returncode,
            "out": stdout.decode("utf-8", errors="replace"),
            "err": stderr.decode("utf-8", errors="replace"),
        }

    async def _run_ssh(self, parts: List[str], timeout_error: str) -> Dict[str, Any]:
        """Run ``parts`` (quoted argv) inside ``cfg.remote_dir`` on the remote host."""
        command = f"cd {shlex.quote(self.cfg.remote_dir)} && {' '.join(shlex.quote(p) for p in parts)}"
        return await self._run_ssh_command(command, timeout_error)

    def _error_detail(self, code: int, out: str, err: str) -> Dict[str, Any]:
        """Build the standard error-detail payload attached to 502 responses."""
        return {
            "host": self.cfg.ssh_host,
            "remote_dir": self.cfg.remote_dir,
            "exit_code": code,
            "stdout_tail": _tail(out),
            "stderr_tail": _tail(err),
        }

    async def _query_rows(self, parts: List[str], timeout_error: str, parse_error_prefix: str) -> Dict[str, Any]:
        """Shared path for all query scripts that print a JSON array of rows.

        Raises 502 on non-zero exit or unparseable output; otherwise returns
        ``{"host", "remote_dir", "rows"}``.
        """
        res = await self._run_ssh(parts, timeout_error)
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=self._error_detail(res["code"], res["out"], res["err"]))
        try:
            rows = _extract_json_array_from_text(res["out"])
        except Exception as e:
            raise HTTPException(
                status_code=502,
                detail={
                    "message": f"{parse_error_prefix}: {e}",
                    "stdout_tail": _tail(res["out"]),
                    "stderr_tail": _tail(res["err"]),
                },
            ) from e
        return {"host": self.cfg.ssh_host, "remote_dir": self.cfg.remote_dir, "rows": rows}

    async def _scp_upload(self, local_path: str, remote_path: str, timeout_error: str) -> None:
        """Copy ``local_path`` to ``remote_path`` on the configured host via scp.

        Raises 504 on timeout (killing the scp process) and 502 on failure.
        """
        scp_args = [
            self.cfg.scp_bin,
            *shlex.split(self.cfg.scp_opts),
            local_path,
            f"{self.cfg.ssh_host}:{remote_path}",
        ]
        proc = await asyncio.create_subprocess_exec(
            *scp_args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.cfg.timeout_sec)
        except asyncio.TimeoutError:
            proc.kill()
            await proc.wait()
            raise HTTPException(status_code=504, detail=timeout_error)
        if proc.returncode != 0:
            raise HTTPException(
                status_code=502,
                detail={
                    "host": self.cfg.ssh_host,
                    "remote_dir": self.cfg.remote_dir,
                    "exit_code": proc.returncode,
                    "stdout_tail": _tail(stdout.decode("utf-8", errors="replace")),
                    "stderr_tail": _tail(stderr.decode("utf-8", errors="replace")),
                },
            )

    async def run_remote_query_imap_checkpoint(
        self,
        host: str,
        mailbox: str,
        username: str,
        table: str,
    ) -> Optional[int]:
        """Query the max ingested IMAP UID for a (host, mailbox, username, table).

        Returns the integer ``max_uid`` or None if the script reports none.
        """
        res = await self._run_ssh(
            [self.cfg.imap_checkpoint_remote_script, host, mailbox, username, table],
            "IMAP checkpoint query timed out",
        )
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=self._error_detail(res["code"], res["out"], res["err"]))
        try:
            obj = _extract_json_object_from_text(res["out"])
            val = obj.get("max_uid")
            if val is None:
                return None
            return int(val)
        except Exception as e:
            raise HTTPException(
                status_code=502,
                detail={
                    "message": f"Unable to parse IMAP checkpoint output: {e}",
                    "stdout_tail": _tail(res["out"]),
                    "stderr_tail": _tail(res["err"]),
                },
            ) from e

    async def run_remote_create_messages_release(self, release_name: str) -> Dict[str, Any]:
        """Create a named messages release on the remote host."""
        res = await self._run_ssh(
            [self.cfg.create_messages_release_remote_script, release_name],
            "Create messages release timed out",
        )
        result = {
            **self._error_detail(res["code"], res["out"], res["err"]),
            "release_name": release_name,
        }
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result

    async def run_remote_projector(self, payload: Any) -> Dict[str, Any]:
        """Run the projector script for a release; reports per-step progress flags.

        ``payload`` is duck-typed: release_name, targets, concept_table, dry_run.
        """
        parts = [
            self.cfg.projector_remote_script,
            "--release-name",
            str(getattr(payload, "release_name", "")),
            "--targets",
            str(getattr(payload, "targets", "both")),
        ]
        if getattr(payload, "concept_table", None):
            parts.extend(["--concept-table", str(getattr(payload, "concept_table"))])
        if bool(getattr(payload, "dry_run", False)):
            parts.append("--dry-run")
        res = await self._run_ssh(parts, "Projector execution timed out")
        result = {
            **self._error_detail(res["code"], res["out"], res["err"]),
            # Step markers printed by the remote script; used for progress reporting.
            "spark_read_done": "[STEP] spark_read_done" in res["out"],
            "projection_done": "[STEP] projection_done" in res["out"],
        }
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result

    async def run_remote_ingest_message(self, payload: Any) -> Dict[str, Any]:
        """Ingest a single message; free-text fields travel base64-encoded."""
        parts = [
            self.cfg.ingest_message_remote_script,
            str(getattr(payload, "table")),
            str(getattr(payload, "thread_id")),
            str(getattr(payload, "message_id")),
            str(getattr(payload, "sender")),
            str(getattr(payload, "channel")),
            str(getattr(payload, "sent_at") or ""),
            _b64(str(getattr(payload, "body") or "")),
            _b64(json.dumps(getattr(payload, "metadata", {}) or {}, ensure_ascii=False)),
        ]
        res = await self._run_ssh(parts, "Message ingest execution timed out")
        result = self._error_detail(res["code"], res["out"], res["err"])
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result

    async def run_remote_ingest_messages_batch(self, payload: Any) -> Dict[str, Any]:
        """Ingest a batch of messages by uploading a JSON file and running the batch script.

        The rows are scp'd to a unique temp path under ``remote_dir``; the remote
        command removes that file after the script runs, preserving its exit code.
        """
        rows = []
        for m in list(getattr(payload, "messages", []) or []):
            rows.append(
                {
                    "thread_id": getattr(m, "thread_id"),
                    "message_id": getattr(m, "message_id"),
                    "sender": getattr(m, "sender"),
                    "channel": getattr(m, "channel"),
                    "sent_at": getattr(m, "sent_at"),
                    "body": getattr(m, "body"),
                    "metadata": getattr(m, "metadata"),
                }
            )
        if not rows:
            # Nothing to upload or run; report success without touching the host.
            return {
                "host": self.cfg.ssh_host,
                "remote_dir": self.cfg.remote_dir,
                "exit_code": 0,
                "rows": 0,
                "stdout_tail": "[INFO] No rows to ingest",
                "stderr_tail": "",
            }
        local_tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8")
        remote_tmp = f"{self.cfg.remote_dir}/.ingest_messages_{uuid.uuid4().hex}.json"
        try:
            try:
                json.dump(rows, local_tmp, ensure_ascii=False)
                local_tmp.flush()
            finally:
                # Always close the handle, even if serialization fails,
                # so the unlink below operates on a closed file.
                local_tmp.close()
            await self._scp_upload(local_tmp.name, remote_tmp, "Batch payload upload timed out")
            parts = [
                self.cfg.ingest_messages_batch_remote_script,
                str(getattr(payload, "table")),
                str(getattr(payload, "dedupe_mode")),
                f"@{remote_tmp}",
            ]
            batch_cmd = " ".join(shlex.quote(p) for p in parts)
            # Run the script, capture its rc, delete the uploaded file, and
            # exit with the script's rc so cleanup never masks a failure.
            command = (
                f"cd {shlex.quote(self.cfg.remote_dir)} && "
                f"({batch_cmd}); rc=$?; rm -f {shlex.quote(remote_tmp)}; exit $rc"
            )
            res = await self._run_ssh_command(command, "Batch message ingest execution timed out")
        finally:
            try:
                os.unlink(local_tmp.name)
            except Exception:
                # Best-effort local cleanup; never mask the real error.
                pass
        result = {
            "host": self.cfg.ssh_host,
            "remote_dir": self.cfg.remote_dir,
            "exit_code": res["code"],
            "rows": len(rows),
            "stdout_tail": _tail(res["out"]),
            "stderr_tail": _tail(res["err"]),
        }
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result

    async def run_remote_assistant_feedback(self, feedback_id: str, payload: Any, created_at_utc: str) -> Dict[str, Any]:
        """Record an assistant feedback row. ``payload.sources`` are pydantic models."""
        confidence = getattr(payload, "confidence", None)
        conf = confidence if confidence is not None else 0.0
        sources = [s.model_dump() for s in list(getattr(payload, "sources", []) or [])]
        parts = [
            self.cfg.assistant_feedback_remote_script,
            feedback_id,
            created_at_utc,
            str(getattr(payload, "outcome")),
            str(getattr(payload, "task_type")),
            str(getattr(payload, "release_name") or ""),
            f"{conf}",
            "true" if bool(getattr(payload, "needs_review", False)) else "false",
            _b64(str(getattr(payload, "goal") or "")),
            _b64(str(getattr(payload, "draft") or "")),
            _b64(str(getattr(payload, "final_text") or "")),
            _b64(json.dumps(sources, ensure_ascii=False)),
            _b64(str(getattr(payload, "notes") or "")),
        ]
        res = await self._run_ssh(parts, "Assistant feedback execution timed out")
        result = self._error_detail(res["code"], res["out"], res["err"])
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result

    async def run_remote_query_assistant_feedback(
        self, outcome: Optional[str], task_type: Optional[str], release_name: Optional[str], limit: int
    ) -> Dict[str, Any]:
        """Query feedback rows; empty strings mean 'no filter' to the remote script."""
        parts = [
            self.cfg.assistant_feedback_query_remote_script,
            outcome or "",
            task_type or "",
            release_name or "",
            str(limit),
        ]
        return await self._query_rows(
            parts, "Assistant feedback query timed out", "Unable to parse feedback query output"
        )

    async def run_remote_query_assistant_metrics(
        self, task_type: Optional[str], release_name: Optional[str], outcome: Optional[str], group_by: str, limit: int
    ) -> Dict[str, Any]:
        """Query aggregated assistant metrics grouped by ``group_by``."""
        parts = [
            self.cfg.assistant_metrics_query_remote_script,
            task_type or "",
            release_name or "",
            outcome or "",
            group_by,
            str(limit),
        ]
        return await self._query_rows(
            parts, "Assistant metrics query timed out", "Unable to parse metrics query output"
        )

    async def run_remote_assistant_action(
        self,
        action_id: str,
        payload: Any,
        step: Any,
        status: str,
        output_json: Dict[str, Any],
        error_text: Optional[str],
        created_at_utc: str,
    ) -> Dict[str, Any]:
        """Log one executed assistant action step to the remote store."""
        parts = [
            self.cfg.assistant_action_remote_script,
            action_id,
            created_at_utc,
            str(getattr(payload, "task_type")),
            str(getattr(payload, "release_name") or ""),
            _b64(str(getattr(payload, "objective") or "")),
            str(getattr(step, "step_id")),
            _b64(str(getattr(step, "title") or "")),
            str(getattr(step, "action_type")),
            "true" if bool(getattr(step, "requires_approval", False)) else "false",
            "true" if bool(getattr(payload, "approved", False)) else "false",
            status,
            _b64(json.dumps(output_json, ensure_ascii=False)),
            _b64(error_text or ""),
        ]
        res = await self._run_ssh(parts, "Assistant action logging timed out")
        result = self._error_detail(res["code"], res["out"], res["err"])
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result

    async def run_remote_query_assistant_actions(
        self,
        status: Optional[str],
        task_type: Optional[str],
        release_name: Optional[str],
        step_id: Optional[str],
        action_type: Optional[str],
        limit: int,
    ) -> Dict[str, Any]:
        """Query logged assistant actions with optional filters."""
        parts = [
            self.cfg.assistant_actions_query_remote_script,
            status or "",
            task_type or "",
            release_name or "",
            step_id or "",
            action_type or "",
            str(limit),
        ]
        return await self._query_rows(
            parts, "Assistant actions query timed out", "Unable to parse actions query output"
        )

    async def run_remote_record_assistant_proposals(
        self,
        proposal_set_id: str,
        created_at_utc: str,
        objective: str,
        release_name: Optional[str],
        summary: str,
        signals: Dict[str, Any],
        proposals: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Record a set of assistant proposals (objective/summary/JSON go base64)."""
        parts = [
            self.cfg.assistant_proposals_remote_script,
            proposal_set_id,
            created_at_utc,
            _b64(objective or ""),
            release_name or "",
            _b64(summary or ""),
            _b64(json.dumps(signals or {}, ensure_ascii=False)),
            _b64(json.dumps(proposals or [], ensure_ascii=False)),
        ]
        res = await self._run_ssh(parts, "Assistant proposals logging timed out")
        result = self._error_detail(res["code"], res["out"], res["err"])
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result

    async def run_remote_query_assistant_proposals(
        self,
        release_name: Optional[str],
        proposal_set_id: Optional[str],
        limit: int,
    ) -> Dict[str, Any]:
        """Query recorded proposal sets with optional filters."""
        parts = [
            self.cfg.assistant_proposals_query_remote_script,
            release_name or "",
            proposal_set_id or "",
            str(limit),
        ]
        return await self._query_rows(
            parts, "Assistant proposals query timed out", "Unable to parse proposals query output"
        )

    async def run_remote_record_run(
        self,
        run_id: str,
        run_type: str,
        status: str,
        started_at_utc: str,
        finished_at_utc: str,
        actor: str,
        input_json: Dict[str, Any],
        output_json: Optional[Dict[str, Any]],
        error_text: Optional[str],
    ) -> None:
        """Record a run in Iceberg; raises 502 on failure, 504 on timeout.

        Previously this awaited the subprocess without a TimeoutError handler,
        leaking the ssh process and surfacing a raw asyncio.TimeoutError; it now
        goes through _run_ssh like every other method.
        """
        parts = [
            self.cfg.runs_remote_script,
            run_id,
            run_type,
            status,
            started_at_utc,
            finished_at_utc,
            actor,
            _b64(json.dumps(input_json, ensure_ascii=False)),
            _b64(json.dumps(output_json, ensure_ascii=False) if output_json is not None else ""),
            _b64(error_text or ""),
        ]
        res = await self._run_ssh(parts, "Run record execution timed out")
        if res["code"] != 0:
            raise HTTPException(
                status_code=502,
                detail={
                    "message": "Failed to record run in Iceberg",
                    "host": self.cfg.ssh_host,
                    "exit_code": res["code"],
                    "stdout_tail": _tail(res["out"]),
                    "stderr_tail": _tail(res["err"]),
                },
            )

    async def run_remote_record_event(
        self, run_id: str, event_type: str, detail_json: Dict[str, Any], created_at_utc: str
    ) -> None:
        """Record a single run event in Iceberg; raises 502 on failure, 504 on timeout.

        Same timeout-handling fix as run_remote_record_run: routed through
        _run_ssh so a hung ssh process is killed and mapped to HTTP 504.
        """
        parts = [
            self.cfg.run_events_remote_script,
            run_id,
            event_type,
            created_at_utc,
            _b64(json.dumps(detail_json, ensure_ascii=False)),
        ]
        res = await self._run_ssh(parts, "Run event record execution timed out")
        if res["code"] != 0:
            raise HTTPException(
                status_code=502,
                detail={
                    "message": "Failed to record run event in Iceberg",
                    "host": self.cfg.ssh_host,
                    "exit_code": res["code"],
                    "stdout_tail": _tail(res["out"]),
                    "stderr_tail": _tail(res["err"]),
                },
            )

    async def record_event_best_effort(
        self, run_id: str, event_type: str, detail_json: Dict[str, Any], created_at_utc: str
    ) -> None:
        """Record a run event, swallowing (and printing) any failure.

        Deliberately best-effort: event logging must never fail the caller.
        """
        try:
            await self.run_remote_record_event(run_id, event_type, detail_json, created_at_utc)
        except Exception as e:
            print(f"[WARN] run event logging failed: run_id={run_id} event={event_type} error={e}")