"""SQLAlchemy-backed store for canonical entities, feature runs and risk runs."""

from __future__ import annotations

from datetime import datetime, timezone
import json
from pathlib import Path
from typing import Any
from uuid import uuid4

from sqlalchemy import create_engine, delete, func, select
from sqlalchemy.orm import Session, sessionmaker

from canonical_layer.models import CanonicalEntity
from canonical_layer.store_models import (
    AnomalySignalRow,
    Base,
    CanonicalEntityRow,
    CanonicalLinkRow,
    FeatureMetricRow,
    FeatureRunRow,
    RefreshCheckpointRow,
    RefreshRunRow,
    RiskPatternRow,
    RiskRunRow,
)


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _dump_json(payload: Any) -> str:
    """Serialize *payload* to a JSON string, keeping non-ASCII characters as-is."""
    return json.dumps(payload, ensure_ascii=False)


def _load_json(payload: str, default: Any) -> Any:
    """Parse *payload* as JSON; return *default* when it is empty or malformed."""
    if not payload:
        return default
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        return default


def _dt_to_iso(value: datetime | None) -> str | None:
    """Render *value* as an ISO-8601 string in UTC, or ``None`` when absent.

    Naive datetimes are treated as already being in UTC.
    """
    if value is None:
        return None
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc).isoformat()


class CanonicalStore:
    """Read/write access to refresh, feature and risk tables behind one engine."""

    def __init__(self, db_url: str) -> None:
        """Create the engine and session factory for *db_url*.

        Sessions are created with ``autoflush=False`` and
        ``expire_on_commit=False``.
        """
        self.db_url = db_url
        self.engine = create_engine(db_url, future=True)
        self.session_factory = sessionmaker(
            bind=self.engine,
            autoflush=False,
            expire_on_commit=False,
            future=True,
        )

    def _ensure_sqlite_path(self) -> None:
        """Create the parent directory of a ``sqlite:///`` database file.

        URLs that do not start with ``sqlite:///`` are left untouched.
        """
        if not self.db_url.startswith("sqlite:///"):
            return
        db_path_raw = self.db_url.replace("sqlite:///", "", 1)
        db_path = Path(db_path_raw)
        db_path.parent.mkdir(parents=True, exist_ok=True)

    def ensure_created(self) -> None:
        """Ensure the SQLite directory exists and create all mapped tables."""
        self._ensure_sqlite_path()
        Base.metadata.create_all(self.engine)

    def _session(self) -> Session:
        """Return a new session from the configured factory."""
        return self.session_factory()

    # ------------------------------------------------------------------
    # Refresh runs, entities, links and checkpoints
    # ------------------------------------------------------------------

    def start_refresh_run(
        self,
        *,
        mode: str,
        requested_entity_sets: list[str],
        date_from: str | None,
        date_to: str | None,
        limit_per_set: int,
    ) -> str:
        """Insert a refresh run with status ``"running"`` and return its id."""
        run_id = uuid4().hex
        with self._session() as session, session.begin():
            session.add(
                RefreshRunRow(
                    id=run_id,
                    mode=mode,
                    status="running",
                    started_at=_utc_now(),
                    requested_entity_sets_json=_dump_json(requested_entity_sets),
                    date_from=date_from,
                    date_to=date_to,
                    limit_per_set=limit_per_set,
                )
            )
        return run_id

    def finish_refresh_run(
        self,
        *,
        run_id: str,
        status: str,
        records_read: int,
        entities_written: int,
        links_written: int,
        checkpoints_updated: int,
        details: dict[str, Any] | None = None,
        error_message: str | None = None,
    ) -> None:
        """Record the final status and counters of a refresh run.

        Does nothing when no run with *run_id* exists.
        """
        with self._session() as session, session.begin():
            run = session.get(RefreshRunRow, run_id)
            if run is None:
                return
            run.status = status
            run.records_read = records_read
            run.entities_written = entities_written
            run.links_written = links_written
            run.checkpoints_updated = checkpoints_updated
            run.details_json = _dump_json(details or {})
            run.error_message = error_message
            run.finished_at = _utc_now()

    def upsert_entities(
        self, *, run_id: str, entities: list[CanonicalEntity]
    ) -> tuple[int, int]:
        """Insert or update *entities* and replace their outgoing links.

        For every entity the row keyed by ``(source_entity, source_id)`` is
        created or updated, all of its existing links are deleted, and the
        links carried by the entity are re-inserted (links with a falsy
        ``target_id`` are skipped).

        Returns:
            ``(entities_written, links_written)`` — every processed entity and
            every inserted link is counted, including repeated keys.
        """
        entities_written = 0
        links_written = 0
        now = _utc_now()
        seen_keys: set[tuple[Any, Any]] = set()
        with self._session() as session, session.begin():
            for entity in entities:
                key = (entity.source_entity, entity.source_id)
                if key in seen_keys:
                    # The session does not autoflush, so rows added for an
                    # earlier occurrence of this key are still pending and
                    # invisible to the SELECT/DELETE below. Flushing first
                    # makes the repeat update the same row and replace its
                    # links instead of inserting duplicates.
                    session.flush()
                seen_keys.add(key)

                row = session.execute(
                    select(CanonicalEntityRow).where(
                        CanonicalEntityRow.source_entity == entity.source_entity,
                        CanonicalEntityRow.source_id == entity.source_id,
                    )
                ).scalar_one_or_none()
                if row is None:
                    row = CanonicalEntityRow(
                        source_entity=entity.source_entity,
                        source_id=entity.source_id,
                        display_name=entity.display_name,
                        attributes_json=_dump_json(entity.attributes),
                        first_seen_at=now,
                        updated_at=now,
                        last_refresh_run_id=run_id,
                    )
                    session.add(row)
                else:
                    row.display_name = entity.display_name
                    row.attributes_json = _dump_json(entity.attributes)
                    row.updated_at = now
                    row.last_refresh_run_id = run_id
                entities_written += 1

                # Links are replaced wholesale for the entity.
                session.execute(
                    delete(CanonicalLinkRow).where(
                        CanonicalLinkRow.source_entity == entity.source_entity,
                        CanonicalLinkRow.source_id == entity.source_id,
                    )
                )
                for link in entity.links:
                    if not link.target_id:
                        continue
                    session.add(
                        CanonicalLinkRow(
                            source_entity=entity.source_entity,
                            source_id=entity.source_id,
                            relation=link.relation,
                            target_entity=link.target_entity,
                            target_id=link.target_id,
                            source_field=link.source_field,
                            updated_at=now,
                            last_refresh_run_id=run_id,
                        )
                    )
                    links_written += 1
        return entities_written, links_written

    def update_checkpoints(
        self,
        *,
        run_id: str,
        entity_sets: list[str],
        date_from: str | None,
        date_to: str | None,
    ) -> int:
        """Create or update one checkpoint per entity set; return how many."""
        if not entity_sets:
            return 0
        now = _utc_now()
        updated = 0
        with self._session() as session, session.begin():
            for entity_set in entity_sets:
                row = session.get(RefreshCheckpointRow, entity_set)
                if row is None:
                    row = RefreshCheckpointRow(
                        entity_set=entity_set,
                        last_success_at=now,
                        last_refresh_run_id=run_id,
                        last_date_from=date_from,
                        last_date_to=date_to,
                    )
                    session.add(row)
                else:
                    row.last_success_at = now
                    row.last_refresh_run_id = run_id
                    row.last_date_from = date_from
                    row.last_date_to = date_to
                updated += 1
        return updated

    def list_recent_runs(self, limit: int = 20) -> list[dict[str, Any]]:
        """Return the most recently started refresh runs as plain dicts.

        *limit* is clamped to the range 1..200.
        """
        safe_limit = max(1, min(limit, 200))
        with self._session() as session:
            rows = (
                session.execute(
                    select(RefreshRunRow)
                    .order_by(RefreshRunRow.started_at.desc())
                    .limit(safe_limit)
                )
                .scalars()
                .all()
            )
            return [
                {
                    "run_id": row.id,
                    "mode": row.mode,
                    "status": row.status,
                    "started_at": _dt_to_iso(row.started_at),
                    "finished_at": _dt_to_iso(row.finished_at),
                    "requested_entity_sets": _load_json(row.requested_entity_sets_json, []),
                    "date_from": row.date_from,
                    "date_to": row.date_to,
                    "limit_per_set": row.limit_per_set,
                    "records_read": row.records_read,
                    "entities_written": row.entities_written,
                    "links_written": row.links_written,
                    "checkpoints_updated": row.checkpoints_updated,
                    "details": _load_json(row.details_json, {}),
                    "error_message": row.error_message,
                }
                for row in rows
            ]

    def store_stats(self) -> dict[str, Any]:
        """Return row counts plus a summary of the latest refresh run."""
        with self._session() as session:
            entities_total = session.execute(
                select(func.count(CanonicalEntityRow.id))
            ).scalar_one()
            links_total = session.execute(
                select(func.count(CanonicalLinkRow.id))
            ).scalar_one()
            checkpoints_total = session.execute(
                select(func.count(RefreshCheckpointRow.entity_set))
            ).scalar_one()
            latest_run = (
                session.execute(
                    select(RefreshRunRow)
                    .order_by(RefreshRunRow.started_at.desc())
                    .limit(1)
                )
                .scalars()
                .first()
            )
            latest_run_payload: dict[str, Any] | None = None
            if latest_run is not None:
                latest_run_payload = {
                    "run_id": latest_run.id,
                    "mode": latest_run.mode,
                    "status": latest_run.status,
                    "started_at": _dt_to_iso(latest_run.started_at),
                    "finished_at": _dt_to_iso(latest_run.finished_at),
                }
        return {
            # NOTE(review): db_url is returned verbatim; if it can embed
            # credentials, consider masking it before exposing these stats.
            "db_url": self.db_url,
            "entities_total": int(entities_total),
            "links_total": int(links_total),
            "checkpoints_total": int(checkpoints_total),
            "latest_run": latest_run_payload,
        }

    # ------------------------------------------------------------------
    # Feature runs, metrics and anomaly signals
    # ------------------------------------------------------------------

    def start_feature_run(self, *, baseline_window_hours: int) -> str:
        """Insert a feature run with status ``"running"`` and return its id."""
        run_id = uuid4().hex
        with self._session() as session, session.begin():
            session.add(
                FeatureRunRow(
                    id=run_id,
                    status="running",
                    started_at=_utc_now(),
                    baseline_window_hours=baseline_window_hours,
                )
            )
        return run_id

    def replace_feature_results(
        self,
        *,
        run_id: str,
        metrics: list[dict[str, Any]],
        anomalies: list[dict[str, Any]],
    ) -> tuple[int, int]:
        """Replace the metrics and anomalies stored for *run_id*.

        Existing rows of the run are deleted, every anomaly that is still
        active is deactivated, and the supplied metrics and anomalies are
        inserted (anomalies as active). Missing dict keys fall back to the
        defaults visible below.

        Returns:
            ``(len(metrics), len(anomalies))``.
        """
        now = _utc_now()
        with self._session() as session, session.begin():
            session.execute(
                delete(FeatureMetricRow).where(FeatureMetricRow.feature_run_id == run_id)
            )
            session.execute(
                delete(AnomalySignalRow).where(AnomalySignalRow.feature_run_id == run_id)
            )
            # Deactivate previously active anomalies before writing a new
            # active snapshot.
            previous_anomalies = (
                session.execute(
                    select(AnomalySignalRow).where(AnomalySignalRow.is_active == 1)
                )
                .scalars()
                .all()
            )
            for item in previous_anomalies:
                item.is_active = 0

            for metric in metrics:
                session.add(
                    FeatureMetricRow(
                        feature_run_id=run_id,
                        metric_key=str(metric.get("metric_key", "")),
                        scope=str(metric.get("scope", "global")),
                        scope_id=str(metric.get("scope_id", "")),
                        metric_type=str(metric.get("metric_type", "gauge")),
                        metric_value=float(metric.get("metric_value", 0.0)),
                        attributes_json=_dump_json(metric.get("attributes", {})),
                        computed_at=now,
                    )
                )
            for anomaly in anomalies:
                session.add(
                    AnomalySignalRow(
                        feature_run_id=run_id,
                        signal_type=str(anomaly.get("signal_type", "unknown_signal")),
                        severity=str(anomaly.get("severity", "medium")),
                        scope=str(anomaly.get("scope", "global")),
                        scope_id=str(anomaly.get("scope_id", "")),
                        score=float(anomaly.get("score", 0.0)),
                        details_json=_dump_json(anomaly.get("details", {})),
                        detected_at=now,
                        is_active=1,
                    )
                )
        return len(metrics), len(anomalies)

    def finish_feature_run(
        self,
        *,
        run_id: str,
        status: str,
        entities_total: int,
        metrics_written: int,
        anomalies_written: int,
        details: dict[str, Any] | None = None,
        error_message: str | None = None,
    ) -> None:
        """Record the final status and counters of a feature run.

        Does nothing when no run with *run_id* exists.
        """
        with self._session() as session, session.begin():
            row = session.get(FeatureRunRow, run_id)
            if row is None:
                return
            row.status = status
            row.entities_total = entities_total
            row.metrics_written = metrics_written
            row.anomalies_written = anomalies_written
            row.details_json = _dump_json(details or {})
            row.error_message = error_message
            row.finished_at = _utc_now()

    def list_recent_feature_runs(self, limit: int = 20) -> list[dict[str, Any]]:
        """Return the most recently started feature runs as plain dicts.

        *limit* is clamped to the range 1..200.
        """
        safe_limit = max(1, min(limit, 200))
        with self._session() as session:
            rows = (
                session.execute(
                    select(FeatureRunRow)
                    .order_by(FeatureRunRow.started_at.desc())
                    .limit(safe_limit)
                )
                .scalars()
                .all()
            )
            return [
                {
                    "run_id": row.id,
                    "status": row.status,
                    "started_at": _dt_to_iso(row.started_at),
                    "finished_at": _dt_to_iso(row.finished_at),
                    "baseline_window_hours": row.baseline_window_hours,
                    "entities_total": row.entities_total,
                    "metrics_written": row.metrics_written,
                    "anomalies_written": row.anomalies_written,
                    "details": _load_json(row.details_json, {}),
                    "error_message": row.error_message,
                }
                for row in rows
            ]

    def list_feature_metrics(
        self,
        *,
        limit: int = 200,
        metric_key: str | None = None,
        scope: str | None = None,
        run_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return feature metrics, newest first, with optional filters.

        *limit* is clamped to the range 1..2000; falsy filters are ignored.
        """
        safe_limit = max(1, min(limit, 2000))
        with self._session() as session:
            stmt = select(FeatureMetricRow).order_by(
                FeatureMetricRow.computed_at.desc(),
                FeatureMetricRow.id.desc(),
            )
            if metric_key:
                stmt = stmt.where(FeatureMetricRow.metric_key == metric_key)
            if scope:
                stmt = stmt.where(FeatureMetricRow.scope == scope)
            if run_id:
                stmt = stmt.where(FeatureMetricRow.feature_run_id == run_id)
            rows = session.execute(stmt.limit(safe_limit)).scalars().all()
            return [
                {
                    "id": row.id,
                    "feature_run_id": row.feature_run_id,
                    "metric_key": row.metric_key,
                    "scope": row.scope,
                    "scope_id": row.scope_id,
                    "metric_type": row.metric_type,
                    "metric_value": row.metric_value,
                    "attributes": _load_json(row.attributes_json, {}),
                    "computed_at": _dt_to_iso(row.computed_at),
                }
                for row in rows
            ]

    def list_anomaly_signals(
        self,
        *,
        limit: int = 200,
        severity: str | None = None,
        active_only: bool = True,
        run_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return anomaly signals, newest first, with optional filters.

        *limit* is clamped to the range 1..2000; by default only active
        signals are returned.
        """
        safe_limit = max(1, min(limit, 2000))
        with self._session() as session:
            stmt = select(AnomalySignalRow).order_by(
                AnomalySignalRow.detected_at.desc(),
                AnomalySignalRow.id.desc(),
            )
            if active_only:
                stmt = stmt.where(AnomalySignalRow.is_active == 1)
            if severity:
                stmt = stmt.where(AnomalySignalRow.severity == severity)
            if run_id:
                stmt = stmt.where(AnomalySignalRow.feature_run_id == run_id)
            rows = session.execute(stmt.limit(safe_limit)).scalars().all()
            return [
                {
                    "id": row.id,
                    "feature_run_id": row.feature_run_id,
                    "signal_type": row.signal_type,
                    "severity": row.severity,
                    "scope": row.scope,
                    "scope_id": row.scope_id,
                    "score": row.score,
                    "details": _load_json(row.details_json, {}),
                    "detected_at": _dt_to_iso(row.detected_at),
                    "is_active": bool(row.is_active),
                }
                for row in rows
            ]

    def feature_store_stats(self) -> dict[str, Any]:
        """Return metric/anomaly counts plus a summary of the latest feature run."""
        with self._session() as session:
            metrics_total = session.execute(
                select(func.count(FeatureMetricRow.id))
            ).scalar_one()
            anomalies_total = session.execute(
                select(func.count(AnomalySignalRow.id))
            ).scalar_one()
            active_anomalies_total = session.execute(
                select(func.count(AnomalySignalRow.id)).where(
                    AnomalySignalRow.is_active == 1
                )
            ).scalar_one()
            latest_feature_run = (
                session.execute(
                    select(FeatureRunRow)
                    .order_by(FeatureRunRow.started_at.desc())
                    .limit(1)
                )
                .scalars()
                .first()
            )
            latest_payload: dict[str, Any] | None = None
            if latest_feature_run is not None:
                latest_payload = {
                    "run_id": latest_feature_run.id,
                    "status": latest_feature_run.status,
                    "started_at": _dt_to_iso(latest_feature_run.started_at),
                    "finished_at": _dt_to_iso(latest_feature_run.finished_at),
                    "metrics_written": latest_feature_run.metrics_written,
                    "anomalies_written": latest_feature_run.anomalies_written,
                }
        return {
            "metrics_total": int(metrics_total),
            "anomalies_total": int(anomalies_total),
            "active_anomalies_total": int(active_anomalies_total),
            "latest_feature_run": latest_payload,
        }

    def latest_successful_feature_run(self) -> dict[str, Any] | None:
        """Return a summary of the latest feature run with status ``"success"``.

        Runs are ordered by ``finished_at`` then ``started_at``, both
        descending. Returns ``None`` when there is no such run.
        """
        with self._session() as session:
            row = (
                session.execute(
                    select(FeatureRunRow)
                    .where(FeatureRunRow.status == "success")
                    .order_by(
                        FeatureRunRow.finished_at.desc(),
                        FeatureRunRow.started_at.desc(),
                    )
                    .limit(1)
                )
                .scalars()
                .first()
            )
            if row is None:
                return None
            return {
                "run_id": row.id,
                "status": row.status,
                "started_at": _dt_to_iso(row.started_at),
                "finished_at": _dt_to_iso(row.finished_at),
                "metrics_written": row.metrics_written,
                "anomalies_written": row.anomalies_written,
            }

    def latest_refresh_finished_at(self) -> datetime | None:
        """Return ``finished_at`` of the latest successful refresh run.

        Runs with status ``"success"`` or ``"partial_success"`` qualify.
        Returns ``None`` when there is no such run.
        """
        with self._session() as session:
            row = (
                session.execute(
                    select(RefreshRunRow)
                    .where(RefreshRunRow.status.in_(("success", "partial_success")))
                    .order_by(RefreshRunRow.finished_at.desc())
                    .limit(1)
                )
                .scalars()
                .first()
            )
            if row is None:
                return None
            return row.finished_at

    def iter_entities_for_features(self, *, limit: int = 200000) -> list[dict[str, Any]]:
        """Return the most recently updated entities as plain dicts.

        *limit* is clamped to the range 1..200000. ``updated_at`` is returned
        as stored, not converted to an ISO string.
        """
        safe_limit = max(1, min(limit, 200000))
        with self._session() as session:
            rows = (
                session.execute(
                    select(CanonicalEntityRow)
                    .order_by(CanonicalEntityRow.updated_at.desc())
                    .limit(safe_limit)
                )
                .scalars()
                .all()
            )
            return [
                {
                    "source_entity": row.source_entity,
                    "source_id": row.source_id,
                    "display_name": row.display_name,
                    "attributes": _load_json(row.attributes_json, {}),
                    "updated_at": row.updated_at,
                }
                for row in rows
            ]

    def link_counts_by_source(self) -> dict[tuple[str, str], int]:
        """Return the number of links per ``(source_entity, source_id)`` pair."""
        with self._session() as session:
            rows = session.execute(
                select(
                    CanonicalLinkRow.source_entity,
                    CanonicalLinkRow.source_id,
                    func.count(CanonicalLinkRow.id),
                ).group_by(CanonicalLinkRow.source_entity, CanonicalLinkRow.source_id)
            ).all()
        return {
            (str(source_entity), str(source_id)): int(count)
            for source_entity, source_id, count in rows
        }

    # ------------------------------------------------------------------
    # Risk runs and patterns
    # ------------------------------------------------------------------

    def start_risk_run(self, *, source_feature_run_id: str | None) -> str:
        """Insert a risk run with status ``"running"`` and return its id."""
        run_id = uuid4().hex
        with self._session() as session, session.begin():
            session.add(
                RiskRunRow(
                    id=run_id,
                    status="running",
                    started_at=_utc_now(),
                    source_feature_run_id=source_feature_run_id,
                )
            )
        return run_id

    def replace_risk_patterns(self, *, run_id: str, patterns: list[dict[str, Any]]) -> int:
        """Replace the risk patterns stored for *run_id*.

        Existing patterns of the run are deleted, every pattern that is still
        active is deactivated, and the supplied patterns are inserted as
        active. Missing dict keys fall back to the defaults visible below.

        Returns:
            ``len(patterns)``.
        """
        now = _utc_now()
        with self._session() as session, session.begin():
            session.execute(
                delete(RiskPatternRow).where(RiskPatternRow.risk_run_id == run_id)
            )
            previous_active = (
                session.execute(
                    select(RiskPatternRow).where(RiskPatternRow.is_active == 1)
                )
                .scalars()
                .all()
            )
            for item in previous_active:
                item.is_active = 0

            for pattern in patterns:
                session.add(
                    RiskPatternRow(
                        risk_run_id=run_id,
                        pattern_key=str(pattern.get("pattern_key", "unknown_pattern")),
                        severity=str(pattern.get("severity", "low")),
                        scope=str(pattern.get("scope", "global")),
                        scope_id=str(pattern.get("scope_id", "")),
                        score=float(pattern.get("score", 0.0)),
                        confidence=float(pattern.get("confidence", 0.0)),
                        details_json=_dump_json(pattern.get("details", {})),
                        detected_at=now,
                        is_active=1,
                    )
                )
        return len(patterns)

    def finish_risk_run(
        self,
        *,
        run_id: str,
        status: str,
        patterns_written: int,
        global_score: float,
        details: dict[str, Any] | None = None,
        error_message: str | None = None,
    ) -> None:
        """Record the final status and score of a risk run.

        Does nothing when no run with *run_id* exists.
        """
        with self._session() as session, session.begin():
            row = session.get(RiskRunRow, run_id)
            if row is None:
                return
            row.status = status
            row.patterns_written = patterns_written
            row.global_score = global_score
            row.details_json = _dump_json(details or {})
            row.error_message = error_message
            row.finished_at = _utc_now()

    def list_recent_risk_runs(self, limit: int = 20) -> list[dict[str, Any]]:
        """Return the most recently started risk runs as plain dicts.

        *limit* is clamped to the range 1..200.
        """
        safe_limit = max(1, min(limit, 200))
        with self._session() as session:
            rows = (
                session.execute(
                    select(RiskRunRow)
                    .order_by(RiskRunRow.started_at.desc())
                    .limit(safe_limit)
                )
                .scalars()
                .all()
            )
            return [
                {
                    "run_id": row.id,
                    "status": row.status,
                    "started_at": _dt_to_iso(row.started_at),
                    "finished_at": _dt_to_iso(row.finished_at),
                    "source_feature_run_id": row.source_feature_run_id,
                    "patterns_written": row.patterns_written,
                    "global_score": row.global_score,
                    "details": _load_json(row.details_json, {}),
                    "error_message": row.error_message,
                }
                for row in rows
            ]

    def list_risk_patterns(
        self,
        *,
        limit: int = 200,
        severity: str | None = None,
        active_only: bool = True,
        run_id: str | None = None,
        pattern_key: str | None = None,
        scope: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return risk patterns, newest first, with optional filters.

        *limit* is clamped to the range 1..2000; by default only active
        patterns are returned.
        """
        safe_limit = max(1, min(limit, 2000))
        with self._session() as session:
            stmt = select(RiskPatternRow).order_by(
                RiskPatternRow.detected_at.desc(),
                RiskPatternRow.id.desc(),
            )
            if active_only:
                stmt = stmt.where(RiskPatternRow.is_active == 1)
            if severity:
                stmt = stmt.where(RiskPatternRow.severity == severity)
            if run_id:
                stmt = stmt.where(RiskPatternRow.risk_run_id == run_id)
            if pattern_key:
                stmt = stmt.where(RiskPatternRow.pattern_key == pattern_key)
            if scope:
                stmt = stmt.where(RiskPatternRow.scope == scope)
            rows = session.execute(stmt.limit(safe_limit)).scalars().all()
            return [
                {
                    "id": row.id,
                    "risk_run_id": row.risk_run_id,
                    "pattern_key": row.pattern_key,
                    "severity": row.severity,
                    "scope": row.scope,
                    "scope_id": row.scope_id,
                    "score": row.score,
                    "confidence": row.confidence,
                    "details": _load_json(row.details_json, {}),
                    "detected_at": _dt_to_iso(row.detected_at),
                    "is_active": bool(row.is_active),
                }
                for row in rows
            ]

    def risk_store_stats(self) -> dict[str, Any]:
        """Return pattern counts plus a summary of the latest risk run."""
        with self._session() as session:
            patterns_total = session.execute(
                select(func.count(RiskPatternRow.id))
            ).scalar_one()
            active_patterns_total = session.execute(
                select(func.count(RiskPatternRow.id)).where(
                    RiskPatternRow.is_active == 1
                )
            ).scalar_one()
            latest_run = (
                session.execute(
                    select(RiskRunRow)
                    .order_by(RiskRunRow.started_at.desc())
                    .limit(1)
                )
                .scalars()
                .first()
            )
            latest_payload: dict[str, Any] | None = None
            if latest_run is not None:
                latest_payload = {
                    "run_id": latest_run.id,
                    "status": latest_run.status,
                    "started_at": _dt_to_iso(latest_run.started_at),
                    "finished_at": _dt_to_iso(latest_run.finished_at),
                    "patterns_written": latest_run.patterns_written,
                    "global_score": latest_run.global_score,
                }
        return {
            "patterns_total": int(patterns_total),
            "active_patterns_total": int(active_patterns_total),
            "latest_risk_run": latest_payload,
        }