146 lines
4.6 KiB
Python
146 lines
4.6 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status
|
|
|
|
from app.middleware.auth import (
|
|
create_access_token,
|
|
create_refresh_token,
|
|
decode_token,
|
|
get_current_user_id,
|
|
hash_password,
|
|
verify_password,
|
|
)
|
|
from app.models import (
|
|
AppleAuthRequest,
|
|
AuthResponse,
|
|
DeviceTokenRequest,
|
|
LoginRequest,
|
|
RefreshRequest,
|
|
RegisterRequest,
|
|
UserOut,
|
|
)
|
|
from app.services import push
|
|
from app.services.db import get_pool
|
|
|
|
# Router for every authentication endpoint in this module: routes are mounted
# under the "/auth" path prefix and grouped under the "auth" OpenAPI tag.
router = APIRouter(tags=["auth"], prefix="/auth")
|
|
|
|
|
|
def _build_auth_response(user_row) -> AuthResponse:
    """Issue a fresh access/refresh token pair for a user row and wrap it in an AuthResponse.

    Args:
        user_row: a DB row indexable by ``"id"``, ``"email"``, ``"display_name"``,
            ``"timezone"`` and ``"created_at"``. Extra columns (e.g. the
            ``password_hash`` selected by ``login``) are ignored.

    Returns:
        AuthResponse carrying both tokens, ``expires_in=3600`` and the user's
        public profile (``UserOut``).
    """
    # Tokens are minted for the stringified id; the profile keeps the raw id value.
    subject = str(user_row["id"])
    access_token = create_access_token(subject)
    refresh_token = create_refresh_token(subject)

    profile = UserOut(
        id=user_row["id"],
        email=user_row["email"],
        display_name=user_row["display_name"],
        timezone=user_row["timezone"],
        created_at=user_row["created_at"],
    )

    # NOTE(review): 3600 is hard-coded here, presumably seconds. It must match the
    # access-token lifetime used inside create_access_token. TODO confirm.
    return AuthResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=3600,
        user=profile,
    )
|
|
|
|
|
|
@router.post("/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED)
async def register(req: RegisterRequest):
    """Create a password-based account and return tokens for it.

    Args:
        req: the new user's email, plaintext password (hashed before storage),
            display name and timezone.

    Returns:
        AuthResponse with a fresh token pair and the created user (HTTP 201).

    Raises:
        HTTPException 409: the email is already registered, either detected up
            front or by the database rejecting a concurrent duplicate insert.
    """
    pool = await get_pool()

    # Fast path: reject a known duplicate before hashing the password.
    existing = await pool.fetchrow("SELECT id FROM users WHERE email = $1", req.email)
    if existing:
        raise HTTPException(status_code=409, detail="Email already registered")

    hashed = hash_password(req.password)
    try:
        row = await pool.fetchrow(
            """INSERT INTO users (email, password_hash, display_name, timezone)
               VALUES ($1, $2, $3, $4)
               RETURNING id, email, display_name, timezone, created_at""",
            req.email,
            hashed,
            req.display_name,
            req.timezone,
        )
    except Exception as exc:
        # The SELECT above is check-then-act. A concurrent request for the same
        # email can insert between the SELECT and this INSERT. SQLSTATE 23505 is
        # unique_violation (exposed as `.sqlstate` on asyncpg errors). Report it
        # as the same 409 instead of letting it surface as a 500.
        # Assumes a UNIQUE constraint on users.email. TODO confirm in the schema.
        if getattr(exc, "sqlstate", None) == "23505":
            raise HTTPException(status_code=409, detail="Email already registered") from exc
        raise

    return _build_auth_response(row)
|
|
|
|
|
|
@router.post("/login", response_model=AuthResponse)
async def login(req: LoginRequest):
    """Authenticate with email and password, then issue a new token pair.

    An unknown email, an account with no stored password hash (Apple sign-in
    creates users without one), and a wrong password all produce the same
    401 "Invalid credentials".

    Raises:
        HTTPException 401: the credentials could not be verified.
    """
    db = await get_pool()

    user = await db.fetchrow(
        "SELECT id, email, password_hash, display_name, timezone, created_at FROM users WHERE email = $1",
        req.email,
    )

    # Short-circuit order matters: verify_password only runs once a row with a
    # non-empty hash exists.
    # NOTE(review): unknown emails skip verify_password entirely, so response
    # timing may reveal whether an email is registered. Consider checking
    # against a dummy hash.
    if not user or not user["password_hash"] or not verify_password(req.password, user["password_hash"]):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    return _build_auth_response(user)
|
|
|
|
|
|
@router.post("/apple", response_model=AuthResponse)
async def apple_auth(req: AppleAuthRequest):
    """Sign in, or sign up, with a Sign in with Apple identity token.

    Resolution order:
      1. A user already linked to the token's ``sub`` claim is signed in.
      2. Otherwise, if the token carries an ``email`` that matches an existing
         user, that user's ``apple_user_id`` is set to ``sub`` and they are
         signed in.
      3. Otherwise a new user is created with ``sub``, the email (may be None),
         ``req.full_name`` as display name and timezone "America/Chicago".

    Raises:
        HTTPException 400: the token cannot be decoded or lacks a usable ``sub``.
        HTTPException 409: a concurrent request registered the same email
            between the lookups and the insert.
    """
    from jose import jwt as jose_jwt

    # SECURITY(review): get_unverified_claims() does NOT check the signature,
    # issuer, audience or expiry. Any client can forge `sub`/`email` and sign in
    # as, or link into, any account. Before production, fetch Apple's JWKS
    # (https://appleid.apple.com/auth/keys) and use
    # jose_jwt.decode(token, jwks, algorithms=["RS256"], audience=<client id>,
    #                 issuer="https://appleid.apple.com") instead.
    try:
        claims = jose_jwt.get_unverified_claims(req.identity_token)
        apple_user_id = claims["sub"]
        email = claims.get("email")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid Apple identity token") from None

    # Without these checks, a non-string claim would be passed straight to the
    # queries below (DB error -> 500). An empty `sub` would also act as one
    # shared identity for every token that omits it.
    if not isinstance(apple_user_id, str) or not apple_user_id:
        raise HTTPException(status_code=400, detail="Invalid Apple identity token")
    if not isinstance(email, str) or not email:
        email = None

    pool = await get_pool()

    # 1. Existing user already linked to this Apple ID.
    row = await pool.fetchrow(
        "SELECT id, email, display_name, timezone, created_at FROM users WHERE apple_user_id = $1",
        apple_user_id,
    )
    if row:
        return _build_auth_response(row)

    # 2. Existing user with the same email: link the accounts.
    if email:
        row = await pool.fetchrow(
            "SELECT id, email, display_name, timezone, created_at FROM users WHERE email = $1",
            email,
        )
        if row:
            await pool.execute("UPDATE users SET apple_user_id = $1 WHERE id = $2", apple_user_id, row["id"])
            return _build_auth_response(row)

    # 3. Brand-new user.
    try:
        row = await pool.fetchrow(
            """INSERT INTO users (apple_user_id, email, display_name, timezone)
               VALUES ($1, $2, $3, $4)
               RETURNING id, email, display_name, timezone, created_at""",
            apple_user_id,
            email,
            req.full_name,
            "America/Chicago",
        )
    except Exception as exc:
        # SQLSTATE 23505 (unique_violation, `.sqlstate` on asyncpg errors) means a
        # concurrent request, e.g. a client retry, inserted first. If that row
        # is ours (same Apple ID), sign into it. Otherwise the clash was on the
        # email.
        if getattr(exc, "sqlstate", None) != "23505":
            raise
        row = await pool.fetchrow(
            "SELECT id, email, display_name, timezone, created_at FROM users WHERE apple_user_id = $1",
            apple_user_id,
        )
        if not row:
            raise HTTPException(status_code=409, detail="Email already registered") from exc

    return _build_auth_response(row)
|
|
|
|
|
|
@router.post("/refresh", response_model=AuthResponse)
async def refresh(req: RefreshRequest):
    """Exchange a refresh token for a new access/refresh token pair.

    The token is decoded by ``decode_token(..., expected_type="refresh")``. Its
    ``sub`` claim is the user id, cast to uuid in the lookup.

    Raises:
        HTTPException 401: no user exists for the token's subject.
    """
    subject = decode_token(req.refresh_token, expected_type="refresh")["sub"]

    db = await get_pool()
    user = await db.fetchrow(
        "SELECT id, email, display_name, timezone, created_at FROM users WHERE id = $1::uuid",
        subject,
    )

    # NOTE(review): nothing in this handler revokes or rotates the presented
    # refresh token, so it remains usable alongside the newly issued one.
    if user:
        return _build_auth_response(user)
    raise HTTPException(status_code=401, detail="User not found")
|
|
|
|
|
|
@router.post("/device-token", status_code=204)
async def register_device(req: DeviceTokenRequest, user_id: str = Depends(get_current_user_id)):
    """Store the caller's device token via ``push.register_device_token``.

    ``user_id`` is resolved by the ``get_current_user_id`` dependency. The
    endpoint answers 204 No Content and returns no body.
    """
    platform = req.platform
    device_token = req.token
    await push.register_device_token(user_id, platform, device_token)
|