p
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
157
ai_intelligence_layer/services/gemini_client.py
Normal file
157
ai_intelligence_layer/services/gemini_client.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Gemini API client wrapper with retry logic and error handling.
|
||||
"""
|
||||
import asyncio
import json
import logging
import time

from typing import Dict, Any, Optional

import google.generativeai as genai

from config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GeminiClient:
    """Wrapper for Google Gemini API with retry logic and JSON parsing.

    Configures the genai SDK from application settings, calls the model in
    JSON mode with retries and backoff, strips markdown fences from the
    reply, and (in demo mode) caches parsed results per prompt/temperature.
    """

    def __init__(self):
        """Initialize Gemini client with API key from settings."""
        settings = get_settings()
        genai.configure(api_key=settings.gemini_api_key)
        self.model = genai.GenerativeModel(settings.gemini_model)
        self.max_retries = settings.gemini_max_retries
        self.demo_mode = settings.demo_mode

        # Cache for demo mode: keyed on prompt prefix + temperature,
        # avoids repeated API calls for identical requests.
        self._demo_cache: Dict[str, Any] = {}

        logger.info(f"Gemini client initialized with model: {settings.gemini_model}")

    async def generate_json(
        self,
        prompt: str,
        temperature: float = 0.7,
        timeout: int = 30
    ) -> Dict[str, Any]:
        """
        Generate JSON response from Gemini with retry logic.

        Args:
            prompt: The prompt to send to Gemini
            temperature: Sampling temperature (0.0-1.0)
            timeout: Request timeout in seconds (raised to at least 60s)

        Returns:
            Parsed JSON response

        Raises:
            Exception: If all retries fail or JSON parsing fails
        """
        # Check demo cache first so cached runs never touch the API.
        if self.demo_mode:
            cache_key = self._get_cache_key(prompt, temperature)
            if cache_key in self._demo_cache:
                logger.info("Returning cached response (demo mode)")
                return self._demo_cache[cache_key]

        last_error = None

        for attempt in range(1, self.max_retries + 1):
            try:
                logger.info(f"Gemini API call attempt {attempt}/{self.max_retries}")

                # Ask the model for JSON output directly.
                generation_config = genai.GenerationConfig(
                    temperature=temperature,
                    response_mime_type="application/json"
                )

                # Use max of provided timeout or 60 seconds.
                # NOTE(review): generate_content is a synchronous call inside an
                # async method and blocks the event loop for its duration;
                # consider generate_content_async if the SDK version supports it.
                actual_timeout = max(timeout, 60)
                response = self.model.generate_content(
                    prompt,
                    generation_config=generation_config,
                    request_options={"timeout": actual_timeout}
                )

                # Extract text
                response_text = response.text
                logger.debug(f"Raw response length: {len(response_text)} chars")

                # Parse JSON
                result = self._parse_json(response_text)

                # Cache in demo mode
                if self.demo_mode:
                    cache_key = self._get_cache_key(prompt, temperature)
                    self._demo_cache[cache_key] = result

                logger.info("Successfully generated and parsed JSON response")
                return result

            except json.JSONDecodeError as e:
                last_error = f"JSON parsing error: {str(e)}"
                logger.warning(f"Attempt {attempt} failed: {last_error}")

                if attempt < self.max_retries:
                    # Retry with stricter prompt
                    prompt = self._add_json_emphasis(prompt)
                    # Fix: non-blocking wait — time.sleep here would stall
                    # every other task on the event loop.
                    await asyncio.sleep(1)

            except Exception as e:
                last_error = f"API error: {str(e)}"
                logger.warning(f"Attempt {attempt} failed: {last_error}")

                if attempt < self.max_retries:
                    # Backoff grows with the attempt number; timeouts get a
                    # longer wait than other API errors.
                    if "timeout" in str(e).lower() or "504" in str(e):
                        wait_time = 5 * attempt
                        logger.info(f"Timeout detected, waiting {wait_time}s before retry")
                    else:
                        wait_time = 2 * attempt
                    # Fix: non-blocking backoff (was time.sleep).
                    await asyncio.sleep(wait_time)

        # All retries failed
        error_msg = f"Failed after {self.max_retries} attempts. Last error: {last_error}"
        logger.error(error_msg)
        raise Exception(error_msg)

    def _parse_json(self, text: str) -> Dict[str, Any]:
        """
        Parse JSON from response text, handling common issues.

        Args:
            text: Raw response text

        Returns:
            Parsed JSON object

        Raises:
            json.JSONDecodeError: If parsing fails
        """
        # Remove markdown code blocks if present (models often wrap JSON
        # in ```json ... ``` fences despite instructions not to).
        text = text.strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]

        text = text.strip()

        # Parse JSON
        return json.loads(text)

    def _add_json_emphasis(self, prompt: str) -> str:
        """Add stronger JSON formatting requirements to prompt (idempotent)."""
        emphasis = "\n\nIMPORTANT: You MUST return ONLY valid JSON. No markdown, no code blocks, no explanations. Just the raw JSON object."
        # Guard keeps the emphasis from being appended on every retry.
        if emphasis not in prompt:
            return prompt + emphasis
        return prompt

    def _get_cache_key(self, prompt: str, temperature: float) -> str:
        """Generate cache key for demo mode."""
        # Use first 100 chars of prompt + temperature as key
        return f"{prompt[:100]}_{temperature}"
132
ai_intelligence_layer/services/strategy_analyzer.py
Normal file
132
ai_intelligence_layer/services/strategy_analyzer.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Strategy analyzer service - Step 2: Analysis & Selection.
|
||||
"""
|
||||
import logging
|
||||
from typing import List
|
||||
from config import get_settings
|
||||
from models.input_models import EnrichedTelemetryWebhook, RaceContext, Strategy
|
||||
from models.output_models import (
|
||||
AnalyzeResponse,
|
||||
AnalyzedStrategy,
|
||||
PredictedOutcome,
|
||||
RiskAssessment,
|
||||
TelemetryInsights,
|
||||
EngineerBrief,
|
||||
ECUCommands,
|
||||
SituationalContext
|
||||
)
|
||||
from services.gemini_client import GeminiClient
|
||||
from prompts.analyze_prompt import build_analyze_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StrategyAnalyzer:
    """Analyzes strategies and selects top 3 using Gemini AI."""

    def __init__(self):
        """Set up the Gemini client and cache the application settings."""
        self.gemini_client = GeminiClient()
        self.settings = get_settings()
        logger.info("Strategy analyzer initialized")

    async def analyze(
        self,
        enriched_telemetry: List[EnrichedTelemetryWebhook],
        race_context: RaceContext,
        strategies: List[Strategy]
    ) -> AnalyzeResponse:
        """
        Analyze strategies and select top 3.

        Args:
            enriched_telemetry: Recent enriched telemetry data
            race_context: Current race context
            strategies: Strategies to analyze

        Returns:
            AnalyzeResponse with top 3 strategies

        Raises:
            Exception: If analysis fails
        """
        logger.info(f"Starting strategy analysis for {len(strategies)} strategies...")

        # Choose the prompt builder; fast mode uses a shorter variant.
        if self.settings.fast_mode:
            from prompts.analyze_prompt import build_analyze_prompt_fast
            prompt = build_analyze_prompt_fast(enriched_telemetry, race_context, strategies)
            logger.info("Using FAST MODE prompt")
        else:
            prompt = build_analyze_prompt(enriched_telemetry, race_context, strategies)
            logger.debug(f"Prompt length: {len(prompt)} chars")

        # Low temperature keeps the analytical output consistent between runs.
        response_data = await self.gemini_client.generate_json(
            prompt=prompt,
            temperature=0.3,
            timeout=self.settings.analyze_timeout
        )

        # Log the response structure for debugging
        logger.info(f"Gemini response keys: {list(response_data.keys())}")

        # Guard against malformed responses before touching the payload.
        if "top_strategies" not in response_data:
            response_preview = str(response_data)[:500]
            logger.error(f"Response preview: {response_preview}...")
            raise Exception(f"Response missing 'top_strategies' field. Got keys: {list(response_data.keys())}. Check logs for details.")

        if "situational_context" not in response_data:
            raise Exception("Response missing 'situational_context' field")

        ranked_payload = response_data["top_strategies"]
        context_payload = response_data["situational_context"]

        logger.info(f"Received {len(ranked_payload)} top strategies from Gemini")

        # Convert each raw entry into an AnalyzedStrategy; entries whose
        # structure does not match the expected schema are logged and dropped.
        top_strategies = []
        for ts_data in ranked_payload:
            try:
                top_strategies.append(self._build_strategy(ts_data))
            except Exception as e:
                logger.warning(f"Failed to parse strategy rank {ts_data.get('rank', '?')}: {e}")

        situational_context = SituationalContext(**context_payload)

        if len(top_strategies) != 3:
            logger.warning(f"Expected 3 top strategies, got {len(top_strategies)}")

        logger.info(f"Successfully analyzed and selected {len(top_strategies)} strategies")

        return AnalyzeResponse(
            top_strategies=top_strategies,
            situational_context=situational_context
        )

    def _build_strategy(self, ts_data) -> AnalyzedStrategy:
        """Construct one AnalyzedStrategy from a raw response entry."""
        # Parse the nested models first, then assemble the strategy.
        predicted_outcome = PredictedOutcome(**ts_data["predicted_outcome"])
        risk_assessment = RiskAssessment(**ts_data["risk_assessment"])
        telemetry_insights = TelemetryInsights(**ts_data["telemetry_insights"])
        engineer_brief = EngineerBrief(**ts_data["engineer_brief"])
        ecu_commands = ECUCommands(**ts_data["ecu_commands"])

        return AnalyzedStrategy(
            rank=ts_data["rank"],
            strategy_id=ts_data["strategy_id"],
            strategy_name=ts_data["strategy_name"],
            classification=ts_data["classification"],
            predicted_outcome=predicted_outcome,
            risk_assessment=risk_assessment,
            telemetry_insights=telemetry_insights,
            engineer_brief=engineer_brief,
            driver_audio_script=ts_data["driver_audio_script"],
            ecu_commands=ecu_commands
        )
87
ai_intelligence_layer/services/strategy_generator.py
Normal file
87
ai_intelligence_layer/services/strategy_generator.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Strategy generator service - Step 1: Brainstorming.
|
||||
"""
|
||||
import logging
|
||||
from typing import List
|
||||
from config import get_settings
|
||||
from models.input_models import EnrichedTelemetryWebhook, RaceContext, Strategy
|
||||
from models.output_models import BrainstormResponse
|
||||
from services.gemini_client import GeminiClient
|
||||
from prompts.brainstorm_prompt import build_brainstorm_prompt
|
||||
from utils.validators import StrategyValidator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StrategyGenerator:
    """Generates diverse race strategies using Gemini AI."""

    def __init__(self):
        """Set up the Gemini client and cache the application settings."""
        self.gemini_client = GeminiClient()
        self.settings = get_settings()
        logger.info("Strategy generator initialized")

    async def generate(
        self,
        enriched_telemetry: List[EnrichedTelemetryWebhook],
        race_context: RaceContext
    ) -> BrainstormResponse:
        """
        Generate 20 diverse race strategies.

        Args:
            enriched_telemetry: Recent enriched telemetry data
            race_context: Current race context

        Returns:
            BrainstormResponse with 20 strategies

        Raises:
            Exception: If generation fails
        """
        logger.info("Starting strategy brainstorming...")
        logger.info(f"Using {len(enriched_telemetry)} telemetry records")

        # Choose the prompt builder; fast mode uses a shorter variant.
        if self.settings.fast_mode:
            from prompts.brainstorm_prompt import build_brainstorm_prompt_fast
            prompt = build_brainstorm_prompt_fast(enriched_telemetry, race_context)
            logger.info("Using FAST MODE prompt")
        else:
            prompt = build_brainstorm_prompt(enriched_telemetry, race_context)
            logger.debug(f"Prompt length: {len(prompt)} chars")

        # High temperature encourages a diverse set of candidate strategies.
        response_data = await self.gemini_client.generate_json(
            prompt=prompt,
            temperature=0.9,
            timeout=self.settings.brainstorm_timeout
        )

        # Bail out early on a malformed response.
        if "strategies" not in response_data:
            raise Exception("Response missing 'strategies' field")

        raw_strategies = response_data["strategies"]
        logger.info(f"Received {len(raw_strategies)} strategies from Gemini")

        # Keep every entry that parses as a Strategy; log and drop the rest.
        strategies = []
        for s_data in raw_strategies:
            try:
                strategies.append(Strategy(**s_data))
            except Exception as e:
                logger.warning(f"Failed to parse strategy {s_data.get('strategy_id', '?')}: {e}")

        logger.info(f"Successfully parsed {len(strategies)} strategies")

        # Run the parsed strategies through the project validator.
        valid_strategies = StrategyValidator.validate_strategies(strategies, race_context)

        if len(valid_strategies) < 10:
            logger.warning(f"Only {len(valid_strategies)} valid strategies (expected 20)")

        return BrainstormResponse(strategies=valid_strategies)
80
ai_intelligence_layer/services/telemetry_client.py
Normal file
80
ai_intelligence_layer/services/telemetry_client.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Telemetry client for fetching enriched data from HPC enrichment service.
|
||||
"""
|
||||
import httpx
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from config import get_settings
|
||||
from models.input_models import EnrichedTelemetryWebhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TelemetryClient:
    """Client for fetching enriched telemetry from enrichment service."""

    def __init__(self):
        """Initialize telemetry client from application settings."""
        settings = get_settings()
        self.base_url = settings.enrichment_service_url
        self.fetch_limit = settings.enrichment_fetch_limit
        logger.info(f"Telemetry client initialized for {self.base_url}")

    async def fetch_latest(self, limit: Optional[int] = None) -> List[EnrichedTelemetryWebhook]:
        """
        Fetch latest enriched telemetry records from enrichment service.

        Args:
            limit: Number of records to fetch (defaults to config setting)

        Returns:
            List of enriched telemetry records

        Raises:
            Exception: If request fails
        """
        if limit is None:
            limit = self.fetch_limit

        url = f"{self.base_url}/enriched"
        params = {"limit": limit}

        try:
            logger.info(f"Fetching telemetry from {url} (limit={limit})")

            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(url, params=params)
                response.raise_for_status()

                data = response.json()
                logger.info(f"Fetched {len(data)} telemetry records")

                # Parse into Pydantic models
                records = [EnrichedTelemetryWebhook(**item) for item in data]
                return records

        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error fetching telemetry: {e.response.status_code}")
            # Fix: chain the original exception so tracebacks keep the root cause.
            raise Exception(f"Enrichment service returned error: {e.response.status_code}") from e
        except httpx.RequestError as e:
            logger.error(f"Request error fetching telemetry: {e}")
            # Fix: chain the original exception so tracebacks keep the root cause.
            raise Exception(f"Cannot connect to enrichment service at {self.base_url}") from e
        except Exception as e:
            logger.error(f"Unexpected error fetching telemetry: {e}")
            raise

    async def health_check(self) -> bool:
        """
        Check if enrichment service is reachable.

        Returns:
            True if service is healthy, False otherwise
        """
        # Best-effort probe: any failure (connection, timeout, non-200)
        # is reported as unhealthy rather than raised.
        try:
            url = f"{self.base_url}/health"
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.get(url)
                return response.status_code == 200
        except Exception as e:
            logger.warning(f"Health check failed: {e}")
            return False
Reference in New Issue
Block a user