140 lines
4.8 KiB
Python
140 lines
4.8 KiB
Python
import argparse
|
|
import base64
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, List
|
|
|
|
from pyspark.sql import SparkSession, types as T
|
|
|
|
|
|
def now_iso() -> str:
    """Return the current UTC moment as an ISO-8601 timestamp string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
|
|
|
|
def decode_payload(payload_b64: str) -> List[Dict[str, Any]]:
    """Decode a base64-encoded UTF-8 JSON array into a list of row dicts.

    Raises:
        ValueError: if the payload does not decode to a JSON array, or if
            any element of that array is not a JSON object.
    """
    decoded_text = base64.b64decode(payload_b64.encode("ascii")).decode("utf-8")
    parsed = json.loads(decoded_text)
    if not isinstance(parsed, list):
        raise ValueError("Payload must decode to a JSON array")
    # Validate every element up front; reject the first non-object found.
    for idx, item in enumerate(parsed):
        if not isinstance(item, dict):
            raise ValueError(f"Row {idx} must be a JSON object")
    return list(parsed)
|
|
|
|
|
|
def normalize_rows(rows: List[Dict[str, Any]]) -> List[tuple]:
    """Validate and coerce raw row dicts into position-ordered tuples.

    Each output tuple is (thread_id, message_id, sender, channel, sent_at,
    body, metadata_json), with every element a string. `sent_at` may be ""
    when absent; non-dict `metadata` is replaced with an empty object.

    Raises:
        ValueError: if any required field is missing or blank after
            stripping whitespace.
    """

    def _as_text(value: Any) -> str:
        # Fix: the previous `str(value or "")` pattern treated present but
        # falsy values (e.g. an integer id of 0, or False) as missing.
        # Only None/absent collapse to ""; everything else is stringified.
        return "" if value is None else str(value).strip()

    norm: List[tuple] = []
    for i, r in enumerate(rows):
        thread_id = _as_text(r.get("thread_id"))
        message_id = _as_text(r.get("message_id"))
        sender = _as_text(r.get("sender"))
        channel = _as_text(r.get("channel"))
        body = _as_text(r.get("body"))
        if not (thread_id and message_id and sender and channel and body):
            raise ValueError(
                f"Row {i} missing required fields. "
                "Required: thread_id, message_id, sender, channel, body"
            )

        sent_at = _as_text(r.get("sent_at"))
        metadata = r.get("metadata", {})
        if not isinstance(metadata, dict):
            # Best-effort tolerance: malformed metadata becomes an empty object
            # rather than failing the whole batch.
            metadata = {}
        # sort_keys makes the serialized form deterministic for downstream diffs.
        metadata_json = json.dumps(metadata, ensure_ascii=False, sort_keys=True)
        norm.append((thread_id, message_id, sender, channel, sent_at, body, metadata_json))
    return norm
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: decode a batch of message rows from base64 or a
    JSON file, normalize them, and INSERT them into the target Iceberg
    table, optionally skipping rows already present (anti-join dedupe).

    Raises:
        ValueError: on bad CLI argument combinations or malformed payloads.
    """
    p = argparse.ArgumentParser(description="Batch ingest messages into Iceberg table")
    p.add_argument("--table", required=True)
    p.add_argument(
        "--dedupe-mode",
        choices=["none", "message_id", "thread_message"],
        default="none",
        help="Optional dedupe strategy against existing target rows",
    )
    p.add_argument("--payload-b64")
    p.add_argument("--payload-file")
    args = p.parse_args()

    # Exactly one payload source must be supplied.
    if not args.payload_b64 and not args.payload_file:
        raise ValueError("Provide either --payload-b64 or --payload-file")
    if args.payload_b64 and args.payload_file:
        raise ValueError("Provide only one of --payload-b64 or --payload-file")

    if args.payload_file:
        with open(args.payload_file, "r", encoding="utf-8") as f:
            file_data = json.load(f)
        if not isinstance(file_data, list):
            raise ValueError("--payload-file must contain a JSON array")
        # Fix: mirror decode_payload's per-row validation so both input
        # paths fail with the same ValueError instead of an opaque
        # AttributeError inside normalize_rows when a row is not an object.
        for i, row in enumerate(file_data):
            if not isinstance(row, dict):
                raise ValueError(f"Row {i} must be a JSON object")
        rows = normalize_rows(file_data)
    else:
        rows = normalize_rows(decode_payload(args.payload_b64 or ""))
    if not rows:
        print("[INFO] No rows supplied; nothing to ingest.")
        return

    spark = SparkSession.builder.appName("ingest-messages-batch").getOrCreate()
    # Fix: stop the session even if the SQL fails, so the driver releases
    # cluster resources promptly.
    try:
        schema = T.StructType(
            [
                T.StructField("thread_id", T.StringType(), False),
                T.StructField("message_id", T.StringType(), False),
                T.StructField("sender", T.StringType(), False),
                T.StructField("channel", T.StringType(), False),
                T.StructField("sent_at_raw", T.StringType(), True),
                T.StructField("body", T.StringType(), False),
                T.StructField("metadata_json", T.StringType(), False),
            ]
        )
        df = spark.createDataFrame(rows, schema=schema)
        df.createOrReplaceTempView("_batch_messages")

        # NOTE(review): CAST(... AS TIMESTAMP) yields NULL for unparseable
        # sent_at_raw values, silently inserting NULL sent_at — confirm
        # that is acceptable upstream.
        base_select = """
        SELECT
          b.thread_id,
          b.message_id,
          b.sender,
          b.channel,
          CASE
            WHEN b.sent_at_raw IS NULL OR TRIM(b.sent_at_raw) = '' THEN current_timestamp()
            ELSE CAST(b.sent_at_raw AS TIMESTAMP)
          END AS sent_at,
          b.body,
          b.metadata_json
        FROM _batch_messages b
        """
        # NOTE(review): args.table is interpolated into SQL as an identifier
        # (identifiers cannot be bound as parameters). It comes from the
        # operator's CLI, not end users, but should be validated against a
        # known-table allowlist if this script is ever exposed more widely.
        if args.dedupe_mode == "none":
            insert_select = base_select
        elif args.dedupe_mode == "message_id":
            # Keep only batch rows whose message_id is absent from the target.
            insert_select = (
                base_select
                + f" LEFT ANTI JOIN {args.table} t ON b.message_id = t.message_id"
            )
        else:  # "thread_message"
            insert_select = (
                base_select
                + f" LEFT ANTI JOIN {args.table} t ON b.thread_id = t.thread_id AND b.message_id = t.message_id"
            )

        spark.sql(
            f"""
            INSERT INTO {args.table} (thread_id, message_id, sender, channel, sent_at, body, metadata_json)
            {insert_select}
            """
        )

        print(f"[INFO] rows_in={len(rows)}")
        print(f"[INFO] dedupe_mode={args.dedupe_mode}")
        print(f"[INFO] table={args.table}")
        print(f"[INFO] ingested_at_utc={now_iso()}")
        print(f"[DONE] Batch ingest finished for {args.table}")
    finally:
        spark.stop()
|
|
|
|
|
|
# Script entry point: run the batch ingest when executed directly.
if __name__ == "__main__":
    main()
|