import argparse
import base64
import json
from datetime import datetime, timezone
from typing import Any, Dict, List

from pyspark.sql import SparkSession, types as T


def now_iso() -> str:
    return datetime.now(timezone.utc).isoformat()


def decode_payload(payload_b64: str) -> List[Dict[str, Any]]:
    """Decode a base64-encoded JSON array of message objects."""
    raw = base64.b64decode(payload_b64.encode("ascii")).decode("utf-8")
    data = json.loads(raw)
    if not isinstance(data, list):
        raise ValueError("Payload must decode to a JSON array")
    out: List[Dict[str, Any]] = []
    for i, row in enumerate(data):
        if not isinstance(row, dict):
            raise ValueError(f"Row {i} must be a JSON object")
        out.append(row)
    return out


def normalize_rows(rows: List[Dict[str, Any]]) -> List[tuple]:
    """Validate required fields and flatten each row into a tuple matching the target schema."""
    norm: List[tuple] = []
    for i, r in enumerate(rows):
        thread_id = str(r.get("thread_id") or "").strip()
        message_id = str(r.get("message_id") or "").strip()
        sender = str(r.get("sender") or "").strip()
        channel = str(r.get("channel") or "").strip()
        body = str(r.get("body") or "").strip()
        if not thread_id or not message_id or not sender or not channel or not body:
            raise ValueError(
                f"Row {i} missing required fields. "
                "Required: thread_id, message_id, sender, channel, body"
            )
        sent_at_raw = r.get("sent_at")
        sent_at = str(sent_at_raw).strip() if sent_at_raw is not None else ""
        # Non-dict metadata is replaced with an empty object; keys are sorted
        # so identical metadata always serializes to the same string.
        metadata = r.get("metadata", {})
        if not isinstance(metadata, dict):
            metadata = {}
        metadata_json = json.dumps(metadata, ensure_ascii=False, sort_keys=True)
        norm.append((thread_id, message_id, sender, channel, sent_at, body, metadata_json))
    return norm


def main() -> None:
    p = argparse.ArgumentParser(description="Batch ingest messages into an Iceberg table")
    p.add_argument("--table", required=True)
    p.add_argument(
        "--dedupe-mode",
        choices=["none", "message_id", "thread_message"],
        default="none",
        help="Optional dedupe strategy against existing target rows",
    )
    p.add_argument("--payload-b64")
    p.add_argument("--payload-file")
    args = p.parse_args()

    # Exactly one payload source must be supplied.
    if not args.payload_b64 and not args.payload_file:
        raise ValueError("Provide either --payload-b64 or --payload-file")
    if args.payload_b64 and args.payload_file:
        raise ValueError("Provide only one of --payload-b64 or --payload-file")

    if args.payload_file:
        with open(args.payload_file, "r", encoding="utf-8") as f:
            file_data = json.load(f)
        if not isinstance(file_data, list):
            raise ValueError("--payload-file must contain a JSON array")
        rows = normalize_rows(file_data)
    else:
        rows = normalize_rows(decode_payload(args.payload_b64 or ""))

    if not rows:
        print("[INFO] No rows supplied; nothing to ingest.")
        return

    spark = SparkSession.builder.appName("ingest-messages-batch").getOrCreate()

    schema = T.StructType(
        [
            T.StructField("thread_id", T.StringType(), False),
            T.StructField("message_id", T.StringType(), False),
            T.StructField("sender", T.StringType(), False),
            T.StructField("channel", T.StringType(), False),
            T.StructField("sent_at_raw", T.StringType(), True),
            T.StructField("body", T.StringType(), False),
            T.StructField("metadata_json", T.StringType(), False),
        ]
    )
    df = spark.createDataFrame(rows, schema=schema)
    df.createOrReplaceTempView("_batch_messages")

    # A blank or missing sent_at falls back to the ingest timestamp; otherwise
    # the raw string is cast to TIMESTAMP (unparseable values become NULL under
    # Spark's default, non-ANSI cast behavior).
    base_select = """
        SELECT
            b.thread_id,
            b.message_id,
            b.sender,
            b.channel,
            CASE
                WHEN b.sent_at_raw IS NULL OR TRIM(b.sent_at_raw) = '' THEN current_timestamp()
                ELSE CAST(b.sent_at_raw AS TIMESTAMP)
            END AS sent_at,
            b.body,
            b.metadata_json
        FROM _batch_messages b
    """

    # Dedupe compares the batch against rows already in the target table only;
    # duplicates within the batch itself are not collapsed here.
    if args.dedupe_mode == "none":
        insert_select = base_select
    elif args.dedupe_mode == "message_id":
        insert_select = (
            base_select
            + f" LEFT ANTI JOIN {args.table} t ON b.message_id = t.message_id"
        )
    else:  # thread_message: dedupe on the composite (thread_id, message_id) key
        insert_select = (
            base_select
            + f" LEFT ANTI JOIN {args.table} t"
            " ON b.thread_id = t.thread_id AND b.message_id = t.message_id"
        )

    # Note: INSERT INTO with an explicit column list requires Spark 3.4+.
    spark.sql(
        f"""
        INSERT INTO {args.table}
            (thread_id, message_id, sender, channel, sent_at, body, metadata_json)
        {insert_select}
        """
    )

    print(f"[INFO] rows_in={len(rows)}")
    print(f"[INFO] dedupe_mode={args.dedupe_mode}")
    print(f"[INFO] table={args.table}")
    print(f"[INFO] ingested_at_utc={now_iso()}")
    print(f"[DONE] Batch ingest finished for {args.table}")

    spark.stop()


if __name__ == "__main__":
    main()
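# ---------------------------------------------------------------------------
# Example usage (a sketch only; the table name, file name, and field values
# below are hypothetical and not part of this script).
#
# A payload file is a JSON array of message objects, e.g. messages.json:
#
#   [
#     {
#       "thread_id": "t-100",
#       "message_id": "m-1",
#       "sender": "alice",
#       "channel": "support",
#       "sent_at": "2024-05-01T12:00:00Z",
#       "body": "Hello!",
#       "metadata": {"priority": "high"}
#     }
#   ]
#
# Submitting against an Iceberg-enabled Spark session, deduping on the
# composite (thread_id, message_id) key:
#
#   spark-submit ingest_messages.py \
#       --table demo.db.messages \
#       --dedupe-mode thread_message \
#       --payload-file messages.json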