178 lines
5.5 KiB
Python
178 lines
5.5 KiB
Python
from __future__ import annotations
|
||
|
||
from dataclasses import asdict, dataclass
|
||
import re
|
||
from typing import Any
|
||
|
||
|
||
# Matches a standalone two-digit number, optionally followed by a dot and two
# more digits (for example "60" or "60.01"), delimited by word boundaries.
# NOTE(review): presumably these are chart-of-accounts codes -- confirm.
ACCOUNT_TOKEN_RE = re.compile(r"\b\d{2}(?:\.\d{2})?\b")
|
||
|
||
|
||
@dataclass
class RouteDecisionFlags:
    """Set of boolean flags describing how a single question should be routed.

    Field order is part of the interface because the generated ``__init__``
    accepts the flags positionally.
    """

    # Requirements derived from the question itself.
    needs_exact_object_trace: bool
    needs_causal_chain: bool
    needs_cross_entity_join: bool
    needs_full_period_aggregation: bool
    needs_ranking: bool
    needs_anomaly_summary: bool
    needs_runtime_truth: bool
    freshness_sensitive: bool
    ambiguous_object_scope: bool
    # Conclusions about what the store can serve.
    store_sufficiency_confident: bool
    precomputed_aggregate_available: bool

    def to_dict(self) -> dict[str, Any]:
        """Return the flags as a plain dict keyed by field name."""
        as_mapping = asdict(self)
        return as_mapping
|
||
|
||
|
||
def _norm(text: str) -> str:
|
||
return text.lower().strip()
|
||
|
||
|
||
def _contains_any(text: str, tokens: list[str]) -> bool:
|
||
return any(token in text for token in tokens)
|
||
|
||
|
||
def _has_account_token(text: str) -> bool:
    """Return True when ``ACCOUNT_TOKEN_RE`` finds a match anywhere in *text*."""
    found = ACCOUNT_TOKEN_RE.search(text)
    return found is not None
|
||
|
||
|
||
def _aggregate_available_for_shape(
|
||
*,
|
||
available: set[str],
|
||
needs_ranking: bool,
|
||
needs_anomaly_summary: bool,
|
||
needs_full_period_aggregation: bool,
|
||
text: str,
|
||
) -> bool:
|
||
if needs_ranking:
|
||
ranking_tokens = {
|
||
"risk_account_ranking",
|
||
"risk_counterparty_ranking",
|
||
"risk_ranking",
|
||
}
|
||
return bool(available.intersection(ranking_tokens))
|
||
|
||
if needs_anomaly_summary:
|
||
anomaly_tokens = {
|
||
"company_anomaly_summary",
|
||
}
|
||
return bool(available.intersection(anomaly_tokens))
|
||
|
||
if needs_full_period_aggregation:
|
||
if "baseline" in text:
|
||
return "baseline_period_summary" in available
|
||
return "full_period_aggregation" in available
|
||
|
||
return False
|
||
|
||
|
||
def classify_query_for_route(
    question_text: str,
    parsed_intent: dict[str, Any],
    store_metadata: dict[str, Any],
) -> RouteDecisionFlags:
    """Derive the routing flags for one question.

    The flags are computed from substring markers in the normalised question
    text, from ``parsed_intent["question_class"]`` and from the aggregate
    names listed under ``store_metadata["precomputed_aggregates"]``.

    Args:
        question_text: Raw question; it is lower-cased and stripped before
            any marker matching.
        parsed_intent: Mapping read only for its "question_class" entry,
            which is stringified, stripped and lower-cased.
        store_metadata: Mapping read only for its "precomputed_aggregates"
            entry. A missing key or a ``None`` value is treated as "no
            aggregates available".

    Returns:
        A ``RouteDecisionFlags`` instance with every flag populated.
    """
    text = _norm(question_text)
    question_class = str(parsed_intent.get("question_class", "")).strip().lower()

    exact_markers = [
        "документ по номеру",
        "source-of-record",
        "источник",
        "цепочка",
        "почему",
        "subconto3",
        "субконто3",
    ]
    causal_markers = [
        "свяжи",
        "цепочка",
        "через",
        "объясни",
        "почему",
        "источник",
        "регистр",
        "первич",
    ]
    cross_markers = [
        "свяжи",
        "документ",
        "провод",
        "контрагент",
        "договор",
        "регистр",
    ]
    ranking_markers = ["рейтинг", "ranking", "топ", "top"]
    anomaly_markers = ["аномал", "summary", "срез", "risk-slice", "риск-срез"]

    # "документ по номеру" is itself an exact marker and "simple_factual" is in
    # the accepted classes, so that combination needs no separate override.
    needs_exact_object_trace = _contains_any(text, exact_markers) and (
        question_class in {"drilldown_explain", "simple_factual", "cross_entity"}
    )

    needs_causal_chain = _contains_any(text, causal_markers) and question_class in {
        "drilldown_explain",
        "cross_entity",
    }
    # Outside the "cross_entity" class, a join is inferred from a cross marker
    # plus the conjunction " и ", unless the text contains an "->" arrow.
    needs_cross_entity_join = question_class == "cross_entity" or (
        _contains_any(text, cross_markers) and " и " in text and "->" not in text
    )

    needs_ranking = _contains_any(text, ranking_markers) and question_class in {
        "heavy_analytical",
        "period_trend",
        "anomaly_control",
    }
    needs_anomaly_summary = _contains_any(text, anomaly_markers)

    # Heavy analytical questions need a full-period aggregation unless they
    # mention "baseline".
    # NOTE(review): because this flag is False whenever "baseline" is in the
    # text, the "baseline" branch of _aggregate_available_for_shape cannot be
    # reached from this call -- confirm whether that is intended.
    is_heavy = question_class == "heavy_analytical"
    needs_full_period_aggregation = is_heavy and "baseline" not in text

    needs_runtime_truth = needs_exact_object_trace or _contains_any(
        text, ["runtime", "source-of-record", "источник регистра"]
    )
    freshness_sensitive = question_class in {
        "period_trend",
        "anomaly_control",
        "heavy_analytical",
    }
    ambiguous_object_scope = question_class == "ambiguous_fuzzy"
    if ambiguous_object_scope and _has_account_token(text):
        # Ambiguous account prompts should avoid hard downcast into
        # canonical-only answers.
        needs_runtime_truth = True

    # ``dict.get`` only falls back to its default when the key is absent, so a
    # key stored with an explicit None must be handled separately to avoid
    # iterating over None.
    raw_aggregates = store_metadata.get("precomputed_aggregates") or ()
    available_aggregates = {str(item).strip().lower() for item in raw_aggregates}
    precomputed_aggregate_available = _aggregate_available_for_shape(
        available=available_aggregates,
        needs_ranking=needs_ranking,
        needs_anomaly_summary=needs_anomaly_summary,
        needs_full_period_aggregation=needs_full_period_aggregation,
        text=text,
    )

    # The store alone is trusted only for simple factual questions that need
    # neither runtime data, a causal chain nor a cross-entity join.
    store_sufficiency_confident = (
        question_class == "simple_factual"
        and not needs_runtime_truth
        and not needs_causal_chain
        and not needs_cross_entity_join
    )

    return RouteDecisionFlags(
        needs_exact_object_trace=needs_exact_object_trace,
        needs_causal_chain=needs_causal_chain,
        needs_cross_entity_join=needs_cross_entity_join,
        needs_full_period_aggregation=needs_full_period_aggregation,
        needs_ranking=needs_ranking,
        needs_anomaly_summary=needs_anomaly_summary,
        needs_runtime_truth=needs_runtime_truth,
        freshness_sensitive=freshness_sensitive,
        ambiguous_object_scope=ambiguous_object_scope,
        store_sufficiency_confident=store_sufficiency_confident,
        precomputed_aggregate_available=precomputed_aggregate_available,
    )
|