99 lines
2.7 KiB
Python
99 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import asdict, dataclass
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Callable
|
|
from uuid import uuid4
|
|
|
|
|
|
# Loosely typed JSON-style mapping: string keys to arbitrary values.
JsonLike = dict[str, Any]
|
|
|
|
|
|
@dataclass
class BatchRuntimeJob:
    """Record describing a single batch-runtime job.

    Every field is a string or a list of strings, so :meth:`to_dict` yields a
    JSON-friendly mapping without further conversion.
    """

    # Unique identifier of this job.
    job_id: str
    # Kind of work requested (the enqueue helper in this module uses "refresh_and_answer").
    job_type: str
    # Identifier of the question this job is for.
    question_id: str
    # Window selector stored as-is; not interpreted within this module.
    slice_window: str
    # Output names requested; the runner checks for "feature_store" and "risk_store".
    requested_outputs: list[str]
    # Reason strings recorded with the job; not interpreted within this module.
    reason: list[str]
    # Creation timestamp string (the enqueue helper stores a UTC isoformat()).
    created_at: str

    def to_dict(self) -> JsonLike:
        """Return the fields as a plain dict; list values are deep-copied by ``asdict``."""
        payload = asdict(self)
        return payload
|
|
|
|
|
|
@dataclass
class BatchRuntimeResult:
    """Outcome of running a batch-runtime job.

    ``status`` and ``execution_mode`` are short string tags; ``run_ids`` holds
    whatever run identifiers and notes were gathered while executing.
    """

    # Identifier of the job this result belongs to.
    job_id: str
    # Outcome tag; the runner in this module emits "success" or "failed".
    status: str
    # Execution tag; the runner emits "batch_runtime_executed" or "batch_runtime_failed".
    execution_mode: str
    # Collected run ids/notes, e.g. "refresh_run_id", "feature_run_id", "refresh_note".
    run_ids: JsonLike
    # Failure text; defaults to None (the runner only sets it on failure).
    error_message: str | None = None

    def to_dict(self) -> JsonLike:
        """Return the fields as a plain dict; ``run_ids`` is deep-copied by ``asdict``."""
        serialized = asdict(self)
        return serialized
|
|
|
|
|
|
def enqueue_refresh_and_answer_job(
    *,
    question_id: str,
    slice_window: str,
    requested_outputs: list[str],
    reason: list[str],
) -> BatchRuntimeJob:
    """Build a new ``refresh_and_answer`` job record.

    Only the :class:`BatchRuntimeJob` is constructed here; no queue or other
    external state is touched. The list arguments are shallow-copied so that
    later mutation by the caller does not affect the job.

    Args:
        question_id: Identifier of the question the job is for.
        slice_window: Window selector, stored unchanged.
        requested_outputs: Output names to produce (copied).
        reason: Reason strings to record (copied).

    Returns:
        A job with a random hex ``job_id`` and a UTC ``isoformat()`` ``created_at``.
    """
    job_fields = {
        "job_id": uuid4().hex,
        "job_type": "refresh_and_answer",
        "question_id": question_id,
        "slice_window": slice_window,
        "requested_outputs": list(requested_outputs),
        "reason": list(reason),
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    return BatchRuntimeJob(**job_fields)
|
|
|
|
|
|
def run_refresh_and_answer_job(
    job: BatchRuntimeJob,
    *,
    refresh_executor: Callable[[], JsonLike] | None = None,
    feature_executor: Callable[[], JsonLike] | None = None,
    risk_executor: Callable[[], JsonLike] | None = None,
    should_refresh: bool = False,
) -> BatchRuntimeResult:
    """Run the refresh, feature and risk steps for *job* and report the outcome.

    Steps run in this order:

    1. Refresh: attempted only when ``should_refresh`` is true. The run id is
       read from the executor result's ``"run_id"`` key, falling back to
       ``"refresh_run_id"``. If refresh is requested but no executor is given,
       ``refresh_run_id`` is recorded as ``None`` with a note. If refresh is not
       requested, only a "skipped" note is recorded.
    2. Feature: runs when ``feature_executor`` is given and ``"feature_store"``
       is in ``job.requested_outputs``.
    3. Risk: runs when ``risk_executor`` is given and ``"risk_store"`` is in
       ``job.requested_outputs``.

    Args:
        job: The job being executed; ``job_id`` and ``requested_outputs`` are read.
        refresh_executor: Zero-arg callable returning a mapping with a run id.
        feature_executor: Zero-arg callable returning a mapping with ``"run_id"``.
        risk_executor: Zero-arg callable returning a mapping with ``"run_id"``.
        should_refresh: Whether the refresh step should be attempted.

    Returns:
        A ``"success"`` result with every collected id/note. If any step raises
        an ``Exception``, a ``"failed"`` result is returned instead. It carries
        the ids gathered before the failure and a non-empty error message.

    Note:
        Exceptions from executors (including an executor returning something
        without ``.get``) are converted into a failed result rather than
        propagated. ``BaseException`` subclasses such as ``KeyboardInterrupt``
        are not caught.
    """
    run_ids: JsonLike = {}

    try:
        if should_refresh and refresh_executor is not None:
            refresh_result = refresh_executor()
            # Executors may report the id under either key; "run_id" wins when truthy.
            run_ids["refresh_run_id"] = refresh_result.get("run_id") or refresh_result.get(
                "refresh_run_id"
            )
        elif should_refresh:
            run_ids["refresh_run_id"] = None
            run_ids["refresh_note"] = "refresh_requested_but_executor_missing"
        else:
            run_ids["refresh_note"] = "refresh_skipped_by_policy"

        # NOTE(review): unlike refresh, a requested feature/risk output whose
        # executor is missing is skipped silently (no note, status stays "success").
        if feature_executor is not None and "feature_store" in job.requested_outputs:
            feature_result = feature_executor()
            run_ids["feature_run_id"] = feature_result.get("run_id")

        if risk_executor is not None and "risk_store" in job.requested_outputs:
            risk_result = risk_executor()
            run_ids["risk_run_id"] = risk_result.get("run_id")
    except Exception as exc:  # batch boundary: report the failure instead of raising
        return BatchRuntimeResult(
            job_id=job.job_id,
            status="failed",
            execution_mode="batch_runtime_failed",
            # Ids collected before the failure are kept for diagnosis.
            run_ids=run_ids,
            # str() of an exception raised without arguments is "", which would
            # leave a failed result with an empty (falsy) message; fall back to
            # the exception's type name in that case.
            error_message=str(exc) or type(exc).__name__,
        )
    else:
        return BatchRuntimeResult(
            job_id=job.job_id,
            status="success",
            execution_mode="batch_runtime_executed",
            run_ids=run_ids,
        )
|
|
|