# jecio/ingest_messages_batch.py
# Batch ingest of base64/file JSON message payloads into an Iceberg table
# via Spark SQL, with optional dedupe against existing target rows.

import argparse
import base64
import json
from datetime import datetime, timezone
from typing import Any, Dict, List
from pyspark.sql import SparkSession, types as T
def now_iso() -> str:
    """Current moment in UTC, formatted as an ISO-8601 timestamp string."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.isoformat()
def decode_payload(payload_b64: str) -> List[Dict[str, Any]]:
    """Decode a base64-wrapped UTF-8 JSON array into a list of row dicts.

    Raises ValueError when the decoded document is not a JSON array or when
    any element of the array is not a JSON object.
    """
    decoded_text = base64.b64decode(payload_b64.encode("ascii")).decode("utf-8")
    parsed = json.loads(decoded_text)
    if not isinstance(parsed, list):
        raise ValueError("Payload must decode to a JSON array")
    # Validate every element up front, then hand back a fresh list.
    for idx, item in enumerate(parsed):
        if not isinstance(item, dict):
            raise ValueError(f"Row {idx} must be a JSON object")
    return list(parsed)
def normalize_rows(rows: List[Dict[str, Any]]) -> List[tuple]:
    """Validate each row dict and flatten it into an ingest tuple.

    Tuple layout: (thread_id, message_id, sender, channel, sent_at, body,
    metadata_json). The five required string fields must be non-empty after
    stripping; sent_at may be blank, and non-dict metadata collapses to {}.

    Raises ValueError for any row with a missing/blank required field.
    """
    normalized: List[tuple] = []
    for idx, record in enumerate(rows):
        required = {
            name: str(record.get(name) or "").strip()
            for name in ("thread_id", "message_id", "sender", "channel", "body")
        }
        if not all(required.values()):
            raise ValueError(
                f"Row {idx} missing required fields. "
                "Required: thread_id, message_id, sender, channel, body"
            )
        raw_sent = record.get("sent_at")
        sent_at = "" if raw_sent is None else str(raw_sent).strip()
        meta = record.get("metadata", {})
        if not isinstance(meta, dict):
            meta = {}
        # sort_keys keeps the serialized metadata deterministic across runs.
        meta_json = json.dumps(meta, ensure_ascii=False, sort_keys=True)
        normalized.append(
            (
                required["thread_id"],
                required["message_id"],
                required["sender"],
                required["channel"],
                sent_at,
                required["body"],
                meta_json,
            )
        )
    return normalized
def _load_rows(args: argparse.Namespace) -> List[tuple]:
    """Read rows from exactly one payload source and normalize them.

    Raises ValueError when neither or both payload options were supplied,
    or when the payload content is malformed.
    """
    if not args.payload_b64 and not args.payload_file:
        raise ValueError("Provide either --payload-b64 or --payload-file")
    if args.payload_b64 and args.payload_file:
        raise ValueError("Provide only one of --payload-b64 or --payload-file")
    if args.payload_file:
        with open(args.payload_file, "r", encoding="utf-8") as f:
            file_data = json.load(f)
        if not isinstance(file_data, list):
            raise ValueError("--payload-file must contain a JSON array")
        return normalize_rows(file_data)
    return normalize_rows(decode_payload(args.payload_b64 or ""))


def main() -> None:
    """Parse CLI args, stage the payload as a temp view, INSERT into Iceberg.

    Dedupe modes: 'none' inserts everything; 'message_id' skips rows whose
    message_id already exists in the target table; 'thread_message' keys the
    anti-join on (thread_id, message_id).
    """
    p = argparse.ArgumentParser(description="Batch ingest messages into Iceberg table")
    p.add_argument("--table", required=True)
    p.add_argument(
        "--dedupe-mode",
        choices=["none", "message_id", "thread_message"],
        default="none",
        help="Optional dedupe strategy against existing target rows",
    )
    p.add_argument("--payload-b64")
    p.add_argument("--payload-file")
    args = p.parse_args()

    # args.table is interpolated directly into SQL text below. Restrict it to
    # a dotted-identifier charset (catalog.db.table) so a hostile value cannot
    # inject arbitrary SQL. (Spark SQL has no bind parameters for identifiers.)
    if not args.table or not all(ch.isalnum() or ch in "._" for ch in args.table):
        raise ValueError(f"Unsafe table name: {args.table!r}")

    rows = _load_rows(args)
    if not rows:
        print("[INFO] No rows supplied; nothing to ingest.")
        return

    spark = SparkSession.builder.appName("ingest-messages-batch").getOrCreate()
    try:
        # Schema mirrors the tuple layout produced by normalize_rows.
        schema = T.StructType(
            [
                T.StructField("thread_id", T.StringType(), False),
                T.StructField("message_id", T.StringType(), False),
                T.StructField("sender", T.StringType(), False),
                T.StructField("channel", T.StringType(), False),
                T.StructField("sent_at_raw", T.StringType(), True),
                T.StructField("body", T.StringType(), False),
                T.StructField("metadata_json", T.StringType(), False),
            ]
        )
        df = spark.createDataFrame(rows, schema=schema)
        df.createOrReplaceTempView("_batch_messages")
        # NOTE(review): CAST(... AS TIMESTAMP) yields NULL for unparseable
        # strings instead of failing — confirm that is the intended behavior.
        base_select = """
            SELECT
                b.thread_id,
                b.message_id,
                b.sender,
                b.channel,
                CASE
                    WHEN b.sent_at_raw IS NULL OR TRIM(b.sent_at_raw) = '' THEN current_timestamp()
                    ELSE CAST(b.sent_at_raw AS TIMESTAMP)
                END AS sent_at,
                b.body,
                b.metadata_json
            FROM _batch_messages b
        """
        if args.dedupe_mode == "none":
            insert_select = base_select
        elif args.dedupe_mode == "message_id":
            insert_select = (
                base_select
                + f" LEFT ANTI JOIN {args.table} t ON b.message_id = t.message_id"
            )
        else:  # thread_message
            insert_select = (
                base_select
                + f" LEFT ANTI JOIN {args.table} t"
                " ON b.thread_id = t.thread_id AND b.message_id = t.message_id"
            )
        spark.sql(
            f"""
            INSERT INTO {args.table} (thread_id, message_id, sender, channel, sent_at, body, metadata_json)
            {insert_select}
            """
        )
        print(f"[INFO] rows_in={len(rows)}")
        print(f"[INFO] dedupe_mode={args.dedupe_mode}")
        print(f"[INFO] table={args.table}")
        print(f"[INFO] ingested_at_utc={now_iso()}")
        print(f"[DONE] Batch ingest finished for {args.table}")
    finally:
        # Fix: the original never stopped the session, leaking Spark resources
        # on both the success and failure paths.
        spark.stop()


if __name__ == "__main__":
    main()