"""Query aggregated assistant-feedback metrics from a Spark table and print them as JSON."""

import argparse
import json
import os

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# --limit is clamped into [1, _MAX_LIMIT] before it is applied to the query.
_MAX_LIMIT = 1000

# (outcome value, rate column alias). Each outcome gets a count column named
# after the outcome value itself plus a share-of-total rate column.
_OUTCOME_RATES = (
    ("accepted", "accept_rate"),
    ("edited", "edit_rate"),
    ("rejected", "reject_rate"),
)

# Maps the --group-by choice to the column names used for grouping and tie-breaking.
_GROUP_COLUMNS = {
    "task_type": ["task_type"],
    "release_name": ["release_name"],
    "both": ["task_type", "release_name"],
}


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the metrics query."""
    p = argparse.ArgumentParser(description="Query assistant feedback metrics")
    p.add_argument(
        "--table",
        default=os.getenv("FEEDBACK_TABLE", "lake.db1.assistant_feedback"),
    )
    p.add_argument("--task-type", default="")
    p.add_argument("--release-name", default="")
    p.add_argument("--outcome", default="")
    p.add_argument(
        "--group-by",
        choices=["task_type", "release_name", "both"],
        default="both",
    )
    p.add_argument("--limit", type=int, default=100)
    return p.parse_args()


def _apply_filters(df, args: argparse.Namespace):
    """Return *df* restricted by every non-empty equality filter in *args*.

    An empty string means "no filter" for that column. Note that filtering on
    --outcome makes the per-outcome counts/rates below trivially 0 or 1.
    """
    for column, value in (
        ("task_type", args.task_type),
        ("release_name", args.release_name),
        ("outcome", args.outcome),
    ):
        if value:
            df = df.where(F.col(column) == value)
    return df


def _aggregate(df, group_by: str, limit: int):
    """Group *df* and compute totals, per-outcome counts/rates and mean confidence.

    Rows are ordered by ``total`` descending, ties broken by the group columns
    ascending, and at most ``limit`` rows (clamped to [1, _MAX_LIMIT]) are kept.
    """
    group_cols = [F.col(name) for name in _GROUP_COLUMNS[group_by]]

    outcome_counts = [
        F.sum(F.when(F.col("outcome") == outcome, F.lit(1)).otherwise(F.lit(0))).alias(outcome)
        for outcome, _ in _OUTCOME_RATES
    ]
    agg = df.groupBy(*group_cols).agg(
        F.count(F.lit(1)).alias("total"),
        *outcome_counts,
        F.avg(F.col("confidence")).alias("avg_confidence"),
    )

    for outcome, rate_alias in _OUTCOME_RATES:
        # Defensive guard: a group produced by groupBy always has total >= 1.
        agg = agg.withColumn(
            rate_alias,
            F.when(F.col("total") > 0, F.col(outcome) / F.col("total")).otherwise(F.lit(0.0)),
        )

    return agg.orderBy(F.col("total").desc(), *[c.asc() for c in group_cols]).limit(
        max(1, min(limit, _MAX_LIMIT))
    )


def main() -> None:
    """Run the metrics query and print the result rows as a JSON array on stdout."""
    args = _parse_args()
    spark = SparkSession.builder.appName("query-assistant-metrics").getOrCreate()
    try:
        df = _apply_filters(spark.table(args.table), args)
        agg = _aggregate(df, args.group_by, args.limit)
        rows = [r.asDict(recursive=True) for r in agg.collect()]
    finally:
        # Release the Spark session even if the table lookup or query fails.
        spark.stop()
    print(json.dumps(rows, ensure_ascii=False))


if __name__ == "__main__":
    main()