# jecio/ingest_messages_batch.py
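"""Batch-ingest message rows into an Iceberg table with Spark SQL.

Each payload row must provide thread_id, message_id, sender, channel, and body;
sent_at and metadata are optional (see normalize_rows).

Illustrative invocation (table and file names are placeholders, and running via
spark-submit is an assumption, not something this script requires):

    spark-submit ingest_messages_batch.py --table analytics.messages --payload-file messages.json --dedupe-mode message_id
"""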

import argparse
import base64
import json
from datetime import datetime, timezone
from typing import Any, Dict, List
from pyspark.sql import SparkSession, types as T


def now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def decode_payload(payload_b64: str) -> List[Dict[str, Any]]:
    """Decode a base64-encoded JSON array into a list of row dicts."""
    raw = base64.b64decode(payload_b64.encode("ascii")).decode("utf-8")
    data = json.loads(raw)
    if not isinstance(data, list):
        raise ValueError("Payload must decode to a JSON array")
    out: List[Dict[str, Any]] = []
    for i, row in enumerate(data):
        if not isinstance(row, dict):
            raise ValueError(f"Row {i} must be a JSON object")
        out.append(row)
    return out


def normalize_rows(rows: List[Dict[str, Any]]) -> List[tuple]:
    """Validate required fields and flatten each row dict into a tuple."""
    norm: List[tuple] = []
    for i, r in enumerate(rows):
        thread_id = str(r.get("thread_id") or "").strip()
        message_id = str(r.get("message_id") or "").strip()
        sender = str(r.get("sender") or "").strip()
        channel = str(r.get("channel") or "").strip()
        body = str(r.get("body") or "").strip()
        if not thread_id or not message_id or not sender or not channel or not body:
            raise ValueError(
                f"Row {i} missing required fields. "
                "Required: thread_id, message_id, sender, channel, body"
            )
        # sent_at is optional; an empty string means "use the ingest timestamp".
        sent_at_raw = r.get("sent_at")
        sent_at = str(sent_at_raw).strip() if sent_at_raw is not None else ""
        # metadata is optional and serialized to a canonical JSON string.
        metadata = r.get("metadata", {})
        if not isinstance(metadata, dict):
            metadata = {}
        metadata_json = json.dumps(metadata, ensure_ascii=False, sort_keys=True)
        norm.append((thread_id, message_id, sender, channel, sent_at, body, metadata_json))
    return norm
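
# Illustrative example (values are made up): a payload row such as
#   {"thread_id": "t-1", "message_id": "m-1", "sender": "ana",
#    "channel": "email", "body": "hi", "metadata": {"lang": "en"}}
# normalizes to
#   ("t-1", "m-1", "ana", "email", "", "hi", '{"lang": "en"}')
# and the empty sent_at is replaced by current_timestamp() at insert time.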


def main() -> None:
    """Parse CLI args, stage the payload, and insert it into the target table."""
    p = argparse.ArgumentParser(description="Batch ingest messages into an Iceberg table")
    p.add_argument("--table", required=True)
    p.add_argument(
        "--dedupe-mode",
        choices=["none", "message_id", "thread_message"],
        default="none",
        help="Optional dedupe strategy against existing target rows",
    )
    p.add_argument("--payload-b64")
    p.add_argument("--payload-file")
    args = p.parse_args()

    # Exactly one payload source must be supplied.
    if not args.payload_b64 and not args.payload_file:
        raise ValueError("Provide either --payload-b64 or --payload-file")
    if args.payload_b64 and args.payload_file:
        raise ValueError("Provide only one of --payload-b64 or --payload-file")

    if args.payload_file:
        with open(args.payload_file, "r", encoding="utf-8") as f:
            file_data = json.load(f)
        if not isinstance(file_data, list):
            raise ValueError("--payload-file must contain a JSON array")
        # Mirror decode_payload's per-row validation for file input.
        for i, row in enumerate(file_data):
            if not isinstance(row, dict):
                raise ValueError(f"Row {i} must be a JSON object")
        rows = normalize_rows(file_data)
    else:
        rows = normalize_rows(decode_payload(args.payload_b64 or ""))

    if not rows:
        print("[INFO] No rows supplied; nothing to ingest.")
        return

    spark = SparkSession.builder.appName("ingest-messages-batch").getOrCreate()

    # Staging schema: timestamps stay as raw strings until the SQL layer casts them.
    schema = T.StructType(
        [
            T.StructField("thread_id", T.StringType(), False),
            T.StructField("message_id", T.StringType(), False),
            T.StructField("sender", T.StringType(), False),
            T.StructField("channel", T.StringType(), False),
            T.StructField("sent_at_raw", T.StringType(), True),
            T.StructField("body", T.StringType(), False),
            T.StructField("metadata_json", T.StringType(), False),
        ]
    )
    df = spark.createDataFrame(rows, schema=schema)
    df.createOrReplaceTempView("_batch_messages")

    # Shared SELECT over the staged rows. sent_at falls back to current_timestamp()
    # when the payload did not supply one; note that CAST(... AS TIMESTAMP) yields
    # NULL for unparseable strings when ANSI SQL mode is off, and raises when it is on.
    base_select = """
        SELECT
            b.thread_id,
            b.message_id,
            b.sender,
            b.channel,
            CASE
                WHEN b.sent_at_raw IS NULL OR TRIM(b.sent_at_raw) = '' THEN current_timestamp()
                ELSE CAST(b.sent_at_raw AS TIMESTAMP)
            END AS sent_at,
            b.body,
            b.metadata_json
        FROM _batch_messages b
    """

    # Optional dedupe: anti-join the staged rows against the target table so that
    # message keys already present are skipped.
    if args.dedupe_mode == "none":
        insert_select = base_select
    elif args.dedupe_mode == "message_id":
        insert_select = (
            base_select
            + f" LEFT ANTI JOIN {args.table} t ON b.message_id = t.message_id"
        )
    else:
        insert_select = (
            base_select
            + f" LEFT ANTI JOIN {args.table} t"
            " ON b.thread_id = t.thread_id AND b.message_id = t.message_id"
        )
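
    # For example, with --dedupe-mode message_id the INSERT below effectively runs:
    #   INSERT INTO <table> (thread_id, ..., metadata_json)
    #   SELECT ... FROM _batch_messages b
    #   LEFT ANTI JOIN <table> t ON b.message_id = t.message_id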
    spark.sql(
        f"""
        INSERT INTO {args.table} (thread_id, message_id, sender, channel, sent_at, body, metadata_json)
        {insert_select}
        """
    )

    print(f"[INFO] rows_in={len(rows)}")
    print(f"[INFO] dedupe_mode={args.dedupe_mode}")
    print(f"[INFO] table={args.table}")
    print(f"[INFO] ingested_at_utc={now_iso()}")
    print(f"[DONE] Batch ingest finished for {args.table}")


if __name__ == "__main__":
    main()