58 lines
2.3 KiB
Python
58 lines
2.3 KiB
Python
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
|
||
|
|
from pyspark.sql import SparkSession
|
||
|
|
from pyspark.sql import functions as F
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> None:
|
||
|
|
p = argparse.ArgumentParser(description="Query assistant feedback metrics")
|
||
|
|
p.add_argument("--table", default=os.getenv("FEEDBACK_TABLE", "lake.db1.assistant_feedback"))
|
||
|
|
p.add_argument("--task-type", default="")
|
||
|
|
p.add_argument("--release-name", default="")
|
||
|
|
p.add_argument("--outcome", default="")
|
||
|
|
p.add_argument("--group-by", choices=["task_type", "release_name", "both"], default="both")
|
||
|
|
p.add_argument("--limit", type=int, default=100)
|
||
|
|
args = p.parse_args()
|
||
|
|
|
||
|
|
spark = SparkSession.builder.appName("query-assistant-metrics").getOrCreate()
|
||
|
|
df = spark.table(args.table)
|
||
|
|
|
||
|
|
if args.task_type:
|
||
|
|
df = df.where(F.col("task_type") == args.task_type)
|
||
|
|
if args.release_name:
|
||
|
|
df = df.where(F.col("release_name") == args.release_name)
|
||
|
|
if args.outcome:
|
||
|
|
df = df.where(F.col("outcome") == args.outcome)
|
||
|
|
|
||
|
|
if args.group_by == "task_type":
|
||
|
|
group_cols = [F.col("task_type")]
|
||
|
|
elif args.group_by == "release_name":
|
||
|
|
group_cols = [F.col("release_name")]
|
||
|
|
else:
|
||
|
|
group_cols = [F.col("task_type"), F.col("release_name")]
|
||
|
|
|
||
|
|
agg = (
|
||
|
|
df.groupBy(*group_cols)
|
||
|
|
.agg(
|
||
|
|
F.count(F.lit(1)).alias("total"),
|
||
|
|
F.sum(F.when(F.col("outcome") == "accepted", F.lit(1)).otherwise(F.lit(0))).alias("accepted"),
|
||
|
|
F.sum(F.when(F.col("outcome") == "edited", F.lit(1)).otherwise(F.lit(0))).alias("edited"),
|
||
|
|
F.sum(F.when(F.col("outcome") == "rejected", F.lit(1)).otherwise(F.lit(0))).alias("rejected"),
|
||
|
|
F.avg(F.col("confidence")).alias("avg_confidence"),
|
||
|
|
)
|
||
|
|
.withColumn("accept_rate", F.when(F.col("total") > 0, F.col("accepted") / F.col("total")).otherwise(F.lit(0.0)))
|
||
|
|
.withColumn("edit_rate", F.when(F.col("total") > 0, F.col("edited") / F.col("total")).otherwise(F.lit(0.0)))
|
||
|
|
.withColumn("reject_rate", F.when(F.col("total") > 0, F.col("rejected") / F.col("total")).otherwise(F.lit(0.0)))
|
||
|
|
.orderBy(F.col("total").desc(), *[c.asc() for c in group_cols])
|
||
|
|
.limit(max(1, min(args.limit, 1000)))
|
||
|
|
)
|
||
|
|
|
||
|
|
rows = [r.asDict(recursive=True) for r in agg.collect()]
|
||
|
|
print(json.dumps(rows, ensure_ascii=False))
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
main()
|