372 lines
13 KiB
Python
372 lines
13 KiB
Python
|
|
import json
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from uuid import UUID
|
||
|
|
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException
|
||
|
|
|
||
|
|
from app.middleware.auth import get_current_user_id
|
||
|
|
from app.models import (
|
||
|
|
OpenSessionOut,
|
||
|
|
ResumeCard,
|
||
|
|
SessionCheckpointRequest,
|
||
|
|
SessionEndRequest,
|
||
|
|
SessionJoinRequest,
|
||
|
|
SessionJoinResponse,
|
||
|
|
SessionOut,
|
||
|
|
SessionResumeResponse,
|
||
|
|
SessionStartRequest,
|
||
|
|
StepOut,
|
||
|
|
)
|
||
|
|
from app.services import llm, push
|
||
|
|
from app.services.db import get_pool
|
||
|
|
|
||
|
|
# Router for the session endpoints; every route registered on it is served
# under the /sessions prefix and tagged "sessions".
router = APIRouter(
    prefix="/sessions",
    tags=["sessions"],
)
|
||
|
|
|
||
|
|
# Column list interpolated into the SELECT / RETURNING clauses of the session
# queries in this module, so each of them yields the same row shape.
SESSION_COLUMNS = (
    "id, user_id, task_id, platform, "
    "started_at, ended_at, status, checkpoint, created_at"
)
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_session_row(row) -> SessionOut:
    """Convert a ``sessions`` row into a ``SessionOut`` model.

    The ``checkpoint`` value may arrive either as a JSON string or as an
    already-decoded object; a string is decoded here. A missing checkpoint
    (SQL NULL or JSON ``null``) is normalised to an empty dict, which is the
    same fallback the route handlers in this module apply when they read the
    column directly.

    Args:
        row: Mapping-like record holding the ``SESSION_COLUMNS`` fields.

    Returns:
        The ``SessionOut`` built from the row.
    """
    result = dict(row)
    checkpoint = result["checkpoint"]
    if isinstance(checkpoint, str):
        checkpoint = json.loads(checkpoint)
    if checkpoint is None:
        checkpoint = {}
    result["checkpoint"] = checkpoint
    return SessionOut(**result)
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/start", response_model=SessionOut, status_code=201)
async def start_session(req: SessionStartRequest, user_id: str = Depends(get_current_user_id)):
    """Start a session for the current user, or return the one already active.

    If the user already has a session with status ``active`` it is returned
    unchanged and nothing is created. Otherwise a new row is inserted whose
    checkpoint records the task title as ``goal`` (when a task is given), the
    optional ``work_app_bundle_ids`` and the starting platform under
    ``devices``. When a task is attached, the task is marked ``in_progress``,
    the other platform is notified by push and a Live Activity start is sent.

    Raises:
        HTTPException: 404 when ``req.task_id`` does not match a task owned
            by the current user.
    """
    pool = await get_pool()

    # Check if an active session already exists for this account
    active = await pool.fetchrow(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE user_id = $1::uuid AND status = 'active'",
        user_id,
    )
    if active:
        # Idempotently return the existing active session and don't create a new one
        return _parse_session_row(active)

    checkpoint = {}
    task_title = None
    if req.task_id:
        task = await pool.fetchrow(
            "SELECT id, title, description FROM tasks WHERE id = $1 AND user_id = $2::uuid",
            req.task_id,
            user_id,
        )
        if not task:
            raise HTTPException(status_code=404, detail="Task not found")

        # Keep the title so the notifications below do not need a second query.
        task_title = task["title"]

        # Scope the write to the owner as well, mirroring the ownership check above.
        await pool.execute(
            "UPDATE tasks SET status = 'in_progress', updated_at = now() "
            "WHERE id = $1 AND user_id = $2::uuid",
            req.task_id,
            user_id,
        )
        checkpoint["goal"] = task_title

    if req.work_app_bundle_ids:
        checkpoint["work_app_bundle_ids"] = req.work_app_bundle_ids

    checkpoint["devices"] = [req.platform]

    row = await pool.fetchrow(
        f"""INSERT INTO sessions (user_id, task_id, platform, checkpoint)
            VALUES ($1::uuid, $2, $3, $4)
            RETURNING {SESSION_COLUMNS}""",
        user_id,
        req.task_id,
        req.platform,
        json.dumps(checkpoint),
    )

    # Notify other devices about new session (only done for task-bound sessions)
    if req.task_id:
        other_platform = "ipad" if req.platform == "mac" else "mac"
        await push.send_push(user_id, other_platform, {
            "type": "session_started",
            "session_id": str(row["id"]),
            "task_title": task_title,
            "platform": req.platform,
        })
        # Start Live Activity on all registered devices
        await push.send_activity_start(user_id, task_title, task_id=req.task_id)

    return _parse_session_row(row)
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/active", response_model=SessionOut)
async def get_active_session(user_id: str = Depends(get_current_user_id)):
    """Return the current user's session whose status is ``active``.

    Raises:
        HTTPException: 404 when the user has no active session.
    """
    db = await get_pool()
    active_row = await db.fetchrow(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE user_id = $1::uuid AND status = 'active'",
        user_id,
    )
    if active_row:
        return _parse_session_row(active_row)
    raise HTTPException(status_code=404, detail="No active session")
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/open", response_model=list[OpenSessionOut])
async def get_open_sessions(user_id: str = Depends(get_current_user_id)):
    """All active + interrupted sessions. Used by VLM on startup for session-aware analysis.

    Sessions are returned newest first (by ``started_at``). Each entry carries
    its decoded checkpoint and, when the session references a task that still
    exists, a ``task`` dict with the task's ``title`` and its description
    exposed as ``goal``.
    """
    pool = await get_pool()
    rows = await pool.fetch(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE user_id = $1::uuid AND status IN ('active', 'interrupted') ORDER BY started_at DESC",
        user_id,
    )

    # Load every referenced task in one round trip instead of one query per session.
    task_ids = list({row["task_id"] for row in rows if row["task_id"]})
    tasks_by_id = {}
    if task_ids:
        task_rows = await pool.fetch(
            "SELECT id, title, description FROM tasks WHERE id = ANY($1)",
            task_ids,
        )
        tasks_by_id = {
            t["id"]: {"title": t["title"], "goal": t["description"]}
            for t in task_rows
        }

    results = []
    for row in rows:
        raw_checkpoint = row["checkpoint"]
        # Decode string payloads; fall back to {} for NULL / JSON null.
        checkpoint = (json.loads(raw_checkpoint) if isinstance(raw_checkpoint, str) else raw_checkpoint) or {}
        results.append(OpenSessionOut(
            id=row["id"],
            task_id=row["task_id"],
            # None when the session has no task or the task row is gone.
            task=tasks_by_id.get(row["task_id"]),
            status=row["status"],
            platform=row["platform"],
            started_at=row["started_at"],
            ended_at=row["ended_at"],
            checkpoint=checkpoint,
        ))
    return results
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/{session_id}/join", response_model=SessionJoinResponse)
async def join_session(
    session_id: UUID,
    req: SessionJoinRequest,
    user_id: str = Depends(get_current_user_id),
):
    """Attach another device to an active session and return its context.

    Adds ``req.platform`` to the checkpoint's ``devices`` list (and replaces
    ``work_app_bundle_ids`` when the request supplies them), then returns the
    session's task, all of its steps, the first step that is ``in_progress``
    and a best-effort work-app suggestion from ``llm.suggest_work_apps``.

    Raises:
        HTTPException: 404 when the session does not exist for this user,
            400 when the session is not ``active``.
    """
    # Function-scope import: logging is only needed for the best-effort
    # suggestion below.
    import logging

    pool = await get_pool()

    session = await pool.fetchrow(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE id = $1 AND user_id = $2::uuid",
        session_id,
        user_id,
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if session["status"] != "active":
        raise HTTPException(status_code=400, detail="Session is not active")

    # Update checkpoint with joining device
    raw_checkpoint = session["checkpoint"]
    checkpoint = (json.loads(raw_checkpoint) if isinstance(raw_checkpoint, str) else raw_checkpoint) or {}
    devices = checkpoint.get("devices")
    if devices is None:
        # No device list recorded yet: seed it with the platform stored on the session.
        devices = [session["platform"]]
    if req.platform not in devices:
        devices.append(req.platform)

    changes = {"devices": devices}
    if req.work_app_bundle_ids:
        changes["work_app_bundle_ids"] = req.work_app_bundle_ids

    # Merge only the keys this endpoint changes rather than writing back the
    # whole checkpoint read above, so keys merged concurrently by the
    # checkpoint endpoint are not overwritten with stale values.
    await pool.execute(
        "UPDATE sessions SET checkpoint = COALESCE(checkpoint, '{}'::jsonb) || $1::jsonb "
        "WHERE id = $2 AND user_id = $3::uuid",
        json.dumps(changes),
        session_id,
        user_id,
    )

    # Build response with full task + step context
    task_info = None
    current_step = None
    all_steps = []
    suggested_app_scheme = None
    suggested_app_name = None

    if session["task_id"]:
        task_row = await pool.fetchrow(
            "SELECT id, title, description FROM tasks WHERE id = $1",
            session["task_id"],
        )
        if task_row:
            task_info = {
                "id": str(task_row["id"]),
                "title": task_row["title"],
                "goal": task_row["description"],
            }
            # Suggest a work app based on task
            try:
                suggestion = await llm.suggest_work_apps(task_row["title"], task_row["description"])
                suggested_app_scheme = suggestion.get("suggested_app_scheme")
                suggested_app_name = suggestion.get("suggested_app_name")
            except Exception:
                # Best-effort: joining must still succeed without a suggestion,
                # but record the failure instead of discarding it silently.
                logging.getLogger(__name__).warning(
                    "Work-app suggestion failed for session %s", session_id, exc_info=True
                )

        step_rows = await pool.fetch(
            """SELECT id, task_id, sort_order, title, description, estimated_minutes,
                      status, checkpoint_note, last_checked_at, completed_at, created_at
               FROM steps WHERE task_id = $1 ORDER BY sort_order""",
            session["task_id"],
        )
        all_steps = [StepOut(**dict(r)) for r in step_rows]

        # Find current in-progress step (first one in sort order)
        in_progress = next((s for s in step_rows if s["status"] == "in_progress"), None)
        if in_progress is not None:
            current_step = {
                "id": str(in_progress["id"]),
                "title": in_progress["title"],
                "status": in_progress["status"],
                "checkpoint_note": in_progress["checkpoint_note"],
            }

    return SessionJoinResponse(
        session_id=session["id"],
        joined=True,
        task=task_info,
        current_step=current_step,
        all_steps=all_steps,
        suggested_app_scheme=suggested_app_scheme,
        suggested_app_name=suggested_app_name,
    )
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/{session_id}/checkpoint", response_model=SessionOut)
async def save_checkpoint(
    session_id: UUID,
    req: SessionCheckpointRequest,
    user_id: str = Depends(get_current_user_id),
):
    """Merge the fields sent by the client into an active session's checkpoint.

    Only fields explicitly set on the request are written
    (``exclude_unset=True``); they are merged key-by-key into the stored
    checkpoint with the ``||`` jsonb operator, so keys not mentioned in the
    request are kept.

    Raises:
        HTTPException: 404 when the session does not exist for this user,
            400 when the session is not ``active`` (including when it stops
            being active between the check and the write).
    """
    pool = await get_pool()

    session = await pool.fetchrow(
        "SELECT id, status FROM sessions WHERE id = $1 AND user_id = $2::uuid",
        session_id,
        user_id,
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if session["status"] != "active":
        raise HTTPException(status_code=400, detail="Session is not active")

    checkpoint = req.model_dump(exclude_unset=True)
    if "current_step_id" in checkpoint and checkpoint["current_step_id"]:
        # Stringify so json.dumps can serialise the value.
        checkpoint["current_step_id"] = str(checkpoint["current_step_id"])

    # Re-assert owner and status in the write itself so a session that was
    # ended after the check above is not modified; COALESCE keeps the merge
    # from collapsing to NULL when the stored checkpoint is NULL.
    row = await pool.fetchrow(
        "UPDATE sessions SET checkpoint = COALESCE(checkpoint, '{}'::jsonb) || $1::jsonb "
        "WHERE id = $2 AND user_id = $3::uuid AND status = 'active' "
        f"RETURNING {SESSION_COLUMNS}",
        json.dumps(checkpoint),
        session_id,
        user_id,
    )
    if not row:
        raise HTTPException(status_code=400, detail="Session is not active")
    return _parse_session_row(row)
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/{session_id}/end", response_model=SessionOut)
async def end_session(
    session_id: UUID,
    req: SessionEndRequest,
    user_id: str = Depends(get_current_user_id),
):
    """Close the caller's active session and tell the other devices.

    The session is moved to ``req.status`` and stamped with ``ended_at`` in a
    single UPDATE that only matches a row owned by the user and still
    ``active``. Every device listed in the checkpoint other than the
    platform stored on the session row then receives a ``session_ended``
    push, and a Live Activity end is sent.

    NOTE(review): ``ended_by`` and the device filter use the platform stored
    on the session row, not the platform of the device making this request.
    Confirm that is intended.

    Raises:
        HTTPException: 404 when no matching active session exists.
    """
    db = await get_pool()

    ended = await db.fetchrow(
        f"""UPDATE sessions SET status = $1, ended_at = now()
            WHERE id = $2 AND user_id = $3::uuid AND status = 'active'
            RETURNING {SESSION_COLUMNS}""",
        req.status,
        session_id,
        user_id,
    )
    if not ended:
        raise HTTPException(status_code=404, detail="Active session not found")

    # Notify other joined devices that session ended
    stored = ended["checkpoint"]
    if isinstance(stored, str):
        state = json.loads(stored)
    else:
        state = stored or {}

    for joined_device in state.get("devices", []):
        if joined_device == ended["platform"]:
            continue
        await push.send_push(user_id, joined_device, {
            "type": "session_ended",
            "session_id": str(ended["id"]),
            "ended_by": ended["platform"],
        })

    # End Live Activity on all devices
    await push.send_activity_end(
        user_id,
        task_title=state.get("goal", "Session ended"),
        task_id=ended["task_id"],
    )

    return _parse_session_row(ended)
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/{session_id}/resume", response_model=SessionResumeResponse)
async def resume_session(session_id: UUID, user_id: str = Depends(get_current_user_id)):
    """Assemble the context needed to pick a session back up.

    Loads the session (scoped to the caller), its task and ordered steps,
    derives step progress and the minutes elapsed since ``ended_at`` (or
    ``started_at`` when the session has not ended), and asks
    ``llm.generate_resume_card`` for the resume card.

    Raises:
        HTTPException: 404 when the session does not exist for this user.
    """
    pool = await get_pool()

    session = await pool.fetchrow(
        f"SELECT {SESSION_COLUMNS} FROM sessions WHERE id = $1 AND user_id = $2::uuid",
        session_id,
        user_id,
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    raw_checkpoint = session["checkpoint"]
    # Decode string payloads; fall back to {} for NULL / JSON null so the
    # .get() calls below are always safe.
    checkpoint = (json.loads(raw_checkpoint) if isinstance(raw_checkpoint, str) else raw_checkpoint) or {}

    task_info = None
    current_step = None
    completed_count = 0
    total_count = 0
    next_step_title = None

    if session["task_id"]:
        task_row = await pool.fetchrow(
            "SELECT id, title, description FROM tasks WHERE id = $1",
            session["task_id"],
        )
        if task_row:
            task_info = {"title": task_row["title"], "overall_goal": task_row["description"]}

        step_rows = await pool.fetch(
            "SELECT id, sort_order, title, status, checkpoint_note, last_checked_at FROM steps WHERE task_id = $1 ORDER BY sort_order",
            session["task_id"],
        )
        total_count = len(step_rows)

        # Single pass in sort order: count finished steps, take the first
        # in-progress step as current, and take the first not-done step
        # after it as the next one.
        found_current = False
        for s in step_rows:
            if s["status"] == "done":
                completed_count += 1
            elif s["status"] == "in_progress" and not found_current:
                current_step = {
                    "id": str(s["id"]),
                    "title": s["title"],
                    "checkpoint_note": s["checkpoint_note"],
                    "last_checked_at": s["last_checked_at"].isoformat() if s["last_checked_at"] else None,
                }
                found_current = True
            elif found_current and next_step_title is None:
                next_step_title = s["title"]

    now = datetime.now(timezone.utc)
    last_activity = session["ended_at"] or session["started_at"]
    if last_activity.tzinfo is None:
        # Naive timestamps cannot be subtracted from an aware "now".
        # NOTE(review): assumes naive values are stored in UTC; confirm.
        last_activity = last_activity.replace(tzinfo=timezone.utc)
    # Clamp at zero so clock skew between app and database never yields a
    # negative "minutes away".
    minutes_away = max(0, int((now - last_activity).total_seconds() / 60))

    resume_card_data = await llm.generate_resume_card(
        task_title=task_info["title"] if task_info else "Unknown task",
        goal=task_info.get("overall_goal") if task_info else None,
        current_step_title=current_step["title"] if current_step else None,
        checkpoint_note=current_step["checkpoint_note"] if current_step else None,
        completed_count=completed_count,
        total_count=total_count,
        next_step_title=next_step_title,
        minutes_away=minutes_away,
        attention_score=checkpoint.get("attention_score"),
    )

    return SessionResumeResponse(
        session_id=session["id"],
        task=task_info,
        current_step=current_step,
        progress={
            "completed": completed_count,
            "total": total_count,
            "attention_score": checkpoint.get("attention_score"),
            "distraction_count": checkpoint.get("distraction_count", 0),
        },
        resume_card=ResumeCard(**resume_card_data),
    )
|