This commit is contained in:
rishubm
2025-10-18 22:36:20 -05:00
parent 3c961efaff
commit a92ddf06bb
45 changed files with 5106 additions and 0 deletions

View File

@@ -0,0 +1,157 @@
"""
Gemini API client wrapper with retry logic and error handling.
"""
import asyncio
import json
import logging
import time
from typing import Dict, Any, Optional

import google.generativeai as genai

from config import get_settings
logger = logging.getLogger(__name__)
class GeminiClient:
    """Wrapper for Google Gemini API with retry logic and JSON parsing."""

    def __init__(self):
        """Initialize Gemini client with API key from settings."""
        settings = get_settings()
        genai.configure(api_key=settings.gemini_api_key)
        self.model = genai.GenerativeModel(settings.gemini_model)
        self.max_retries = settings.gemini_max_retries
        self.demo_mode = settings.demo_mode
        # Cache for demo mode: identical prompts return the stored response
        # instead of re-calling the API.
        self._demo_cache: Dict[str, Any] = {}
        logger.info(f"Gemini client initialized with model: {settings.gemini_model}")

    async def generate_json(
        self,
        prompt: str,
        temperature: float = 0.7,
        timeout: int = 30
    ) -> Dict[str, Any]:
        """
        Generate JSON response from Gemini with retry logic.

        Args:
            prompt: The prompt to send to Gemini
            temperature: Sampling temperature (0.0-1.0)
            timeout: Request timeout in seconds (floored to 60s, see below)

        Returns:
            Parsed JSON response

        Raises:
            Exception: If all retries fail or JSON parsing fails
        """
        # Check demo cache first to skip the API entirely.
        if self.demo_mode:
            cache_key = self._get_cache_key(prompt, temperature)
            if cache_key in self._demo_cache:
                logger.info("Returning cached response (demo mode)")
                return self._demo_cache[cache_key]

        last_error = None
        for attempt in range(1, self.max_retries + 1):
            try:
                logger.info(f"Gemini API call attempt {attempt}/{self.max_retries}")

                # Ask the model for JSON directly via the response MIME type.
                generation_config = genai.GenerationConfig(
                    temperature=temperature,
                    response_mime_type="application/json"
                )

                # Use max of provided timeout or 60 seconds.
                actual_timeout = max(timeout, 60)

                # The SDK call is synchronous; run it in a worker thread so the
                # event loop stays responsive for the duration of the request.
                response = await asyncio.to_thread(
                    self.model.generate_content,
                    prompt,
                    generation_config=generation_config,
                    request_options={"timeout": actual_timeout}
                )

                response_text = response.text
                logger.debug(f"Raw response length: {len(response_text)} chars")

                result = self._parse_json(response_text)

                # Cache successful results in demo mode.
                if self.demo_mode:
                    cache_key = self._get_cache_key(prompt, temperature)
                    self._demo_cache[cache_key] = result

                logger.info("Successfully generated and parsed JSON response")
                return result

            except json.JSONDecodeError as e:
                last_error = f"JSON parsing error: {str(e)}"
                logger.warning(f"Attempt {attempt} failed: {last_error}")
                if attempt < self.max_retries:
                    # Retry with stricter prompt.
                    prompt = self._add_json_emphasis(prompt)
                    # await (not time.sleep): sleeping synchronously inside an
                    # async method would block the whole event loop.
                    await asyncio.sleep(1)

            except Exception as e:
                last_error = f"API error: {str(e)}"
                logger.warning(f"Attempt {attempt} failed: {last_error}")
                if attempt < self.max_retries:
                    # Exponential backoff, longer for timeout errors.
                    if "timeout" in str(e).lower() or "504" in str(e):
                        wait_time = 5 * attempt
                        logger.info(f"Timeout detected, waiting {wait_time}s before retry")
                    else:
                        wait_time = 2 * attempt
                    await asyncio.sleep(wait_time)

        # All retries failed.
        error_msg = f"Failed after {self.max_retries} attempts. Last error: {last_error}"
        logger.error(error_msg)
        raise Exception(error_msg)

    def _parse_json(self, text: str) -> Dict[str, Any]:
        """
        Parse JSON from response text, handling common issues.

        Strips markdown code fences (```json ... ```) that models sometimes
        wrap around otherwise-valid JSON.

        Args:
            text: Raw response text

        Returns:
            Parsed JSON object

        Raises:
            json.JSONDecodeError: If parsing fails
        """
        text = text.strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
        return json.loads(text)

    def _add_json_emphasis(self, prompt: str) -> str:
        """Add stronger JSON formatting requirements to prompt (idempotent)."""
        emphasis = "\n\nIMPORTANT: You MUST return ONLY valid JSON. No markdown, no code blocks, no explanations. Just the raw JSON object."
        if emphasis not in prompt:
            return prompt + emphasis
        return prompt

    def _get_cache_key(self, prompt: str, temperature: float) -> str:
        """Generate cache key for demo mode.

        Uses only the first 100 chars of the prompt plus the temperature, so
        long prompts sharing a prefix can collide — acceptable for demo use.
        """
        return f"{prompt[:100]}_{temperature}"

View File

@@ -0,0 +1,132 @@
"""
Strategy analyzer service - Step 2: Analysis & Selection.
"""
import logging
from typing import List
from config import get_settings
from models.input_models import EnrichedTelemetryWebhook, RaceContext, Strategy
from models.output_models import (
AnalyzeResponse,
AnalyzedStrategy,
PredictedOutcome,
RiskAssessment,
TelemetryInsights,
EngineerBrief,
ECUCommands,
SituationalContext
)
from services.gemini_client import GeminiClient
from prompts.analyze_prompt import build_analyze_prompt
logger = logging.getLogger(__name__)
class StrategyAnalyzer:
    """Analyzes strategies and selects top 3 using Gemini AI."""

    def __init__(self):
        """Initialize strategy analyzer with a Gemini client and app settings."""
        self.gemini_client = GeminiClient()
        self.settings = get_settings()
        logger.info("Strategy analyzer initialized")

    async def analyze(
        self,
        enriched_telemetry: List[EnrichedTelemetryWebhook],
        race_context: RaceContext,
        strategies: List[Strategy]
    ) -> AnalyzeResponse:
        """
        Analyze strategies and select top 3.

        Args:
            enriched_telemetry: Recent enriched telemetry data
            race_context: Current race context
            strategies: Strategies to analyze

        Returns:
            AnalyzeResponse with top 3 strategies

        Raises:
            Exception: If the Gemini call fails or the response is missing
                required fields
        """
        logger.info(f"Starting strategy analysis for {len(strategies)} strategies...")

        # Build prompt (use fast mode if enabled).
        if self.settings.fast_mode:
            # Local import so the fast prompt is only loaded when needed.
            from prompts.analyze_prompt import build_analyze_prompt_fast
            prompt = build_analyze_prompt_fast(enriched_telemetry, race_context, strategies)
            logger.info("Using FAST MODE prompt")
        else:
            prompt = build_analyze_prompt(enriched_telemetry, race_context, strategies)
        logger.debug(f"Prompt length: {len(prompt)} chars")

        # Low temperature for analytical consistency.
        response_data = await self.gemini_client.generate_json(
            prompt=prompt,
            temperature=0.3,
            timeout=self.settings.analyze_timeout
        )

        # Log the response structure for debugging.
        logger.info(f"Gemini response keys: {list(response_data.keys())}")

        # Validate top-level response shape before parsing any entries.
        if "top_strategies" not in response_data:
            # Log first 500 chars of response for debugging.
            response_preview = str(response_data)[:500]
            logger.error(f"Response preview: {response_preview}...")
            raise Exception(f"Response missing 'top_strategies' field. Got keys: {list(response_data.keys())}. Check logs for details.")
        if "situational_context" not in response_data:
            raise Exception("Response missing 'situational_context' field")

        top_strategies_data = response_data["top_strategies"]
        situational_context_data = response_data["situational_context"]
        logger.info(f"Received {len(top_strategies_data)} top strategies from Gemini")

        # Parse each entry; entries that fail validation are skipped with a
        # warning rather than failing the whole analysis.
        top_strategies = []
        for ts_data in top_strategies_data:
            analyzed = self._parse_strategy(ts_data)
            if analyzed is not None:
                top_strategies.append(analyzed)

        situational_context = SituationalContext(**situational_context_data)

        # We ask Gemini for exactly 3; warn (don't fail) if we got fewer.
        if len(top_strategies) != 3:
            logger.warning(f"Expected 3 top strategies, got {len(top_strategies)}")
        logger.info(f"Successfully analyzed and selected {len(top_strategies)} strategies")

        return AnalyzeResponse(
            top_strategies=top_strategies,
            situational_context=situational_context
        )

    @staticmethod
    def _parse_strategy(ts_data: dict):
        """Parse one raw strategy dict into an AnalyzedStrategy, or return None
        (after logging a warning) if any nested structure fails validation."""
        try:
            predicted_outcome = PredictedOutcome(**ts_data["predicted_outcome"])
            risk_assessment = RiskAssessment(**ts_data["risk_assessment"])
            telemetry_insights = TelemetryInsights(**ts_data["telemetry_insights"])
            engineer_brief = EngineerBrief(**ts_data["engineer_brief"])
            ecu_commands = ECUCommands(**ts_data["ecu_commands"])
            return AnalyzedStrategy(
                rank=ts_data["rank"],
                strategy_id=ts_data["strategy_id"],
                strategy_name=ts_data["strategy_name"],
                classification=ts_data["classification"],
                predicted_outcome=predicted_outcome,
                risk_assessment=risk_assessment,
                telemetry_insights=telemetry_insights,
                engineer_brief=engineer_brief,
                driver_audio_script=ts_data["driver_audio_script"],
                ecu_commands=ecu_commands
            )
        except Exception as e:
            logger.warning(f"Failed to parse strategy rank {ts_data.get('rank', '?')}: {e}")
            return None

View File

@@ -0,0 +1,87 @@
"""
Strategy generator service - Step 1: Brainstorming.
"""
import logging
from typing import List
from config import get_settings
from models.input_models import EnrichedTelemetryWebhook, RaceContext, Strategy
from models.output_models import BrainstormResponse
from services.gemini_client import GeminiClient
from prompts.brainstorm_prompt import build_brainstorm_prompt
from utils.validators import StrategyValidator
logger = logging.getLogger(__name__)
class StrategyGenerator:
    """Generates diverse race strategies using Gemini AI."""

    def __init__(self):
        """Set up the Gemini client and cache application settings."""
        self.gemini_client = GeminiClient()
        self.settings = get_settings()
        logger.info("Strategy generator initialized")

    async def generate(
        self,
        enriched_telemetry: List[EnrichedTelemetryWebhook],
        race_context: RaceContext
    ) -> BrainstormResponse:
        """
        Generate 20 diverse race strategies.

        Args:
            enriched_telemetry: Recent enriched telemetry data
            race_context: Current race context

        Returns:
            BrainstormResponse with 20 strategies

        Raises:
            Exception: If generation fails
        """
        logger.info("Starting strategy brainstorming...")
        logger.info(f"Using {len(enriched_telemetry)} telemetry records")

        prompt = self._select_prompt(enriched_telemetry, race_context)
        logger.debug(f"Prompt length: {len(prompt)} chars")

        # High temperature encourages a creative spread of candidate strategies.
        response_data = await self.gemini_client.generate_json(
            prompt=prompt,
            temperature=0.9,
            timeout=self.settings.brainstorm_timeout
        )

        if "strategies" not in response_data:
            raise Exception("Response missing 'strategies' field")

        raw_strategies = response_data["strategies"]
        logger.info(f"Received {len(raw_strategies)} strategies from Gemini")

        # Malformed entries are logged and dropped rather than failing the batch.
        parsed: List[Strategy] = []
        for raw in raw_strategies:
            try:
                parsed.append(Strategy(**raw))
            except Exception as parse_error:
                logger.warning(f"Failed to parse strategy {raw.get('strategy_id', '?')}: {parse_error}")
        logger.info(f"Successfully parsed {len(parsed)} strategies")

        valid_strategies = StrategyValidator.validate_strategies(parsed, race_context)
        if len(valid_strategies) < 10:
            logger.warning(f"Only {len(valid_strategies)} valid strategies (expected 20)")

        return BrainstormResponse(strategies=valid_strategies)

    def _select_prompt(self, enriched_telemetry, race_context) -> str:
        """Return the fast-mode prompt when enabled, otherwise the full one."""
        if self.settings.fast_mode:
            # Deferred import: the fast prompt builder is only needed here.
            from prompts.brainstorm_prompt import build_brainstorm_prompt_fast
            fast_prompt = build_brainstorm_prompt_fast(enriched_telemetry, race_context)
            logger.info("Using FAST MODE prompt")
            return fast_prompt
        return build_brainstorm_prompt(enriched_telemetry, race_context)

View File

@@ -0,0 +1,80 @@
"""
Telemetry client for fetching enriched data from HPC enrichment service.
"""
import httpx
import logging
from typing import List, Optional
from config import get_settings
from models.input_models import EnrichedTelemetryWebhook
logger = logging.getLogger(__name__)
class TelemetryClient:
    """Client for fetching enriched telemetry from enrichment service."""

    def __init__(self):
        """Initialize telemetry client from application settings."""
        settings = get_settings()
        self.base_url = settings.enrichment_service_url
        self.fetch_limit = settings.enrichment_fetch_limit
        logger.info(f"Telemetry client initialized for {self.base_url}")

    async def fetch_latest(self, limit: Optional[int] = None) -> List[EnrichedTelemetryWebhook]:
        """
        Fetch latest enriched telemetry records from enrichment service.

        Args:
            limit: Number of records to fetch (defaults to config setting)

        Returns:
            List of enriched telemetry records

        Raises:
            Exception: If the request fails or the service returns an error
        """
        if limit is None:
            limit = self.fetch_limit

        url = f"{self.base_url}/enriched"
        params = {"limit": limit}
        try:
            logger.info(f"Fetching telemetry from {url} (limit={limit})")
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(url, params=params)
                response.raise_for_status()
                data = response.json()
                logger.info(f"Fetched {len(data)} telemetry records")
                # Parse into Pydantic models.
                records = [EnrichedTelemetryWebhook(**item) for item in data]
                return records
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error fetching telemetry: {e.response.status_code}")
            # Chain the original exception so the httpx details survive in tracebacks.
            raise Exception(f"Enrichment service returned error: {e.response.status_code}") from e
        except httpx.RequestError as e:
            logger.error(f"Request error fetching telemetry: {e}")
            raise Exception(f"Cannot connect to enrichment service at {self.base_url}") from e
        except Exception as e:
            logger.error(f"Unexpected error fetching telemetry: {e}")
            raise

    async def health_check(self) -> bool:
        """
        Check if enrichment service is reachable.

        Returns:
            True if service is healthy, False otherwise
        """
        try:
            url = f"{self.base_url}/health"
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.get(url)
                return response.status_code == 200
        except Exception as e:
            # Best-effort probe: any failure (timeout, DNS, refused) means unhealthy.
            logger.warning(f"Health check failed: {e}")
            return False