# jecio/services/remote_ops.py
import asyncio
import base64
import json
import os
import shlex
import tempfile
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from fastapi import HTTPException
@dataclass(frozen=True)
class RemoteOpsConfig:
    """Immutable configuration for executing remote scripts over SSH/SCP.

    Bundles connection settings with the remote-side script paths that
    :class:`RemoteOps` invokes. Frozen so an instance can be shared safely
    across requests.
    """

    ssh_host: str  # target host (e.g. user@host) passed to ssh/scp
    remote_dir: str  # remote working directory every command cd's into
    ssh_bin: str  # path to the ssh binary
    ssh_opts: str  # extra ssh options, split shell-style with shlex
    scp_bin: str  # path to the scp binary
    scp_opts: str  # extra scp options, split shell-style with shlex
    timeout_sec: int  # wall-clock limit for each remote invocation
    # Remote script paths, one per operation exposed by RemoteOps:
    projector_remote_script: str
    ingest_message_remote_script: str
    ingest_messages_batch_remote_script: str
    assistant_feedback_remote_script: str
    assistant_feedback_query_remote_script: str
    assistant_metrics_query_remote_script: str
    assistant_action_remote_script: str
    assistant_actions_query_remote_script: str
    assistant_proposals_remote_script: str
    assistant_proposals_query_remote_script: str
    runs_remote_script: str
    run_events_remote_script: str
    imap_checkpoint_remote_script: str
    create_messages_release_remote_script: str
def _tail(text: str, max_chars: int = 8000) -> str:
if len(text) <= max_chars:
return text
return text[-max_chars:]
def _b64(s: str) -> str:
return base64.b64encode(s.encode("utf-8")).decode("ascii")
def _extract_json_array_from_text(text: str) -> List[Dict[str, Any]]:
start = text.find("[")
end = text.rfind("]")
if start == -1 or end == -1 or end < start:
raise ValueError("No JSON array found in output")
candidate = text[start : end + 1]
obj = json.loads(candidate)
if not isinstance(obj, list):
raise ValueError("Parsed value is not a JSON array")
out: List[Dict[str, Any]] = []
for item in obj:
if isinstance(item, dict):
out.append(item)
return out
def _extract_json_object_from_text(text: str) -> Dict[str, Any]:
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end < start:
raise ValueError("No JSON object found in output")
candidate = text[start : end + 1]
obj = json.loads(candidate)
if not isinstance(obj, dict):
raise ValueError("Parsed value is not a JSON object")
return obj
class RemoteOps:
    """Runs project maintenance scripts on a remote host over SSH."""

    def __init__(self, cfg: RemoteOpsConfig):
        # cfg: connection settings and remote script paths (see RemoteOpsConfig).
        self.cfg = cfg

    def _ssh_args(self, command: str) -> List[str]:
        # Build argv for one ssh invocation: binary, extra options split
        # shell-style, target host, then the remote command as one string.
        return [self.cfg.ssh_bin, *shlex.split(self.cfg.ssh_opts), self.cfg.ssh_host, command]
async def _run_ssh(self, parts: List[str], timeout_error: str) -> Dict[str, Any]:
    """Execute ``parts`` as a shell-quoted command inside the remote dir.

    Returns a dict with ``code`` (exit status), ``out`` and ``err``
    (decoded with replacement for invalid UTF-8).

    Raises:
        HTTPException: 504 with ``timeout_error`` when the command exceeds
            the configured timeout; the ssh child is killed first.
    """
    quoted = " ".join(shlex.quote(part) for part in parts)
    remote_command = f"cd {shlex.quote(self.cfg.remote_dir)} && {quoted}"
    process = await asyncio.create_subprocess_exec(
        *self._ssh_args(remote_command),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        raw_out, raw_err = await asyncio.wait_for(process.communicate(), timeout=self.cfg.timeout_sec)
    except asyncio.TimeoutError:
        # Reap the hung child before surfacing a gateway timeout.
        process.kill()
        await process.wait()
        raise HTTPException(status_code=504, detail=timeout_error)
    return {
        "code": process.returncode,
        "out": raw_out.decode("utf-8", errors="replace"),
        "err": raw_err.decode("utf-8", errors="replace"),
    }
def _error_detail(self, code: int, out: str, err: str) -> Dict[str, Any]:
    """Assemble the standard error payload attached to HTTPException details."""
    detail: Dict[str, Any] = {
        "host": self.cfg.ssh_host,
        "remote_dir": self.cfg.remote_dir,
        "exit_code": code,
    }
    # Tails keep responses bounded even for very chatty remote scripts.
    detail["stdout_tail"] = _tail(out)
    detail["stderr_tail"] = _tail(err)
    return detail
async def run_remote_query_imap_checkpoint(
    self,
    host: str,
    mailbox: str,
    username: str,
    table: str,
) -> Optional[int]:
    """Fetch the stored IMAP checkpoint (``max_uid``) for a mailbox.

    Returns ``None`` when the remote script reports no checkpoint.

    Raises:
        HTTPException: 502 on non-zero exit or unparseable script output,
            504 on timeout (via ``_run_ssh``).
    """
    script = self.cfg.imap_checkpoint_remote_script
    outcome = await self._run_ssh(
        [script, host, mailbox, username, table],
        "IMAP checkpoint query timed out",
    )
    if outcome["code"] != 0:
        raise HTTPException(
            status_code=502,
            detail=self._error_detail(outcome["code"], outcome["out"], outcome["err"]),
        )
    try:
        parsed = _extract_json_object_from_text(outcome["out"])
        max_uid = parsed.get("max_uid")
        # int() stays inside the try so a malformed value also maps to 502.
        return None if max_uid is None else int(max_uid)
    except Exception as e:
        raise HTTPException(
            status_code=502,
            detail={
                "message": f"Unable to parse IMAP checkpoint output: {e}",
                "stdout_tail": _tail(outcome["out"]),
                "stderr_tail": _tail(outcome["err"]),
            },
        )
async def run_remote_create_messages_release(self, release_name: str) -> Dict[str, Any]:
    """Create a named messages release on the remote host.

    Returns the standard host/exit/stdout/stderr summary plus the release
    name; raises HTTPException(502) with that same payload on failure.
    """
    outcome = await self._run_ssh(
        [self.cfg.create_messages_release_remote_script, release_name],
        "Create messages release timed out",
    )
    summary = dict(self._error_detail(outcome["code"], outcome["out"], outcome["err"]))
    summary["release_name"] = release_name
    if outcome["code"] != 0:
        raise HTTPException(status_code=502, detail=summary)
    return summary
async def run_remote_projector(self, payload: Any) -> Dict[str, Any]:
    """Run the remote projector script and report per-step completion flags.

    Optional payload attributes ``concept_table`` and ``dry_run`` add the
    corresponding CLI flags. The result includes booleans derived from
    ``[STEP]`` markers in stdout.
    """
    args = [
        self.cfg.projector_remote_script,
        "--release-name", str(getattr(payload, "release_name", "")),
        "--targets", str(getattr(payload, "targets", "both")),
    ]
    concept_table = getattr(payload, "concept_table", None)
    if concept_table:
        args += ["--concept-table", str(concept_table)]
    if bool(getattr(payload, "dry_run", False)):
        args.append("--dry-run")
    outcome = await self._run_ssh(args, "Projector execution timed out")
    stdout_text = outcome["out"]
    report = dict(self._error_detail(outcome["code"], stdout_text, outcome["err"]))
    # Progress markers emitted by the remote script indicate step completion.
    report["spark_read_done"] = "[STEP] spark_read_done" in stdout_text
    report["projection_done"] = "[STEP] projection_done" in stdout_text
    if outcome["code"] != 0:
        raise HTTPException(status_code=502, detail=report)
    return report
async def run_remote_ingest_message(self, payload: Any) -> Dict[str, Any]:
    """Ingest a single message row via the remote ingest script.

    Free-text fields (body, metadata JSON) are base64-encoded so they pass
    safely as shell arguments. Raises HTTPException(502) on failure.
    """
    metadata = getattr(payload, "metadata", {}) or {}
    args = [
        self.cfg.ingest_message_remote_script,
        str(getattr(payload, "table")),
        str(getattr(payload, "thread_id")),
        str(getattr(payload, "message_id")),
        str(getattr(payload, "sender")),
        str(getattr(payload, "channel")),
        str(getattr(payload, "sent_at") or ""),
        _b64(str(getattr(payload, "body") or "")),
        _b64(json.dumps(metadata, ensure_ascii=False)),
    ]
    outcome = await self._run_ssh(args, "Message ingest execution timed out")
    summary = self._error_detail(outcome["code"], outcome["out"], outcome["err"])
    if outcome["code"] != 0:
        raise HTTPException(status_code=502, detail=summary)
    return summary
async def run_remote_ingest_messages_batch(self, payload: Any) -> Dict[str, Any]:
    """Upload a JSON batch of messages via scp, then ingest it remotely.

    Flow: serialize rows to a local temp file, scp it to a unique path in
    the remote dir, run the batch ingest script on that path, and remove
    the remote file in the same shell command regardless of exit code.

    Returns the standard host/exit/stdout/stderr summary plus a ``rows``
    count. Raises HTTPException 504 on upload/exec timeout and 502 on any
    non-zero exit.
    """
    # Materialize message objects into plain dicts for JSON serialization.
    rows = []
    for m in list(getattr(payload, "messages", []) or []):
        rows.append(
            {
                "thread_id": getattr(m, "thread_id"),
                "message_id": getattr(m, "message_id"),
                "sender": getattr(m, "sender"),
                "channel": getattr(m, "channel"),
                "sent_at": getattr(m, "sent_at"),
                "body": getattr(m, "body"),
                "metadata": getattr(m, "metadata"),
            }
        )
    # Empty batch: short-circuit with a synthetic success result, no ssh.
    if not rows:
        return {
            "host": self.cfg.ssh_host,
            "remote_dir": self.cfg.remote_dir,
            "exit_code": 0,
            "rows": 0,
            "stdout_tail": "[INFO] No rows to ingest",
            "stderr_tail": "",
        }
    # delete=False so the file survives close() until scp has read it;
    # it is removed explicitly in the finally block below.
    local_tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8")
    # Unique remote path avoids collisions between concurrent batches.
    remote_tmp = f"{self.cfg.remote_dir}/.ingest_messages_{uuid.uuid4().hex}.json"
    try:
        json.dump(rows, local_tmp, ensure_ascii=False)
        local_tmp.flush()
        local_tmp.close()
        # Step 1: copy the payload file to the remote host.
        scp_target = f"{self.cfg.ssh_host}:{remote_tmp}"
        scp_args = [self.cfg.scp_bin, *shlex.split(self.cfg.scp_opts), local_tmp.name, scp_target]
        scp_proc = await asyncio.create_subprocess_exec(
            *scp_args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            scp_stdout, scp_stderr = await asyncio.wait_for(scp_proc.communicate(), timeout=self.cfg.timeout_sec)
        except asyncio.TimeoutError:
            scp_proc.kill()
            await scp_proc.wait()
            raise HTTPException(status_code=504, detail="Batch payload upload timed out")
        if scp_proc.returncode != 0:
            # NOTE(review): on this path the remote temp file (if partially
            # created) is never cleaned up — confirm whether that matters.
            raise HTTPException(
                status_code=502,
                detail={
                    "host": self.cfg.ssh_host,
                    "remote_dir": self.cfg.remote_dir,
                    "exit_code": scp_proc.returncode,
                    "stdout_tail": _tail(scp_stdout.decode("utf-8", errors="replace")),
                    "stderr_tail": _tail(scp_stderr.decode("utf-8", errors="replace")),
                },
            )
        # Step 2: run the ingest script; "@<path>" tells it to read the file.
        payload_arg = f"@{remote_tmp}"
        parts = [
            self.cfg.ingest_messages_batch_remote_script,
            str(getattr(payload, "table")),
            str(getattr(payload, "dedupe_mode")),
            payload_arg,
        ]
        batch_cmd = " ".join(shlex.quote(p) for p in parts)
        # rm -f runs whether or not the script fails; the script's exit code
        # is preserved via rc so failures still surface.
        command = (
            f"cd {shlex.quote(self.cfg.remote_dir)} && "
            f"({batch_cmd}); rc=$?; rm -f {shlex.quote(remote_tmp)}; exit $rc"
        )
        proc = await asyncio.create_subprocess_exec(
            *self._ssh_args(command),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.cfg.timeout_sec)
        except asyncio.TimeoutError:
            proc.kill()
            await proc.wait()
            raise HTTPException(status_code=504, detail="Batch message ingest execution timed out")
    finally:
        # Always remove the local temp file, even on upload/exec failure.
        try:
            os.unlink(local_tmp.name)
        except Exception:
            pass
    out = stdout.decode("utf-8", errors="replace")
    err = stderr.decode("utf-8", errors="replace")
    result = {
        "host": self.cfg.ssh_host,
        "remote_dir": self.cfg.remote_dir,
        "exit_code": proc.returncode,
        "rows": len(rows),
        "stdout_tail": _tail(out),
        "stderr_tail": _tail(err),
    }
    if proc.returncode != 0:
        raise HTTPException(status_code=502, detail=result)
    return result
async def run_remote_assistant_feedback(self, feedback_id: str, payload: Any, created_at_utc: str) -> Dict[str, Any]:
    """Record one assistant feedback row on the remote host.

    Free-text fields are base64-encoded to pass safely as shell arguments;
    sources are serialized via their ``model_dump()``. Raises
    HTTPException(502) on a non-zero exit.
    """
    raw_conf = getattr(payload, "confidence", None)
    conf_value = 0.0 if raw_conf is None else raw_conf
    source_dicts = [src.model_dump() for src in list(getattr(payload, "sources", []) or [])]
    args = [
        self.cfg.assistant_feedback_remote_script,
        feedback_id,
        created_at_utc,
        str(getattr(payload, "outcome")),
        str(getattr(payload, "task_type")),
        str(getattr(payload, "release_name") or ""),
        f"{conf_value}",
        "true" if bool(getattr(payload, "needs_review", False)) else "false",
        _b64(str(getattr(payload, "goal") or "")),
        _b64(str(getattr(payload, "draft") or "")),
        _b64(str(getattr(payload, "final_text") or "")),
        _b64(json.dumps(source_dicts, ensure_ascii=False)),
        _b64(str(getattr(payload, "notes") or "")),
    ]
    outcome = await self._run_ssh(args, "Assistant feedback execution timed out")
    summary = self._error_detail(outcome["code"], outcome["out"], outcome["err"])
    if outcome["code"] != 0:
        raise HTTPException(status_code=502, detail=summary)
    return summary
async def run_remote_query_assistant_feedback(
    self, outcome: Optional[str], task_type: Optional[str], release_name: Optional[str], limit: int
) -> Dict[str, Any]:
    """Query recorded assistant feedback rows; empty filters match all.

    Returns ``{"host", "remote_dir", "rows"}`` where rows is the parsed
    JSON array from the remote script. Raises HTTPException(502) on a
    non-zero exit or unparseable output.
    """
    filters = [outcome or "", task_type or "", release_name or "", str(limit)]
    result = await self._run_ssh(
        [self.cfg.assistant_feedback_query_remote_script, *filters],
        "Assistant feedback query timed out",
    )
    if result["code"] != 0:
        raise HTTPException(status_code=502, detail=self._error_detail(result["code"], result["out"], result["err"]))
    try:
        rows = _extract_json_array_from_text(result["out"])
    except Exception as e:
        failure = {
            "message": f"Unable to parse feedback query output: {e}",
            "stdout_tail": _tail(result["out"]),
            "stderr_tail": _tail(result["err"]),
        }
        raise HTTPException(status_code=502, detail=failure)
    return {"host": self.cfg.ssh_host, "remote_dir": self.cfg.remote_dir, "rows": rows}
async def run_remote_query_assistant_metrics(
    self, task_type: Optional[str], release_name: Optional[str], outcome: Optional[str], group_by: str, limit: int
) -> Dict[str, Any]:
    """Query aggregated assistant metrics grouped by ``group_by``.

    Empty string filters match everything. Raises HTTPException(502) on a
    non-zero exit or unparseable output.
    """
    result = await self._run_ssh(
        [
            self.cfg.assistant_metrics_query_remote_script,
            task_type or "",
            release_name or "",
            outcome or "",
            group_by,
            str(limit),
        ],
        "Assistant metrics query timed out",
    )
    code, out, err = result["code"], result["out"], result["err"]
    if code != 0:
        raise HTTPException(status_code=502, detail=self._error_detail(code, out, err))
    try:
        parsed_rows = _extract_json_array_from_text(out)
    except Exception as e:
        failure = {
            "message": f"Unable to parse metrics query output: {e}",
            "stdout_tail": _tail(out),
            "stderr_tail": _tail(err),
        }
        raise HTTPException(status_code=502, detail=failure)
    return {"host": self.cfg.ssh_host, "remote_dir": self.cfg.remote_dir, "rows": parsed_rows}
async def run_remote_assistant_action(
    self, action_id: str, payload: Any, step: Any, status: str, output_json: Dict[str, Any], error_text: Optional[str], created_at_utc: str
) -> Dict[str, Any]:
    """Log the execution of one assistant plan step on the remote host.

    Booleans are rendered as the literal strings "true"/"false" for the
    shell-level script; free text and JSON are base64-encoded. Raises
    HTTPException(502) on a non-zero exit.
    """
    requires_approval = "true" if bool(getattr(step, "requires_approval", False)) else "false"
    approved = "true" if bool(getattr(payload, "approved", False)) else "false"
    argv = [
        self.cfg.assistant_action_remote_script,
        action_id,
        created_at_utc,
        str(getattr(payload, "task_type")),
        str(getattr(payload, "release_name") or ""),
        _b64(str(getattr(payload, "objective") or "")),
        str(getattr(step, "step_id")),
        _b64(str(getattr(step, "title") or "")),
        str(getattr(step, "action_type")),
        requires_approval,
        approved,
        status,
        _b64(json.dumps(output_json, ensure_ascii=False)),
        _b64(error_text or ""),
    ]
    outcome = await self._run_ssh(argv, "Assistant action logging timed out")
    summary = self._error_detail(outcome["code"], outcome["out"], outcome["err"])
    if outcome["code"] != 0:
        raise HTTPException(status_code=502, detail=summary)
    return summary
async def run_remote_query_assistant_actions(
    self,
    status: Optional[str],
    task_type: Optional[str],
    release_name: Optional[str],
    step_id: Optional[str],
    action_type: Optional[str],
    limit: int,
) -> Dict[str, Any]:
    """Query logged assistant actions; empty string filters match all.

    Raises HTTPException(502) on a non-zero exit or unparseable output.
    """
    result = await self._run_ssh(
        [
            self.cfg.assistant_actions_query_remote_script,
            status or "",
            task_type or "",
            release_name or "",
            step_id or "",
            action_type or "",
            str(limit),
        ],
        "Assistant actions query timed out",
    )
    code, out, err = result["code"], result["out"], result["err"]
    if code != 0:
        raise HTTPException(status_code=502, detail=self._error_detail(code, out, err))
    try:
        parsed_rows = _extract_json_array_from_text(out)
    except Exception as e:
        failure = {
            "message": f"Unable to parse actions query output: {e}",
            "stdout_tail": _tail(out),
            "stderr_tail": _tail(err),
        }
        raise HTTPException(status_code=502, detail=failure)
    return {"host": self.cfg.ssh_host, "remote_dir": self.cfg.remote_dir, "rows": parsed_rows}
async def run_remote_record_assistant_proposals(
    self,
    proposal_set_id: str,
    created_at_utc: str,
    objective: str,
    release_name: Optional[str],
    summary: str,
    signals: Dict[str, Any],
    proposals: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Persist an assistant proposal set (objective, summary, proposals).

    Free text and JSON payloads are base64-encoded for safe shell passing.
    Raises HTTPException(502) on a non-zero exit.
    """
    argv = [
        self.cfg.assistant_proposals_remote_script,
        proposal_set_id,
        created_at_utc,
        _b64(objective or ""),
        release_name or "",
        _b64(summary or ""),
        _b64(json.dumps(signals or {}, ensure_ascii=False)),
        _b64(json.dumps(proposals or [], ensure_ascii=False)),
    ]
    outcome = await self._run_ssh(argv, "Assistant proposals logging timed out")
    detail = self._error_detail(outcome["code"], outcome["out"], outcome["err"])
    if outcome["code"] != 0:
        raise HTTPException(status_code=502, detail=detail)
    return detail
async def run_remote_query_assistant_proposals(
    self,
    release_name: Optional[str],
    proposal_set_id: Optional[str],
    limit: int,
) -> Dict[str, Any]:
    """Query recorded assistant proposal sets; empty filters match all.

    Raises HTTPException(502) on a non-zero exit or unparseable output.
    """
    result = await self._run_ssh(
        [
            self.cfg.assistant_proposals_query_remote_script,
            release_name or "",
            proposal_set_id or "",
            str(limit),
        ],
        "Assistant proposals query timed out",
    )
    code, out, err = result["code"], result["out"], result["err"]
    if code != 0:
        raise HTTPException(status_code=502, detail=self._error_detail(code, out, err))
    try:
        parsed_rows = _extract_json_array_from_text(out)
    except Exception as e:
        failure = {
            "message": f"Unable to parse proposals query output: {e}",
            "stdout_tail": _tail(out),
            "stderr_tail": _tail(err),
        }
        raise HTTPException(status_code=502, detail=failure)
    return {"host": self.cfg.ssh_host, "remote_dir": self.cfg.remote_dir, "rows": parsed_rows}
async def run_remote_record_run(
    self,
    run_id: str,
    run_type: str,
    status: str,
    started_at_utc: str,
    finished_at_utc: str,
    actor: str,
    input_json: Dict[str, Any],
    output_json: Optional[Dict[str, Any]],
    error_text: Optional[str],
) -> None:
    """Persist a run record in Iceberg via the remote runs script.

    Fix: this method previously duplicated the ssh plumbing from
    ``_run_ssh`` but omitted its timeout handling — a hung ssh call raised
    a bare ``asyncio.TimeoutError`` (surfacing as HTTP 500) and leaked the
    child process. Delegating to ``_run_ssh`` kills the child and raises
    HTTPException(504), consistent with every other remote call here.

    Raises:
        HTTPException: 502 on a non-zero exit (same detail payload as
            before), 504 on timeout.
    """
    parts = [
        self.cfg.runs_remote_script,
        run_id,
        run_type,
        status,
        started_at_utc,
        finished_at_utc,
        actor,
        # JSON payloads are base64-encoded to pass safely as shell args;
        # a missing output_json becomes the empty string.
        _b64(json.dumps(input_json, ensure_ascii=False)),
        _b64(json.dumps(output_json, ensure_ascii=False) if output_json is not None else ""),
        _b64(error_text or ""),
    ]
    res = await self._run_ssh(parts, "Run recording timed out")
    if res["code"] != 0:
        raise HTTPException(
            status_code=502,
            detail={
                "message": "Failed to record run in Iceberg",
                "host": self.cfg.ssh_host,
                "exit_code": res["code"],
                "stdout_tail": _tail(res["out"]),
                "stderr_tail": _tail(res["err"]),
            },
        )
async def run_remote_record_event(self, run_id: str, event_type: str, detail_json: Dict[str, Any], created_at_utc: str) -> None:
    """Persist one run event row in Iceberg via the remote events script.

    Fix: this method previously duplicated the ssh plumbing from
    ``_run_ssh`` but omitted its timeout handling — a hung ssh call raised
    a bare ``asyncio.TimeoutError`` (surfacing as HTTP 500) and leaked the
    child process. Delegating to ``_run_ssh`` kills the child and raises
    HTTPException(504), consistent with every other remote call here.

    Raises:
        HTTPException: 502 on a non-zero exit (same detail payload as
            before), 504 on timeout.
    """
    parts = [
        self.cfg.run_events_remote_script,
        run_id,
        event_type,
        created_at_utc,
        # Event detail is base64-encoded JSON to pass safely as a shell arg.
        _b64(json.dumps(detail_json, ensure_ascii=False)),
    ]
    res = await self._run_ssh(parts, "Run event recording timed out")
    if res["code"] != 0:
        raise HTTPException(
            status_code=502,
            detail={
                "message": "Failed to record run event in Iceberg",
                "host": self.cfg.ssh_host,
                "exit_code": res["code"],
                "stdout_tail": _tail(res["out"]),
                "stderr_tail": _tail(res["err"]),
            },
        )
async def record_event_best_effort(self, run_id: str, event_type: str, detail_json: Dict[str, Any], created_at_utc: str) -> None:
    """Log a run event, swallowing any failure (best-effort telemetry)."""
    try:
        await self.run_remote_record_event(run_id, event_type, detail_json, created_at_utc)
    except Exception as e:
        # Deliberately best-effort: event logging must never fail the caller.
        print(f"[WARN] run event logging failed: run_id={run_id} event={event_type} error={e}")