556 lines
22 KiB
Python
556 lines
22 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
import os
|
|
import shlex
|
|
import tempfile
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import HTTPException
|
|
|
|
|
|
@dataclass(frozen=True)
class RemoteOpsConfig:
    """Immutable configuration for executing operational scripts on a remote host over SSH."""

    # SSH target (as accepted by the ssh CLI, e.g. "user@host") and the remote
    # working directory every command cd's into before running.
    ssh_host: str
    remote_dir: str
    # ssh/scp binaries plus extra CLI options; the *_opts strings are
    # shlex-split into separate argv items before use.
    ssh_bin: str
    ssh_opts: str
    scp_bin: str
    scp_opts: str
    # Timeout in seconds applied to each remote command / upload.
    timeout_sec: int
    # Commands for the individual remote scripts, executed from within
    # remote_dir (so relative paths are resolved against it).
    projector_remote_script: str
    ingest_message_remote_script: str
    ingest_messages_batch_remote_script: str
    assistant_feedback_remote_script: str
    assistant_feedback_query_remote_script: str
    assistant_metrics_query_remote_script: str
    assistant_action_remote_script: str
    assistant_actions_query_remote_script: str
    assistant_proposals_remote_script: str
    assistant_proposals_query_remote_script: str
    runs_remote_script: str
    run_events_remote_script: str
    imap_checkpoint_remote_script: str
    create_messages_release_remote_script: str
|
|
|
|
|
|
def _tail(text: str, max_chars: int = 8000) -> str:
|
|
if len(text) <= max_chars:
|
|
return text
|
|
return text[-max_chars:]
|
|
|
|
|
|
def _b64(s: str) -> str:
|
|
return base64.b64encode(s.encode("utf-8")).decode("ascii")
|
|
|
|
|
|
def _extract_json_array_from_text(text: str) -> List[Dict[str, Any]]:
|
|
start = text.find("[")
|
|
end = text.rfind("]")
|
|
if start == -1 or end == -1 or end < start:
|
|
raise ValueError("No JSON array found in output")
|
|
candidate = text[start : end + 1]
|
|
obj = json.loads(candidate)
|
|
if not isinstance(obj, list):
|
|
raise ValueError("Parsed value is not a JSON array")
|
|
out: List[Dict[str, Any]] = []
|
|
for item in obj:
|
|
if isinstance(item, dict):
|
|
out.append(item)
|
|
return out
|
|
|
|
|
|
def _extract_json_object_from_text(text: str) -> Dict[str, Any]:
|
|
start = text.find("{")
|
|
end = text.rfind("}")
|
|
if start == -1 or end == -1 or end < start:
|
|
raise ValueError("No JSON object found in output")
|
|
candidate = text[start : end + 1]
|
|
obj = json.loads(candidate)
|
|
if not isinstance(obj, dict):
|
|
raise ValueError("Parsed value is not a JSON object")
|
|
return obj
|
|
|
|
|
|
class RemoteOps:
    """Runs operational scripts on a remote host via ssh/scp.

    Non-zero exits and timeouts are surfaced to FastAPI callers as
    HTTPException 502 / 504 with stdout/stderr tails in the detail payload.
    """

    def __init__(self, cfg: RemoteOpsConfig):
        # All connection details, timeouts and remote script commands come from cfg.
        self.cfg = cfg

    def _ssh_args(self, command: str) -> List[str]:
        """Build the argv for running *command* on the configured SSH host."""
        # ssh_opts is a single shell-style string; shlex-split it into argv items.
        return [self.cfg.ssh_bin, *shlex.split(self.cfg.ssh_opts), self.cfg.ssh_host, command]
|
|
|
|
    async def _run_ssh(self, parts: List[str], timeout_error: str) -> Dict[str, Any]:
        """Execute *parts* as one command in the remote working directory.

        Each element of *parts* is shell-quoted, the command is prefixed with a
        ``cd`` into ``cfg.remote_dir``, and both output streams are captured.

        Returns:
            Dict with ``code`` (process exit status), ``out`` and ``err``
            (decoded stdout/stderr).

        Raises:
            HTTPException: 504 with *timeout_error* as detail when the command
                exceeds ``cfg.timeout_sec`` seconds.
        """
        command = f"cd {shlex.quote(self.cfg.remote_dir)} && {' '.join(shlex.quote(p) for p in parts)}"
        proc = await asyncio.create_subprocess_exec(
            *self._ssh_args(command),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.cfg.timeout_sec)
        except asyncio.TimeoutError:
            # Kill the hung ssh process before reporting the timeout upstream.
            proc.kill()
            await proc.wait()
            raise HTTPException(status_code=504, detail=timeout_error)
        # Remote output may contain arbitrary bytes; never fail on decoding.
        out = stdout.decode("utf-8", errors="replace")
        err = stderr.decode("utf-8", errors="replace")
        return {"code": proc.returncode, "out": out, "err": err}
|
|
|
|
def _error_detail(self, code: int, out: str, err: str) -> Dict[str, Any]:
|
|
return {
|
|
"host": self.cfg.ssh_host,
|
|
"remote_dir": self.cfg.remote_dir,
|
|
"exit_code": code,
|
|
"stdout_tail": _tail(out),
|
|
"stderr_tail": _tail(err),
|
|
}
|
|
|
|
async def run_remote_query_imap_checkpoint(
|
|
self,
|
|
host: str,
|
|
mailbox: str,
|
|
username: str,
|
|
table: str,
|
|
) -> Optional[int]:
|
|
res = await self._run_ssh(
|
|
[self.cfg.imap_checkpoint_remote_script, host, mailbox, username, table],
|
|
"IMAP checkpoint query timed out",
|
|
)
|
|
if res["code"] != 0:
|
|
raise HTTPException(status_code=502, detail=self._error_detail(res["code"], res["out"], res["err"]))
|
|
try:
|
|
obj = _extract_json_object_from_text(res["out"])
|
|
val = obj.get("max_uid")
|
|
if val is None:
|
|
return None
|
|
return int(val)
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail={
|
|
"message": f"Unable to parse IMAP checkpoint output: {e}",
|
|
"stdout_tail": _tail(res["out"]),
|
|
"stderr_tail": _tail(res["err"]),
|
|
},
|
|
)
|
|
|
|
async def run_remote_create_messages_release(self, release_name: str) -> Dict[str, Any]:
|
|
res = await self._run_ssh(
|
|
[self.cfg.create_messages_release_remote_script, release_name],
|
|
"Create messages release timed out",
|
|
)
|
|
result = {
|
|
**self._error_detail(res["code"], res["out"], res["err"]),
|
|
"release_name": release_name,
|
|
}
|
|
if res["code"] != 0:
|
|
raise HTTPException(status_code=502, detail=result)
|
|
return result
|
|
|
|
    async def run_remote_projector(self, payload: Any) -> Dict[str, Any]:
        """Run the remote projector script and summarize its progress.

        *payload* is expected to expose ``release_name``, ``targets``,
        ``concept_table`` and ``dry_run`` attributes (presumably a Pydantic
        model — confirm against callers).

        Returns the standard error-detail dict plus two booleans derived from
        progress markers on stdout. Raises HTTPException 502 on a non-zero
        exit (504 on timeout, via _run_ssh).
        """
        parts = [
            self.cfg.projector_remote_script,
            "--release-name", str(getattr(payload, "release_name", "")),
            "--targets", str(getattr(payload, "targets", "both")),
        ]
        # Optional flags are appended only when present/true on the payload.
        if getattr(payload, "concept_table", None):
            parts.extend(["--concept-table", str(getattr(payload, "concept_table"))])
        if bool(getattr(payload, "dry_run", False)):
            parts.append("--dry-run")

        res = await self._run_ssh(parts, "Projector execution timed out")
        result = {
            **self._error_detail(res["code"], res["out"], res["err"]),
            # Progress markers the remote script prints to stdout.
            "spark_read_done": "[STEP] spark_read_done" in res["out"],
            "projection_done": "[STEP] projection_done" in res["out"],
        }
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result
|
|
|
|
    async def run_remote_ingest_message(self, payload: Any) -> Dict[str, Any]:
        """Ingest a single message row via the remote ingest script.

        *payload* is expected to expose ``table``, ``thread_id``,
        ``message_id``, ``sender``, ``channel``, ``sent_at``, ``body`` and
        ``metadata``. Body and metadata are base64-encoded so arbitrary text
        survives shell argument passing.

        Raises HTTPException 502 on a non-zero exit (504 on timeout, via _run_ssh).
        """
        # NOTE: the positional order below is the remote script's calling
        # convention — do not reorder.
        parts = [
            self.cfg.ingest_message_remote_script,
            str(getattr(payload, "table")),
            str(getattr(payload, "thread_id")),
            str(getattr(payload, "message_id")),
            str(getattr(payload, "sender")),
            str(getattr(payload, "channel")),
            str(getattr(payload, "sent_at") or ""),
            _b64(str(getattr(payload, "body") or "")),
            _b64(json.dumps(getattr(payload, "metadata", {}) or {}, ensure_ascii=False)),
        ]
        res = await self._run_ssh(parts, "Message ingest execution timed out")
        result = self._error_detail(res["code"], res["out"], res["err"])
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result
|
|
|
|
async def run_remote_ingest_messages_batch(self, payload: Any) -> Dict[str, Any]:
|
|
rows = []
|
|
for m in list(getattr(payload, "messages", []) or []):
|
|
rows.append(
|
|
{
|
|
"thread_id": getattr(m, "thread_id"),
|
|
"message_id": getattr(m, "message_id"),
|
|
"sender": getattr(m, "sender"),
|
|
"channel": getattr(m, "channel"),
|
|
"sent_at": getattr(m, "sent_at"),
|
|
"body": getattr(m, "body"),
|
|
"metadata": getattr(m, "metadata"),
|
|
}
|
|
)
|
|
if not rows:
|
|
return {
|
|
"host": self.cfg.ssh_host,
|
|
"remote_dir": self.cfg.remote_dir,
|
|
"exit_code": 0,
|
|
"rows": 0,
|
|
"stdout_tail": "[INFO] No rows to ingest",
|
|
"stderr_tail": "",
|
|
}
|
|
|
|
local_tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8")
|
|
remote_tmp = f"{self.cfg.remote_dir}/.ingest_messages_{uuid.uuid4().hex}.json"
|
|
try:
|
|
json.dump(rows, local_tmp, ensure_ascii=False)
|
|
local_tmp.flush()
|
|
local_tmp.close()
|
|
|
|
scp_target = f"{self.cfg.ssh_host}:{remote_tmp}"
|
|
scp_args = [self.cfg.scp_bin, *shlex.split(self.cfg.scp_opts), local_tmp.name, scp_target]
|
|
scp_proc = await asyncio.create_subprocess_exec(
|
|
*scp_args,
|
|
stdout=asyncio.subprocess.PIPE,
|
|
stderr=asyncio.subprocess.PIPE,
|
|
)
|
|
try:
|
|
scp_stdout, scp_stderr = await asyncio.wait_for(scp_proc.communicate(), timeout=self.cfg.timeout_sec)
|
|
except asyncio.TimeoutError:
|
|
scp_proc.kill()
|
|
await scp_proc.wait()
|
|
raise HTTPException(status_code=504, detail="Batch payload upload timed out")
|
|
if scp_proc.returncode != 0:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail={
|
|
"host": self.cfg.ssh_host,
|
|
"remote_dir": self.cfg.remote_dir,
|
|
"exit_code": scp_proc.returncode,
|
|
"stdout_tail": _tail(scp_stdout.decode("utf-8", errors="replace")),
|
|
"stderr_tail": _tail(scp_stderr.decode("utf-8", errors="replace")),
|
|
},
|
|
)
|
|
|
|
payload_arg = f"@{remote_tmp}"
|
|
parts = [
|
|
self.cfg.ingest_messages_batch_remote_script,
|
|
str(getattr(payload, "table")),
|
|
str(getattr(payload, "dedupe_mode")),
|
|
payload_arg,
|
|
]
|
|
batch_cmd = " ".join(shlex.quote(p) for p in parts)
|
|
command = (
|
|
f"cd {shlex.quote(self.cfg.remote_dir)} && "
|
|
f"({batch_cmd}); rc=$?; rm -f {shlex.quote(remote_tmp)}; exit $rc"
|
|
)
|
|
proc = await asyncio.create_subprocess_exec(
|
|
*self._ssh_args(command),
|
|
stdout=asyncio.subprocess.PIPE,
|
|
stderr=asyncio.subprocess.PIPE,
|
|
)
|
|
try:
|
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.cfg.timeout_sec)
|
|
except asyncio.TimeoutError:
|
|
proc.kill()
|
|
await proc.wait()
|
|
raise HTTPException(status_code=504, detail="Batch message ingest execution timed out")
|
|
finally:
|
|
try:
|
|
os.unlink(local_tmp.name)
|
|
except Exception:
|
|
pass
|
|
|
|
out = stdout.decode("utf-8", errors="replace")
|
|
err = stderr.decode("utf-8", errors="replace")
|
|
result = {
|
|
"host": self.cfg.ssh_host,
|
|
"remote_dir": self.cfg.remote_dir,
|
|
"exit_code": proc.returncode,
|
|
"rows": len(rows),
|
|
"stdout_tail": _tail(out),
|
|
"stderr_tail": _tail(err),
|
|
}
|
|
if proc.returncode != 0:
|
|
raise HTTPException(status_code=502, detail=result)
|
|
return result
|
|
|
|
    async def run_remote_assistant_feedback(self, feedback_id: str, payload: Any, created_at_utc: str) -> Dict[str, Any]:
        """Record one assistant feedback row on the remote host.

        *payload* is expected to expose ``outcome``, ``task_type``,
        ``release_name``, ``confidence``, ``needs_review``, ``goal``,
        ``draft``, ``final_text``, ``sources`` and ``notes``. Free text and
        JSON are base64-encoded for safe shell transport; booleans are passed
        as "true"/"false".

        Raises HTTPException 502 on a non-zero exit (504 on timeout, via _run_ssh).
        """
        confidence = getattr(payload, "confidence", None)
        # Missing confidence is recorded as 0.0 rather than an empty argument.
        conf = confidence if confidence is not None else 0.0
        # sources elements expose model_dump() (presumably Pydantic models —
        # confirm against callers); serialize them to plain dicts first.
        sources = [s.model_dump() for s in list(getattr(payload, "sources", []) or [])]
        # Positional order is the remote script's calling convention — do not reorder.
        parts = [
            self.cfg.assistant_feedback_remote_script,
            feedback_id,
            created_at_utc,
            str(getattr(payload, "outcome")),
            str(getattr(payload, "task_type")),
            str(getattr(payload, "release_name") or ""),
            f"{conf}",
            "true" if bool(getattr(payload, "needs_review", False)) else "false",
            _b64(str(getattr(payload, "goal") or "")),
            _b64(str(getattr(payload, "draft") or "")),
            _b64(str(getattr(payload, "final_text") or "")),
            _b64(json.dumps(sources, ensure_ascii=False)),
            _b64(str(getattr(payload, "notes") or "")),
        ]
        res = await self._run_ssh(parts, "Assistant feedback execution timed out")
        result = self._error_detail(res["code"], res["out"], res["err"])
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result
|
|
|
|
async def run_remote_query_assistant_feedback(
|
|
self, outcome: Optional[str], task_type: Optional[str], release_name: Optional[str], limit: int
|
|
) -> Dict[str, Any]:
|
|
parts = [
|
|
self.cfg.assistant_feedback_query_remote_script,
|
|
outcome or "",
|
|
task_type or "",
|
|
release_name or "",
|
|
str(limit),
|
|
]
|
|
res = await self._run_ssh(parts, "Assistant feedback query timed out")
|
|
if res["code"] != 0:
|
|
raise HTTPException(status_code=502, detail=self._error_detail(res["code"], res["out"], res["err"]))
|
|
try:
|
|
rows = _extract_json_array_from_text(res["out"])
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail={
|
|
"message": f"Unable to parse feedback query output: {e}",
|
|
"stdout_tail": _tail(res["out"]),
|
|
"stderr_tail": _tail(res["err"]),
|
|
},
|
|
)
|
|
return {"host": self.cfg.ssh_host, "remote_dir": self.cfg.remote_dir, "rows": rows}
|
|
|
|
async def run_remote_query_assistant_metrics(
|
|
self, task_type: Optional[str], release_name: Optional[str], outcome: Optional[str], group_by: str, limit: int
|
|
) -> Dict[str, Any]:
|
|
parts = [
|
|
self.cfg.assistant_metrics_query_remote_script,
|
|
task_type or "",
|
|
release_name or "",
|
|
outcome or "",
|
|
group_by,
|
|
str(limit),
|
|
]
|
|
res = await self._run_ssh(parts, "Assistant metrics query timed out")
|
|
if res["code"] != 0:
|
|
raise HTTPException(status_code=502, detail=self._error_detail(res["code"], res["out"], res["err"]))
|
|
try:
|
|
rows = _extract_json_array_from_text(res["out"])
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail={
|
|
"message": f"Unable to parse metrics query output: {e}",
|
|
"stdout_tail": _tail(res["out"]),
|
|
"stderr_tail": _tail(res["err"]),
|
|
},
|
|
)
|
|
return {"host": self.cfg.ssh_host, "remote_dir": self.cfg.remote_dir, "rows": rows}
|
|
|
|
    async def run_remote_assistant_action(
        self, action_id: str, payload: Any, step: Any, status: str, output_json: Dict[str, Any], error_text: Optional[str], created_at_utc: str
    ) -> Dict[str, Any]:
        """Log a single assistant action step on the remote host.

        *payload* supplies task-level fields (``task_type``, ``release_name``,
        ``objective``, ``approved``); *step* supplies the per-step fields
        (``step_id``, ``title``, ``action_type``, ``requires_approval``).
        Free text and JSON are base64-encoded for safe shell transport;
        booleans are passed as "true"/"false".

        Raises HTTPException 502 on a non-zero exit (504 on timeout, via _run_ssh).
        """
        # Positional order is the remote script's calling convention — do not reorder.
        parts = [
            self.cfg.assistant_action_remote_script,
            action_id,
            created_at_utc,
            str(getattr(payload, "task_type")),
            str(getattr(payload, "release_name") or ""),
            _b64(str(getattr(payload, "objective") or "")),
            str(getattr(step, "step_id")),
            _b64(str(getattr(step, "title") or "")),
            str(getattr(step, "action_type")),
            "true" if bool(getattr(step, "requires_approval", False)) else "false",
            "true" if bool(getattr(payload, "approved", False)) else "false",
            status,
            _b64(json.dumps(output_json, ensure_ascii=False)),
            _b64(error_text or ""),
        ]
        res = await self._run_ssh(parts, "Assistant action logging timed out")
        result = self._error_detail(res["code"], res["out"], res["err"])
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result
|
|
|
|
async def run_remote_query_assistant_actions(
|
|
self,
|
|
status: Optional[str],
|
|
task_type: Optional[str],
|
|
release_name: Optional[str],
|
|
step_id: Optional[str],
|
|
action_type: Optional[str],
|
|
limit: int,
|
|
) -> Dict[str, Any]:
|
|
parts = [
|
|
self.cfg.assistant_actions_query_remote_script,
|
|
status or "",
|
|
task_type or "",
|
|
release_name or "",
|
|
step_id or "",
|
|
action_type or "",
|
|
str(limit),
|
|
]
|
|
res = await self._run_ssh(parts, "Assistant actions query timed out")
|
|
if res["code"] != 0:
|
|
raise HTTPException(status_code=502, detail=self._error_detail(res["code"], res["out"], res["err"]))
|
|
try:
|
|
rows = _extract_json_array_from_text(res["out"])
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail={
|
|
"message": f"Unable to parse actions query output: {e}",
|
|
"stdout_tail": _tail(res["out"]),
|
|
"stderr_tail": _tail(res["err"]),
|
|
},
|
|
)
|
|
return {"host": self.cfg.ssh_host, "remote_dir": self.cfg.remote_dir, "rows": rows}
|
|
|
|
    async def run_remote_record_assistant_proposals(
        self,
        proposal_set_id: str,
        created_at_utc: str,
        objective: str,
        release_name: Optional[str],
        summary: str,
        signals: Dict[str, Any],
        proposals: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Record a proposal set on the remote host.

        Free text and JSON payloads are base64-encoded for safe shell
        transport; a missing release name is sent as an empty string.

        Raises HTTPException 502 on a non-zero exit (504 on timeout, via _run_ssh).
        """
        # Positional order is the remote script's calling convention — do not reorder.
        parts = [
            self.cfg.assistant_proposals_remote_script,
            proposal_set_id,
            created_at_utc,
            _b64(objective or ""),
            release_name or "",
            _b64(summary or ""),
            _b64(json.dumps(signals or {}, ensure_ascii=False)),
            _b64(json.dumps(proposals or [], ensure_ascii=False)),
        ]
        res = await self._run_ssh(parts, "Assistant proposals logging timed out")
        result = self._error_detail(res["code"], res["out"], res["err"])
        if res["code"] != 0:
            raise HTTPException(status_code=502, detail=result)
        return result
|
|
|
|
async def run_remote_query_assistant_proposals(
|
|
self,
|
|
release_name: Optional[str],
|
|
proposal_set_id: Optional[str],
|
|
limit: int,
|
|
) -> Dict[str, Any]:
|
|
parts = [
|
|
self.cfg.assistant_proposals_query_remote_script,
|
|
release_name or "",
|
|
proposal_set_id or "",
|
|
str(limit),
|
|
]
|
|
res = await self._run_ssh(parts, "Assistant proposals query timed out")
|
|
if res["code"] != 0:
|
|
raise HTTPException(status_code=502, detail=self._error_detail(res["code"], res["out"], res["err"]))
|
|
try:
|
|
rows = _extract_json_array_from_text(res["out"])
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail={
|
|
"message": f"Unable to parse proposals query output: {e}",
|
|
"stdout_tail": _tail(res["out"]),
|
|
"stderr_tail": _tail(res["err"]),
|
|
},
|
|
)
|
|
return {"host": self.cfg.ssh_host, "remote_dir": self.cfg.remote_dir, "rows": rows}
|
|
|
|
async def run_remote_record_run(
|
|
self,
|
|
run_id: str,
|
|
run_type: str,
|
|
status: str,
|
|
started_at_utc: str,
|
|
finished_at_utc: str,
|
|
actor: str,
|
|
input_json: Dict[str, Any],
|
|
output_json: Optional[Dict[str, Any]],
|
|
error_text: Optional[str],
|
|
) -> None:
|
|
parts = [
|
|
self.cfg.runs_remote_script,
|
|
run_id,
|
|
run_type,
|
|
status,
|
|
started_at_utc,
|
|
finished_at_utc,
|
|
actor,
|
|
_b64(json.dumps(input_json, ensure_ascii=False)),
|
|
_b64(json.dumps(output_json, ensure_ascii=False) if output_json is not None else ""),
|
|
_b64(error_text or ""),
|
|
]
|
|
command = f"cd {shlex.quote(self.cfg.remote_dir)} && {' '.join(shlex.quote(p) for p in parts)}"
|
|
proc = await asyncio.create_subprocess_exec(
|
|
*self._ssh_args(command),
|
|
stdout=asyncio.subprocess.PIPE,
|
|
stderr=asyncio.subprocess.PIPE,
|
|
)
|
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.cfg.timeout_sec)
|
|
if proc.returncode != 0:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail={
|
|
"message": "Failed to record run in Iceberg",
|
|
"host": self.cfg.ssh_host,
|
|
"exit_code": proc.returncode,
|
|
"stdout_tail": _tail(stdout.decode("utf-8", errors="replace")),
|
|
"stderr_tail": _tail(stderr.decode("utf-8", errors="replace")),
|
|
},
|
|
)
|
|
|
|
async def run_remote_record_event(self, run_id: str, event_type: str, detail_json: Dict[str, Any], created_at_utc: str) -> None:
|
|
parts = [
|
|
self.cfg.run_events_remote_script,
|
|
run_id,
|
|
event_type,
|
|
created_at_utc,
|
|
_b64(json.dumps(detail_json, ensure_ascii=False)),
|
|
]
|
|
command = f"cd {shlex.quote(self.cfg.remote_dir)} && {' '.join(shlex.quote(p) for p in parts)}"
|
|
proc = await asyncio.create_subprocess_exec(
|
|
*self._ssh_args(command),
|
|
stdout=asyncio.subprocess.PIPE,
|
|
stderr=asyncio.subprocess.PIPE,
|
|
)
|
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.cfg.timeout_sec)
|
|
if proc.returncode != 0:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail={
|
|
"message": "Failed to record run event in Iceberg",
|
|
"host": self.cfg.ssh_host,
|
|
"exit_code": proc.returncode,
|
|
"stdout_tail": _tail(stdout.decode("utf-8", errors="replace")),
|
|
"stderr_tail": _tail(stderr.decode("utf-8", errors="replace")),
|
|
},
|
|
)
|
|
|
|
async def record_event_best_effort(self, run_id: str, event_type: str, detail_json: Dict[str, Any], created_at_utc: str) -> None:
|
|
try:
|
|
await self.run_remote_record_event(run_id, event_type, detail_json, created_at_utc)
|
|
except Exception as e:
|
|
print(f"[WARN] run event logging failed: run_id={run_id} event={event_type} error={e}")
|