222 lines
6.7 KiB
Python
222 lines
6.7 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import asdict, dataclass
|
|
import re
|
|
from typing import Any
|
|
|
|
|
|
# Standalone two-digit number, optionally followed by a ".NN" suffix
# (e.g. "41", "60.01") -- presumably an account / sub-account code.
# The \b anchors keep it from matching inside a longer run of digits
# such as "2024".
ACCOUNT_TOKEN_RE = re.compile(
    r"\b\d{2}(?:\.\d{2})?\b"
)


@dataclass()
class RouteDecisionFlags:
    """Routing signals computed for one question.

    Every field is annotated as ``bool``. The declaration order below is
    also the positional-constructor order and the key order of
    :meth:`to_dict`, so do not reorder fields.
    """

    # Signals about what the question asks for.
    needs_exact_object_trace: bool
    needs_causal_chain: bool
    needs_cross_entity_join: bool
    needs_full_period_aggregation: bool
    needs_ranking: bool
    needs_anomaly_summary: bool

    # Signals about where an answer may be sourced from.
    needs_runtime_truth: bool
    freshness_sensitive: bool
    ambiguous_object_scope: bool
    store_sufficiency_confident: bool
    precomputed_aggregate_available: bool

    def to_dict(self) -> dict[str, Any]:
        """Return a new ``{field_name: value}`` dict built by ``asdict``."""
        flags: dict[str, Any] = asdict(self)
        return flags


def _norm(text: str) -> str:
|
|
return text.lower().strip()
|
|
|
|
|
|
def _contains_any(text: str, tokens: list[str]) -> bool:
|
|
return any(token in text for token in tokens)
|
|
|
|
|
|
def _has_account_token(text: str) -> bool:
    """Return True when ``ACCOUNT_TOKEN_RE`` finds a match anywhere in *text*."""
    match = ACCOUNT_TOKEN_RE.search(text)
    return match is not None


def _has_iso_date(text: str) -> bool:
|
|
return bool(re.search(r'\b\d{4}-\d{2}-\d{2}\b', text))
|
|
|
|
|
|
def _aggregate_available_for_shape(
|
|
*,
|
|
available: set[str],
|
|
needs_ranking: bool,
|
|
needs_anomaly_summary: bool,
|
|
needs_full_period_aggregation: bool,
|
|
text: str,
|
|
) -> bool:
|
|
if needs_ranking:
|
|
ranking_tokens = {
|
|
"risk_account_ranking",
|
|
"risk_counterparty_ranking",
|
|
"risk_ranking",
|
|
}
|
|
return bool(available.intersection(ranking_tokens))
|
|
|
|
if needs_anomaly_summary:
|
|
anomaly_tokens = {
|
|
"company_anomaly_summary",
|
|
}
|
|
return bool(available.intersection(anomaly_tokens))
|
|
|
|
if needs_full_period_aggregation:
|
|
if "baseline" in text:
|
|
return "baseline_period_summary" in available
|
|
return "full_period_aggregation" in available
|
|
|
|
return False
|
|
|
|
|
|
def classify_query_for_route(
    question_text: str,
    parsed_intent: dict[str, Any],
    store_metadata: dict[str, Any],
) -> RouteDecisionFlags:
    """Derive route-decision flags for a question from keyword heuristics.

    Keyword checks are substring matches (via ``_contains_any``) against the
    normalized question text, usually gated by the parsed question class.

    Args:
        question_text: Raw question; lower-cased and stripped via ``_norm``.
        parsed_intent: Only ``parsed_intent["question_class"]`` is read; it
            is stringified, stripped and lower-cased (missing -> ``""``).
        store_metadata: Only ``store_metadata["precomputed_aggregates"]`` is
            read: an iterable of aggregate names, each stringified, stripped
            and lower-cased. A missing or ``None`` entry counts as empty.

    Returns:
        The populated :class:`RouteDecisionFlags`.
    """
    text = _norm(question_text)
    question_class = str(parsed_intent.get("question_class", "")).strip().lower()

    # Question classes for which the object-level tracing rules apply.
    trace_classes = {"drilldown_explain", "simple_factual", "cross_entity"}

    # "41" as a standalone number (also "41.01" or "сч41"). The digit
    # lookarounds stop "141", "4100" or "2410" from counting, which a plain
    # substring test would accept. Presumably account code 41 -- TODO confirm.
    mentions_41 = re.search(r"(?<!\d)41(?!\d)", text) is not None

    exact_markers = [
        "document by number",
        "source-of-record",
        "source",
        "chain",
        "why",
        "subconto3",
        "supplier",
        "поставщик",
        "buyer",
        "покупатель",
        "purchase",
        "закуп",
        "document",
        "документ",
        "date",
        "дата",
    ]
    inventory_markers = [
        "склад",
        "остаток",
        "поставщик",
        "закуплен",
        "куплен",
        "продан",
        "inventory",
        "stock",
        "supplier",
        "purchase",
        "sale",
    ]
    causal_markers = [
        "свяжи",
        "цепочка",
        "через",
        "объясни",
        "почему",
        "источник",
        "регистр",
        "первич",
    ]
    cross_markers = [
        "свяжи",
        "документ",
        "провод",
        "контрагент",
        "договор",
        "регистр",
    ]
    ranking_markers = ["рейтинг", "ranking", "топ", "top"]
    anomaly_markers = ["аномал", "summary", "срез", "risk-slice", "риск-срез"]

    # Exact object trace. Supplier/buyer, purchase/document and
    # "document by number" mentions are all already in exact_markers, so
    # they need no separate rule.
    needs_exact_object_trace = question_class in trace_classes and (
        _contains_any(text, exact_markers)
        or _contains_any(text, inventory_markers)
        or (mentions_41 and _contains_any(text, ["остаток", "balance", "stock"]))
    )
    if question_class == "unknown" and mentions_41 and _has_iso_date(text):
        needs_exact_object_trace = True

    needs_causal_chain = question_class in {
        "drilldown_explain",
        "cross_entity",
    } and _contains_any(text, causal_markers)
    # NOTE(review): "???????" looks like an encoding-damaged (mojibake)
    # non-ASCII literal. As written it only matches seven literal question
    # marks, so this rule is effectively dead. Restore the intended word.
    if _contains_any(text, inventory_markers) and "???????" in text:
        needs_causal_chain = True

    # NOTE(review): " ? " may likewise be a damaged single-character literal;
    # confirm the intended separator token.
    needs_cross_entity_join = question_class == "cross_entity" or (
        _contains_any(text, cross_markers) and " ? " in text and "->" not in text
    )

    needs_ranking = question_class in {
        "heavy_analytical",
        "period_trend",
        "anomaly_control",
    } and _contains_any(text, ranking_markers)
    # Anomaly markers are intentionally not gated by question class here.
    needs_anomaly_summary = _contains_any(text, anomaly_markers)
    is_heavy = question_class == "heavy_analytical"
    is_baseline_heavy = is_heavy and "baseline" in text
    needs_full_period_aggregation = is_heavy and not is_baseline_heavy

    # NOTE(review): the third marker is another mojibake literal and will
    # not match natural text -- restore the intended phrase.
    needs_runtime_truth = needs_exact_object_trace or _contains_any(
        text, ["runtime", "source-of-record", "???????????????? ????????????????"]
    )

    freshness_sensitive = question_class in {
        "period_trend",
        "anomaly_control",
        "heavy_analytical",
    }

    ambiguous_object_scope = question_class == "ambiguous_fuzzy"
    if ambiguous_object_scope and _has_account_token(text):
        needs_runtime_truth = True

    # `or ()` treats an explicit None the same as a missing key instead of
    # raising TypeError when iterating it.
    available_aggregates = {
        str(item).strip().lower()
        for item in (store_metadata.get("precomputed_aggregates") or ())
    }
    # Baseline-heavy questions have needs_full_period_aggregation=False, but
    # the helper only consults "baseline_period_summary" under that flag.
    # Pass them through so the baseline aggregate is actually considered.
    precomputed_aggregate_available = _aggregate_available_for_shape(
        available=available_aggregates,
        needs_ranking=needs_ranking,
        needs_anomaly_summary=needs_anomaly_summary,
        needs_full_period_aggregation=needs_full_period_aggregation or is_baseline_heavy,
        text=text,
    )

    store_sufficiency_confident = (
        question_class == "simple_factual"
        and not needs_runtime_truth
        and not needs_causal_chain
        and not needs_cross_entity_join
    )

    return RouteDecisionFlags(
        needs_exact_object_trace=needs_exact_object_trace,
        needs_causal_chain=needs_causal_chain,
        needs_cross_entity_join=needs_cross_entity_join,
        needs_full_period_aggregation=needs_full_period_aggregation,
        needs_ranking=needs_ranking,
        needs_anomaly_summary=needs_anomaly_summary,
        needs_runtime_truth=needs_runtime_truth,
        freshness_sensitive=freshness_sensitive,
        ambiguous_object_scope=ambiguous_object_scope,
        store_sufficiency_confident=store_sufficiency_confident,
        precomputed_aggregate_available=precomputed_aggregate_available,
    )