Middleware Chains
Learn how to combine multiple middleware components to create powerful processing pipelines for your RouteMQ applications.
Understanding Middleware Chains
Middleware in RouteMQ executes in a chain pattern where each middleware can:
Process the request before passing it to the next middleware
Modify the context that flows through the chain
Process the response after receiving it from subsequent middleware
Stop the chain by not calling the next handler
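The four behaviours above can be seen in a small, self-contained sketch. The `Middleware` base class, the `compose` helper, and the `trace` and `blocked` context keys are stand-ins written for this illustration; how RouteMQ builds the chain internally is not shown in this guide, so treat the code as one plausible model rather than the framework's implementation.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

Context = Dict[str, Any]
Handler = Callable[[Context], Awaitable[Any]]


class Middleware:
    """Stand-in for RouteMQ's base class (assumed interface)."""

    async def handle(self, context: Context, next_handler: Handler) -> Any:
        return await next_handler(context)


class TraceMiddleware(Middleware):
    """Records when it runs before and after the rest of the chain."""

    def __init__(self, name: str) -> None:
        self.name = name

    async def handle(self, context, next_handler):
        context["trace"].append(f"{self.name}:pre")   # pre-process, modify context
        result = await next_handler(context)
        context["trace"].append(f"{self.name}:post")  # post-process
        return result


class BlockMiddleware(Middleware):
    """Stops the chain by returning without calling next_handler."""

    async def handle(self, context, next_handler):
        if context.get("blocked"):
            return {"error": "blocked"}
        return await next_handler(context)


def compose(middleware: List[Middleware], handler: Handler) -> Handler:
    """Wrap handler so the first middleware in the list runs outermost."""
    wrapped = handler
    for mw in reversed(middleware):
        def make(mw=mw, nxt=wrapped):
            async def call(ctx):
                return await mw.handle(ctx, nxt)
            return call
        wrapped = make()
    return wrapped


async def final_handler(context):
    context["trace"].append("handler")
    return {"status": "ok"}


chain = compose(
    [TraceMiddleware("a"), BlockMiddleware(), TraceMiddleware("b")],
    final_handler,
)

ok_ctx = {"trace": []}
ok_result = asyncio.run(chain(ok_ctx))

blocked_ctx = {"trace": [], "blocked": True}
blocked_result = asyncio.run(chain(blocked_ctx))
```

In the blocked run, middleware `a` still post-processes the early return value, while `b` and the handler never execute.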
Chain Execution Flow
Request  → Middleware 1 → Middleware 2 → Middleware 3 → Handler
                ↓              ↓              ↓            ↓
           Pre-process    Pre-process    Pre-process    Execute
                ↑              ↑              ↑            ↑
           Post-process   Post-process   Post-process   Return
Response ← Middleware 1 ← Middleware 2 ← Middleware 3 ← Handler

Basic Chain Example
from core.router import Router
from app.middleware.logging import LoggingMiddleware
from app.middleware.auth import AuthenticationMiddleware
from app.middleware.rate_limit import RateLimitMiddleware

router = Router()

# Middleware executes in order: Logging → Auth → RateLimit → Handler
router.on("api/{endpoint}",
          ApiController.handle_request,
          middleware=[
              LoggingMiddleware(),         # 1st: Log the request
              AuthenticationMiddleware(),  # 2nd: Authenticate user
              RateLimitMiddleware(),       # 3rd: Check rate limits
          ])

Execution Order:
1. LoggingMiddleware logs incoming request
2. AuthenticationMiddleware validates credentials
3. RateLimitMiddleware checks request limits
4. Handler executes business logic
5. RateLimitMiddleware can log rate limit status
6. AuthenticationMiddleware can log auth success
7. LoggingMiddleware logs final response and timing
Chain Design Patterns
1. Security-First Chain
Place security middleware at the beginning to fail fast:
security_chain = [
    SecurityMiddleware(),          # 1st: Block threats immediately
    RateLimitMiddleware(),         # 2nd: Prevent abuse
    AuthenticationMiddleware(),    # 3rd: Verify identity
    AuthorizationMiddleware(),     # 4th: Check permissions
    ValidationMiddleware(schema),  # 5th: Validate input
    # Final: the route handler (SecureController.handle) processes the request.
    # The handler is passed to router.on(), not listed as middleware.
]

router.on("secure/api/{endpoint}",
          SecureController.handle,
          middleware=security_chain)

2. Performance-Optimized Chain
Order middleware by execution cost (fastest first):
performance_chain = [
    CacheMiddleware(),           # 1st: Check cache (fast)
    RateLimitMiddleware(),       # 2nd: Simple counter check
    AuthenticationMiddleware(),  # 3rd: Token validation
    ValidationMiddleware(),      # 4th: Schema validation
    DatabaseMiddleware(),        # 5th: Database operations
    LoggingMiddleware()          # Last: I/O operations
]

3. IoT Device Chain
Specialized chain for IoT devices:
iot_chain = [
    DeviceAuthMiddleware(),           # Device certificate auth
    TelemetryValidationMiddleware(),  # Validate sensor data
    DataTransformMiddleware(),        # Convert units, formats
    MetricsMiddleware(),              # Collect telemetry metrics
    PersistenceMiddleware()           # Store to time-series DB
]

router.on("iot/devices/{device_id}/telemetry",
          IoTController.handle_telemetry,
          middleware=iot_chain)

Context Flow Through Chains
Context Modification Example
import time
import uuid

class RequestIdMiddleware(Middleware):
    async def handle(self, context, next_handler):
        # Add request ID to context
        context['request_id'] = str(uuid.uuid4())
        context['start_time'] = time.time()

        result = await next_handler(context)

        # Add timing information to result
        if isinstance(result, dict):
            result['processing_time'] = time.time() - context['start_time']
            result['request_id'] = context['request_id']
        return result

class UserContextMiddleware(Middleware):
    async def handle(self, context, next_handler):
        # Use request_id from previous middleware
        request_id = context.get('request_id', 'unknown')

        # Add user context
        user_id = context.get('payload', {}).get('user_id')
        if user_id:
            user = await self.get_user(user_id)
            context['user'] = user
            context['user_permissions'] = user.permissions
        return await next_handler(context)

# Chain shows context flowing through middleware
context_chain = [
    RequestIdMiddleware(),    # Adds: request_id, start_time
    UserContextMiddleware(),  # Adds: user, user_permissions (uses request_id)
    BusinessMiddleware()      # Uses: all previous context data
]

Context Dependencies
import time

class DatabaseSessionMiddleware(Middleware):
    async def handle(self, context, next_handler):
        # Provide database session to downstream middleware
        session = await Model.get_session()
        context['db_session'] = session
        try:
            result = await next_handler(context)
            await session.commit()
            return result
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

class AuditMiddleware(Middleware):
    async def handle(self, context, next_handler):
        # Depends on user context and database session
        user = context.get('user')
        db_session = context.get('db_session')
        if user and db_session:
            # Create audit log entry
            audit_entry = AuditLog(
                user_id=user.id,
                action=context['topic'],
                timestamp=time.time()
            )
            db_session.add(audit_entry)
        return await next_handler(context)

# Order matters: Database session must come before audit
audit_chain = [
    AuthenticationMiddleware(),   # Provides: user
    DatabaseSessionMiddleware(),  # Provides: db_session
    AuditMiddleware(),            # Requires: user, db_session
]

Conditional Chain Execution
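The examples in this section build the condition into each middleware. As a complementary sketch, the condition can also live in a small wrapper so that any existing middleware is applied only when a predicate matches. `When`, `RequireToken`, and the stand-in `Middleware` base class are hypothetical names introduced here for illustration; they are not part of RouteMQ.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict

Context = Dict[str, Any]
Handler = Callable[[Context], Awaitable[Any]]


class Middleware:
    """Stand-in for RouteMQ's base class (assumed interface)."""

    async def handle(self, context: Context, next_handler: Handler) -> Any:
        return await next_handler(context)


class When(Middleware):
    """Hypothetical wrapper: run `inner` only if predicate(context) is true."""

    def __init__(self, predicate: Callable[[Context], bool], inner: Middleware) -> None:
        self.predicate = predicate
        self.inner = inner

    async def handle(self, context, next_handler):
        if self.predicate(context):
            return await self.inner.handle(context, next_handler)
        # Condition not met: skip the wrapped middleware entirely.
        return await next_handler(context)


class RequireToken(Middleware):
    """Toy authentication check used only for this example."""

    async def handle(self, context, next_handler):
        if not context.get("payload", {}).get("token"):
            return {"error": "unauthorized"}
        return await next_handler(context)


async def handler(context):
    return {"status": "ok"}


# Apply the token check to every topic except those under public/.
auth_unless_public = When(
    lambda ctx: not ctx["topic"].startswith("public/"),
    RequireToken(),
)

public_result = asyncio.run(
    auth_unless_public.handle({"topic": "public/news", "payload": {}}, handler)
)
private_result = asyncio.run(
    auth_unless_public.handle({"topic": "orders/1", "payload": {}}, handler)
)
```

Keeping the predicate outside the middleware leaves the wrapped class reusable in chains where the check should always run.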
Skip Middleware Based on Conditions
class ConditionalAuthMiddleware(Middleware):
    async def handle(self, context, next_handler):
        # Skip auth for public endpoints
        if context['topic'].startswith('public/'):
            context['authenticated'] = False
            return await next_handler(context)

        # Apply authentication for other endpoints
        return await self.authenticate_request(context, next_handler)

class SmartCacheMiddleware(Middleware):
    async def handle(self, context, next_handler):
        # Only cache GET-like operations (queries)
        payload = context.get('payload', {})
        if payload.get('action') == 'query':
            return await self.handle_with_cache(context, next_handler)
        else:
            # Skip caching for mutations
            return await next_handler(context)

Dynamic Chain Modification
class AdaptiveMiddleware(Middleware):
    def __init__(self):
        self.performance_mode = False
        self.request_count = 0

    async def handle(self, context, next_handler):
        self.request_count += 1

        # Re-check system load every 100 requests
        if self.request_count % 100 == 0:
            current_load = await self.get_system_load()
            self.performance_mode = current_load > 0.8

        if self.performance_mode:
            # Skip expensive operations under high load
            context['skip_analytics'] = True
            context['cache_aggressively'] = True
        return await next_handler(context)

Error Handling in Chains
Early Termination
class ValidationMiddleware(Middleware):
    async def handle(self, context, next_handler):
        if not self.is_valid_payload(context['payload']):
            # Stop chain execution here
            return {
                "error": "validation_failed",
                "message": "Invalid payload format"
            }
        # Continue chain only if validation passes
        return await next_handler(context)

class AuthorizationMiddleware(Middleware):
    async def handle(self, context, next_handler):
        user = context.get('user')
        if not user or not self.has_permission(user, context['topic']):
            # Stop chain - don't call next_handler
            return {
                "error": "access_denied",
                "message": "Insufficient permissions"
            }
        return await next_handler(context)

Error Recovery
import time

class ErrorRecoveryMiddleware(Middleware):
    async def handle(self, context, next_handler):
        try:
            return await next_handler(context)
        except DatabaseConnectionError:
            # Try with read-only database
            context['read_only_mode'] = True
            return await self.handle_read_only(context)
        except RateLimitExceeded:
            # Queue request for later processing
            await self.queue_for_later(context)
            return {"status": "queued", "message": "Request queued due to rate limit"}
        except Exception as e:
            # Log error and return generic response
            self.logger.error(f"Unhandled error: {e}")
            return {"error": "internal_error", "message": "Please try again later"}

class CircuitBreakerMiddleware(Middleware):
    def __init__(self, failure_threshold=5, recovery_timeout=60):
        self.failure_count = 0
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.circuit_open_time = None

    async def handle(self, context, next_handler):
        # Check if circuit is open
        if self.circuit_open_time:
            if time.time() - self.circuit_open_time < self.recovery_timeout:
                return {"error": "service_unavailable", "message": "Circuit breaker open"}
            else:
                # Recovery timeout elapsed: close the circuit and try again
                self.circuit_open_time = None
                self.failure_count = 0

        try:
            result = await next_handler(context)
            self.failure_count = 0  # Reset on success
            return result
        except Exception:
            self.failure_count += 1
            if self.failure_count >= self.failure_threshold:
                self.circuit_open_time = time.time()
                self.logger.error("Circuit breaker opened due to repeated failures")
            raise

Chain Composition Patterns
Reusable Chain Components
# Define reusable middleware stacks
class MiddlewareStacks:
    @staticmethod
    def security_stack():
        """Standard security middleware stack"""
        return [
            SecurityMiddleware(),
            RateLimitMiddleware(max_requests=1000),
            AuthenticationMiddleware(),
            AuthorizationMiddleware()
        ]

    @staticmethod
    def performance_stack():
        """Performance optimization stack"""
        return [
            CacheMiddleware(ttl=300),
            CompressionMiddleware(),
            MetricsMiddleware()
        ]

    @staticmethod
    def data_stack():
        """Data processing stack"""
        return [
            ValidationMiddleware(),
            TransformationMiddleware(),
            PersistenceMiddleware()
        ]

# Compose stacks for different route types
api_middleware = (
    MiddlewareStacks.security_stack() +
    MiddlewareStacks.performance_stack() +
    [LoggingMiddleware()]
)

data_middleware = (
    [AuthenticationMiddleware()] +
    MiddlewareStacks.data_stack() +
    [MetricsMiddleware()]
)

Conditional Stack Assembly
import os
from typing import List

def build_middleware_stack(route_type: str, environment: str) -> List[Middleware]:
    """Build middleware stack based on route type and environment"""
    stack = []

    # Always include security for non-public routes
    if route_type != 'public':
        stack.extend([
            SecurityMiddleware(),
            AuthenticationMiddleware()
        ])

    # Add rate limiting (different limits for different environments)
    if environment == 'production':
        stack.append(RateLimitMiddleware(max_requests=1000, window_seconds=3600))
    else:
        stack.append(RateLimitMiddleware(max_requests=100, window_seconds=60))

    # Add caching for read operations
    if route_type in ['api', 'public']:
        stack.append(CacheMiddleware(ttl=300))

    # Add validation for data routes
    if route_type == 'data':
        stack.append(ValidationMiddleware(DataSchema()))

    # Always include logging and metrics
    stack.extend([
        MetricsMiddleware(),
        LoggingMiddleware()
    ])
    return stack

# Use dynamic stack building
environment = os.getenv('ENVIRONMENT', 'development')

router.on("api/{endpoint}",
          ApiController.handle,
          middleware=build_middleware_stack('api', environment))

router.on("data/{type}/{id}",
          DataController.handle,
          middleware=build_middleware_stack('data', environment))

Advanced Chain Patterns
Parallel Middleware Execution
import asyncio
from typing import List

class ParallelMiddleware(Middleware):
    def __init__(self, middleware_list: List[Middleware]):
        self.middleware_list = middleware_list

    async def handle(self, context, next_handler):
        """Execute multiple middleware in parallel (where safe)"""
        # Create tasks for independent middleware; each gets a shallow copy
        # of the context, so its changes do not reach the main chain
        tasks = []
        for middleware in self.middleware_list:
            if self.is_safe_for_parallel(middleware):
                task = asyncio.create_task(
                    middleware.handle(context.copy(), self.noop_handler)
                )
                tasks.append(task)

        # Wait for all parallel middleware to complete.
        # return_exceptions=True collects failures instead of raising them.
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        # Continue with main chain
        return await next_handler(context)

    async def noop_handler(self, context):
        """No-op handler for parallel middleware"""
        return None

    def is_safe_for_parallel(self, middleware):
        """Check if middleware can run in parallel"""
        # Only metrics and logging middleware are safe for parallel execution
        return isinstance(middleware, (MetricsMiddleware, LoggingMiddleware))

# Usage
parallel_stack = [
    AuthenticationMiddleware(),
    ParallelMiddleware([
        MetricsMiddleware(),
        LoggingMiddleware(),
        AnalyticsMiddleware()  # Skipped as written: not allowed by is_safe_for_parallel()
    ]),
    ValidationMiddleware()
]

Middleware with State Sharing
import time

class SharedStateMiddleware(Middleware):
    """Middleware that shares state across the chain"""
    def __init__(self):
        self.shared_state = {}

    async def handle(self, context, next_handler):
        # Add shared state to context
        context['shared_state'] = self.shared_state

        result = await next_handler(context)

        # Extract any updates to shared state
        if 'shared_state_updates' in context:
            self.shared_state.update(context['shared_state_updates'])
        return result

class StatefulMiddleware(Middleware):
    """Middleware that uses shared state"""
    async def handle(self, context, next_handler):
        shared_state = context.get('shared_state', {})

        # Use shared state
        request_count = shared_state.get('request_count', 0) + 1

        # Update shared state
        context['shared_state_updates'] = {
            'request_count': request_count,
            'last_request_time': time.time()
        }
        return await next_handler(context)

# Chain with shared state
stateful_chain = [
    SharedStateMiddleware(),  # Provides shared state
    StatefulMiddleware(),     # Uses and updates shared state
    BusinessMiddleware()      # Can also use shared state
]

Testing Middleware Chains
Unit Testing Chains
import pytest
from unittest.mock import AsyncMock

@pytest.mark.asyncio
async def test_middleware_chain_execution():
    """Test that middleware executes in correct order"""
    execution_order = []

    class OrderTrackingMiddleware(Middleware):
        def __init__(self, name):
            self.name = name

        async def handle(self, context, next_handler):
            execution_order.append(f"{self.name}_start")
            result = await next_handler(context)
            execution_order.append(f"{self.name}_end")
            return result

    # Create middleware chain
    middleware_chain = [
        OrderTrackingMiddleware("first"),
        OrderTrackingMiddleware("second"),
        OrderTrackingMiddleware("third")
    ]

    # Mock final handler
    final_handler = AsyncMock(return_value="success")

    # Build chain from the inside out; the factory binds each
    # middleware and its next handler so the loop variables are not shared
    def create_handler(mw, next_h):
        async def handler(ctx):
            return await mw.handle(ctx, next_h)
        return handler

    current_handler = final_handler
    for middleware in reversed(middleware_chain):
        current_handler = create_handler(middleware, current_handler)

    # Execute chain
    context = {'topic': 'test', 'payload': {}}
    result = await current_handler(context)

    # Verify execution order
    expected_order = [
        "first_start", "second_start", "third_start",
        "third_end", "second_end", "first_end"
    ]
    assert execution_order == expected_order
    assert result == "success"

@pytest.mark.asyncio
async def test_middleware_early_termination():
    """Test middleware chain stops when middleware doesn't call next"""
    class BlockingMiddleware(Middleware):
        async def handle(self, context, next_handler):
            if context.get('block', False):
                return {"blocked": True}
            return await next_handler(context)

    class TrackingMiddleware(Middleware):
        def __init__(self):
            self.called = False

        async def handle(self, context, next_handler):
            self.called = True
            return await next_handler(context)

    tracking = TrackingMiddleware()
    final_handler = AsyncMock()

    # Build chain: blocking -> tracking -> final
    chain = [BlockingMiddleware(), tracking]

    # Test blocked execution
    context = {'block': True}
    # ... (build and execute chain)

    # Verify tracking middleware and final handler weren't called
    assert not tracking.called
    assert not final_handler.called

Integration Testing
@pytest.mark.asyncio
async def test_complete_middleware_stack():
    """Test complete middleware stack integration"""
    # Setup test middleware stack
    test_stack = [
        AuthenticationMiddleware(),
        ValidationMiddleware(TestSchema()),
        LoggingMiddleware()
    ]

    # Mock dependencies
    redis_manager.enable()  # Enable Redis for testing

    # Create test route with middleware
    router = Router()
    router.on("test/{id}", test_handler, middleware=test_stack)

    # Test valid request
    valid_context = {
        'topic': 'test/123',
        'payload': {'token': 'valid_token', 'data': 'test'},
        'params': {'id': '123'}
    }
    result = await router.dispatch(**valid_context)

    # Verify successful processing
    assert result['status'] == 'success'
    assert 'request_id' in result

    # Test invalid request (should be blocked by middleware)
    invalid_context = {
        'topic': 'test/123',
        'payload': {'invalid': 'data'},  # Missing token
        'params': {'id': '123'}
    }
    result = await router.dispatch(**invalid_context)

    # Verify request was blocked
    assert 'error' in result

Performance Optimization
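Ordering middleware by cost presupposes knowing what each one costs. One way to find out, sketched here under assumptions, is a wrapper that records the time a middleware spends in its own code, excluding everything downstream of it. `TimedMiddleware`, `SlowCheckMiddleware`, and the stand-in `Middleware` base class are hypothetical names for illustration only.

```python
import asyncio
import time
from typing import Any, Awaitable, Callable, Dict, List

Context = Dict[str, Any]
Handler = Callable[[Context], Awaitable[Any]]


class Middleware:
    """Stand-in for RouteMQ's base class (assumed interface)."""

    async def handle(self, context: Context, next_handler: Handler) -> Any:
        return await next_handler(context)


class TimedMiddleware(Middleware):
    """Hypothetical wrapper: records the wrapped middleware's own time."""

    def __init__(self, inner: Middleware, timings: Dict[str, List[float]]) -> None:
        self.inner = inner
        self.timings = timings

    async def handle(self, context, next_handler):
        downstream = 0.0

        async def timed_next(ctx):
            # Measure everything after this middleware so it can be subtracted.
            nonlocal downstream
            started = time.perf_counter()
            try:
                return await next_handler(ctx)
            finally:
                downstream += time.perf_counter() - started

        started = time.perf_counter()
        try:
            return await self.inner.handle(context, timed_next)
        finally:
            own = time.perf_counter() - started - downstream
            self.timings.setdefault(type(self.inner).__name__, []).append(own)


class SlowCheckMiddleware(Middleware):
    """Simulates a middleware that does about 20 ms of its own work."""

    async def handle(self, context, next_handler):
        await asyncio.sleep(0.02)
        return await next_handler(context)


async def slow_handler(context):
    await asyncio.sleep(0.2)  # downstream work that should not be counted
    return {"status": "ok"}


timings: Dict[str, List[float]] = {}
timed = TimedMiddleware(SlowCheckMiddleware(), timings)
result = asyncio.run(timed.handle({"topic": "demo"}, slow_handler))
```

Wrapping each entry of a chain this way during a load test gives per-middleware figures to order by, instead of relying on the rough fast/medium/slow labels.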
Middleware Ordering for Performance
# Optimal ordering (fastest to slowest):
optimized_chain = [
    CacheMiddleware(),           # Fast: Memory/Redis lookup
    RateLimitMiddleware(),       # Fast: Simple counter check
    SecurityMiddleware(),        # Medium: Pattern matching
    AuthenticationMiddleware(),  # Slow: Database/token validation
    ValidationMiddleware(),      # Slow: Schema validation
    LoggingMiddleware()          # Slowest: File I/O
]

Lazy Loading in Middleware
class LazyLoadingMiddleware(Middleware):
    def __init__(self):
        self._expensive_resource = None

    async def get_expensive_resource(self):
        """Lazy load expensive resource only when needed"""
        if self._expensive_resource is None:
            self._expensive_resource = await self.initialize_resource()
        return self._expensive_resource

    async def handle(self, context, next_handler):
        # Only load resource if actually needed
        if self.needs_resource(context):
            resource = await self.get_expensive_resource()
            context['resource'] = resource
        return await next_handler(context)

Middleware Caching
import time

class CachingWrapperMiddleware(Middleware):
    """Wrapper that adds caching to any middleware"""
    def __init__(self, wrapped_middleware: Middleware, cache_ttl: int = 60):
        self.wrapped_middleware = wrapped_middleware
        self.cache_ttl = cache_ttl
        self.cache = {}  # Note: expired entries are overwritten, never evicted

    async def handle(self, context, next_handler):
        # Generate cache key for middleware result
        cache_key = self._generate_cache_key(context)

        # Check cache first
        if cache_key in self.cache:
            cached_data, timestamp = self.cache[cache_key]
            if time.time() - timestamp < self.cache_ttl:
                # Apply cached modifications to context and skip the wrapped middleware
                context.update(cached_data)
                return await next_handler(context)

        # Execute wrapped middleware
        original_context = context.copy()
        result = await self.wrapped_middleware.handle(context, next_handler)

        # Cache the context modifications
        context_changes = {
            k: v for k, v in context.items()
            if k not in original_context or original_context[k] != v
        }
        self.cache[cache_key] = (context_changes, time.time())
        return result

Best Practices
1. Design for Composability
# Good: Focused, reusable middleware
class AuthMiddleware(Middleware): pass
class LoggingMiddleware(Middleware): pass
class MetricsMiddleware(Middleware): pass

# Bad: Monolithic middleware
class AuthLoggingMetricsMiddleware(Middleware): pass

2. Handle Dependencies Explicitly
# Good: Clear dependency chain
security_chain = [
    AuthenticationMiddleware(),  # Provides: user
    AuthorizationMiddleware(),   # Requires: user
]

# Bad: Hidden dependencies
mixed_chain = [
    AuthorizationMiddleware(),   # Fails: no user context
    AuthenticationMiddleware(),  # Provides: user (too late)
]

3. Document Chain Requirements
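Documented requirements can also be checked mechanically. The sketch below assumes hypothetical `provides` and `requires` class attributes, which the middleware in this guide do not declare, and a `check_chain` helper that walks a chain in order and reports any middleware whose required context keys have not been provided yet. The two middleware classes are stripped-down stand-ins, not the real implementations.

```python
from typing import FrozenSet, Iterable, List


class Middleware:
    """Stand-in base class with assumed dependency declarations."""

    provides: FrozenSet[str] = frozenset()  # context keys this middleware adds
    requires: FrozenSet[str] = frozenset()  # context keys it expects to exist


class AuthenticationMiddleware(Middleware):
    provides = frozenset({"user"})


class AuthorizationMiddleware(Middleware):
    requires = frozenset({"user"})


def check_chain(
    chain: Iterable[Middleware],
    initial_keys: Iterable[str] = ("topic", "payload"),
) -> List[str]:
    """Return one message per middleware whose requirements are unmet."""
    available = set(initial_keys)
    problems = []
    for mw in chain:
        missing = set(mw.requires) - available
        if missing:
            problems.append(f"{type(mw).__name__} requires {sorted(missing)}")
        available |= set(mw.provides)
    return problems


good_problems = check_chain([AuthenticationMiddleware(), AuthorizationMiddleware()])
bad_problems = check_chain([AuthorizationMiddleware(), AuthenticationMiddleware()])
```

Running such a check once at startup, when routes are registered, would surface ordering mistakes before any message is processed.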
class ApiMiddlewareStack:
    """
    Standard API middleware stack.

    Chain order:
    1. SecurityMiddleware - Blocks threats
    2. RateLimitMiddleware - Prevents abuse
    3. AuthenticationMiddleware - Validates identity
    4. ValidationMiddleware - Validates payload (only when a schema is given)
    5. LoggingMiddleware - Logs request/response

    Context additions:
    - request_id: Unique request identifier
    - user: Authenticated user object
    - validated_payload: Schema-validated payload
    """
    @staticmethod
    def create(schema: Schema = None):
        stack = [
            SecurityMiddleware(),
            RateLimitMiddleware(),
            AuthenticationMiddleware(),
        ]
        # Add validation only when a schema is supplied, so the
        # returned list never contains None
        if schema:
            stack.append(ValidationMiddleware(schema))
        stack.append(LoggingMiddleware())
        return stack

4. Implement Graceful Degradation
class ResilientMiddleware(Middleware):
    async def handle(self, context, next_handler):
        try:
            # Try normal operation
            return await self.normal_operation(context, next_handler)
        except ServiceUnavailableError:
            # Degrade gracefully
            return await self.degraded_operation(context, next_handler)
        except Exception:
            # Log error but continue chain.
            # Caution: if the exception was raised downstream, after
            # next_handler had already been called, this runs the rest
            # of the chain a second time.
            self.logger.error("Middleware failed, continuing...")
            return await next_handler(context)

Next Steps
Authentication Middleware - Deep dive into authentication patterns
Caching Middleware - Advanced caching strategies
Creating Middleware - Build custom middleware components
Built-in Middleware - Explore available middleware