229 lines
7.7 KiB
Python
229 lines
7.7 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
@dataclass
class MetricBundle:
    """Gate-relevant metrics pulled from one run_summary file.

    One bundle is built per summary file; the baseline bundle and the
    candidate bundle are then compared field by field.
    """

    # String form of the path the summary was loaded from.
    summary_path: str
    # Run identifier from the summary, or "<unknown>" when none is present.
    run_id: str
    # Number of questions in the run; None when the summary does not report it.
    questions_total: int | None
    # Pass rates: the gate requires the candidate not to drop below baseline.
    strict_pass_rate: float
    route_pass_rate: float
    # Error metrics: the gate requires the candidate not to exceed baseline.
    execution_error_count: int
    false_factual_rate: float
    # Human-readable remarks about values that were derived or defaulted.
    notes: list[str]
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the comparator's command-line options.

    Args:
        argv: Argument strings to parse. When None (the default), argparse
            reads the process arguments from ``sys.argv[1:]``, which is the
            behaviour existing callers rely on. Passing an explicit list
            allows the parser to be driven programmatically.

    Returns:
        Namespace with ``baseline_summary``, ``candidate_summary``,
        ``report_json`` (empty string when not given) and ``epsilon``
        (float, 1e-9 by default).
    """
    parser = argparse.ArgumentParser(
        description="Compare ADDRESS run_summary against baseline and fail on gate regressions."
    )
    parser.add_argument("--baseline-summary", required=True, help="Path to baseline run_summary.json")
    parser.add_argument("--candidate-summary", required=True, help="Path to candidate run_summary.json")
    parser.add_argument(
        "--report-json",
        default="",
        help="Optional path to write comparator report JSON.",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=1e-9,
        help="Tolerance for float comparison.",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_summary(path: Path) -> dict[str, Any]:
    """Read a summary file and return its top-level JSON object.

    The file is decoded as ``utf-8-sig``, so a leading byte-order mark is
    tolerated.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded JSON object as a dict.

    Raises:
        ValueError: If the top-level JSON value is not an object.
    """
    raw_text = path.read_text(encoding="utf-8-sig")
    parsed = json.loads(raw_text)
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(f"{path} must contain JSON object")
|
|
|
|
|
|
def extract_metric_float(summary: dict[str, Any], key: str, *, default: float | None = None) -> float:
    """Look up a float metric, preferring ``summary["totals"]`` over the top level.

    Args:
        summary: Parsed summary object.
        key: Metric name to look up.
        default: Value returned as-is when the key is found in neither place.

    Returns:
        The metric converted with ``float``; a present but falsy value
        (``None``, ``0``, ``""``) yields ``0.0``. If the key is absent and a
        default was given, the default.

    Raises:
        ValueError: If the key is absent and no default was supplied.
    """
    nested = summary.get("totals")
    # Search order matters: the nested "totals" mapping wins over the top level.
    containers = (nested, summary) if isinstance(nested, dict) else (summary,)
    for container in containers:
        if key in container:
            found = container[key]
            return float(found if found else 0.0)
    if default is None:
        raise ValueError(f"missing required metric: {key}")
    return default
|
|
|
|
|
|
def extract_metric_int(summary: dict[str, Any], key: str, *, default: int | None = None) -> int:
    """Look up an integer metric, preferring ``summary["totals"]`` over the top level.

    Args:
        summary: Parsed summary object.
        key: Metric name to look up.
        default: Value returned as-is when the key is found in neither place.

    Returns:
        The metric converted with ``int``; a present but falsy value yields
        ``0``. If the key is absent and a default was given, the default.

    Raises:
        ValueError: If the key is absent and no default was supplied (``int``
            itself may also raise ValueError for an unparsable string).
    """
    nested = summary.get("totals")
    if isinstance(nested, dict) and key in nested:
        holder: dict[str, Any] | None = nested
    elif key in summary:
        holder = summary
    else:
        holder = None

    if holder is not None:
        stored = holder[key]
        # Falsy stored values (None, 0, "") are normalised to 0 before conversion.
        return int(stored) if stored else 0
    if default is None:
        raise ValueError(f"missing required metric: {key}")
    return default
|
|
|
|
|
|
def extract_limited_reason_count(summary: dict[str, Any], reason: str) -> int:
    """Return the count stored at ``distributions.limited_reason_category[reason]``.

    Args:
        summary: Parsed summary object.
        reason: Key to read inside the ``limited_reason_category`` mapping.

    Returns:
        The stored value converted with ``int``, or 0 when either
        intermediate mapping is missing or not a dict, or when the stored
        value is absent or falsy.
    """
    node: Any = summary
    # Walk down the two nested mappings, bailing out as soon as one is unusable.
    for section in ("distributions", "limited_reason_category"):
        node = node.get(section)
        if not isinstance(node, dict):
            return 0
    tally = node.get(reason)
    return int(tally) if tally else 0
|
|
|
|
|
|
def collect_metrics(path: Path) -> MetricBundle:
    """Load a summary file and assemble the metrics the gate compares.

    Values that are not reported directly are derived or defaulted, and each
    such fallback is recorded as a note on the returned bundle:

    * ``questions_total`` stays None when it cannot be extracted.
    * ``execution_error_count`` falls back to ``http_error_count`` plus the
      ``execution_error`` entry of ``distributions.limited_reason_category``.
    * ``false_factual_rate`` falls back to
      ``false_factual_count / questions_total``, or 0.0 when no positive
      ``questions_total`` is available.

    Args:
        path: Location of the run summary JSON file.

    Returns:
        A MetricBundle populated from the summary.
    """
    data = load_summary(path)
    remarks: list[str] = []

    run_label = str(data.get("run_id", "")).strip()
    if not run_label:
        run_label = "<unknown>"

    try:
        total_questions: int | None = extract_metric_int(data, "questions_total")
    except ValueError:
        total_questions = None
        remarks.append("questions_total missing")

    strict_rate = extract_metric_float(data, "strict_pass_rate", default=0.0)
    route_rate = extract_metric_float(data, "route_pass_rate", default=0.0)

    http_errors = extract_metric_int(data, "http_error_count", default=0)
    # -1 is a sentinel meaning "the summary did not report this metric".
    reported_exec_errors = extract_metric_int(data, "execution_error_count", default=-1)
    limited_exec_errors = extract_limited_reason_count(data, "execution_error")
    if reported_exec_errors < 0:
        exec_errors = http_errors + limited_exec_errors
        remarks.append("execution_error_count derived as http_error_count + limited_reason_category.execution_error")
    else:
        exec_errors = reported_exec_errors

    # -1.0 is the matching "not reported" sentinel for the rate.
    reported_false_rate = extract_metric_float(data, "false_factual_rate", default=-1.0)
    if reported_false_rate >= 0:
        false_rate = reported_false_rate
    else:
        false_count = extract_metric_int(data, "false_factual_count", default=0)
        if total_questions is not None and total_questions > 0:
            false_rate = false_count / total_questions
            remarks.append("false_factual_rate derived from false_factual_count/questions_total")
        else:
            false_rate = 0.0
            remarks.append("false_factual_rate defaulted to 0.0 (no questions_total)")

    return MetricBundle(
        summary_path=str(path),
        run_id=run_label,
        questions_total=total_questions,
        strict_pass_rate=strict_rate,
        route_pass_rate=route_rate,
        execution_error_count=exec_errors,
        false_factual_rate=false_rate,
        notes=remarks,
    )
|
|
|
|
|
|
def main() -> None:
    """Compare candidate metrics against the baseline and report the outcome.

    Parses the command line, collects metrics from both summary files,
    evaluates the gate checks, optionally writes a JSON report, and prints a
    PASS/FAIL line per check followed by any notes.

    Raises:
        SystemExit: With exit code 1 when at least one check fails.
    """
    cli = parse_args()
    baseline_file = Path(cli.baseline_summary).resolve()
    candidate_file = Path(cli.candidate_summary).resolve()

    baseline = collect_metrics(baseline_file)
    candidate = collect_metrics(candidate_file)
    tolerance = float(cli.epsilon)

    def make_row(metric: str, ok: bool, base_value: Any, cand_value: Any, rule: str) -> dict[str, Any]:
        # One report row; key order is kept stable for the JSON output.
        return {
            "metric": metric,
            "passed": ok,
            "baseline": base_value,
            "candidate": cand_value,
            "rule": rule,
        }

    checks: list[dict[str, Any]] = [
        make_row(
            "strict_pass_rate",
            candidate.strict_pass_rate + tolerance >= baseline.strict_pass_rate,
            baseline.strict_pass_rate,
            candidate.strict_pass_rate,
            "candidate >= baseline",
        ),
        make_row(
            "route_pass_rate",
            candidate.route_pass_rate + tolerance >= baseline.route_pass_rate,
            baseline.route_pass_rate,
            candidate.route_pass_rate,
            "candidate >= baseline",
        ),
        # Integer counts are compared exactly; the tolerance applies to floats only.
        make_row(
            "execution_error_count",
            candidate.execution_error_count <= baseline.execution_error_count,
            baseline.execution_error_count,
            candidate.execution_error_count,
            "candidate <= baseline",
        ),
        make_row(
            "false_factual_rate",
            candidate.false_factual_rate <= baseline.false_factual_rate + tolerance,
            baseline.false_factual_rate,
            candidate.false_factual_rate,
            "candidate <= baseline",
        ),
    ]

    # The size check is only meaningful when both summaries report a total.
    if not (baseline.questions_total is None or candidate.questions_total is None):
        checks.append(
            make_row(
                "questions_total_match",
                candidate.questions_total == baseline.questions_total,
                baseline.questions_total,
                candidate.questions_total,
                "candidate == baseline",
            )
        )

    overall_pass = all(bool(row["passed"]) for row in checks)

    report = {
        "generated_at": datetime.now().isoformat(timespec="seconds"),
        "overall_pass": overall_pass,
        "baseline": vars(baseline),
        "candidate": vars(candidate),
        "checks": checks,
    }

    if cli.report_json:
        destination = Path(cli.report_json).resolve()
        destination.parent.mkdir(parents=True, exist_ok=True)
        serialized = json.dumps(report, ensure_ascii=False, indent=2) + "\n"
        destination.write_text(serialized, encoding="utf-8")

    print(f"Baseline: {baseline.summary_path} ({baseline.run_id})")
    print(f"Candidate: {candidate.summary_path} ({candidate.run_id})")
    for row in checks:
        status = "PASS" if row["passed"] else "FAIL"
        print(
            f"[{status}] {row['metric']}: baseline={row['baseline']} candidate={row['candidate']} rule={row['rule']}"
        )

    for heading, remarks in (
        ("\nBaseline notes:", baseline.notes),
        ("\nCandidate notes:", candidate.notes),
    ):
        if remarks:
            print(heading)
            for remark in remarks:
                print(f"- {remark}")

    if not overall_pass:
        raise SystemExit(1)
|
|
|
|
|
|
# Run the comparator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|