NODEDC_1C/router/query_classifier.py

222 lines
6.7 KiB
Python

from __future__ import annotations
from dataclasses import asdict, dataclass
import re
from typing import Any
# Matches a standalone two-digit token optionally followed by ".NN"
# (e.g. "41" or "41.01"); used by _has_account_token.
# NOTE(review): presumably these are chart-of-accounts codes -- confirm. Any
# isolated two-digit number (including the "01" in "2024-01-15") also matches.
ACCOUNT_TOKEN_RE = re.compile(r"\b\d{2}(?:\.\d{2})?\b")
@dataclass
class RouteDecisionFlags:
    """Boolean signals computed by ``classify_query_for_route``.

    Every field is a plain ``bool``. The field order is part of the
    positional constructor signature and of the key order of ``to_dict``.
    """

    # Requirements derived from the question text and its parsed class.
    needs_exact_object_trace: bool
    needs_causal_chain: bool
    needs_cross_entity_join: bool
    needs_full_period_aggregation: bool
    needs_ranking: bool
    needs_anomaly_summary: bool
    needs_runtime_truth: bool

    # Properties of the question class itself.
    freshness_sensitive: bool
    ambiguous_object_scope: bool

    # Conclusions about what the store can serve.
    store_sufficiency_confident: bool
    precomputed_aggregate_available: bool

    def to_dict(self) -> dict[str, Any]:
        """Return the flags as a ``{field_name: value}`` dictionary."""
        flags_by_name: dict[str, Any] = asdict(self)
        return flags_by_name
def _norm(text: str) -> str:
return text.lower().strip()
def _contains_any(text: str, tokens: list[str]) -> bool:
return any(token in text for token in tokens)
def _has_account_token(text: str) -> bool:
    """Report whether ``ACCOUNT_TOKEN_RE`` matches anywhere in *text*."""
    found = ACCOUNT_TOKEN_RE.search(text)
    return found is not None
def _has_iso_date(text: str) -> bool:
return bool(re.search(r'\b\d{4}-\d{2}-\d{2}\b', text))
def _aggregate_available_for_shape(
*,
available: set[str],
needs_ranking: bool,
needs_anomaly_summary: bool,
needs_full_period_aggregation: bool,
text: str,
) -> bool:
if needs_ranking:
ranking_tokens = {
"risk_account_ranking",
"risk_counterparty_ranking",
"risk_ranking",
}
return bool(available.intersection(ranking_tokens))
if needs_anomaly_summary:
anomaly_tokens = {
"company_anomaly_summary",
}
return bool(available.intersection(anomaly_tokens))
if needs_full_period_aggregation:
if "baseline" in text:
return "baseline_period_summary" in available
return "full_period_aggregation" in available
return False
def classify_query_for_route(
    question_text: str,
    parsed_intent: dict[str, Any],
    store_metadata: dict[str, Any],
) -> RouteDecisionFlags:
    """Derive routing flags for a question using keyword heuristics.

    Args:
        question_text: Raw question; it is lower-cased and trimmed before any
            marker matching, and all matching is plain substring containment.
        parsed_intent: Mapping from which only ``"question_class"`` is read.
            A missing or ``None`` value is treated as an empty class.
        store_metadata: Mapping from which only ``"precomputed_aggregates"``
            is read. A missing or ``None`` value means "no aggregates"; a
            single string is treated as a one-element collection.

    Returns:
        A ``RouteDecisionFlags`` instance with every flag populated.
    """
    text = _norm(question_text)
    # ``or ""`` keeps an explicit None from being stringified to "none".
    question_class = str(parsed_intent.get("question_class") or "").strip().lower()

    exact_markers = [
        "document by number",
        "source-of-record",
        "source",
        "chain",
        "why",
        "subconto3",
        "supplier",
        "поставщик",
        "buyer",
        "покупатель",
        "purchase",
        "закуп",
        "document",
        "документ",
        "date",
        "дата",
    ]
    inventory_markers = [
        "склад",
        "остаток",
        "поставщик",
        "закуплен",
        "куплен",
        "продан",
        "inventory",
        "stock",
        "supplier",
        "purchase",
        "sale",
    ]
    causal_markers = [
        "свяжи",
        "цепочка",
        "через",
        "объясни",
        "почему",
        "источник",
        "регистр",
        "первич",
    ]
    cross_markers = [
        "свяжи",
        "документ",
        "провод",
        "контрагент",
        "договор",
        "регистр",
    ]
    ranking_markers = ["рейтинг", "ranking", "топ", "top"]
    anomaly_markers = ["аномал", "summary", "срез", "risk-slice", "риск-срез"]

    # Question classes for which marker hits may demand an exact object trace.
    trace_classes = {"drilldown_explain", "simple_factual", "cross_entity"}
    is_trace_class = question_class in trace_classes

    # Any one of these textual signals is enough for a trace-capable class.
    needs_exact_object_trace = is_trace_class and (
        _contains_any(text, exact_markers)
        or _contains_any(text, inventory_markers)
        or ("41" in text and _contains_any(text, ["остаток", "balance", "stock"]))
        or (
            _contains_any(text, ["supplier", "поставщик", "buyer", "покупатель"])
            and _contains_any(text, ["purchase", "закуп", "document", "документ"])
        )
    )
    if question_class == "simple_factual" and "document by number" in text:
        needs_exact_object_trace = True
    # An unclassified question that mentions "41" together with an ISO date
    # is still traced exactly.
    if question_class == "unknown" and "41" in text and _has_iso_date(text):
        needs_exact_object_trace = True

    needs_causal_chain = question_class in {
        "drilldown_explain",
        "cross_entity",
    } and _contains_any(text, causal_markers)
    # NOTE(review): the literal below is a run of "?" characters; it looks
    # like a non-ASCII marker lost in an encoding round-trip. Kept verbatim --
    # confirm the intended text.
    if _contains_any(text, inventory_markers) and "???????" in text:
        needs_causal_chain = True

    # NOTE(review): " ? " also looks like a damaged non-ASCII token (possibly
    # a conjunction or an arrow) -- kept verbatim, confirm the intended text.
    needs_cross_entity_join = question_class == "cross_entity" or (
        _contains_any(text, cross_markers) and " ? " in text and "->" not in text
    )

    needs_ranking = question_class in {
        "heavy_analytical",
        "period_trend",
        "anomaly_control",
    } and _contains_any(text, ranking_markers)
    # Not gated on the question class: any anomaly marker sets the flag.
    needs_anomaly_summary = _contains_any(text, anomaly_markers)

    is_heavy = question_class == "heavy_analytical"
    is_baseline_heavy = is_heavy and "baseline" in text
    needs_full_period_aggregation = is_heavy and not is_baseline_heavy

    # NOTE(review): the last marker is another "?"-only literal, presumably a
    # damaged non-ASCII phrase -- kept verbatim, confirm the intended text.
    needs_runtime_truth = needs_exact_object_trace or _contains_any(
        text, ["runtime", "source-of-record", "???????????????? ????????????????"]
    )
    freshness_sensitive = question_class in {
        "period_trend",
        "anomaly_control",
        "heavy_analytical",
    }
    ambiguous_object_scope = question_class == "ambiguous_fuzzy"
    if ambiguous_object_scope and _has_account_token(text):
        needs_runtime_truth = True

    # ``.get(key, [])`` only substitutes the default when the key is absent,
    # so an explicit None would make the comprehension raise TypeError; a
    # bare string would be split into single characters.
    raw_aggregates = store_metadata.get("precomputed_aggregates") or []
    if isinstance(raw_aggregates, str):
        raw_aggregates = [raw_aggregates]
    available_aggregates = {str(item).strip().lower() for item in raw_aggregates}

    # NOTE(review): for baseline-heavy questions needs_full_period_aggregation
    # is False, so the helper's "baseline_period_summary" check is never
    # reached from this call -- confirm whether that is intended.
    precomputed_aggregate_available = _aggregate_available_for_shape(
        available=available_aggregates,
        needs_ranking=needs_ranking,
        needs_anomaly_summary=needs_anomaly_summary,
        needs_full_period_aggregation=needs_full_period_aggregation,
        text=text,
    )

    # The store alone is trusted only for simple factual questions that need
    # neither runtime data, a causal chain, nor a cross-entity join.
    store_sufficiency_confident = (
        question_class == "simple_factual"
        and not needs_runtime_truth
        and not needs_causal_chain
        and not needs_cross_entity_join
    )

    return RouteDecisionFlags(
        needs_exact_object_trace=needs_exact_object_trace,
        needs_causal_chain=needs_causal_chain,
        needs_cross_entity_join=needs_cross_entity_join,
        needs_full_period_aggregation=needs_full_period_aggregation,
        needs_ranking=needs_ranking,
        needs_anomaly_summary=needs_anomaly_summary,
        needs_runtime_truth=needs_runtime_truth,
        freshness_sensitive=freshness_sensitive,
        ambiguous_object_scope=ambiguous_object_scope,
        store_sufficiency_confident=store_sufficiency_confident,
        precomputed_aggregate_available=precomputed_aggregate_available,
    )