NODEDC_1C/router/query_classifier.py

178 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
from dataclasses import asdict, dataclass
import re
from typing import Any
# Two digits, optionally followed by ".NN" (e.g. "60" or "60.01"), delimited by
# word boundaries so longer digit runs such as "2024" do not match.
# NOTE(review): date fragments like "01.01" in "01.01.2024" also match.
ACCOUNT_TOKEN_RE = re.compile(
    r"""
    \b
    \d{2}          # leading two-digit code
    (?:\.\d{2})?   # optional ".NN" suffix
    \b
    """,
    re.VERBOSE,
)
@dataclass
class RouteDecisionFlags:
    """Boolean routing signals computed for a single user question.

    Instances are built by ``classify_query_for_route``; ``to_dict`` exposes the
    flags as a plain mapping keyed by field name, in declaration order.
    """

    # Tracing / linking requirements.
    needs_exact_object_trace: bool
    needs_causal_chain: bool
    needs_cross_entity_join: bool
    # Aggregation-shape requirements.
    needs_full_period_aggregation: bool
    needs_ranking: bool
    needs_anomaly_summary: bool
    # Data-source and scope signals.
    needs_runtime_truth: bool
    freshness_sensitive: bool
    ambiguous_object_scope: bool
    # Store capability signals.
    store_sufficiency_confident: bool
    precomputed_aggregate_available: bool

    def to_dict(self) -> dict[str, Any]:
        """Return a new ``{field_name: value}`` dict built via ``dataclasses.asdict``."""
        snapshot = asdict(self)
        return snapshot
def _norm(text: str) -> str:
    """Lower-case *text*, then trim surrounding whitespace, for marker matching."""
    lowered = text.lower()
    return lowered.strip()
def _contains_any(text: str, tokens: list[str]) -> bool:
    """Return True as soon as any token occurs as a substring of *text*.

    Tokens are checked in order; an empty token list yields False.
    """
    for marker in tokens:
        if marker in text:
            return True
    return False
def _has_account_token(text: str) -> bool:
    """Report whether *text* contains a token matching ``ACCOUNT_TOKEN_RE``."""
    return ACCOUNT_TOKEN_RE.search(text) is not None
def _aggregate_available_for_shape(
    *,
    available: set[str],
    needs_ranking: bool,
    needs_anomaly_summary: bool,
    needs_full_period_aggregation: bool,
    text: str,
) -> bool:
    """Report whether a precomputed aggregate exists for the requested query shape.

    Shapes are checked in priority order and the first requested one decides:
    ranking, then anomaly summary, then full-period aggregation. If no shape is
    requested the result is False.

    Args:
        available: Names of precomputed aggregates present in the store.
        needs_ranking: Any of the ranking aggregates satisfies the request.
        needs_anomaly_summary: ``"company_anomaly_summary"`` satisfies it.
        needs_full_period_aggregation: ``"baseline_period_summary"`` is required
            when ``"baseline"`` occurs in ``text``; otherwise
            ``"full_period_aggregation"`` is required.
        text: Question text. The ``"baseline"`` check is case-sensitive; the
            caller in this module passes normalized (lower-cased) text.

    Returns:
        True if the aggregate required by the highest-priority shape is present.
    """
    if needs_ranking:
        wanted = {"risk_account_ranking", "risk_counterparty_ranking", "risk_ranking"}
        return not available.isdisjoint(wanted)
    if needs_anomaly_summary:
        return "company_anomaly_summary" in available
    if not needs_full_period_aggregation:
        return False
    aggregate_name = (
        "baseline_period_summary" if "baseline" in text else "full_period_aggregation"
    )
    return aggregate_name in available
# Short ranking words ("топ"/"top") must stand alone as words. Digits, hyphens
# and a few Russian case endings may follow, so "топ-10", "в топе" and
# "топовые" still match, while words that merely contain the letters, such as
# "топливо" (fuel), "stop", "laptop" or "topic", do not.
_RANKING_SHORT_WORD_RE = re.compile(
    r"(?<![^\W\d_])(?:топ(?:ов\w*|ом|[аеуы])?|tops?)(?![^\W\d_])"
)


def classify_query_for_route(
    question_text: str,
    parsed_intent: dict[str, Any],
    store_metadata: dict[str, Any],
) -> RouteDecisionFlags:
    """Derive routing flags for a question from its text, intent and store metadata.

    Flags come from substring markers found in the normalized (lower-cased,
    stripped) question text, combined with ``parsed_intent["question_class"]``.

    Args:
        question_text: Raw user question.
        parsed_intent: Parsed intent. Only ``"question_class"`` is read; a
            missing key is treated as an empty class.
        store_metadata: Store description. Only ``"precomputed_aggregates"`` is
            read: an iterable of aggregate names or a single name. A missing or
            ``None`` value means no aggregates are available.

    Returns:
        A populated :class:`RouteDecisionFlags`.
    """
    text = _norm(question_text)
    question_class = str(parsed_intent.get("question_class", "")).strip().lower()

    exact_markers = [
        "документ по номеру",
        "source-of-record",
        "источник",
        "цепочка",
        "почему",
        "subconto3",
        "субконто3",
    ]
    causal_markers = [
        "свяжи",
        "цепочка",
        "через",
        "объясни",
        "почему",
        "источник",
        "регистр",
        "первич",
    ]
    cross_markers = [
        "свяжи",
        "документ",
        "провод",
        "контрагент",
        "договор",
        "регистр",
    ]
    # "топ"/"top" are matched separately via _RANKING_SHORT_WORD_RE, because
    # plain substring matching gives false positives on them.
    ranking_markers = ["рейтинг", "ranking"]
    anomaly_markers = ["аномал", "summary", "срез", "risk-slice", "риск-срез"]

    needs_exact_object_trace = _contains_any(text, exact_markers) and (
        question_class in {"drilldown_explain", "simple_factual", "cross_entity"}
    )
    # Explicit guarantee for lookups by document number. This is currently also
    # implied by "документ по номеру" being one of the exact_markers.
    if question_class == "simple_factual" and "документ по номеру" in text:
        needs_exact_object_trace = True

    needs_causal_chain = _contains_any(text, causal_markers) and question_class in {
        "drilldown_explain",
        "cross_entity",
    }
    needs_cross_entity_join = question_class == "cross_entity" or (
        _contains_any(text, cross_markers) and " и " in text and "->" not in text
    )

    has_ranking_marker = _contains_any(text, ranking_markers) or (
        _RANKING_SHORT_WORD_RE.search(text) is not None
    )
    needs_ranking = has_ranking_marker and question_class in {
        "heavy_analytical",
        "period_trend",
        "anomaly_control",
    }
    # Unlike ranking, anomaly markers are not gated on the question class.
    needs_anomaly_summary = _contains_any(text, anomaly_markers)

    is_heavy = question_class == "heavy_analytical"
    is_baseline_heavy = is_heavy and "baseline" in text
    needs_full_period_aggregation = is_heavy and not is_baseline_heavy

    needs_runtime_truth = needs_exact_object_trace or _contains_any(
        text, ["runtime", "source-of-record", "источник регистра"]
    )
    freshness_sensitive = question_class in {
        "period_trend",
        "anomaly_control",
        "heavy_analytical",
    }
    ambiguous_object_scope = question_class == "ambiguous_fuzzy"
    if ambiguous_object_scope and _has_account_token(text):
        # Ambiguous account prompts should avoid a hard downcast into
        # canonical-only answers.
        needs_runtime_truth = True

    # Tolerate an explicit None and a single bare name. Iterating a str would
    # otherwise yield one "aggregate" per character.
    raw_aggregates = store_metadata.get("precomputed_aggregates") or []
    if isinstance(raw_aggregates, str):
        raw_aggregates = [raw_aggregates]
    available_aggregates = {str(item).strip().lower() for item in raw_aggregates}

    precomputed_aggregate_available = _aggregate_available_for_shape(
        available=available_aggregates,
        needs_ranking=needs_ranking,
        needs_anomaly_summary=needs_anomaly_summary,
        # Baseline-heavy questions are deliberately excluded from
        # needs_full_period_aggregation. They can still be served by the
        # "baseline_period_summary" aggregate, so the helper's baseline branch
        # (selected by "baseline" in text) must be reachable for them.
        needs_full_period_aggregation=needs_full_period_aggregation or is_baseline_heavy,
        text=text,
    )
    store_sufficiency_confident = (
        question_class == "simple_factual"
        and not needs_runtime_truth
        and not needs_causal_chain
        and not needs_cross_entity_join
    )
    return RouteDecisionFlags(
        needs_exact_object_trace=needs_exact_object_trace,
        needs_causal_chain=needs_causal_chain,
        needs_cross_entity_join=needs_cross_entity_join,
        needs_full_period_aggregation=needs_full_period_aggregation,
        needs_ranking=needs_ranking,
        needs_anomaly_summary=needs_anomaly_summary,
        needs_runtime_truth=needs_runtime_truth,
        freshness_sensitive=freshness_sensitive,
        ambiguous_object_scope=ambiguous_object_scope,
        store_sufficiency_confident=store_sufficiency_confident,
        precomputed_aggregate_available=precomputed_aggregate_available,
    )