NODEDC_1C/canonical_layer/store.py

745 lines
28 KiB
Python

from __future__ import annotations

import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from uuid import uuid4

from sqlalchemy import create_engine, delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker

from canonical_layer.models import CanonicalEntity
from canonical_layer.store_models import (
    AnomalySignalRow,
    Base,
    CanonicalEntityRow,
    CanonicalLinkRow,
    FeatureMetricRow,
    FeatureRunRow,
    RefreshCheckpointRow,
    RefreshRunRow,
    RiskPatternRow,
    RiskRunRow,
)
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _dump_json(payload: Any) -> str:
return json.dumps(payload, ensure_ascii=False)
def _load_json(payload: str, default: Any) -> Any:
if not payload:
return default
try:
return json.loads(payload)
except json.JSONDecodeError:
return default
def _dt_to_iso(value: datetime | None) -> str | None:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc).isoformat()
class CanonicalStore:
    """Persistence facade over the canonical-layer tables.

    Wraps a SQLAlchemy engine and session factory and exposes the write and
    read operations for refresh runs, canonical entities and links, refresh
    checkpoints, feature runs (metrics and anomaly signals) and risk runs
    (risk patterns). Every method opens its own short-lived session; write
    methods run inside a single ``session.begin()`` transaction.
    """

    def __init__(self, db_url: str) -> None:
        """Create the engine and session factory for *db_url*.

        Sessions are created with ``autoflush=False`` and
        ``expire_on_commit=False``.
        """
        self.db_url = db_url
        self.engine = create_engine(db_url, future=True)
        self.session_factory = sessionmaker(
            bind=self.engine,
            autoflush=False,
            expire_on_commit=False,
            future=True,
        )

    @staticmethod
    def _clamp_limit(limit: int, ceiling: int) -> int:
        """Clamp a caller-supplied row limit to the range ``1..ceiling``."""
        return max(1, min(limit, ceiling))

    def _ensure_sqlite_path(self) -> None:
        """Create the parent directory of a file-based ``sqlite:///`` database.

        URLs that do not start with ``sqlite:///`` are ignored, as are an empty
        path and the special ``:memory:`` database, which have no directory.
        """
        prefix = "sqlite:///"
        if not self.db_url.startswith(prefix):
            return
        db_path_raw = self.db_url.replace(prefix, "", 1)
        if not db_path_raw or db_path_raw == ":memory:":
            return
        Path(db_path_raw).parent.mkdir(parents=True, exist_ok=True)

    def ensure_created(self) -> None:
        """Make sure the SQLite directory exists, then create all mapped tables."""
        self._ensure_sqlite_path()
        Base.metadata.create_all(self.engine)

    def _session(self) -> Session:
        """Return a new session from the configured factory."""
        return self.session_factory()

    # ------------------------------------------------------------------
    # Refresh runs, entities, links, checkpoints
    # ------------------------------------------------------------------

    def start_refresh_run(
        self,
        *,
        mode: str,
        requested_entity_sets: list[str],
        date_from: str | None,
        date_to: str | None,
        limit_per_set: int,
    ) -> str:
        """Insert a refresh-run row with status ``"running"`` and return its id.

        The id is a fresh ``uuid4`` hex string; *requested_entity_sets* is
        stored as JSON text.
        """
        run_id = uuid4().hex
        with self._session() as session, session.begin():
            session.add(
                RefreshRunRow(
                    id=run_id,
                    mode=mode,
                    status="running",
                    started_at=_utc_now(),
                    requested_entity_sets_json=_dump_json(requested_entity_sets),
                    date_from=date_from,
                    date_to=date_to,
                    limit_per_set=limit_per_set,
                )
            )
        return run_id

    def finish_refresh_run(
        self,
        *,
        run_id: str,
        status: str,
        records_read: int,
        entities_written: int,
        links_written: int,
        checkpoints_updated: int,
        details: dict[str, Any] | None = None,
        error_message: str | None = None,
    ) -> None:
        """Record the final status and counters of a refresh run.

        Does nothing when no row with *run_id* exists. A missing *details*
        is stored as an empty JSON object.
        """
        with self._session() as session, session.begin():
            run = session.get(RefreshRunRow, run_id)
            if run is None:
                return
            run.status = status
            run.records_read = records_read
            run.entities_written = entities_written
            run.links_written = links_written
            run.checkpoints_updated = checkpoints_updated
            run.details_json = _dump_json(details or {})
            run.error_message = error_message
            run.finished_at = _utc_now()

    def upsert_entities(self, *, run_id: str, entities: list[CanonicalEntity]) -> tuple[int, int]:
        """Insert or update canonical entities and replace their outgoing links.

        For each entity the row matching ``(source_entity, source_id)`` is
        updated, or created when absent; all existing links from that source
        are deleted and re-created from ``entity.links`` (links without a
        ``target_id`` are skipped).

        Returns:
            ``(entities_written, links_written)`` for the de-duplicated batch.
        """
        # Sessions run with autoflush disabled, so a row added earlier in this
        # batch would not be found by the SELECT below and a repeated key would
        # be inserted twice (with its first set of links left pending).
        # Collapse repeats up front; the last occurrence wins.
        unique_entities = {
            (entity.source_entity, entity.source_id): entity for entity in entities
        }

        entities_written = 0
        links_written = 0
        now = _utc_now()
        with self._session() as session, session.begin():
            for entity in unique_entities.values():
                row = session.execute(
                    select(CanonicalEntityRow).where(
                        CanonicalEntityRow.source_entity == entity.source_entity,
                        CanonicalEntityRow.source_id == entity.source_id,
                    )
                ).scalar_one_or_none()
                if row is None:
                    session.add(
                        CanonicalEntityRow(
                            source_entity=entity.source_entity,
                            source_id=entity.source_id,
                            display_name=entity.display_name,
                            attributes_json=_dump_json(entity.attributes),
                            first_seen_at=now,
                            updated_at=now,
                            last_refresh_run_id=run_id,
                        )
                    )
                else:
                    row.display_name = entity.display_name
                    row.attributes_json = _dump_json(entity.attributes)
                    row.updated_at = now
                    row.last_refresh_run_id = run_id
                entities_written += 1

                # Links are replaced wholesale rather than diffed.
                session.execute(
                    delete(CanonicalLinkRow).where(
                        CanonicalLinkRow.source_entity == entity.source_entity,
                        CanonicalLinkRow.source_id == entity.source_id,
                    )
                )
                for link in entity.links:
                    if not link.target_id:
                        continue
                    session.add(
                        CanonicalLinkRow(
                            source_entity=entity.source_entity,
                            source_id=entity.source_id,
                            relation=link.relation,
                            target_entity=link.target_entity,
                            target_id=link.target_id,
                            source_field=link.source_field,
                            updated_at=now,
                            last_refresh_run_id=run_id,
                        )
                    )
                    links_written += 1
        return entities_written, links_written

    def update_checkpoints(
        self,
        *,
        run_id: str,
        entity_sets: list[str],
        date_from: str | None,
        date_to: str | None,
    ) -> int:
        """Create or refresh the checkpoint row of each entity set.

        Returns the number of distinct entity sets touched (``0`` for an empty
        list, in which case no session is opened).
        """
        if not entity_sets:
            return 0
        now = _utc_now()
        updated = 0
        with self._session() as session, session.begin():
            # ``session.get`` does not see rows that are only pending, so a
            # repeated name would queue two rows with the same primary key.
            # ``dict.fromkeys`` removes repeats while keeping order.
            for entity_set in dict.fromkeys(entity_sets):
                row = session.get(RefreshCheckpointRow, entity_set)
                if row is None:
                    session.add(
                        RefreshCheckpointRow(
                            entity_set=entity_set,
                            last_success_at=now,
                            last_refresh_run_id=run_id,
                            last_date_from=date_from,
                            last_date_to=date_to,
                        )
                    )
                else:
                    row.last_success_at = now
                    row.last_refresh_run_id = run_id
                    row.last_date_from = date_from
                    row.last_date_to = date_to
                updated += 1
        return updated

    def list_recent_runs(self, limit: int = 20) -> list[dict[str, Any]]:
        """Return the most recently started refresh runs as plain dicts.

        *limit* is clamped to ``1..200``.
        """
        safe_limit = self._clamp_limit(limit, 200)
        with self._session() as session:
            rows = (
                session.execute(
                    select(RefreshRunRow)
                    .order_by(RefreshRunRow.started_at.desc())
                    .limit(safe_limit)
                )
                .scalars()
                .all()
            )
            return [
                {
                    "run_id": row.id,
                    "mode": row.mode,
                    "status": row.status,
                    "started_at": _dt_to_iso(row.started_at),
                    "finished_at": _dt_to_iso(row.finished_at),
                    "requested_entity_sets": _load_json(row.requested_entity_sets_json, []),
                    "date_from": row.date_from,
                    "date_to": row.date_to,
                    "limit_per_set": row.limit_per_set,
                    "records_read": row.records_read,
                    "entities_written": row.entities_written,
                    "links_written": row.links_written,
                    "checkpoints_updated": row.checkpoints_updated,
                    "details": _load_json(row.details_json, {}),
                    "error_message": row.error_message,
                }
                for row in rows
            ]

    def store_stats(self) -> dict[str, Any]:
        """Return row totals for entities, links and checkpoints plus the latest run.

        ``latest_run`` is ``None`` when no refresh run has been recorded.
        """
        with self._session() as session:
            entities_total = session.execute(
                select(func.count(CanonicalEntityRow.id))
            ).scalar_one()
            links_total = session.execute(
                select(func.count(CanonicalLinkRow.id))
            ).scalar_one()
            checkpoints_total = session.execute(
                select(func.count(RefreshCheckpointRow.entity_set))
            ).scalar_one()
            latest_run = (
                session.execute(
                    select(RefreshRunRow).order_by(RefreshRunRow.started_at.desc()).limit(1)
                )
                .scalars()
                .first()
            )
            latest_run_payload: dict[str, Any] | None = None
            if latest_run is not None:
                latest_run_payload = {
                    "run_id": latest_run.id,
                    "mode": latest_run.mode,
                    "status": latest_run.status,
                    "started_at": _dt_to_iso(latest_run.started_at),
                    "finished_at": _dt_to_iso(latest_run.finished_at),
                }
            return {
                # NOTE(review): the URL is returned verbatim; if it can carry
                # credentials, consider masking it before exposing these stats.
                "db_url": self.db_url,
                "entities_total": int(entities_total),
                "links_total": int(links_total),
                "checkpoints_total": int(checkpoints_total),
                "latest_run": latest_run_payload,
            }

    # ------------------------------------------------------------------
    # Feature runs, metrics, anomaly signals
    # ------------------------------------------------------------------

    def start_feature_run(self, *, baseline_window_hours: int) -> str:
        """Insert a feature-run row with status ``"running"`` and return its id."""
        run_id = uuid4().hex
        with self._session() as session, session.begin():
            session.add(
                FeatureRunRow(
                    id=run_id,
                    status="running",
                    started_at=_utc_now(),
                    baseline_window_hours=baseline_window_hours,
                )
            )
        return run_id

    def replace_feature_results(
        self,
        *,
        run_id: str,
        metrics: list[dict[str, Any]],
        anomalies: list[dict[str, Any]],
    ) -> tuple[int, int]:
        """Store the metrics and anomaly signals of a feature run.

        Existing metric and anomaly rows of *run_id* are deleted, every
        anomaly still flagged active is deactivated, and the supplied items
        are inserted (anomalies as active). Missing dict keys fall back to the
        defaults visible in the code below.

        Returns:
            ``(len(metrics), len(anomalies))``.
        """
        now = _utc_now()
        with self._session() as session, session.begin():
            session.execute(
                delete(FeatureMetricRow).where(FeatureMetricRow.feature_run_id == run_id)
            )
            session.execute(
                delete(AnomalySignalRow).where(AnomalySignalRow.feature_run_id == run_id)
            )
            # Deactivate previously active anomalies before writing the new
            # active snapshot. One bulk UPDATE avoids loading every active row;
            # no ORM objects are loaded in this session yet, so there is
            # nothing to synchronize.
            session.execute(
                update(AnomalySignalRow)
                .where(AnomalySignalRow.is_active == 1)
                .values(is_active=0)
                .execution_options(synchronize_session=False)
            )
            session.add_all(
                FeatureMetricRow(
                    feature_run_id=run_id,
                    metric_key=str(metric.get("metric_key", "")),
                    scope=str(metric.get("scope", "global")),
                    scope_id=str(metric.get("scope_id", "")),
                    metric_type=str(metric.get("metric_type", "gauge")),
                    metric_value=float(metric.get("metric_value", 0.0)),
                    attributes_json=_dump_json(metric.get("attributes", {})),
                    computed_at=now,
                )
                for metric in metrics
            )
            session.add_all(
                AnomalySignalRow(
                    feature_run_id=run_id,
                    signal_type=str(anomaly.get("signal_type", "unknown_signal")),
                    severity=str(anomaly.get("severity", "medium")),
                    scope=str(anomaly.get("scope", "global")),
                    scope_id=str(anomaly.get("scope_id", "")),
                    score=float(anomaly.get("score", 0.0)),
                    details_json=_dump_json(anomaly.get("details", {})),
                    detected_at=now,
                    is_active=1,
                )
                for anomaly in anomalies
            )
        return len(metrics), len(anomalies)

    def finish_feature_run(
        self,
        *,
        run_id: str,
        status: str,
        entities_total: int,
        metrics_written: int,
        anomalies_written: int,
        details: dict[str, Any] | None = None,
        error_message: str | None = None,
    ) -> None:
        """Record the final status and counters of a feature run.

        Does nothing when no row with *run_id* exists. A missing *details*
        is stored as an empty JSON object.
        """
        with self._session() as session, session.begin():
            row = session.get(FeatureRunRow, run_id)
            if row is None:
                return
            row.status = status
            row.entities_total = entities_total
            row.metrics_written = metrics_written
            row.anomalies_written = anomalies_written
            row.details_json = _dump_json(details or {})
            row.error_message = error_message
            row.finished_at = _utc_now()

    def list_recent_feature_runs(self, limit: int = 20) -> list[dict[str, Any]]:
        """Return the most recently started feature runs as plain dicts.

        *limit* is clamped to ``1..200``.
        """
        safe_limit = self._clamp_limit(limit, 200)
        with self._session() as session:
            rows = (
                session.execute(
                    select(FeatureRunRow)
                    .order_by(FeatureRunRow.started_at.desc())
                    .limit(safe_limit)
                )
                .scalars()
                .all()
            )
            return [
                {
                    "run_id": row.id,
                    "status": row.status,
                    "started_at": _dt_to_iso(row.started_at),
                    "finished_at": _dt_to_iso(row.finished_at),
                    "baseline_window_hours": row.baseline_window_hours,
                    "entities_total": row.entities_total,
                    "metrics_written": row.metrics_written,
                    "anomalies_written": row.anomalies_written,
                    "details": _load_json(row.details_json, {}),
                    "error_message": row.error_message,
                }
                for row in rows
            ]

    def list_feature_metrics(
        self,
        *,
        limit: int = 200,
        metric_key: str | None = None,
        scope: str | None = None,
        run_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return feature metrics, newest first, optionally filtered.

        Falsy filter values are ignored; *limit* is clamped to ``1..2000``.
        """
        safe_limit = self._clamp_limit(limit, 2000)
        with self._session() as session:
            stmt = select(FeatureMetricRow).order_by(
                FeatureMetricRow.computed_at.desc(), FeatureMetricRow.id.desc()
            )
            if metric_key:
                stmt = stmt.where(FeatureMetricRow.metric_key == metric_key)
            if scope:
                stmt = stmt.where(FeatureMetricRow.scope == scope)
            if run_id:
                stmt = stmt.where(FeatureMetricRow.feature_run_id == run_id)
            rows = session.execute(stmt.limit(safe_limit)).scalars().all()
            return [
                {
                    "id": row.id,
                    "feature_run_id": row.feature_run_id,
                    "metric_key": row.metric_key,
                    "scope": row.scope,
                    "scope_id": row.scope_id,
                    "metric_type": row.metric_type,
                    "metric_value": row.metric_value,
                    "attributes": _load_json(row.attributes_json, {}),
                    "computed_at": _dt_to_iso(row.computed_at),
                }
                for row in rows
            ]

    def list_anomaly_signals(
        self,
        *,
        limit: int = 200,
        severity: str | None = None,
        active_only: bool = True,
        run_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return anomaly signals, newest first, optionally filtered.

        With *active_only* (the default) only rows with ``is_active == 1`` are
        returned. Falsy filter values are ignored; *limit* is clamped to
        ``1..2000``.
        """
        safe_limit = self._clamp_limit(limit, 2000)
        with self._session() as session:
            stmt = select(AnomalySignalRow).order_by(
                AnomalySignalRow.detected_at.desc(), AnomalySignalRow.id.desc()
            )
            if active_only:
                stmt = stmt.where(AnomalySignalRow.is_active == 1)
            if severity:
                stmt = stmt.where(AnomalySignalRow.severity == severity)
            if run_id:
                stmt = stmt.where(AnomalySignalRow.feature_run_id == run_id)
            rows = session.execute(stmt.limit(safe_limit)).scalars().all()
            return [
                {
                    "id": row.id,
                    "feature_run_id": row.feature_run_id,
                    "signal_type": row.signal_type,
                    "severity": row.severity,
                    "scope": row.scope,
                    "scope_id": row.scope_id,
                    "score": row.score,
                    "details": _load_json(row.details_json, {}),
                    "detected_at": _dt_to_iso(row.detected_at),
                    "is_active": bool(row.is_active),
                }
                for row in rows
            ]

    def feature_store_stats(self) -> dict[str, Any]:
        """Return metric/anomaly totals and a summary of the latest feature run.

        ``latest_feature_run`` is ``None`` when no feature run has been recorded.
        """
        with self._session() as session:
            metrics_total = session.execute(
                select(func.count(FeatureMetricRow.id))
            ).scalar_one()
            anomalies_total = session.execute(
                select(func.count(AnomalySignalRow.id))
            ).scalar_one()
            active_anomalies_total = session.execute(
                select(func.count(AnomalySignalRow.id)).where(AnomalySignalRow.is_active == 1)
            ).scalar_one()
            latest_feature_run = (
                session.execute(
                    select(FeatureRunRow).order_by(FeatureRunRow.started_at.desc()).limit(1)
                )
                .scalars()
                .first()
            )
            latest_payload: dict[str, Any] | None = None
            if latest_feature_run is not None:
                latest_payload = {
                    "run_id": latest_feature_run.id,
                    "status": latest_feature_run.status,
                    "started_at": _dt_to_iso(latest_feature_run.started_at),
                    "finished_at": _dt_to_iso(latest_feature_run.finished_at),
                    "metrics_written": latest_feature_run.metrics_written,
                    "anomalies_written": latest_feature_run.anomalies_written,
                }
            return {
                "metrics_total": int(metrics_total),
                "anomalies_total": int(anomalies_total),
                "active_anomalies_total": int(active_anomalies_total),
                "latest_feature_run": latest_payload,
            }

    def latest_successful_feature_run(self) -> dict[str, Any] | None:
        """Return a summary of the latest feature run with status ``"success"``.

        Runs are ordered by ``finished_at`` then ``started_at``, newest first.
        Returns ``None`` when there is no such run.
        """
        with self._session() as session:
            row = (
                session.execute(
                    select(FeatureRunRow)
                    .where(FeatureRunRow.status == "success")
                    .order_by(FeatureRunRow.finished_at.desc(), FeatureRunRow.started_at.desc())
                    .limit(1)
                )
                .scalars()
                .first()
            )
            if row is None:
                return None
            return {
                "run_id": row.id,
                "status": row.status,
                "started_at": _dt_to_iso(row.started_at),
                "finished_at": _dt_to_iso(row.finished_at),
                "metrics_written": row.metrics_written,
                "anomalies_written": row.anomalies_written,
            }

    def latest_refresh_finished_at(self) -> datetime | None:
        """Return ``finished_at`` of the newest ``success``/``partial_success`` refresh run.

        Returns ``None`` when no such run exists. The value is returned as
        stored, without timezone normalization.
        """
        with self._session() as session:
            row = (
                session.execute(
                    select(RefreshRunRow)
                    .where(RefreshRunRow.status.in_(("success", "partial_success")))
                    .order_by(RefreshRunRow.finished_at.desc())
                    .limit(1)
                )
                .scalars()
                .first()
            )
            return None if row is None else row.finished_at

    def iter_entities_for_features(self, *, limit: int = 200000) -> list[dict[str, Any]]:
        """Return canonical entities, most recently updated first, as plain dicts.

        Despite the name, the result is a fully materialized list. *limit* is
        clamped to ``1..200000``; ``updated_at`` is returned as a ``datetime``.
        """
        safe_limit = self._clamp_limit(limit, 200000)
        with self._session() as session:
            rows = (
                session.execute(
                    select(CanonicalEntityRow)
                    .order_by(CanonicalEntityRow.updated_at.desc())
                    .limit(safe_limit)
                )
                .scalars()
                .all()
            )
            return [
                {
                    "source_entity": row.source_entity,
                    "source_id": row.source_id,
                    "display_name": row.display_name,
                    "attributes": _load_json(row.attributes_json, {}),
                    "updated_at": row.updated_at,
                }
                for row in rows
            ]

    def link_counts_by_source(self) -> dict[tuple[str, str], int]:
        """Return the number of links per ``(source_entity, source_id)`` pair."""
        with self._session() as session:
            rows = session.execute(
                select(
                    CanonicalLinkRow.source_entity,
                    CanonicalLinkRow.source_id,
                    func.count(CanonicalLinkRow.id),
                ).group_by(CanonicalLinkRow.source_entity, CanonicalLinkRow.source_id)
            ).all()
            return {
                (str(source_entity), str(source_id)): int(count)
                for source_entity, source_id, count in rows
            }

    # ------------------------------------------------------------------
    # Risk runs and patterns
    # ------------------------------------------------------------------

    def start_risk_run(self, *, source_feature_run_id: str | None) -> str:
        """Insert a risk-run row with status ``"running"`` and return its id."""
        run_id = uuid4().hex
        with self._session() as session, session.begin():
            session.add(
                RiskRunRow(
                    id=run_id,
                    status="running",
                    started_at=_utc_now(),
                    source_feature_run_id=source_feature_run_id,
                )
            )
        return run_id

    def replace_risk_patterns(self, *, run_id: str, patterns: list[dict[str, Any]]) -> int:
        """Store the risk patterns of a risk run.

        Existing pattern rows of *run_id* are deleted, every pattern still
        flagged active is deactivated, and the supplied patterns are inserted
        as active. Missing dict keys fall back to the defaults visible in the
        code below.

        Returns:
            ``len(patterns)``.
        """
        now = _utc_now()
        with self._session() as session, session.begin():
            session.execute(delete(RiskPatternRow).where(RiskPatternRow.risk_run_id == run_id))
            # One bulk UPDATE retires the previous active snapshot without
            # loading it; no ORM objects are loaded in this session yet, so
            # there is nothing to synchronize.
            session.execute(
                update(RiskPatternRow)
                .where(RiskPatternRow.is_active == 1)
                .values(is_active=0)
                .execution_options(synchronize_session=False)
            )
            session.add_all(
                RiskPatternRow(
                    risk_run_id=run_id,
                    pattern_key=str(pattern.get("pattern_key", "unknown_pattern")),
                    severity=str(pattern.get("severity", "low")),
                    scope=str(pattern.get("scope", "global")),
                    scope_id=str(pattern.get("scope_id", "")),
                    score=float(pattern.get("score", 0.0)),
                    confidence=float(pattern.get("confidence", 0.0)),
                    details_json=_dump_json(pattern.get("details", {})),
                    detected_at=now,
                    is_active=1,
                )
                for pattern in patterns
            )
        return len(patterns)

    def finish_risk_run(
        self,
        *,
        run_id: str,
        status: str,
        patterns_written: int,
        global_score: float,
        details: dict[str, Any] | None = None,
        error_message: str | None = None,
    ) -> None:
        """Record the final status, pattern count and global score of a risk run.

        Does nothing when no row with *run_id* exists. A missing *details*
        is stored as an empty JSON object.
        """
        with self._session() as session, session.begin():
            row = session.get(RiskRunRow, run_id)
            if row is None:
                return
            row.status = status
            row.patterns_written = patterns_written
            row.global_score = global_score
            row.details_json = _dump_json(details or {})
            row.error_message = error_message
            row.finished_at = _utc_now()

    def list_recent_risk_runs(self, limit: int = 20) -> list[dict[str, Any]]:
        """Return the most recently started risk runs as plain dicts.

        *limit* is clamped to ``1..200``.
        """
        safe_limit = self._clamp_limit(limit, 200)
        with self._session() as session:
            rows = (
                session.execute(
                    select(RiskRunRow)
                    .order_by(RiskRunRow.started_at.desc())
                    .limit(safe_limit)
                )
                .scalars()
                .all()
            )
            return [
                {
                    "run_id": row.id,
                    "status": row.status,
                    "started_at": _dt_to_iso(row.started_at),
                    "finished_at": _dt_to_iso(row.finished_at),
                    "source_feature_run_id": row.source_feature_run_id,
                    "patterns_written": row.patterns_written,
                    "global_score": row.global_score,
                    "details": _load_json(row.details_json, {}),
                    "error_message": row.error_message,
                }
                for row in rows
            ]

    def list_risk_patterns(
        self,
        *,
        limit: int = 200,
        severity: str | None = None,
        active_only: bool = True,
        run_id: str | None = None,
        pattern_key: str | None = None,
        scope: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return risk patterns, newest first, optionally filtered.

        With *active_only* (the default) only rows with ``is_active == 1`` are
        returned. Falsy filter values are ignored; *limit* is clamped to
        ``1..2000``.
        """
        safe_limit = self._clamp_limit(limit, 2000)
        with self._session() as session:
            stmt = select(RiskPatternRow).order_by(
                RiskPatternRow.detected_at.desc(), RiskPatternRow.id.desc()
            )
            if active_only:
                stmt = stmt.where(RiskPatternRow.is_active == 1)
            if severity:
                stmt = stmt.where(RiskPatternRow.severity == severity)
            if run_id:
                stmt = stmt.where(RiskPatternRow.risk_run_id == run_id)
            if pattern_key:
                stmt = stmt.where(RiskPatternRow.pattern_key == pattern_key)
            if scope:
                stmt = stmt.where(RiskPatternRow.scope == scope)
            rows = session.execute(stmt.limit(safe_limit)).scalars().all()
            return [
                {
                    "id": row.id,
                    "risk_run_id": row.risk_run_id,
                    "pattern_key": row.pattern_key,
                    "severity": row.severity,
                    "scope": row.scope,
                    "scope_id": row.scope_id,
                    "score": row.score,
                    "confidence": row.confidence,
                    "details": _load_json(row.details_json, {}),
                    "detected_at": _dt_to_iso(row.detected_at),
                    "is_active": bool(row.is_active),
                }
                for row in rows
            ]

    def risk_store_stats(self) -> dict[str, Any]:
        """Return pattern totals and a summary of the latest risk run.

        ``latest_risk_run`` is ``None`` when no risk run has been recorded.
        """
        with self._session() as session:
            patterns_total = session.execute(
                select(func.count(RiskPatternRow.id))
            ).scalar_one()
            active_patterns_total = session.execute(
                select(func.count(RiskPatternRow.id)).where(RiskPatternRow.is_active == 1)
            ).scalar_one()
            latest_run = (
                session.execute(
                    select(RiskRunRow).order_by(RiskRunRow.started_at.desc()).limit(1)
                )
                .scalars()
                .first()
            )
            latest_payload: dict[str, Any] | None = None
            if latest_run is not None:
                latest_payload = {
                    "run_id": latest_run.id,
                    "status": latest_run.status,
                    "started_at": _dt_to_iso(latest_run.started_at),
                    "finished_at": _dt_to_iso(latest_run.finished_at),
                    "patterns_written": latest_run.patterns_written,
                    "global_score": latest_run.global_score,
                }
            return {
                "patterns_total": int(patterns_total),
                "active_patterns_total": int(active_patterns_total),
                "latest_risk_run": latest_payload,
            }