"""Focus-session endpoints mounted under ``/sessions``: start, join, checkpoint, end and resume."""

import json
import logging
from datetime import datetime, timezone
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException

from app.middleware.auth import get_current_user_id
from app.models import (
    OpenSessionOut,
    ResumeCard,
    SessionCheckpointRequest,
    SessionEndRequest,
    SessionJoinRequest,
    SessionJoinResponse,
    SessionOut,
    SessionResumeResponse,
    SessionStartRequest,
    StepOut,
)
from app.services import llm, push
from app.services.db import get_pool

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/sessions", tags=["sessions"])

SESSION_COLUMNS = "id, user_id, task_id, platform, started_at, ended_at, status, checkpoint, created_at"


def _decode_checkpoint(raw):
    """Decode a ``sessions.checkpoint`` value that may arrive as JSON text or already decoded.

    NOTE(review): presumably the driver returns jsonb as ``str`` unless a type codec is
    registered, which is why both shapes are accepted here — confirm against ``get_pool``.
    """
    return json.loads(raw) if isinstance(raw, str) else raw


def _checkpoint_dict(raw) -> dict:
    """Return the checkpoint as a dict, treating a missing/null checkpoint as empty."""
    return _decode_checkpoint(raw) or {}


def _parse_session_row(row) -> SessionOut:
    """Convert a ``sessions`` row selected with SESSION_COLUMNS into a SessionOut."""
    result = dict(row)
    result["checkpoint"] = _decode_checkpoint(result["checkpoint"])
    return SessionOut(**result)


@router.post("/start", response_model=SessionOut, status_code=201)
async def start_session(req: SessionStartRequest, user_id: str = Depends(get_current_user_id)):
    """Start a focus session for the current user.

    If the user already has an active session it is returned unchanged and nothing is
    created. When ``req.task_id`` is given, the task must belong to the user (404
    otherwise); it is marked ``in_progress`` and its title is stored as the checkpoint
    ``goal``. Task-bound sessions also push a ``session_started`` notification to the
    other platform and start a Live Activity.
    """
    pool = await get_pool()

    # Idempotency: at most one active session per account.
    active = await pool.fetchrow(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE user_id = $1::uuid AND status = 'active'",
        user_id,
    )
    if active:
        return _parse_session_row(active)

    checkpoint = {}
    task_title = None
    if req.task_id:
        task = await pool.fetchrow(
            "SELECT id, title, description FROM tasks WHERE id = $1 AND user_id = $2::uuid",
            req.task_id,
            user_id,
        )
        if not task:
            raise HTTPException(status_code=404, detail="Task not found")
        await pool.execute(
            "UPDATE tasks SET status = 'in_progress', updated_at = now() WHERE id = $1",
            req.task_id,
        )
        task_title = task["title"]
        checkpoint["goal"] = task_title

    if req.work_app_bundle_ids:
        checkpoint["work_app_bundle_ids"] = req.work_app_bundle_ids
    checkpoint["devices"] = [req.platform]

    row = await pool.fetchrow(
        f"""INSERT INTO sessions (user_id, task_id, platform, checkpoint)
            VALUES ($1::uuid, $2, $3, $4)
            RETURNING {SESSION_COLUMNS}""",
        user_id,
        req.task_id,
        req.platform,
        json.dumps(checkpoint),
    )

    if req.task_id:
        # Notify the other device. NOTE(review): assumes exactly two platforms, mac and ipad.
        await push.send_push(user_id, "ipad" if req.platform == "mac" else "mac", {
            "type": "session_started",
            "session_id": str(row["id"]),
            "task_title": task_title,
            "platform": req.platform,
        })
        # Start Live Activity on all registered devices.
        await push.send_activity_start(user_id, task_title, task_id=req.task_id)

    return _parse_session_row(row)


@router.get("/active", response_model=SessionOut)
async def get_active_session(user_id: str = Depends(get_current_user_id)):
    """Return the user's active session, or 404 if there is none."""
    pool = await get_pool()
    row = await pool.fetchrow(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE user_id = $1::uuid AND status = 'active'",
        user_id,
    )
    if not row:
        raise HTTPException(status_code=404, detail="No active session")
    return _parse_session_row(row)


@router.get("/open", response_model=list[OpenSessionOut])
async def get_open_sessions(user_id: str = Depends(get_current_user_id)):
    """Return all active and interrupted sessions, newest first, with their task info.

    Used by the VLM on startup for session-aware analysis.
    """
    pool = await get_pool()
    rows = await pool.fetch(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE user_id = $1::uuid "
        "AND status IN ('active', 'interrupted') ORDER BY started_at DESC",
        user_id,
    )

    results = []
    for row in rows:
        task_info = None
        if row["task_id"]:
            task_row = await pool.fetchrow(
                "SELECT title, description FROM tasks WHERE id = $1", row["task_id"]
            )
            if task_row:
                task_info = {"title": task_row["title"], "goal": task_row["description"]}

        results.append(OpenSessionOut(
            id=row["id"],
            task_id=row["task_id"],
            task=task_info,
            status=row["status"],
            platform=row["platform"],
            started_at=row["started_at"],
            ended_at=row["ended_at"],
            checkpoint=_checkpoint_dict(row["checkpoint"]),
        ))
    return results


@router.post("/{session_id}/join", response_model=SessionJoinResponse)
async def join_session(
    session_id: UUID,
    req: SessionJoinRequest,
    user_id: str = Depends(get_current_user_id),
):
    """Attach another device to an active session and return its task/step context.

    The joining platform is added to the checkpoint ``devices`` list (without
    duplicates), and ``work_app_bundle_ids`` is replaced when provided. The response
    carries task info, the first in-progress step, all steps, and a best-effort
    LLM work-app suggestion.

    Raises:
        HTTPException 404: session does not exist or is not owned by the user.
        HTTPException 400: session is not active.
    """
    pool = await get_pool()
    session = await pool.fetchrow(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE id = $1 AND user_id = $2::uuid",
        session_id,
        user_id,
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if session["status"] != "active":
        raise HTTPException(status_code=400, detail="Session is not active")

    checkpoint = _checkpoint_dict(session["checkpoint"])
    devices = checkpoint.get("devices", [session["platform"]])
    if req.platform not in devices:
        devices.append(req.platform)

    # Merge only the keys this endpoint changes so keys written concurrently by
    # /checkpoint (e.g. attention_score) are not overwritten with a stale copy.
    patch = {"devices": devices}
    if req.work_app_bundle_ids:
        patch["work_app_bundle_ids"] = req.work_app_bundle_ids
    await pool.execute(
        "UPDATE sessions SET checkpoint = COALESCE(checkpoint, '{}'::jsonb) || $1::jsonb WHERE id = $2",
        json.dumps(patch),
        session_id,
    )

    # Build response with full task + step context.
    task_info = None
    current_step = None
    all_steps = []
    suggested_app_scheme = None
    suggested_app_name = None

    if session["task_id"]:
        task_row = await pool.fetchrow(
            "SELECT id, title, description FROM tasks WHERE id = $1",
            session["task_id"],
        )
        if task_row:
            task_info = {
                "id": str(task_row["id"]),
                "title": task_row["title"],
                "goal": task_row["description"],
            }
            # Best-effort: a failed suggestion must not block joining, but should be visible.
            try:
                suggestion = await llm.suggest_work_apps(task_row["title"], task_row["description"])
                suggested_app_scheme = suggestion.get("suggested_app_scheme")
                suggested_app_name = suggestion.get("suggested_app_name")
            except Exception:
                logger.warning("Work-app suggestion failed for session %s", session_id, exc_info=True)

        step_rows = await pool.fetch(
            """SELECT id, task_id, sort_order, title, description, estimated_minutes,
                      status, checkpoint_note, last_checked_at, completed_at, created_at
               FROM steps WHERE task_id = $1 ORDER BY sort_order""",
            session["task_id"],
        )
        all_steps = [StepOut(**dict(r)) for r in step_rows]

        # First in-progress step (in sort order) is the current one.
        current = next((s for s in step_rows if s["status"] == "in_progress"), None)
        if current is not None:
            current_step = {
                "id": str(current["id"]),
                "title": current["title"],
                "status": current["status"],
                "checkpoint_note": current["checkpoint_note"],
            }

    return SessionJoinResponse(
        session_id=session["id"],
        joined=True,
        task=task_info,
        current_step=current_step,
        all_steps=all_steps,
        suggested_app_scheme=suggested_app_scheme,
        suggested_app_name=suggested_app_name,
    )


@router.post("/{session_id}/checkpoint", response_model=SessionOut)
async def save_checkpoint(
    session_id: UUID,
    req: SessionCheckpointRequest,
    user_id: str = Depends(get_current_user_id),
):
    """Merge the explicitly-set request fields into an active session's checkpoint.

    Only fields the client actually sent are written (``exclude_unset``); other
    checkpoint keys are preserved by the jsonb ``||`` merge.

    Raises:
        HTTPException 404: session does not exist or is not owned by the user.
        HTTPException 400: session is not (or is no longer) active.
    """
    pool = await get_pool()
    session = await pool.fetchrow(
        "SELECT id, status FROM sessions WHERE id = $1 AND user_id = $2::uuid",
        session_id,
        user_id,
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if session["status"] != "active":
        raise HTTPException(status_code=400, detail="Session is not active")

    # mode="json" makes every field JSON-serialisable (UUIDs, datetimes, ...).
    checkpoint = req.model_dump(mode="json", exclude_unset=True)
    if checkpoint.get("current_step_id"):
        checkpoint["current_step_id"] = str(checkpoint["current_step_id"])

    # COALESCE: `NULL || jsonb` is NULL, which would silently wipe the checkpoint.
    # The status guard covers a session ended between the SELECT above and this UPDATE.
    row = await pool.fetchrow(
        f"""UPDATE sessions SET checkpoint = COALESCE(checkpoint, '{{}}'::jsonb) || $1::jsonb
            WHERE id = $2 AND user_id = $3::uuid AND status = 'active'
            RETURNING {SESSION_COLUMNS}""",
        json.dumps(checkpoint),
        session_id,
        user_id,
    )
    if not row:
        raise HTTPException(status_code=400, detail="Session is not active")
    return _parse_session_row(row)


@router.post("/{session_id}/end", response_model=SessionOut)
async def end_session(
    session_id: UUID,
    req: SessionEndRequest,
    user_id: str = Depends(get_current_user_id),
):
    """End an active session with ``req.status`` and notify the other joined devices.

    Raises:
        HTTPException 404: no active session with this id for the user.
    """
    pool = await get_pool()
    row = await pool.fetchrow(
        f"""UPDATE sessions SET status = $1, ended_at = now()
            WHERE id = $2 AND user_id = $3::uuid AND status = 'active'
            RETURNING {SESSION_COLUMNS}""",
        req.status,
        session_id,
        user_id,
    )
    if not row:
        raise HTTPException(status_code=404, detail="Active session not found")

    checkpoint = _checkpoint_dict(row["checkpoint"])

    # NOTE(review): row["platform"] is the platform that *started* the session, not
    # necessarily the device calling /end. If a joined device ends the session, it is
    # notified and the starter is not. Confirm whether SessionEndRequest carries the
    # caller's platform.
    for device in checkpoint.get("devices", []):
        if device != row["platform"]:
            await push.send_push(user_id, device, {
                "type": "session_ended",
                "session_id": str(row["id"]),
                "ended_by": row["platform"],
            })

    # End Live Activity on all devices.
    task_title = checkpoint.get("goal", "Session ended")
    await push.send_activity_end(user_id, task_title=task_title, task_id=row["task_id"])

    return _parse_session_row(row)


@router.get("/{session_id}/resume", response_model=SessionResumeResponse)
async def resume_session(session_id: UUID, user_id: str = Depends(get_current_user_id)):
    """Build a resume summary (progress + LLM-generated resume card) for a session.

    Progress counts "done" steps. The first "in_progress" step is the current step.
    The next step is the first non-done step after it.
    ``minutes_away`` is measured from ``ended_at`` (or ``started_at`` if still open).

    Raises:
        HTTPException 404: session does not exist or is not owned by the user.
    """
    pool = await get_pool()
    session = await pool.fetchrow(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE id = $1 AND user_id = $2::uuid",
        session_id,
        user_id,
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    checkpoint = _checkpoint_dict(session["checkpoint"])

    task_info = None
    current_step = None
    completed_count = 0
    total_count = 0
    next_step_title = None

    if session["task_id"]:
        task_row = await pool.fetchrow(
            "SELECT id, title, description FROM tasks WHERE id = $1",
            session["task_id"],
        )
        if task_row:
            task_info = {"title": task_row["title"], "overall_goal": task_row["description"]}

        step_rows = await pool.fetch(
            "SELECT id, sort_order, title, status, checkpoint_note, last_checked_at "
            "FROM steps WHERE task_id = $1 ORDER BY sort_order",
            session["task_id"],
        )
        total_count = len(step_rows)
        found_current = False
        for s in step_rows:
            if s["status"] == "done":
                completed_count += 1
            elif s["status"] == "in_progress" and not found_current:
                current_step = {
                    "id": str(s["id"]),
                    "title": s["title"],
                    "checkpoint_note": s["checkpoint_note"],
                    "last_checked_at": s["last_checked_at"].isoformat() if s["last_checked_at"] else None,
                }
                found_current = True
            elif found_current and next_step_title is None:
                next_step_title = s["title"]

    now = datetime.now(timezone.utc)
    last_activity = session["ended_at"] or session["started_at"]
    if last_activity.tzinfo is None:
        # Subtracting naive from aware raises TypeError.
        # NOTE(review): assumes naive timestamps are stored in UTC — confirm against the schema.
        last_activity = last_activity.replace(tzinfo=timezone.utc)
    minutes_away = int((now - last_activity).total_seconds() / 60)

    resume_card_data = await llm.generate_resume_card(
        task_title=task_info["title"] if task_info else "Unknown task",
        goal=task_info.get("overall_goal") if task_info else None,
        current_step_title=current_step["title"] if current_step else None,
        checkpoint_note=current_step["checkpoint_note"] if current_step else None,
        completed_count=completed_count,
        total_count=total_count,
        next_step_title=next_step_title,
        minutes_away=minutes_away,
        attention_score=checkpoint.get("attention_score"),
    )

    return SessionResumeResponse(
        session_id=session["id"],
        task=task_info,
        current_step=current_step,
        progress={
            "completed": completed_count,
            "total": total_count,
            "attention_score": checkpoint.get("attention_score"),
            "distraction_count": checkpoint.get("distraction_count", 0),
        },
        resume_card=ResumeCard(**resume_card_data),
    )