284 lines
10 KiB
Python
284 lines
10 KiB
Python
|
|
"""APNs push notification service.

Uses the HTTP/2 APNs provider API with a .p8 auth key (token-based auth).
Falls back to logging if APNS_KEY_ID / APNS_TEAM_ID / APNS_P8_PATH are not configured.

Required .env vars:
    APNS_KEY_ID     — 10-char key ID from the Apple Developer portal
    APNS_TEAM_ID    — 10-char team ID from the Apple Developer portal
    APNS_P8_PATH    — absolute path to the AuthKey_XXXXXXXXXX.p8 file
    APPLE_BUNDLE_ID — the app's bundle ID, used as the apns-topic header
    APNS_SANDBOX    — True for development (Xcode-signed) builds, which use the
                      sandbox APNs environment; False (default) for production.
                      TestFlight and App Store builds use the production environment.
"""
|
||
|
|
|
||
|
|
import base64
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
import time
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
|
||
|
|
from cryptography.hazmat.primitives import hashes
|
||
|
|
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
||
|
|
|
||
|
|
from app.config import settings
|
||
|
|
from app.services.db import get_pool
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
# Cache the provider JWT — valid for 60 min, refresh 5 min early
|
||
|
|
_apns_token: str | None = None
|
||
|
|
_apns_token_exp: float = 0.0
|
||
|
|
_private_key = None
|
||
|
|
|
||
|
|
|
||
|
|
def _b64(data: bytes) -> str:
|
||
|
|
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
|
||
|
|
|
||
|
|
|
||
|
|
def _apns_configured() -> bool:
    """Return True only when the APNs key ID, team ID and .p8 path are all set (non-empty)."""
    required = ("APNS_KEY_ID", "APNS_TEAM_ID", "APNS_P8_PATH")
    # Generator keeps the original short-circuit: stops at the first empty setting.
    return all(getattr(settings, name) for name in required)
|
||
|
|
|
||
|
|
|
||
|
|
def _make_apns_jwt() -> str:
    """Return the APNs provider JWT, minting and caching a new one when needed.

    The signing key is read once from ``settings.APNS_P8_PATH`` and kept in the
    module-level ``_private_key``. A minted token is cached in ``_apns_token``
    and reused until ``_apns_token_exp`` (55 minutes after minting).

    Returns:
        The compact JWT string ``header.payload.signature``.

    Raises:
        Any error from opening the .p8 file or from ``load_pem_private_key``
        is propagated to the caller.
    """
    global _apns_token, _apns_token_exp, _private_key

    # Local import keeps this change self-contained; `cryptography` is already
    # a dependency of this module.
    from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature

    now = time.time()
    if _apns_token and now < _apns_token_exp:
        return _apns_token

    if _private_key is None:
        with open(settings.APNS_P8_PATH, "rb") as f:
            _private_key = load_pem_private_key(f.read(), password=None)

    header = _b64(json.dumps({"alg": "ES256", "kid": settings.APNS_KEY_ID}).encode())
    payload = _b64(json.dumps({"iss": settings.APNS_TEAM_ID, "iat": int(now)}).encode())
    signing_input = f"{header}.{payload}".encode()

    # cryptography returns a DER-encoded ECDSA signature, but JWS ES256
    # (RFC 7518 §3.4) requires the raw R || S concatenation, each integer
    # left-padded to the curve's byte length (32 bytes for P-256).
    der_signature = _private_key.sign(signing_input, ECDSA(hashes.SHA256()))
    r, s = decode_dss_signature(der_signature)
    width = (_private_key.curve.key_size + 7) // 8
    sig = _b64(r.to_bytes(width, "big") + s.to_bytes(width, "big"))

    token = f"{header}.{payload}.{sig}"

    _apns_token = token
    _apns_token_exp = now + 3300  # 55-minute lifetime (APNs tokens last 60 min)
    return token
|
||
|
|
|
||
|
|
|
||
|
|
async def _send_apns(device_token: str, aps_payload: dict, push_type: str = "alert") -> bool:
    """Send one notification to APNs for a single device token.

    Args:
        device_token: APNs device token (or Live Activity push token).
        aps_payload: JSON body posted to APNs.
        push_type: value for the ``apns-push-type`` header. ``"liveactivity"``
            also appends ``.push-type.liveactivity`` to the bundle-ID topic.

    Returns:
        True if APNs answered 200, otherwise False. On 410 the token is removed
        from every user. Exceptions (including failure to build the provider
        JWT) are logged and reported as False rather than raised.
    """
    token_tail = device_token[-8:]  # only the tail of a token is ever logged

    try:
        host = "api.sandbox.push.apple.com" if settings.APNS_SANDBOX else "api.push.apple.com"
        url = f"https://{host}/3/device/{device_token}"

        topic = settings.APPLE_BUNDLE_ID
        if push_type == "liveactivity":
            topic += ".push-type.liveactivity"

        # Minting the JWT reads/parses the .p8 key and can fail; it is inside
        # the try so such failures are logged and returned as False.
        headers = {
            "authorization": f"bearer {_make_apns_jwt()}",
            "apns-topic": topic,
            "apns-push-type": push_type,
            "apns-priority": "10",
        }

        async with httpx.AsyncClient(http2=True) as client:
            resp = await client.post(url, json=aps_payload, headers=headers, timeout=10.0)

        # Debug level via the logger (not stdout), and never the full token.
        logger.debug(
            "APNs response: %s http_version=%s token=…%s push_type=%s payload=%s body=%s",
            resp.status_code,
            resp.http_version,
            token_tail,
            push_type,
            aps_payload,
            resp.text,
        )

        if resp.status_code == 200:
            return True

        if resp.status_code == 410:
            # Token is dead — device uninstalled or revoked push. Remove from DB.
            logger.warning(f"APNs 410 Unregistered for token …{token_tail}, removing from DB")
            await _remove_device_token(device_token)
            return False

        logger.error(f"APNs {resp.status_code} for token …{token_tail}: {resp.text}")
        return False

    except Exception as exc:
        logger.error(f"APNs request failed for token …{token_tail}: {exc}")
        return False
|
||
|
|
|
||
|
|
|
||
|
|
async def _remove_device_token(device_token: str):
    """Strip a dead APNs token out of every user's ``device_tokens`` array.

    Only rows whose array contains an entry with this exact ``token`` are
    rewritten; in those rows, entries whose ``token`` differs are kept.
    """
    containment_filter = json.dumps([{"token": device_token}])
    db = await get_pool()
    await db.execute(
        """UPDATE users SET device_tokens = (
            SELECT COALESCE(jsonb_agg(t), '[]'::jsonb)
            FROM jsonb_array_elements(device_tokens) t
            WHERE t->>'token' != $1
        ) WHERE device_tokens @> $2::jsonb""",
        device_token,
        containment_filter,
    )
|
||
|
|
|
||
|
|
|
||
|
|
# ── Public API ────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
async def get_device_tokens(user_id: str, platform: str | None = None) -> list[dict]:
|
||
|
|
pool = await get_pool()
|
||
|
|
row = await pool.fetchrow(
|
||
|
|
"SELECT device_tokens FROM users WHERE id = $1::uuid", user_id
|
||
|
|
)
|
||
|
|
if not row or not row["device_tokens"]:
|
||
|
|
return []
|
||
|
|
tokens = (
|
||
|
|
json.loads(row["device_tokens"])
|
||
|
|
if isinstance(row["device_tokens"], str)
|
||
|
|
else row["device_tokens"]
|
||
|
|
)
|
||
|
|
if platform:
|
||
|
|
tokens = [t for t in tokens if t.get("platform", "").startswith(platform)]
|
||
|
|
return tokens
|
||
|
|
|
||
|
|
|
||
|
|
async def register_device_token(user_id: str, platform: str, token: str):
    """Make *token* the user's single stored entry for *platform*.

    Every existing entry whose ``platform`` differs is kept, the old entry for
    this platform is dropped, and ``{"platform", "token"}`` is appended to
    ``users.device_tokens``.
    """
    new_entry = json.dumps([{"platform": platform, "token": token}])
    db = await get_pool()
    await db.execute(
        """UPDATE users SET device_tokens = (
            SELECT COALESCE(jsonb_agg(t), '[]'::jsonb)
            FROM jsonb_array_elements(device_tokens) t
            WHERE t->>'platform' != $2
        ) || $3::jsonb
        WHERE id = $1::uuid""",
        user_id,
        platform,
        new_entry,
    )
|
||
|
|
|
||
|
|
|
||
|
|
async def send_push(user_id: str, platform: str, aps_payload: dict):
    """Send an APNs push to all registered tokens for a user/platform.

    *platform* is matched as a prefix of each stored entry's ``platform``
    (see ``get_device_tokens``). When APNs is not configured, each would-be
    send is logged instead. Tokens are sent one at a time; the boolean result
    of each ``_send_apns`` call is not checked.
    """
    tokens = await get_device_tokens(user_id, platform)
    configured = _apns_configured()
    # Log a count rather than the token list: full device tokens stay out of logs/stdout.
    logger.debug(
        f"send_push → user={user_id} platform={platform} tokens={len(tokens)} configured={configured}"
    )
    if not tokens:
        return

    if not configured:
        for t in tokens:
            logger.info(
                f"[APNs STUB] platform={t['platform']} token=…{t['token'][-8:]} payload={aps_payload}"
            )
        return

    for t in tokens:
        await _send_apns(t["token"], aps_payload)
|
||
|
|
|
||
|
|
|
||
|
|
async def send_task_added(user_id: str, task_title: str, step_count: int = 0):
    """Push a "task added" alert to the user's iPhone and iPad tokens.

    The alert title is the task title and the subtitle is the subtask count,
    pluralized unless the count is exactly 1 (e.g. "1 subtask", "0 subtasks").
    """
    noun = "subtask" if step_count == 1 else "subtasks"
    alert = {"title": task_title, "subtitle": f"{step_count} {noun}"}
    payload = {"aps": {"alert": alert, "sound": "default"}}

    for device_platform in ("iphone", "ipad"):
        await send_push(user_id, device_platform, payload)
|
||
|
|
|
||
|
|
|
||
|
|
async def send_activity_update(user_id: str, task_title: str, task_id=None, started_at: int | None = None):
    """Send an ActivityKit "update" push to every per-activity update token.

    Targets stored entries whose platform starts with ``liveactivity_update_``.
    The content state carries the task title, a start timestamp (``started_at``,
    or the current time when it is falsy) and the task's step progress. When
    APNs is not configured, the content state is logged per token instead.
    """
    update_tokens = await get_device_tokens(user_id, "liveactivity_update_")
    if not update_tokens:
        return

    progress = await _get_step_progress(task_id)
    start_ts = started_at or int(time.time())
    state = _build_content_state(task_title, start_ts, progress)

    if _apns_configured():
        body = {"aps": {"timestamp": int(time.time()), "content-state": state, "event": "update"}}
        for entry in update_tokens:
            await _send_apns(entry["token"], body, push_type="liveactivity")
        return

    for entry in update_tokens:
        logger.info(f"[ActivityKit STUB] token=…{entry['token'][-8:]} state={state}")
|
||
|
|
|
||
|
|
|
||
|
|
async def send_activity_end(user_id: str, task_title: str = "Session ended", task_id=None):
    """Send an ActivityKit "end" push to every per-activity update token.

    The final content state reflects current step progress. ``dismissal-date``
    is set to the same current timestamp as ``timestamp`` (presumably so the
    ended activity is dismissed right away — confirm against ActivityKit docs).
    When APNs is not configured, a stub line is logged per token instead.
    """
    update_tokens = await get_device_tokens(user_id, "liveactivity_update_")
    if not update_tokens:
        return

    ended_at = int(time.time())
    progress = await _get_step_progress(task_id)
    aps = {
        "timestamp": ended_at,
        "event": "end",
        "content-state": _build_content_state(task_title, ended_at, progress),
        "dismissal-date": ended_at,
    }
    payload = {"aps": aps}

    if _apns_configured():
        for entry in update_tokens:
            await _send_apns(entry["token"], payload, push_type="liveactivity")
        return

    for entry in update_tokens:
        logger.info(f"[ActivityKit END STUB] token=...{entry['token'][-8:]}")
|
||
|
|
|
||
|
|
|
||
|
|
async def _get_step_progress(task_id) -> dict:
|
||
|
|
"""Fetch step progress for a task: completed count, total count, current step title."""
|
||
|
|
if not task_id:
|
||
|
|
return {"stepsCompleted": 0, "stepsTotal": 0, "currentStepTitle": None, "lastCompletedStepTitle": None}
|
||
|
|
pool = await get_pool()
|
||
|
|
rows = await pool.fetch(
|
||
|
|
"SELECT title, status FROM steps WHERE task_id = $1 ORDER BY sort_order", task_id
|
||
|
|
)
|
||
|
|
total = len(rows)
|
||
|
|
completed = sum(1 for r in rows if r["status"] == "done")
|
||
|
|
current = next((r["title"] for r in rows if r["status"] in ("in_progress", "pending")), None)
|
||
|
|
last_completed = next((r["title"] for r in reversed(rows) if r["status"] == "done"), None)
|
||
|
|
return {"stepsCompleted": completed, "stepsTotal": total, "currentStepTitle": current, "lastCompletedStepTitle": last_completed}
|
||
|
|
|
||
|
|
|
||
|
|
def _build_content_state(task_title: str, started_at: int, step_progress: dict) -> dict:
|
||
|
|
state = {
|
||
|
|
"taskTitle": task_title,
|
||
|
|
"startedAt": started_at,
|
||
|
|
"stepsCompleted": step_progress["stepsCompleted"],
|
||
|
|
"stepsTotal": step_progress["stepsTotal"],
|
||
|
|
}
|
||
|
|
if step_progress["currentStepTitle"]:
|
||
|
|
state["currentStepTitle"] = step_progress["currentStepTitle"]
|
||
|
|
if step_progress.get("lastCompletedStepTitle"):
|
||
|
|
state["lastCompletedStepTitle"] = step_progress["lastCompletedStepTitle"]
|
||
|
|
return state
|
||
|
|
|
||
|
|
|
||
|
|
async def send_activity_start(user_id: str, task_title: str, task_id=None):
    """Send an ActivityKit push-to-start event to the user's push-to-start tokens.

    Candidates are stored entries whose platform starts with ``liveactivity``.
    That prefix also matches the per-activity update tokens
    (``liveactivity_update_*``) that ``send_activity_update`` and
    ``send_activity_end`` target, so those are excluded here. When APNs is not
    configured, the start payload is logged per token instead.
    """
    candidates = await get_device_tokens(user_id, "liveactivity")
    # A start event is for push-to-start tokens, not for tokens tied to one
    # already-existing activity.
    tokens = [
        t for t in candidates
        if not (t.get("platform") or "").startswith("liveactivity_update_")
    ]
    if not tokens:
        return

    now_ts = int(time.time())
    step_progress = await _get_step_progress(task_id)
    payload = {
        "aps": {
            "timestamp": now_ts,
            "event": "start",
            "content-state": _build_content_state(task_title, now_ts, step_progress),
            "attributes-type": "FocusSessionAttributes",
            "attributes": {
                "sessionType": "Focus"
            },
            "alert": {
                "title": "Focus Session Started",
                "body": task_title
            }
        }
    }

    if not _apns_configured():
        for t in tokens:
            logger.info(f"[ActivityKit START STUB] token=...{t['token'][-8:]} start payload={payload}")
        return

    for t in tokens:
        await _send_apns(t["token"], payload, push_type="liveactivity")
|