414 lines
18 KiB
Python
414 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import statistics
|
|
import time
|
|
import urllib.error
|
|
import urllib.request
|
|
from collections import Counter
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
# Directory one level above the folder that contains this script
# (parents[0] is the script's own directory, parents[1] is its parent).
# Used below as the base for the default --output-root.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
|
|
|
|
|
@dataclass
class QuestionCase:
    """One question to send to the assistant, plus optional expectations.

    An expectation set to ``None`` means "do not check this aspect".
    """

    # Case identifier; used in progress output, reports and per-case session ids.
    id: str
    # The user message sent to the backend.
    text: str
    # Expected value of the response's ``debug.detected_intent``, or None.
    expected_intent: str | None
    # Expected value of the response's ``debug.detected_mode``, or None.
    expected_mode: str | None
    # Expected value of the response's ``reply_type``, or None.
    expected_reply_type: str | None
    # Logical session label; cases sharing a label reuse one backend session id.
    session: str | None
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the live-batch runner.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of the original
            zero-argument call.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Run slang stress live-batch against /api/assistant/message and produce summary artifacts."
    )
    parser.add_argument(
        "--questions-file",
        required=True,
        help="Path to JSON questions file (list of strings or objects with id/text/expected_intent/session).",
    )
    parser.add_argument(
        "--backend-url",
        default="http://127.0.0.1:8787/api/assistant/message",
        help="Assistant endpoint URL.",
    )
    parser.add_argument("--prompt-version", default="address_query_runtime_v1")
    parser.add_argument("--llm-provider", default="local")
    parser.add_argument("--llm-model", default="qwen2.5-14b-instruct-1m")
    parser.add_argument("--llm-base-url", default="http://127.0.0.1:1234")
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--max-output-tokens", type=int, default=900)
    parser.add_argument("--timeout-sec", type=int, default=120)
    parser.add_argument("--run-id", default="")
    parser.add_argument(
        "--strict-policy",
        default="route",
        choices=["semantic", "route", "factual"],
        help="Pass policy: semantic=intent/mode only, route=semantic+non-blocked route, factual=semantic+factual reply",
    )
    parser.add_argument(
        "--output-root",
        default=str(PROJECT_ROOT / "docs" / "ADDRESS" / "runs"),
        help="Root directory where run folder will be created.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def now_stamp() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
|
|
|
|
|
def load_cases(path: Path) -> list[QuestionCase]:
|
|
raw = json.loads(path.read_text(encoding="utf-8-sig"))
|
|
if not isinstance(raw, list):
|
|
raise ValueError("questions-file must contain JSON array")
|
|
|
|
cases: list[QuestionCase] = []
|
|
for idx, item in enumerate(raw, start=1):
|
|
if isinstance(item, str):
|
|
text = item.strip()
|
|
if not text:
|
|
continue
|
|
cases.append(
|
|
QuestionCase(
|
|
id=f"Q{idx:03d}",
|
|
text=text,
|
|
expected_intent=None,
|
|
expected_mode="address_query",
|
|
expected_reply_type=None,
|
|
session=None,
|
|
)
|
|
)
|
|
continue
|
|
|
|
if not isinstance(item, dict):
|
|
raise ValueError(f"questions-file element #{idx} must be string or object")
|
|
|
|
text = str(item.get("text", "")).strip()
|
|
if not text:
|
|
continue
|
|
case_id = str(item.get("id", f"Q{idx:03d}")).strip() or f"Q{idx:03d}"
|
|
expected_intent = item.get("expected_intent")
|
|
expected_mode = item.get("expected_mode", "address_query")
|
|
expected_reply_type = item.get("expected_reply_type")
|
|
session = item.get("session")
|
|
cases.append(
|
|
QuestionCase(
|
|
id=case_id,
|
|
text=text,
|
|
expected_intent=str(expected_intent).strip() if expected_intent else None,
|
|
expected_mode=str(expected_mode).strip() if expected_mode else None,
|
|
expected_reply_type=str(expected_reply_type).strip() if expected_reply_type else None,
|
|
session=str(session).strip() if session else None,
|
|
)
|
|
)
|
|
|
|
if not cases:
|
|
raise ValueError("questions-file has no non-empty cases")
|
|
return cases
|
|
|
|
|
|
def post_json(url: str, payload: dict[str, Any], timeout_sec: int) -> tuple[int, dict[str, Any]]:
    """POST ``payload`` as JSON to ``url`` and return ``(status, body)``.

    The function never raises for a failed request, so that one bad response
    cannot abort a whole batch run:

    * HTTP error statuses return the real status code with the decoded body,
      or an ``HTTP_ERROR`` stub when the body is not a JSON object.
    * Transport failures (connection refused, timeout, broken response)
      return status ``0`` with a ``NETWORK_ERROR`` stub.
    * A successful status whose body is not a JSON object returns that status
      with an ``INVALID_RESPONSE_BODY`` stub.

    Args:
        url: Endpoint URL.
        payload: JSON-serialisable request body.
        timeout_sec: Timeout passed to ``urllib.request.urlopen``.

    Returns:
        ``(status_code, body_dict)``; ``body_dict`` is always a dict.
    """
    # Local import: only the transport-error handler below needs it.
    import http.client

    data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"}, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=timeout_sec) as response:
            status = int(response.getcode())
            raw = response.read().decode("utf-8", errors="replace")
        fallback_code = "INVALID_RESPONSE_BODY"
    except urllib.error.HTTPError as error:
        # Must be handled before OSError: HTTPError is a URLError/OSError subclass.
        status = int(error.code)
        raw = error.read().decode("utf-8", errors="replace")
        fallback_code = "HTTP_ERROR"
    except (OSError, http.client.HTTPException) as error:
        # URLError, socket timeouts and connection resets are all OSError.
        return 0, {"ok": False, "error": {"code": "NETWORK_ERROR", "message": str(error)}}

    try:
        body = json.loads(raw)
    except json.JSONDecodeError:
        body = None
    if not isinstance(body, dict):
        return status, {"ok": False, "error": {"code": fallback_code, "message": raw}}
    return status, body
|
|
|
|
|
|
def first_line(text: str | None) -> str:
    """Return the first line of ``text``, stripped of surrounding whitespace.

    ``None``, empty and whitespace-only input all yield ``""``.
    """
    value = str(text or "").strip()
    # ``value`` is non-empty here, so splitlines() has at least one element.
    return value.splitlines()[0].strip() if value else ""
|
|
|
|
|
|
def classify_route_health(
    *,
    status_code: int,
    ok_flag: bool,
    reply_type: str | None,
    limited_reason_category: str | None,
    mcp_call_status: str | None,
) -> str:
    """Map one response's outcome fields to a route-health label.

    Checks run in this order, first match wins:

    1. Non-200 status or a false ``ok`` flag -> ``"http_or_backend_error"``.
    2. ``reply_type`` ``"clarification_required"`` / ``"backend_error"`` ->
       ``"blocked_clarification"`` / ``"blocked_backend_error"``.
    3. Any other ``reply_type`` except ``"partial_coverage"`` ->
       ``"ok_or_factual"``.
    4. For ``"partial_coverage"``: a recognised ``limited_reason_category``
       yields ``"blocked_<category>"``; otherwise a skipped or filtered
       ``mcp_call_status`` yields ``"likely_blocked_route"``; anything else
       is ``"partial_non_blocking"``.

    Returns:
        One of the label strings listed above.
    """
    if status_code != 200 or not ok_flag:
        return "http_or_backend_error"
    if reply_type == "clarification_required":
        return "blocked_clarification"
    if reply_type == "backend_error":
        return "blocked_backend_error"
    if reply_type != "partial_coverage":
        return "ok_or_factual"

    # Tuple membership compares with ==, so it also tolerates unhashable
    # values coming from an untyped debug payload.
    if limited_reason_category in ("missing_anchor", "unsupported", "recipe_visibility_gap", "execution_error"):
        return f"blocked_{limited_reason_category}"
    if mcp_call_status in (
        "skipped",
        "materialized_but_not_anchor_matched",
        "materialized_but_filtered_out_by_recipe",
    ):
        return "likely_blocked_route"
    return "partial_non_blocking"
|
|
|
|
|
|
def _extract_debug(body: Any) -> dict[str, Any]:
    """Return the response's debug dict (top-level or under ``conversation_item``), else ``{}``."""
    if not isinstance(body, dict):
        return {}
    debug = body.get("debug")
    if not debug:
        # ``conversation_item`` may be absent, null or not an object.
        item = body.get("conversation_item")
        debug = item.get("debug") if isinstance(item, dict) else None
    return debug if isinstance(debug, dict) else {}


def _evaluate_case(
    *,
    index: int,
    case: QuestionCase,
    session_id: str,
    status_code: int,
    body: Any,
    elapsed_ms: int,
    strict_policy: str,
) -> dict[str, Any]:
    """Compare one backend response with the case's expectations and build its result row."""
    # Treat a non-object body as empty so every lookup below is safe.
    payload = body if isinstance(body, dict) else {}
    debug = _extract_debug(payload)
    error = payload.get("error")
    if not isinstance(error, dict):
        error = {}

    ok_flag = bool(payload.get("ok"))
    actual_intent = debug.get("detected_intent")
    actual_mode = debug.get("detected_mode")
    reply_type = payload.get("reply_type")
    trace_id = debug.get("trace_id") or payload.get("trace_id")
    assistant_reply = payload.get("assistant_reply")

    intent_match = case.expected_intent is None or actual_intent == case.expected_intent
    mode_match = case.expected_mode is None or actual_mode == case.expected_mode
    reply_match = case.expected_reply_type is None or reply_type == case.expected_reply_type
    semantic_pass = bool(intent_match and mode_match and status_code == 200 and ok_flag)
    route_health = classify_route_health(
        status_code=status_code,
        ok_flag=ok_flag,
        reply_type=str(reply_type) if reply_type is not None else None,
        limited_reason_category=debug.get("limited_reason_category"),
        mcp_call_status=debug.get("mcp_call_status"),
    )
    route_pass = bool(semantic_pass and not route_health.startswith("blocked") and route_health != "likely_blocked_route")

    if strict_policy == "semantic":
        policy_pass = semantic_pass
    elif strict_policy == "factual":
        # With an explicit expected reply type, ``reply_match`` below does the
        # reply check; otherwise the reply must be "factual".
        if case.expected_reply_type is not None:
            policy_pass = semantic_pass
        else:
            policy_pass = bool(semantic_pass and reply_type == "factual")
    else:
        policy_pass = route_pass
    strict_pass = bool(policy_pass and reply_match)

    return {
        "index": index,
        "id": case.id,
        "question": case.text,
        "session": case.session,
        "session_id": session_id,
        "status_code": status_code,
        "ok": ok_flag,
        "elapsed_ms": elapsed_ms,
        "reply_type": reply_type,
        "trace_id": trace_id,
        "assistant_reply": assistant_reply,
        "assistant_reply_first_line": first_line(assistant_reply),
        "expected_intent": case.expected_intent,
        "actual_intent": actual_intent,
        "intent_match": intent_match,
        "expected_mode": case.expected_mode,
        "actual_mode": actual_mode,
        "mode_match": mode_match,
        "expected_reply_type": case.expected_reply_type,
        "reply_match": reply_match,
        "semantic_pass": semantic_pass,
        "route_pass": route_pass,
        "route_health": route_health,
        "strict_policy": strict_policy,
        "strict_pass": strict_pass,
        "selected_recipe": debug.get("selected_recipe"),
        "missing_required_filters": debug.get("missing_required_filters"),
        "match_failure_stage": debug.get("match_failure_stage"),
        "match_failure_reason": debug.get("match_failure_reason"),
        "rows_fetched": debug.get("rows_fetched"),
        "rows_matched": debug.get("rows_matched"),
        "mcp_call_status": debug.get("mcp_call_status"),
        "limited_reason_category": debug.get("limited_reason_category"),
        "llm_decomposition_applied": debug.get("llm_decomposition_applied"),
        "llm_decomposition_reason": debug.get("llm_decomposition_reason"),
        "fallback_rule_hit": debug.get("fallback_rule_hit"),
        "debug_payload": debug,
        "error_code": error.get("code"),
        "error_message": error.get("message"),
    }


def _build_summary(
    *,
    run_id: str,
    questions_path: Path,
    args: argparse.Namespace,
    rows: list[dict[str, Any]],
    elapsed_values: list[int],
) -> dict[str, Any]:
    """Aggregate the per-case rows into the run summary (totals and distributions)."""
    total = len(rows)

    def count(predicate: Any) -> int:
        return sum(1 for r in rows if predicate(r))

    def rate(hits: int) -> float:
        return round(hits / total, 4) if total else 0.0

    def distribution(key: str) -> dict[str, int]:
        return dict(Counter(str(r.get(key)) for r in rows))

    semantic_pass_count = count(lambda r: r.get("semantic_pass"))
    route_pass_count = count(lambda r: r.get("route_pass"))
    strict_pass_count = count(lambda r: r.get("strict_pass"))
    # Unlike the other distributions, rows without a category are left out.
    limited_counter = Counter(
        str(r.get("limited_reason_category")) for r in rows if r.get("limited_reason_category") is not None
    )

    return {
        "run_id": run_id,
        "generated_at": datetime.now().isoformat(timespec="seconds"),
        "source_questions_file": str(questions_path),
        "backend_url": args.backend_url,
        "llm_provider": args.llm_provider,
        "llm_model": args.llm_model,
        "llm_base_url": args.llm_base_url,
        "strict_policy": args.strict_policy,
        "totals": {
            "questions_total": total,
            "ok_200_count": count(lambda r: r.get("status_code") == 200 and r.get("ok")),
            "semantic_pass_count": semantic_pass_count,
            "semantic_pass_rate": rate(semantic_pass_count),
            "route_pass_count": route_pass_count,
            "route_pass_rate": rate(route_pass_count),
            "strict_pass_count": strict_pass_count,
            "strict_pass_rate": rate(strict_pass_count),
            "factual_count": count(lambda r: r.get("reply_type") == "factual"),
            "partial_coverage_count": count(lambda r: r.get("reply_type") == "partial_coverage"),
            "clarification_required_count": count(lambda r: r.get("reply_type") == "clarification_required"),
            "http_error_count": count(lambda r: r.get("status_code") != 200),
            "llm_decomposition_applied_count": count(lambda r: r.get("llm_decomposition_applied") is True),
            "avg_elapsed_ms": round(statistics.mean(elapsed_values), 1) if elapsed_values else 0.0,
        },
        "distributions": {
            "reply_type": distribution("reply_type"),
            "actual_intent": distribution("actual_intent"),
            "actual_mode": distribution("actual_mode"),
            "mcp_call_status": distribution("mcp_call_status"),
            "limited_reason_category": dict(limited_counter),
            "route_health": distribution("route_health"),
        },
    }


def _write_artifacts(
    *,
    run_dir: Path,
    run_id: str,
    questions_path: Path,
    args: argparse.Namespace,
    summary: dict[str, Any],
    rows: list[dict[str, Any]],
) -> None:
    """Write the JSON artifacts, README.md and response_audit.md into ``run_dir``."""

    def dump(obj: Any) -> str:
        return json.dumps(obj, ensure_ascii=False, indent=2) + "\n"

    failures = [
        r
        for r in rows
        if not r.get("strict_pass")
        or r.get("status_code") != 200
        or r.get("reply_type") in {"clarification_required", "backend_error"}
    ]

    (run_dir / "run_summary.json").write_text(dump(summary), encoding="utf-8")
    (run_dir / "full_live_results.json").write_text(
        dump({"run_id": run_id, "generated_at": summary["generated_at"], "summary": summary, "rows": rows}),
        encoding="utf-8",
    )
    (run_dir / "failures_only.json").write_text(dump(failures), encoding="utf-8")

    lines = [
        f"# {run_id}",
        "",
        f"Generated at: {summary['generated_at']}",
        f"Questions file: {questions_path}",
        f"Backend URL: {args.backend_url}",
        f"LLM: {args.llm_provider} / {args.llm_model} @ {args.llm_base_url}",
        f"Strict policy: {args.strict_policy}",
        "",
        "## Totals",
        # The totals dict is built in exactly the order the README lists it.
        *(f"- {key}: {value}" for key, value in summary["totals"].items()),
        "",
        "## Files",
        "- run_summary.json",
        "- full_live_results.json",
        "- failures_only.json",
    ]
    (run_dir / "README.md").write_text("\n".join(lines) + "\n", encoding="utf-8")

    audit_lines = [
        f"# Response Audit: {run_id}",
        "",
        "| id | strict | route_health | reply_type | intent | limited_reason | question | assistant_first_line |",
        "|---|---|---|---|---|---|---|---|",
    ]
    for row in rows:
        # Pipes inside cell text would break the Markdown table.
        audit_lines.append(
            f"| {row.get('id')} | {row.get('strict_pass')} | {row.get('route_health')} | {row.get('reply_type')} | "
            f"{row.get('actual_intent')} | {row.get('limited_reason_category')} | "
            f"{str(row.get('question', '')).replace('|', '/')} | {str(row.get('assistant_reply_first_line', '')).replace('|', '/')} |"
        )
    (run_dir / "response_audit.md").write_text("\n".join(audit_lines) + "\n", encoding="utf-8")


def main() -> None:
    """Run every question against the backend and write the run artifacts.

    Steps: parse CLI options, load the cases, create the run directory, POST
    each case sequentially (cases sharing a ``session`` label reuse one
    session id), evaluate each response, then write ``run_summary.json``,
    ``full_live_results.json``, ``failures_only.json``, ``README.md`` and
    ``response_audit.md`` and print the pass counts.
    """
    args = parse_args()
    questions_path = Path(args.questions_file).resolve()
    output_root = Path(args.output_root).resolve()
    cases = load_cases(questions_path)

    run_id = args.run_id.strip() or f"{datetime.now().date().isoformat()}_Address_Slang_Live_Stress_{now_stamp()}"
    run_dir = output_root / run_id
    run_dir.mkdir(parents=True, exist_ok=True)

    session_map: dict[str, str] = {}
    rows: list[dict[str, Any]] = []
    elapsed_values: list[int] = []

    for index, case in enumerate(cases, start=1):
        if case.session:
            # First case of a session label fixes the id; later ones reuse it.
            session_id = session_map.setdefault(case.session, f"asst-{run_id}-{case.session}")
        else:
            session_id = f"asst-{run_id}-{case.id.lower()}"

        payload = {
            "session_id": session_id,
            "user_message": case.text,
            "mode": "assistant",
            "promptVersion": args.prompt_version,
            "llmProvider": args.llm_provider,
            "model": args.llm_model,
            "baseUrl": args.llm_base_url,
            "temperature": args.temperature,
            "maxOutputTokens": args.max_output_tokens,
            "useMock": False,
        }

        started = time.perf_counter()
        status_code, body = post_json(args.backend_url, payload, args.timeout_sec)
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        elapsed_values.append(elapsed_ms)

        row = _evaluate_case(
            index=index,
            case=case,
            session_id=session_id,
            status_code=status_code,
            body=body,
            elapsed_ms=elapsed_ms,
            strict_policy=args.strict_policy,
        )
        rows.append(row)

        print(
            f"[{index:03d}/{len(cases):03d}] {case.id} | status={status_code} reply={row['reply_type']} "
            f"intent={row['actual_intent']} mode={row['actual_mode']} semantic={row['semantic_pass']} route={row['route_pass']} strict={row['strict_pass']} health={row['route_health']}"
        )

    summary = _build_summary(
        run_id=run_id,
        questions_path=questions_path,
        args=args,
        rows=rows,
        elapsed_values=elapsed_values,
    )
    _write_artifacts(
        run_dir=run_dir,
        run_id=run_id,
        questions_path=questions_path,
        args=args,
        summary=summary,
        rows=rows,
    )

    totals = summary["totals"]
    print(f"\nRun directory: {run_dir}")
    print(f"Semantic pass: {totals['semantic_pass_count']}/{len(rows)}")
    print(f"Route pass: {totals['route_pass_count']}/{len(rows)}")
    print(f"Strict pass ({args.strict_policy}): {totals['strict_pass_count']}/{len(rows)}")
|
|
|
|
|
|
# Run the batch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|