
中间件(Middleware)是现代Web框架和服务架构中的核心概念。在Rust生态中,从Actix-web到Axum,从Tower到Tonic,几乎所有主流框架都采用了中间件模式。但Rust的中间件设计与其他语言有着本质区别:它不是简单的函数链,而是基于类型系统的零成本抽象。
理解Rust中间件的设计哲学,不仅能帮你构建高性能的服务,更能让你深刻理解Rust的类型系统、异步编程和零成本抽象的威力。本文将从基础概念到高级实现,全面剖析Rust中间件系统的设计精髓。

// 概念上,中间件是一个包装器
Request -> [Middleware1 -> [Middleware2 -> [Handler] -> Response] -> Response] -> Response
// 典型的处理流程
Request
-> 日志中间件(记录开始)
-> 认证中间件(验证token)
-> 限流中间件(检查速率)
-> 业务处理器
<- 限流中间件(更新计数)
<- 认证中间件(无操作)
<- 日志中间件(记录结束)
<- Response

这种"进入-处理-退出"的模式,形成了经典的洋葱模型。
// Typical implementation in other languages (pseudocode, not valid Rust).
// Each middleware receives the request plus a `next` handle, presumably
// representing the rest of the chain.
interface Middleware {
    fn handle(request: Request, next: Next) -> Response
}
// Runtime dynamic chain: every call goes through a vtable (`Box<dyn ...>`),
// which makes inlining and cross-middleware optimization hard.
let middlewares: Vec<Box<dyn Middleware>> = vec![
    Box::new(LoggingMiddleware),
    Box::new(AuthMiddleware),
];

Rust的设计目标是编译时确定整个中间件链,实现零成本抽象。
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

/// A simplified copy of Tower's `Service` trait: an asynchronous function
/// from `Request` to `Result<Self::Response, Self::Error>`.
pub trait Service<Request> {
    /// Value produced when a request succeeds.
    type Response;
    /// Value produced when readiness or the call itself fails.
    type Error;
    /// Future returned by `call`, resolving to the request's outcome.
    type Future: Future<Output = Result<Self::Response, Self::Error>>;

    /// Checks whether the service is ready to accept a request.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;

    /// Processes `req`, returning a future for its outcome.
    fn call(&mut self, req: Request) -> Self::Future;
}

这个设计的精妙之处:
Request作为泛型参数,允许不同类型的请求Response、Error、Future都是关联类型,提供类型安全Future,天然支持异步use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
// 一个简单的Echo服务
struct EchoService;
impl Service<String> for EchoService {
type Response = String;
type Error = std::io::Error;
type Future = Pin<Box<dyn Future<Output = Result<String, std::io::Error>>>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// 总是准备好
Poll::Ready(Ok(()))
}
fn call(&mut self, req: String) -> Self::Future {
Box::pin(async move {
Ok(format!("Echo: {}", req))
})
}
}
// 使用
async fn use_service() {
let mut service = EchoService;
// 等待服务准备好
futures::future::poll_fn(|cx| service.poll_ready(cx)).await.unwrap();
// 发送请求
let response = service.call("Hello".to_string()).await.unwrap();
println!("{}", response); // 输出:Echo: Hello
}pub trait Layer<S> {
type Service;
fn layer(&self, inner: S) -> Self::Service;
}Layer的职责是将一个Service包装成另一个Service。这是构建中间件链的关键。
use std::time::Instant;

/// Layer that wraps a service in `LoggingService`.
struct LoggingLayer;

impl<S> Layer<S> for LoggingLayer {
    type Service = LoggingService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        LoggingService { inner }
    }
}

/// Middleware that prints each request. Once the inner future finishes, it
/// prints the response (or an error marker) along with the elapsed time.
struct LoggingService<S> {
    inner: S,
}

impl<S, Request> Service<Request> for LoggingService<S>
where
    S: Service<Request>,
    Request: std::fmt::Debug,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = LoggingFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated entirely to the wrapped service.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        println!("[Request] {:?}", req);
        // Record the start time before calling the inner service, so the
        // measurement includes the inner `call` itself.
        let start = Instant::now();
        let inner = self.inner.call(req);
        LoggingFuture { inner, start }
    }
}

/// Future returned by `LoggingService`. It wraps the inner future and logs
/// the outcome and elapsed time when the inner future completes.
struct LoggingFuture<F> {
    inner: F,
    start: Instant,
}

impl<F, T, E> Future for LoggingFuture<F>
where
    F: Future<Output = Result<T, E>>,
    T: std::fmt::Debug,
{
    type Output = Result<T, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Instant` is `Copy`, so read it before projecting the pin onto `inner`.
        let started = self.start;
        // SAFETY: `inner` is structurally pinned. `LoggingFuture` has no `Drop`
        // impl and no manual `Unpin` impl, and `inner` is never moved out of
        // the pinned struct.
        let inner = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
        let outcome = match inner.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(outcome) => outcome,
        };
        let took = started.elapsed();
        if let Ok(response) = &outcome {
            println!("[Response] {:?} (took {:?})", response, took);
        } else {
            println!("[Error] (took {:?})", took);
        }
        Poll::Ready(outcome)
    }
}

use tower::ServiceBuilder;
// 构建中间件栈
let service = ServiceBuilder::new()
.layer(LoggingLayer)
.layer(TimeoutLayer::new(Duration::from_secs(30)))
.layer(RateLimitLayer::new(100))
.service(my_service);
// 编译器会将这展开成类似:
// LoggingService<TimeoutService<RateLimitService<MyService>>>关键点:整个中间件链在编译时确定,没有动态分发开销。
use std::marker::PhantomData;

/// Layer that wraps a service in `AuthService`.
///
/// `T` is the request type the produced service accepts. It is carried only
/// as a type-level marker.
struct AuthLayer<T> {
    secret: String,
    _marker: PhantomData<T>,
}

impl<T> AuthLayer<T> {
    /// Creates a layer that authenticates requests against `secret`.
    fn new(secret: String) -> Self {
        Self {
            secret,
            _marker: PhantomData,
        }
    }
}

impl<S, T> Layer<S> for AuthLayer<T> {
    type Service = AuthService<S, T>;

    fn layer(&self, inner: S) -> Self::Service {
        AuthService {
            inner,
            // Each wrapped service gets its own copy of the secret.
            secret: self.secret.clone(),
            _marker: PhantomData,
        }
    }
}

/// Middleware that rejects requests without a valid token and attaches the
/// decoded `User` to successful responses.
struct AuthService<S, T> {
    inner: S,
    secret: String,
    _marker: PhantomData<T>,
}

/// Requests must be able to expose an optional token.
trait HasToken {
    fn token(&self) -> Option<&str>;
}

/// Responses can have the authenticated user attached.
trait WithUser {
    fn with_user(self, user: User) -> Self;
}

/// The authenticated user produced by `AuthService`.
#[derive(Debug, Clone)]
struct User {
    id: u64,
    name: String,
}

impl<S, Request> Service<Request> for AuthService<S, Request>
where
    S: Service<Request>,
    Request: HasToken,
    S::Response: WithUser,
{
    type Response = S::Response;
    type Error = AuthError<S::Error>;
    type Future = AuthFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(AuthError::Inner)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        // Verify the token. The borrow of `req` ends with this `match`, before
        // `req` is moved into the inner service below.
        let user = match req.token() {
            Some(t) if self.verify_token(t) => Some(self.decode_token(t)),
            _ => None,
        };
        match user {
            // Authenticated: forward the request to the inner service.
            Some(user) => AuthFuture::Authenticated {
                inner: self.inner.call(req),
                user,
            },
            // Rejected: the inner service is never called.
            None => AuthFuture::Unauthorized,
        }
    }
}

impl<S, Request> AuthService<S, Request> {
    /// Simplified check: accepts tokens that start with the configured secret.
    ///
    /// An empty secret is rejected outright, because `starts_with("")` would
    /// otherwise accept every token.
    /// NOTE: prefix matching is for illustration only. Real code should verify
    /// a signed token, using a constant-time comparison.
    fn verify_token(&self, token: &str) -> bool {
        !self.secret.is_empty() && token.starts_with(&self.secret)
    }

    /// Simplified decoding: always returns the same demo user.
    fn decode_token(&self, _token: &str) -> User {
        User {
            id: 1,
            name: "User".to_string(),
        }
    }
}

/// Future returned by `AuthService`.
enum AuthFuture<F> {
    /// Token accepted: waits on the inner service and attaches `user` on success.
    Authenticated { inner: F, user: User },
    /// Token missing or invalid: resolves immediately to `AuthError::Unauthorized`.
    Unauthorized,
}

/// Errors produced by `AuthService`.
#[derive(Debug)]
enum AuthError<E> {
    /// The request had no token, or its token failed verification.
    Unauthorized,
    /// The wrapped service returned an error.
    Inner(E),
}

impl<F, T, E> Future for AuthFuture<F>
where
    F: Future<Output = Result<T, E>>,
    T: WithUser,
{
    type Output = Result<T, AuthError<E>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Pin::get_mut` would need `AuthFuture<F>: Unpin`, i.e. `F: Unpin`. That
        // is not guaranteed (`async` blocks are `!Unpin`), so project manually.
        // SAFETY: we never move out of `*this` or replace its variant, and
        // `AuthFuture` has no `Drop` or manual `Unpin` impl, so `inner` stays pinned.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            AuthFuture::Authenticated { inner, user } => {
                // SAFETY: `inner` is structurally pinned (see above).
                let inner = unsafe { Pin::new_unchecked(inner) };
                match inner.poll(cx) {
                    Poll::Ready(Ok(response)) => {
                        Poll::Ready(Ok(response.with_user(user.clone())))
                    }
                    Poll::Ready(Err(e)) => Poll::Ready(Err(AuthError::Inner(e))),
                    Poll::Pending => Poll::Pending,
                }
            }
            AuthFuture::Unauthorized => Poll::Ready(Err(AuthError::Unauthorized)),
        }
    }
}

// Conditional Layer
/// Layer that applies `layer` only to requests for which `predicate` returns
/// `true`. All other requests go straight to the unwrapped inner service.
struct ConditionalLayer<L, F> {
    layer: L,
    predicate: F,
}

impl<L, F> ConditionalLayer<L, F> {
    /// Creates a conditional layer from an inner layer and a request predicate.
    fn new(layer: L, predicate: F) -> Self {
        Self { layer, predicate }
    }
}

// The request type is deliberately NOT a parameter of this impl. It would only
// appear inside the `Fn(&Request)` bound, which does not constrain it (E0207).
// The predicate bound is checked on the `Service` impl instead, where
// `Service<Request>` fixes the request type.
impl<S, L, F> Layer<S> for ConditionalLayer<L, F>
where
    L: Layer<S>,
    // The inner service is needed twice: once wrapped, once as the bypass.
    S: Clone,
    F: Clone,
{
    type Service = ConditionalService<L::Service, S, F>;

    fn layer(&self, inner: S) -> Self::Service {
        ConditionalService {
            wrapped: self.layer.layer(inner.clone()),
            bypass: inner,
            predicate: self.predicate.clone(),
        }
    }
}

/// Service produced by `ConditionalLayer`.
struct ConditionalService<W, B, F> {
    wrapped: W,   // inner service with the layer applied
    bypass: B,    // original, unwrapped inner service
    predicate: F, // decides which of the two handles a request
}

impl<W, B, F, Request> Service<Request> for ConditionalService<W, B, F>
where
    W: Service<Request>,
    B: Service<Request, Response = W::Response, Error = W::Error>,
    F: Fn(&Request) -> bool,
{
    type Response = W::Response;
    type Error = W::Error;
    type Future = ConditionalFuture<W::Future, B::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Both services must be ready, because `call` may route to either one.
        match self.wrapped.poll_ready(cx)? {
            Poll::Ready(()) => self.bypass.poll_ready(cx),
            Poll::Pending => Poll::Pending,
        }
    }

    fn call(&mut self, req: Request) -> Self::Future {
        if (self.predicate)(&req) {
            ConditionalFuture::Wrapped(self.wrapped.call(req))
        } else {
            ConditionalFuture::Bypass(self.bypass.call(req))
        }
    }
}

/// Future returned by `ConditionalService`: the future of whichever branch handled the request.
enum ConditionalFuture<W, B> {
    Wrapped(W),
    Bypass(B),
}

impl<W, B, T, E> Future for ConditionalFuture<W, B>
where
    W: Future<Output = Result<T, E>>,
    B: Future<Output = Result<T, E>>,
{
    type Output = Result<T, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: the active variant's future is structurally pinned. It is never
        // moved out, the variant is never replaced, and there is no `Drop` or
        // manual `Unpin` impl.
        unsafe {
            match self.get_unchecked_mut() {
                ConditionalFuture::Wrapped(f) => Pin::new_unchecked(f).poll(cx),
                ConditionalFuture::Bypass(f) => Pin::new_unchecked(f).poll(cx),
            }
        }
    }
}

// Usage example: only `/api/` paths go through the wrapped layer.
// NOTE(review): `ConditionalService` requires both branches to share the same
// `Response` and `Error` types. `AuthLayer` changes the error type to
// `AuthError<E>`, so as written this only type-checks after unifying the error
// types. A type-preserving layer (e.g. `LoggingLayer`) works directly.
let service = ServiceBuilder::new()
    .layer(ConditionalLayer::new(
        AuthLayer::new("secret".to_string()),
        |req: &HttpRequest| req.path().starts_with("/api/"),
    ))
    .service(handler);

use std::sync::Arc;
use tokio::sync::RwLock;

/// Application-wide shared state. Every field is behind an `Arc`, so cloning
/// an `AppState` only bumps reference counts.
#[derive(Clone)]
struct AppState {
    db: Arc<Database>,
    cache: Arc<RwLock<Cache>>,
    config: Arc<Config>,
}

/// Layer that injects a clone of `state` into every request.
struct StateLayer<T> {
    state: T,
}

impl<T: Clone> StateLayer<T> {
    /// Creates a layer that hands `state` to each wrapped service.
    fn new(state: T) -> Self {
        StateLayer { state }
    }
}

impl<S, T: Clone> Layer<S> for StateLayer<T> {
    type Service = StateService<S, T>;

    fn layer(&self, inner: S) -> Self::Service {
        let state = self.state.clone();
        StateService { inner, state }
    }
}

/// Service produced by `StateLayer`.
struct StateService<S, T> {
    inner: S,
    state: T,
}

/// Requests that can receive injected state of type `T`.
trait WithState<T> {
    fn with_state(self, state: T) -> Self;
}

impl<S, Request, T> Service<Request> for StateService<S, T>
where
    S: Service<Request>,
    Request: WithState<T>,
    T: Clone,
{
    type Response = S::Response;
    type Error = S::Error;
    // Nothing happens after the call, so the inner future is returned unchanged.
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        // Give each request its own clone of the state, then forward it.
        self.inner.call(req.with_state(self.state.clone()))
    }
}

// Error-mapping Layer
/// Layer that converts a service's error type with `mapper`.
struct MapErrorLayer<F> {
    mapper: F,
}

impl<F> MapErrorLayer<F> {
    /// Creates a layer that applies `mapper` to every error.
    fn new(mapper: F) -> Self {
        MapErrorLayer { mapper }
    }
}

impl<S, F> Layer<S> for MapErrorLayer<F>
where
    F: Clone,
{
    type Service = MapErrorService<S, F>;

    fn layer(&self, inner: S) -> Self::Service {
        let mapper = self.mapper.clone();
        MapErrorService { inner, mapper }
    }
}

/// Service produced by `MapErrorLayer`.
struct MapErrorService<S, F> {
    inner: S,
    mapper: F,
}

impl<S, F, Request, E2> Service<Request> for MapErrorService<S, F>
where
    S: Service<Request>,
    F: Fn(S::Error) -> E2 + Clone,
{
    type Response = S::Response;
    type Error = E2;
    type Future = MapErrorFuture<S::Future, F>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness errors are mapped too, so callers only ever see `E2`.
        let readiness = self.inner.poll_ready(cx);
        readiness.map_err(&self.mapper)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        // Each response future carries its own copy of the mapper.
        let inner = self.inner.call(req);
        let mapper = self.mapper.clone();
        MapErrorFuture { inner, mapper }
    }
}

/// Future returned by `MapErrorService`. It maps the error, if any, once the
/// inner future resolves.
struct MapErrorFuture<F, M> {
    inner: F,
    mapper: M,
}

impl<F, M, T, E1, E2> Future for MapErrorFuture<F, M>
where
    F: Future<Output = Result<T, E1>>,
    M: Fn(E1) -> E2,
{
    type Output = Result<T, E2>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `inner` is structurally pinned. It is never moved out of the
        // pinned struct, there are no `Drop` or manual `Unpin` impls, and
        // `mapper` is only accessed by shared reference.
        let this = unsafe { self.get_unchecked_mut() };
        let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
        match inner.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(value)) => Poll::Ready(Ok(value)),
            Poll::Ready(Err(err)) => Poll::Ready(Err((this.mapper)(err))),
        }
    }
}

// Usage example: turn database errors into HTTP 500 errors.
let to_http = |e: DbError| HttpError::InternalServerError(e.to_string());
let service = ServiceBuilder::new()
    .layer(MapErrorLayer::new(to_http))
    .service(db_service);

// Bad: heap-allocates on every call
type BadFuture = Pin<Box<dyn Future<Output = Result<Response, Error>>>>;

// Good: a concrete wrapper type, so no per-call heap allocation.
struct GoodFuture<F> {
    inner: F,
}

impl<F> Future for GoodFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
{
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `inner` is structurally pinned. It is never moved out of
        // `GoodFuture`, which has no `Drop` or manual `Unpin` impl.
        let pinned_inner: Pin<&mut F> = unsafe { self.map_unchecked_mut(|s| &mut s.inner) };
        pinned_inner.poll(cx)
    }
}

use tower::util::{BoxService, BoxCloneService};
// 当确实需要类型擦除时
let boxed: BoxService<Request, Response, Error> = BoxService::new(service);
// 需要Clone时
let clonable: BoxCloneService<Request, Response, Error> =
BoxCloneService::new(service);use axum::{Router, routing::get};
use tower::ServiceBuilder;
use tower_http::{
    compression::CompressionLayer,
    cors::CorsLayer,
    trace::TraceLayer,
};

/// Builds the axum `Router` with the middleware stack applied.
///
/// `ServiceBuilder` applies layers top-down: the first one listed is the
/// outermost, so it sees the request first and the response last.
/// NOTE(review): axum's `Router::layer` requires a `Clone` layer whose service
/// is `Clone` and whose error converts into `Infallible`. The custom
/// `AuthLayer`/`RateLimitLayer` from this article would need adapting for that
/// (e.g. with `HandleErrorLayer`); confirm before using as-is.
async fn build_app() -> Router {
    let middleware = ServiceBuilder::new()
        .layer(CompressionLayer::new()) // compress responses
        .layer(CorsLayer::permissive()) // CORS
        .layer(TraceLayer::new_for_http()) // tracing
        // custom middleware
        .layer(RateLimitLayer::new(100))
        .layer(AuthLayer::new("secret".into()));

    Router::new().route("/", get(handler)).layer(middleware)
}

// Problem: cloning a large object on every request
/// Anti-pattern: the payload is owned inline, so every clone copies the whole
/// buffer.
struct ExpensiveState {
    large_data: Vec<u8>, // ~1 MB of data
}

/// Fix: share the payload behind an `Arc`, so a clone is only a refcount bump.
struct BetterState {
    large_data: std::sync::Arc<Vec<u8>>,
}

// Wrong: forgetting to check poll_ready
async fn wrong_usage<S>(mut service: S, req: Request)
where
S: Service<Request>,
{
let response = service.call(req).await; // 可能panic!
}
// 正确:先检查ready
async fn correct_usage<S>(mut service: S, req: Request)
where
S: Service<Request>,
{
futures::future::poll_fn(|cx| service.poll_ready(cx)).await.unwrap();
let response = service.call(req).await;
}use tower::ServiceExt;
// `ServiceExt::ready` waits for `poll_ready` and then yields `&mut service`,
// so the call below is only made once the service is ready.
let ready_service = service.ready().await?;
let response = ready_service.call(request).await?;

Rust的中间件系统展示了语言设计的精妙:
掌握这些设计模式,你就能:
这是从框架使用者到框架设计者的关键一步。理解中间件的设计原理,你就掌握了构建大型Rust应用的核心能力! 🚀