"""APNs push notification service.

Uses HTTP/2 APNs provider API with .p8 auth key (token-based auth).
Falls back to logging if APNS_KEY_ID / APNS_TEAM_ID / APNS_P8_PATH are not
configured.

Required .env vars:
    APNS_KEY_ID   — 10-char key ID from Apple Developer portal
    APNS_TEAM_ID  — 10-char team ID from Apple Developer portal
    APNS_P8_PATH  — absolute path to the AuthKey_XXXXXXXXXX.p8 file
    APNS_SANDBOX  — True for development/TestFlight, False (default) for production
"""
import base64
import json
import logging
import time

import httpx
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
from cryptography.hazmat.primitives.serialization import load_pem_private_key

from app.config import settings
from app.services.db import get_pool

logger = logging.getLogger(__name__)

# Cache the provider JWT — valid for 60 min, refresh 5 min early
_apns_token: str | None = None
_apns_token_exp: float = 0.0
_private_key = None

# How long a cached provider token is reused before a new one is signed.
_APNS_TOKEN_LIFETIME_S = 3300  # 55-minute lifetime (APNs tokens last 60 min)

# ES256 signatures are two P-256 integers (R and S), 32 bytes each.
_ES256_COORD_BYTES = 32


def _b64(data: bytes) -> str:
    """Return *data* as unpadded URL-safe base64 text (the JWT segment encoding)."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def _apns_configured() -> bool:
    """Return True when key ID, team ID and .p8 path are all set in settings."""
    return bool(settings.APNS_KEY_ID and settings.APNS_TEAM_ID and settings.APNS_P8_PATH)


def _make_apns_jwt() -> str:
    """Return the APNs provider JWT, signing a new one when the cache is stale.

    The private key is read from ``settings.APNS_P8_PATH`` on first use and
    kept in a module global; the signed token is cached for
    ``_APNS_TOKEN_LIFETIME_S`` seconds.

    Raises:
        OSError: if the .p8 file cannot be opened.
    """
    global _apns_token, _apns_token_exp, _private_key

    now = time.time()
    if _apns_token and now < _apns_token_exp:
        return _apns_token

    if _private_key is None:
        with open(settings.APNS_P8_PATH, "rb") as f:
            _private_key = load_pem_private_key(f.read(), password=None)

    header = _b64(json.dumps({"alg": "ES256", "kid": settings.APNS_KEY_ID}).encode())
    payload = _b64(json.dumps({"iss": settings.APNS_TEAM_ID, "iat": int(now)}).encode())
    signing_input = f"{header}.{payload}".encode()

    # sign() returns a DER-encoded signature, but JWS ES256 requires the raw
    # fixed-width big-endian R || S concatenation. Convert before encoding,
    # otherwise the token does not verify.
    der_sig = _private_key.sign(signing_input, ECDSA(hashes.SHA256()))
    r, s = decode_dss_signature(der_sig)
    raw_sig = r.to_bytes(_ES256_COORD_BYTES, "big") + s.to_bytes(_ES256_COORD_BYTES, "big")

    token = f"{header}.{payload}.{_b64(raw_sig)}"
    _apns_token = token
    _apns_token_exp = now + _APNS_TOKEN_LIFETIME_S
    return token


async def _send_apns(device_token: str, aps_payload: dict, push_type: str = "alert") -> bool:
    """POST one notification to APNs for a single device token.

    Args:
        device_token: APNs device (or Live Activity) token to deliver to.
        aps_payload: JSON-serialisable notification body.
        push_type: value for the ``apns-push-type`` header; ``"liveactivity"``
            also appends ``.push-type.liveactivity`` to the topic.

    Returns:
        True on HTTP 200, False on any other status or on a request error.
        On HTTP 410 the token is additionally removed from the database.
    """
    host = "api.sandbox.push.apple.com" if settings.APNS_SANDBOX else "api.push.apple.com"
    url = f"https://{host}/3/device/{device_token}"

    topic = settings.APPLE_BUNDLE_ID
    if push_type == "liveactivity":
        topic += ".push-type.liveactivity"

    headers = {
        "authorization": f"bearer {_make_apns_jwt()}",
        "apns-topic": topic,
        "apns-push-type": push_type,
        "apns-priority": "10",
    }
    token_tail = device_token[-8:]  # never log the full token

    try:
        async with httpx.AsyncClient(http2=True) as client:
            resp = await client.post(url, json=aps_payload, headers=headers, timeout=10.0)

        # Diagnostics go through the logger (not stdout) and only at DEBUG.
        logger.debug(
            "APNs response: %s http_version=%s token=…%s body=%s payload=%s",
            resp.status_code,
            resp.http_version,
            token_tail,
            resp.text,
            json.dumps(aps_payload),
        )

        if resp.status_code == 200:
            return True
        if resp.status_code == 410:
            # Token is dead — device uninstalled or revoked push. Remove from DB.
            logger.warning(
                "APNs 410 Unregistered for token …%s, removing from DB", token_tail
            )
            await _remove_device_token(device_token)
            return False
        logger.error("APNs %s for token …%s: %s", resp.status_code, token_tail, resp.text)
        return False
    except Exception as exc:
        # Best-effort delivery: a failed push must not break the caller.
        logger.error("APNs request failed: %s", exc)
        return False


async def _remove_device_token(device_token: str):
    """Remove a dead APNs token from all users."""
    pool = await get_pool()
    await pool.execute(
        """UPDATE users
           SET device_tokens = (
               SELECT COALESCE(jsonb_agg(t), '[]'::jsonb)
               FROM jsonb_array_elements(device_tokens) t
               WHERE t->>'token' != $1
           )
           WHERE device_tokens @> $2::jsonb""",
        device_token,
        json.dumps([{"token": device_token}]),
    )


# ── Public API ────────────────────────────────────────────────────────────────


async def get_device_tokens(user_id: str, platform: str | None = None) -> list[dict]:
    """Return the user's stored device-token entries.

    Args:
        user_id: user UUID as text.
        platform: optional prefix; when given, only entries whose
            ``"platform"`` value starts with it are returned.

    Returns:
        A list of token dicts, or an empty list when the user has none.
    """
    pool = await get_pool()
    row = await pool.fetchrow(
        "SELECT device_tokens FROM users WHERE id = $1::uuid", user_id
    )
    if not row or not row["device_tokens"]:
        return []

    raw = row["device_tokens"]
    # The column value may come back as JSON text or as an already-decoded value.
    tokens = json.loads(raw) if isinstance(raw, str) else raw
    if platform:
        tokens = [t for t in tokens if t.get("platform", "").startswith(platform)]
    return tokens


async def register_device_token(user_id: str, platform: str, token: str):
    """Store *token* for *platform*, replacing any existing entry for that platform."""
    pool = await get_pool()
    await pool.execute(
        """UPDATE users
           SET device_tokens = (
               SELECT COALESCE(jsonb_agg(t), '[]'::jsonb)
               FROM jsonb_array_elements(device_tokens) t
               WHERE t->>'platform' != $2
           ) || $3::jsonb
           WHERE id = $1::uuid""",
        user_id,
        platform,
        json.dumps([{"platform": platform, "token": token}]),
    )


async def send_push(user_id: str, platform: str, aps_payload: dict):
    """Send an APNs push to all registered tokens for a user/platform."""
    tokens = await get_device_tokens(user_id, platform)
    # Log only the token count — full device tokens must not reach the logs.
    logger.debug(
        "send_push → user=%s platform=%s tokens=%d configured=%s",
        user_id,
        platform,
        len(tokens),
        _apns_configured(),
    )
    if not tokens:
        return

    if not _apns_configured():
        for t in tokens:
            logger.info(
                "[APNs STUB] platform=%s token=…%s payload=%s",
                t["platform"],
                t["token"][-8:],
                aps_payload,
            )
        return

    for t in tokens:
        await _send_apns(t["token"], aps_payload)


async def send_task_added(user_id: str, task_title: str, step_count: int = 0):
    """Notify the user that a new task was added."""
    subtitle = f"{step_count} subtask{'s' if step_count != 1 else ''}"
    payload = {
        "aps": {
            "alert": {"title": task_title, "subtitle": subtitle},
            "sound": "default",
        }
    }
    for platform in ("iphone", "ipad"):
        await send_push(user_id, platform, payload)


async def send_activity_update(
    user_id: str, task_title: str, task_id=None, started_at: int | None = None
):
    """Send ActivityKit push to update Live Activity on all devices with current step progress."""
    tokens = await get_device_tokens(user_id, "liveactivity_update_")
    if not tokens:
        return

    step_progress = await _get_step_progress(task_id)
    now_ts = started_at or int(time.time())
    content_state = _build_content_state(task_title, now_ts, step_progress)

    if not _apns_configured():
        for t in tokens:
            logger.info(
                "[ActivityKit STUB] token=…%s state=%s", t["token"][-8:], content_state
            )
        return

    payload = {
        "aps": {
            "timestamp": int(time.time()),
            "content-state": content_state,
            "event": "update",
        }
    }
    for t in tokens:
        await _send_apns(t["token"], payload, push_type="liveactivity")


async def send_activity_end(user_id: str, task_title: str = "Session ended", task_id=None):
    """Send ActivityKit push-to-end using per-activity update tokens."""
    tokens = await get_device_tokens(user_id, "liveactivity_update_")
    if not tokens:
        return

    now_ts = int(time.time())
    step_progress = await _get_step_progress(task_id)
    payload = {
        "aps": {
            "timestamp": now_ts,
            "event": "end",
            "content-state": _build_content_state(task_title, now_ts, step_progress),
            "dismissal-date": now_ts,
        }
    }

    if not _apns_configured():
        for t in tokens:
            logger.info("[ActivityKit END STUB] token=...%s", t["token"][-8:])
        return

    for t in tokens:
        await _send_apns(t["token"], payload, push_type="liveactivity")


async def _get_step_progress(task_id) -> dict:
    """Fetch step progress for a task: completed count, total count, current step title."""
    if not task_id:
        return {
            "stepsCompleted": 0,
            "stepsTotal": 0,
            "currentStepTitle": None,
            "lastCompletedStepTitle": None,
        }

    pool = await get_pool()
    rows = await pool.fetch(
        "SELECT title, status FROM steps WHERE task_id = $1 ORDER BY sort_order", task_id
    )
    total = len(rows)
    completed = sum(1 for r in rows if r["status"] == "done")
    # First step (in sort order) that is not finished yet.
    current = next(
        (r["title"] for r in rows if r["status"] in ("in_progress", "pending")), None
    )
    # Last step (in sort order) that is finished.
    last_completed = next(
        (r["title"] for r in reversed(rows) if r["status"] == "done"), None
    )
    return {
        "stepsCompleted": completed,
        "stepsTotal": total,
        "currentStepTitle": current,
        "lastCompletedStepTitle": last_completed,
    }


def _build_content_state(task_title: str, started_at: int, step_progress: dict) -> dict:
    """Build the Live Activity ``content-state`` dict.

    The two step-title keys are included only when they have a truthy value.
    """
    state = {
        "taskTitle": task_title,
        "startedAt": started_at,
        "stepsCompleted": step_progress["stepsCompleted"],
        "stepsTotal": step_progress["stepsTotal"],
    }
    for key in ("currentStepTitle", "lastCompletedStepTitle"):
        title = step_progress.get(key)
        if title:
            state[key] = title
    return state


async def send_activity_start(user_id: str, task_title: str, task_id=None):
    """Send ActivityKit push-to-start to all liveactivity tokens."""
    # NOTE(review): get_device_tokens filters by prefix, so "liveactivity" also
    # matches "liveactivity_update_*" entries — confirm start pushes are meant
    # to go to those tokens as well.
    tokens = await get_device_tokens(user_id, "liveactivity")
    if not tokens:
        return

    now_ts = int(time.time())
    step_progress = await _get_step_progress(task_id)
    payload = {
        "aps": {
            "timestamp": now_ts,
            "event": "start",
            "content-state": _build_content_state(task_title, now_ts, step_progress),
            "attributes-type": "FocusSessionAttributes",
            "attributes": {"sessionType": "Focus"},
            "alert": {"title": "Focus Session Started", "body": task_title},
        }
    }

    if not _apns_configured():
        for t in tokens:
            logger.info(
                "[ActivityKit START STUB] token=...%s start payload=%s",
                t["token"][-8:],
                payload,
            )
        return

    for t in tokens:
        await _send_apns(t["token"], payload, push_type="liveactivity")