294 lines
11 KiB
Python
294 lines
11 KiB
Python
"""Session manager — fetches and caches session state from the backend.
|
|
|
|
Provides session context to the VLM prompt and handles session lifecycle
|
|
actions (start, resume, switch, complete).
|
|
|
|
Swift portability notes:
|
|
- SessionManager becomes an ObservableObject with @Published properties
|
|
- Backend calls use URLSession instead of httpx
|
|
- match_session() logic is identical
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
|
|
import httpx
|
|
|
|
from argus.config import BACKEND_BASE_URL, BACKEND_JWT
|
|
|
|
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass()
class SessionInfo:
    """Flattened view of one backend focus session.

    Instances are produced by ``SessionManager._parse_session`` and read by
    the session-matching and prompt-formatting helpers.
    """

    # Identity
    session_id: str
    task_id: str | None  # None if the backend sent no task_id

    # What the user is working on (the parser fills task_goal from the same value)
    task_title: str
    task_goal: str

    # Lifecycle state
    status: str  # active | interrupted

    # Where the user left off, taken from the session's latest checkpoint
    last_app: str
    last_file: str
    checkpoint_note: str

    # Timestamps as received from the backend
    started_at: str
    ended_at: str | None
    minutes_ago: int | None = field(default=None)  # whole minutes since ended_at, if known
|
|
|
|
|
|
class SessionManager:
    """Cache open backend sessions and provide matching + lifecycle operations.

    The backend is queried through ``GET /sessions/active`` (a single session
    object or 404), so the cache currently holds at most one session; the
    matching and prompt-formatting helpers nevertheless work over a list.
    """

    # Filler words ignored when comparing an inferred task with a session title.
    _STOPWORDS = frozenset(
        {"the", "a", "an", "in", "on", "to", "and", "or", "is", "for", "of", "with"}
    )

    # Minimum match_session() score required to accept a match (avoids false matches).
    _MIN_MATCH_SCORE = 4

    def __init__(self, jwt: str | None = None, base_url: str | None = None):
        """Create a session manager.

        Args:
            jwt: Bearer token for backend calls; falls back to ``BACKEND_JWT``.
            base_url: Backend root URL; falls back to ``BACKEND_BASE_URL``.
        """
        self._jwt = jwt or BACKEND_JWT
        self._base_url = base_url or BACKEND_BASE_URL
        self._sessions: list[SessionInfo] = []
        self._active_session: SessionInfo | None = None
        # -inf rather than 0: time.monotonic() has an undefined reference point,
        # so "monotonic() - 0" is not guaranteed to exceed the interval on the
        # first maybe_refresh() call. -inf makes the cache start out stale.
        self._last_fetch: float = float("-inf")
        self._fetch_interval = 30.0  # re-fetch every 30s
        self._mock: bool = False  # when True, maybe_refresh() never hits the backend

        # Track inferred_task stability for new-session suggestions.
        self._inferred_task_history: list[str] = []
        self._stable_threshold = 3  # consecutive similar inferred_tasks required

    @property
    def active(self) -> SessionInfo | None:
        """The session most recently reported active by the backend, or None."""
        return self._active_session

    @property
    def sessions(self) -> list[SessionInfo]:
        """Cached open sessions (the internal list, not a copy)."""
        return self._sessions

    def has_sessions(self) -> bool:
        """Return True if at least one session is cached."""
        return bool(self._sessions)

    # ── Backend communication ────────────────────────────────────────

    def _auth_headers(self) -> dict[str, str]:
        """Return the Authorization header for backend calls ({} when no JWT is set)."""
        return {"Authorization": f"Bearer {self._jwt}"} if self._jwt else {}

    async def fetch_open_sessions(self) -> list[SessionInfo]:
        """Fetch the active session from the backend and refresh the cache.

        ``GET /sessions/active`` returns a single session object, or 404 when
        no session is active (which clears the cache). Called on startup and
        periodically via maybe_refresh().

        Returns:
            The cached session list. On network errors, HTTP errors, or a
            malformed response body, the previous cache is returned unchanged.
        """
        url = f"{self._base_url}/sessions/active"

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(url, headers=self._auth_headers())

            if resp.status_code == 404:
                # No active session — clear cache
                self._sessions = []
                self._active_session = None
                self._last_fetch = time.monotonic()
                log.info("No active session on backend")
                return []

            resp.raise_for_status()
            data = resp.json()
        except httpx.HTTPStatusError as e:
            log.warning("Failed to fetch sessions: %s", e.response.status_code)
            return self._sessions
        except httpx.RequestError as e:
            log.debug("Backend unreachable for sessions: %s", e)
            return self._sessions
        except ValueError:
            # resp.json() raises a ValueError subclass (e.g. json.JSONDecodeError)
            # when the body is not valid JSON; keep the previous cache.
            log.warning("Active-session response was not valid JSON")
            return self._sessions

        if not isinstance(data, dict):
            log.warning("Unexpected active-session payload type: %s", type(data).__name__)
            return self._sessions

        session = self._parse_session(data)
        self._sessions = [session]
        self._active_session = session
        self._last_fetch = time.monotonic()
        log.info("Fetched active session: %s (%s)", session.session_id, session.task_title)
        return self._sessions

    async def maybe_refresh(self) -> None:
        """Re-fetch sessions if the cache is older than the fetch interval.

        Does nothing when mock data is in use.
        """
        if self._mock:
            return
        if time.monotonic() - self._last_fetch > self._fetch_interval:
            await self.fetch_open_sessions()

    async def start_session(self, task_title: str, task_goal: str = "") -> dict | None:
        """Start a new focus session via ``POST /sessions/start``.

        NOTE(review): task_title and task_goal are accepted but not sent; the
        payload only carries the platform. Confirm what the backend expects
        before wiring them in.

        Returns:
            The backend's JSON response on success, or None on any failure
            (the error is logged).
        """
        url = f"{self._base_url}/sessions/start"

        # If we have a task_title but no task_id, the backend should
        # create an ad-hoc session (or we create the task first).
        payload = {"platform": "mac"}

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(url, json=payload, headers=self._auth_headers())
            resp.raise_for_status()
            result = resp.json()
            await self.fetch_open_sessions()  # refresh the cache so .active reflects it
            log.info("Started session: %s", result.get("id"))
            return result
        except Exception:
            # Best-effort: any failure is logged and reported as None.
            log.exception("Failed to start session")
            return None

    async def end_session(self, session_id: str, status: str = "completed") -> bool:
        """End a session via ``POST /sessions/{session_id}/end``.

        Args:
            session_id: Backend id of the session to end.
            status: Final status sent to the backend.

        Returns:
            True on success, False on any failure (the error is logged).
        """
        url = f"{self._base_url}/sessions/{session_id}/end"

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(url, json={"status": status}, headers=self._auth_headers())
            resp.raise_for_status()
            await self.fetch_open_sessions()  # refresh
            log.info("Ended session %s (%s)", session_id, status)
            return True
        except Exception:
            # Best-effort: any failure is logged and reported as False.
            log.exception("Failed to end session %s", session_id)
            return False

    async def get_resume_card(self, session_id: str) -> dict | None:
        """Fetch the AI-generated resume card (``GET /sessions/{session_id}/resume``).

        Returns:
            The parsed JSON response, or None on any failure (the error is logged).
        """
        url = f"{self._base_url}/sessions/{session_id}/resume"

        try:
            # Longer timeout than other calls (15s vs 10s), as in the original code.
            async with httpx.AsyncClient(timeout=15.0) as client:
                resp = await client.get(url, headers=self._auth_headers())
            resp.raise_for_status()
            return resp.json()
        except Exception:
            log.exception("Failed to get resume card for %s", session_id)
            return None

    # ── Session matching ─────────────────────────────────────────────

    def match_session(self, inferred_task: str, app_name: str) -> SessionInfo | None:
        """Find the cached session that best matches the current screen state.

        Scoring per session (all comparisons case-insensitive):
          * +2 for each non-stopword shared by *inferred_task* and task_title,
          * +3 if *app_name* is a substring of the session's last_app,
          * +5 if the session's last_file appears in *inferred_task*.
        The highest score wins (earliest session on ties) and must reach
        ``_MIN_MATCH_SCORE``.

        Returns:
            The best-matching session, or None.
        """
        if not inferred_task or not self._sessions:
            return None

        inferred_lower = inferred_task.lower()
        app_lower = app_name.lower() if app_name else ""
        # Loop-invariant: content words of the inferred task.
        inferred_words = set(inferred_lower.split()) - self._STOPWORDS

        best_match: SessionInfo | None = None
        best_score = 0

        for session in self._sessions:
            title_words = set(session.task_title.lower().split())
            score = 2 * len(inferred_words & title_words)

            # App match
            if app_lower and session.last_app and app_lower in session.last_app.lower():
                score += 3

            # File match — session's last_file mentioned in the inferred task
            if session.last_file and session.last_file.lower() in inferred_lower:
                score += 5

            if score > best_score:
                best_score = score
                best_match = session

        return best_match if best_score >= self._MIN_MATCH_SCORE else None

    def should_suggest_new_session(self, inferred_task: str) -> bool:
        """Decide whether to suggest starting a new session.

        Records *inferred_task* in a short history. Returns True when the last
        ``_stable_threshold`` inferred tasks are similar (each later one shares
        at least half of the first one's non-stopword words) and the task does
        not match any cached session. An empty *inferred_task* resets the
        history and returns False.
        """
        if not inferred_task:
            self._inferred_task_history.clear()
            return False

        self._inferred_task_history.append(inferred_task)
        # Keep only recent history
        keep = self._stable_threshold + 2
        if len(self._inferred_task_history) > keep:
            self._inferred_task_history = self._inferred_task_history[-keep:]
        history = self._inferred_task_history

        if len(history) < self._stable_threshold:
            return False

        # Check if the last N inferred tasks are "similar" (share key words)
        recent = history[-self._stable_threshold:]
        first_words = set(recent[0].lower().split()) - self._STOPWORDS
        if not first_words:
            # Only stopwords (e.g. "the a"): nothing meaningful to compare, and
            # the ">= 50% overlap" test below would pass vacuously.
            return False

        all_similar = all(
            len(first_words & set(t.lower().split())) >= len(first_words) * 0.5
            for t in recent[1:]
        )
        if not all_similar:
            return False

        # Make sure it doesn't match an existing session
        return self.match_session(inferred_task, "") is None

    # ── Prompt formatting ────────────────────────────────────────────

    def format_for_prompt(self) -> str:
        """Format cached sessions for injection into the VLM prompt.

        One line per session, including session_id so the VLM can reference the
        exact ID in session_action. The "(paused Nm ago)" suffix is added only
        when minutes_ago is set and non-zero.
        """
        if not self._sessions:
            return "(no open sessions)"

        lines: list[str] = []
        for s in self._sessions:
            status_tag = f"[{s.status}]"
            ago = f" (paused {s.minutes_ago}m ago)" if s.minutes_ago else ""
            line = f" session_id=\"{s.session_id}\" {status_tag} \"{s.task_title}\" — last in {s.last_app}"
            if s.last_file:
                line += f"/{s.last_file}"
            if s.checkpoint_note:
                line += f", \"{s.checkpoint_note}\""
            line += ago
            lines.append(line)
        return "\n".join(lines)

    # ── Internal ─────────────────────────────────────────────────────

    @staticmethod
    def _text(source: dict, key: str) -> str:
        """Return source[key] as a string, mapping a missing key or JSON null to ""."""
        value = source.get(key)
        return "" if value is None else str(value)

    @staticmethod
    def _minutes_since(timestamp: object) -> int | None:
        """Return whole minutes elapsed since an ISO-8601 *timestamp*, or None.

        A trailing "Z" is rewritten to "+00:00" because ``fromisoformat`` only
        accepts "Z" from Python 3.11. None is returned when the value is not a
        string, cannot be parsed, or carries no UTC offset (a naive datetime
        cannot be subtracted from the aware current time).
        """
        import datetime

        try:
            ended_dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
            now = datetime.datetime.now(datetime.timezone.utc)
            return int((now - ended_dt).total_seconds() / 60)
        except (ValueError, TypeError, AttributeError):
            return None

    def _parse_session(self, data: dict) -> SessionInfo:
        """Convert a backend session JSON object into a SessionInfo.

        Missing keys and JSON nulls in text fields become "", so matching and
        prompt formatting can call string methods on them safely.
        """
        checkpoint = data.get("checkpoint")
        if not isinstance(checkpoint, dict):
            checkpoint = {}
        ended = data.get("ended_at")
        minutes_ago = self._minutes_since(ended) if ended else None

        # SessionOut has no nested "task" object — task title is stored in checkpoint["goal"]
        # by POST /sessions/start when a task_id is provided.
        task_title = self._text(checkpoint, "goal")
        # Prefer the action summary; fall back to the VLM summary when it is missing or empty.
        checkpoint_note = self._text(checkpoint, "last_action_summary") or self._text(
            checkpoint, "last_vlm_summary"
        )

        return SessionInfo(
            session_id=self._text(data, "id"),
            task_id=data.get("task_id"),
            task_title=task_title,
            task_goal=task_title,
            status=self._text(data, "status"),
            last_app=self._text(checkpoint, "active_app"),
            last_file=self._text(checkpoint, "active_file"),
            checkpoint_note=checkpoint_note,
            started_at=self._text(data, "started_at"),
            ended_at=ended,
            minutes_ago=minutes_ago,
        )
|