# app/controllers/gateway_controller.py
from core.controller import Controller
from core.redis_manager import redis_manager
from app.services.service_discovery import ServiceDiscovery
from app.services.load_balancer import LoadBalancer
from app.services.circuit_breaker import CircuitBreaker
import json
import time
import uuid
import asyncio
import aiohttp
from typing import Dict, Any, Optional
class GatewayController(Controller):
    """API gateway handlers that forward requests to backend microservices.

    Each ``route_to_*`` handler delegates to :meth:`_route_to_service`, which
    resolves an instance via ``ServiceDiscovery``, consults a per-service
    ``CircuitBreaker``, forwards the payload over HTTP with aiohttp and records
    request/response/error logs and counters in Redis via ``redis_manager``.
    """

    # Incoming header names (lower-cased) that are never copied onto the upstream
    # request: the standard hop-by-hop headers, headers aiohttp derives itself from
    # the re-serialized JSON body / target URL, and X-User-ID, which only the
    # gateway may assert (it is set from the request context, never the client).
    _STRIPPED_FORWARD_HEADERS = frozenset({
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "host",
        "content-length",
        "x-user-id",
    })

    @staticmethod
    async def route_to_user_service(payload: Dict[str, Any], client, **kwargs):
        """Route requests to user microservice"""
        return await GatewayController._route_to_service(
            service_name="user-service",
            payload=payload,
            client=client,
            **kwargs
        )

    @staticmethod
    async def route_to_order_service(payload: Dict[str, Any], client, **kwargs):
        """Route requests to order microservice"""
        return await GatewayController._route_to_service(
            service_name="order-service",
            payload=payload,
            client=client,
            **kwargs
        )

    @staticmethod
    async def route_to_payment_service(payload: Dict[str, Any], client, **kwargs):
        """Route requests to payment microservice"""
        return await GatewayController._route_to_service(
            service_name="payment-service",
            payload=payload,
            client=client,
            **kwargs
        )

    @staticmethod
    async def route_to_notification_service(payload: Dict[str, Any], client, **kwargs):
        """Route requests to notification microservice"""
        return await GatewayController._route_to_service(
            service_name="notification-service",
            payload=payload,
            client=client,
            **kwargs
        )

    @staticmethod
    async def _route_to_service(service_name: str, payload: Dict[str, Any], client, **kwargs):
        """Forward ``payload`` to ``service_name`` with circuit breaking and metrics.

        Args:
            service_name: Service-discovery name of the target service.
            payload: JSON-serializable request body sent to the service.
            client: Caller handle; accepted for handler-signature compatibility,
                not used in this method.
            **kwargs: ``context`` (dict) may supply ``path`` (appended to the
                instance URL), ``method`` (default ``"POST"``), ``headers``,
                ``user_id`` and ``auth_token``.

        Returns:
            On success a dict with ``status``, ``data`` (decoded JSON or text
            body), ``request_id``, ``service`` and ``processing_time``. On
            failure, whatever :meth:`_try_fallback` returns, if truthy.

        Raises:
            Exception: no instance is available, the circuit breaker is open,
                the service answered with status >= 400, or the HTTP call
                failed — and no fallback response was available.
        """
        request_id = str(uuid.uuid4())
        start_time = time.time()
        circuit_breaker: Optional[CircuitBreaker] = None
        # Track what has already been recorded so the error path below does not
        # record the same failure a second time.
        breaker_updated = False
        failure_counted = False
        try:
            context = kwargs.get('context') or {}
            method = context.get('method', 'POST')

            # Get service instance from service discovery
            service_instance = await ServiceDiscovery.get_service_instance(service_name)
            if not service_instance:
                raise Exception(f"No available instances for service {service_name}")

            circuit_breaker = CircuitBreaker(service_name)
            if circuit_breaker.is_open():
                # The call is rejected locally and never reaches the service, so
                # there is no new outcome to record on the breaker.
                breaker_updated = True
                raise Exception(f"Circuit breaker is open for service {service_name}")

            service_url = f"{service_instance['url']}{context.get('path', '')}"
            request_headers = GatewayController._build_request_headers(
                request_id, service_name, context
            )

            # NOTE(review): failures raised above (no instance / open breaker)
            # increment failed_requests in _log_error without total_requests
            # having been incremented here, which skews success_rate.
            await GatewayController._log_request(request_id, service_name, service_url, payload)

            # NOTE(review): a new ClientSession per request means a new connection
            # pool each time (no keep-alive reuse); consider a shared session.
            async with aiohttp.ClientSession() as session:
                async with session.request(
                    method=method,
                    url=service_url,
                    json=payload,
                    headers=request_headers,
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    status = response.status
                    if response.content_type == 'application/json':
                        response_data = await response.json()
                    else:
                        response_data = await response.text()
                    processing_time = time.time() - start_time

            is_error = status >= 400

            # Record exactly one outcome on the breaker per upstream call.
            if is_error:
                circuit_breaker.record_failure()
            else:
                circuit_breaker.record_success()
            breaker_updated = True

            # _log_response increments successful_requests or failed_requests.
            await GatewayController._log_response(
                request_id,
                service_name,
                status,
                processing_time
            )
            failure_counted = is_error

            await GatewayController._update_service_metrics(
                service_name,
                status,
                processing_time
            )

            if is_error:
                raise Exception(f"Service error: {status} - {response_data}")

            return {
                "status": "success",
                "data": response_data,
                "request_id": request_id,
                "service": service_name,
                "processing_time": processing_time
            }
        except Exception as e:
            if circuit_breaker is not None and not breaker_updated:
                circuit_breaker.record_failure()

            # Logging and fallback are best effort: if they fail (e.g. Redis is
            # unreachable) the original error is re-raised rather than theirs.
            try:
                await GatewayController._log_error(
                    request_id, service_name, str(e), count_failure=not failure_counted
                )
            except Exception as log_err:
                print(f"Error logging failure for {service_name}: {log_err}")

            try:
                fallback_response = await GatewayController._try_fallback(service_name, payload)
            except Exception as fallback_err:
                print(f"Fallback lookup failed for {service_name}: {fallback_err}")
                fallback_response = None
            if fallback_response:
                return fallback_response
            # Re-raises the exception bound to `e` (the nested handlers above have exited).
            raise

    @staticmethod
    def _set_header(headers: Dict[str, Any], name: str, value: Any) -> None:
        """Set ``name`` in ``headers``, replacing any existing case-variant of it."""
        lowered = name.lower()
        for existing in [key for key in headers if key.lower() == lowered]:
            del headers[existing]
        headers[name] = value

    @staticmethod
    def _build_request_headers(request_id: str, service_name: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Build the header dict sent to the upstream service.

        Starts from the gateway defaults; headers from ``context['headers']``
        then override them (matched case-insensitively, so e.g. ``content-type``
        replaces ``Content-Type`` rather than duplicating it), except names in
        ``_STRIPPED_FORWARD_HEADERS``. Finally ``X-User-ID`` / ``Authorization``
        are set from ``context['user_id']`` / ``context['auth_token']`` when present.
        """
        request_headers: Dict[str, Any] = {
            "Content-Type": "application/json",
            "X-Request-ID": request_id,
            "X-Gateway-Service": service_name,
        }
        # NOTE(review): assumes context['headers'] holds the incoming client
        # headers — confirm with the caller that builds the context.
        for name, value in (context.get('headers') or {}).items():
            if name.lower() in GatewayController._STRIPPED_FORWARD_HEADERS:
                continue
            GatewayController._set_header(request_headers, name, value)

        # Add authentication headers if present
        if context.get('user_id'):
            GatewayController._set_header(request_headers, "X-User-ID", str(context['user_id']))
        if context.get('auth_token'):
            GatewayController._set_header(
                request_headers, "Authorization", f"Bearer {context['auth_token']}"
            )
        return request_headers

    @staticmethod
    async def route_to_user_service_v2(action: str, payload: Dict[str, Any], client, **kwargs):
        """Route to the v2 user service at ``/v2/users/{action}``.

        The payload goes through :meth:`_transform_payload_v2`, and a copy of
        the request ``context`` with ``path`` overridden is forwarded, so
        neither the caller's payload nor its context dict is modified.
        Errors are printed and re-raised.
        """
        try:
            # API v2 specific transformations (returns a new dict)
            payload = await GatewayController._transform_payload_v2(payload, action)
            # NOTE(review): `action` is interpolated unescaped; if it can come from
            # clients, values containing "/" or ".." reach other upstream paths.
            context = {**(kwargs.get('context') or {}), 'path': f"/v2/users/{action}"}
            forward_kwargs = {**kwargs, 'context': context}
            return await GatewayController._route_to_service(
                service_name="user-service-v2",
                payload=payload,
                client=client,
                **forward_kwargs
            )
        except Exception as e:
            print(f"Error routing to user service v2: {e}")
            raise

    @staticmethod
    async def health_check(payload: Dict[str, Any], client, **kwargs):
        """API Gateway health check.

        Returns a status dict that includes per-service health from
        ``ServiceDiscovery.check_all_services_health()`` and Redis ping status.
        ``status`` is "healthy", or "degraded" when the Redis ping fails; if
        collecting service health raises, returns "unhealthy" with the error.
        ``**kwargs`` is accepted for signature parity with the route handlers.
        """
        try:
            gateway_status = {
                "status": "healthy",
                "timestamp": time.time(),
                "version": "1.0.0"
            }
            # Check service health
            services_health = await ServiceDiscovery.check_all_services_health()
            gateway_status["services"] = services_health
            # Check Redis connectivity
            try:
                await redis_manager.ping()
                gateway_status["redis"] = "healthy"
            except Exception:
                gateway_status["redis"] = "unhealthy"
                gateway_status["status"] = "degraded"
            return gateway_status
        except Exception as e:
            return {
                "status": "unhealthy",
                "error": str(e),
                "timestamp": time.time()
            }

    @staticmethod
    async def list_services(payload: Dict[str, Any], client, **kwargs):
        """List all registered services with their count and a timestamp.

        ``**kwargs`` is accepted for signature parity with the route handlers.
        """
        try:
            services = await ServiceDiscovery.list_all_services()
            return {
                "status": "success",
                "services": services,
                "count": len(services),
                "timestamp": time.time()
            }
        except Exception as e:
            print(f"Error listing services: {e}")
            raise

    @staticmethod
    async def _read_metric(key: str, cast=int):
        """Read a numeric metric from Redis, treating a missing/empty value as 0.

        ``cast`` (``int`` or ``float``) converts the stored value, which may be
        returned as ``str``/``bytes`` depending on the Redis client.
        """
        return cast(await redis_manager.get(key) or 0)

    @staticmethod
    async def get_metrics(payload: Dict[str, Any], client, **kwargs):
        """Get API Gateway metrics.

        Returns gateway-wide counters (with a success percentage) and
        per-service request count, average response time and error rate, all
        as numbers. ``**kwargs`` is accepted for signature parity with the
        route handlers.
        """
        try:
            read = GatewayController._read_metric
            total_requests = await read("metrics:gateway:total_requests")
            successful_requests = await read("metrics:gateway:successful_requests")
            failed_requests = await read("metrics:gateway:failed_requests")

            service_metrics = {}
            services = await ServiceDiscovery.list_all_services()
            for service in services:
                service_name = service["name"]
                prefix = f"metrics:service:{service_name}"
                # Converted like the gateway counters so every value is numeric
                # (raw Redis values could be str/bytes and not JSON-friendly).
                service_metrics[service_name] = {
                    "total_requests": await read(f"{prefix}:requests"),
                    "avg_response_time": await read(f"{prefix}:avg_time", float),
                    "error_rate": await read(f"{prefix}:error_rate", float)
                }

            rate_limit_hits = await read("metrics:gateway:rate_limit_hits")
            return {
                "gateway": {
                    "total_requests": total_requests,
                    "successful_requests": successful_requests,
                    "failed_requests": failed_requests,
                    "success_rate": (successful_requests / max(total_requests, 1)) * 100,
                    "rate_limit_hits": rate_limit_hits
                },
                "services": service_metrics,
                "timestamp": time.time()
            }
        except Exception as e:
            print(f"Error getting metrics: {e}")
            raise

    # Helper methods
    @staticmethod
    async def _log_request(request_id: str, service_name: str, url: str, payload: Dict[str, Any]):
        """Push a request log entry (keeping the newest 1000) and count the request."""
        log_entry = {
            "request_id": request_id,
            "service": service_name,
            "url": url,
            "payload_size": len(json.dumps(payload)),
            "timestamp": time.time()
        }
        await redis_manager.lpush("gateway:request_logs", json.dumps(log_entry))
        await redis_manager.ltrim("gateway:request_logs", 0, 999)  # Keep last 1000 logs
        await redis_manager.incr("metrics:gateway:total_requests")

    @staticmethod
    async def _log_response(request_id: str, service_name: str, status_code: int, processing_time: float):
        """Push a response log entry and count it as successful (< 400) or failed."""
        log_entry = {
            "request_id": request_id,
            "service": service_name,
            "status_code": status_code,
            "processing_time": processing_time,
            "timestamp": time.time()
        }
        await redis_manager.lpush("gateway:response_logs", json.dumps(log_entry))
        await redis_manager.ltrim("gateway:response_logs", 0, 999)
        if status_code < 400:
            await redis_manager.incr("metrics:gateway:successful_requests")
        else:
            await redis_manager.incr("metrics:gateway:failed_requests")

    @staticmethod
    async def _log_error(request_id: str, service_name: str, error_message: str,
                         *, count_failure: bool = True):
        """Push an error log entry (keeping the newest 1000).

        Args:
            count_failure: Also increment ``metrics:gateway:failed_requests``.
                Pass ``False`` when the failure was already counted (e.g. by
                :meth:`_log_response` for an HTTP error status).
        """
        error_entry = {
            "request_id": request_id,
            "service": service_name,
            "error": error_message,
            "timestamp": time.time()
        }
        await redis_manager.lpush("gateway:error_logs", json.dumps(error_entry))
        await redis_manager.ltrim("gateway:error_logs", 0, 999)
        if count_failure:
            await redis_manager.incr("metrics:gateway:failed_requests")

    @staticmethod
    async def _update_service_metrics(service_name: str, status_code: int, processing_time: float):
        """Update a service's request count, running average latency and error rate.

        NOTE(review): these read-modify-write steps are not atomic, so
        concurrent requests can interleave and skew the figures; exact values
        would need a Redis transaction or script.
        """
        prefix = f"metrics:service:{service_name}"
        await redis_manager.incr(f"{prefix}:requests")

        # Update average response time (simplified running mean)
        current_avg = float(await redis_manager.get(f"{prefix}:avg_time") or 0)
        # Never below 1, so the divisions below cannot hit zero.
        total_requests = max(int(await redis_manager.get(f"{prefix}:requests") or 1), 1)
        new_avg = ((current_avg * (total_requests - 1)) + processing_time) / total_requests
        await redis_manager.set(f"{prefix}:avg_time", new_avg)

        if status_code >= 400:
            await redis_manager.incr(f"{prefix}:errors")
        # Recompute on every request, not only on errors, so successes lower the rate.
        error_count = int(await redis_manager.get(f"{prefix}:errors") or 0)
        error_rate = (error_count / total_requests) * 100
        await redis_manager.set(f"{prefix}:error_rate", error_rate)

    @staticmethod
    async def _transform_payload_v2(payload: Dict[str, Any], action: str) -> Dict[str, Any]:
        """Return a v2 copy of ``payload``; the caller's dict is not modified.

        Adds ``api_version`` and ``action`` and renames a legacy ``user_data``
        field to ``user``.
        """
        transformed = dict(payload)
        transformed["api_version"] = "v2"
        transformed["action"] = action
        # Transform legacy fields if present
        if "user_data" in transformed:
            transformed["user"] = transformed.pop("user_data")
        return transformed

    @staticmethod
    async def _try_fallback(service_name: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return a fallback response for a failed call to ``service_name``, if any.

        First looks for a cached dict response keyed on the service and the
        key-sorted JSON payload (marked with ``source="cache"``). Otherwise, if
        a ``{service_name}-fallback`` instance is registered, returns a static
        placeholder response; that instance is not actually called. Returns
        ``None`` when neither applies.
        """
        # NOTE(review): built-in hash() of a str is salted per interpreter process
        # (PYTHONHASHSEED), so this key is not stable across workers or restarts.
        # A hashlib digest would be, but whatever writes these cache entries is
        # not visible here and would have to change in lockstep.
        cache_key = f"fallback:{service_name}:{hash(json.dumps(payload, sort_keys=True))}"
        cached_response = await redis_manager.get_json(cache_key)
        if cached_response and isinstance(cached_response, dict):
            cached_response["source"] = "cache"
            return cached_response

        # Try fallback service
        fallback_service = f"{service_name}-fallback"
        fallback_instance = await ServiceDiscovery.get_service_instance(fallback_service)
        if fallback_instance:
            # Placeholder: the fallback instance is not called yet.
            return {
                "status": "fallback",
                "message": "Primary service unavailable, fallback response",
                "data": {}
            }
        return None