RUST web框架axum快速入门教程3之状态共享

文章目录

本文主要讨论axum的状态共享,通过状态共享我们可以减少不必要的对象反复创建以及状态共享,共享状态的一个常用场景是共享数据库连接,通过复用数据库的连接对象可以极大的提升数据库操作效率。

往期文章:

数据库

为了能提供一个完整可用的代码,本文打算先简单的讲解一下sqlx的用法,详细教程可参考我之前的文章:https://youerning.top/sqlx-tutorial。

use sqlx::{
    Postgres,
    postgres::PgPoolOptions, migrate::MigrateDatabase};

    const DB_URL: &str = "postgres://用户名:密码@服务器地址:服务器端口/数据库名";

#[tokio::main]
async fn main() {
    // 判断数据库是否存在,不存在则创建
    if !Postgres::database_exists(DB_URL).await.unwrap_or(false) {
        println!("创建数据库 {}", DB_URL);
        match Postgres::create_database(DB_URL).await {
            Ok(_) => println!("创建数据库成功"),
            Err(err) => panic!("创建数据库失败: {}", err)
        }
    } else {
        println!("创建库已存在, 无需创建");
    }

    // 创建连接池
    let db: sqlx::Pool<_> = PgPoolOptions::new()
        // 设置最大连接数
        .max_connections(20)
        .connect(DB_URL)
        .await.unwrap();

    // 执行创建表的sql
    let result = sqlx::query(r#"
        CREATE TABLE IF NOT EXISTS todos
        (
            id          BIGSERIAL PRIMARY KEY,
            description TEXT    NOT NULL,
            done        BOOLEAN NOT NULL DEFAULT FALSE
        );"#)
        .execute(&db)
        .await
        .unwrap();
    println!("建表结果: {result:?}");

    // 插入数据
    let result = sqlx::query!(
        r#"
    INSERT into todos (description)
        values($1)
        RETURNING id
        "#,
        "hello world")
        .fetch_one(&db)
        .await
        .expect("插入数据失败.");

    println!("插入数据成功, 对应的id是: {:?}", result.id);
    // 查询数据
    let result = sqlx::query!(r#"SELECT * from todos"#,)
        .fetch_all(&db)
        .await
        .expect("查询数据失败.");
    
    println!("查询结果: {result:?}");


    for row in result {
        println!("查询数据结果: [{}] {} {}", row.id, row.description, row.done);
    }
}

代码中的注释已经足够清楚了,所以就不做过多解释了,后面的代码中就直接使用这里的db对象了。

值得注意的是: 使用sqlx的query!宏需要存在DATABASE_URL环境变量,可以在本地设置一个.env的文件,然后填充内容DATABASE_URL=“postgres://用户名:密码@服务器地址:服务器端口/数据库名”

状态共享

axum提供了三种状态共享的方式,分别是StateExtension和闭包, 这三种方式各有优缺点,作者推荐第一种,因为第一种是类型安全的方式。

State

use serde::Deserialize;
use axum::{
    response::Html,
    routing::get, Router, extract::{Path, State, Query},
};
use sqlx::{
    Postgres,
    postgres::PgPoolOptions, migrate::MigrateDatabase
};

const DB_URL: &str = "postgres://用户名:密码@服务器地址:服务器端口/数据库名";


#[tokio::main]
async fn main() {
    // 判断数据库是否存在,不存在则创建
    if !Postgres::database_exists(DB_URL).await.unwrap_or(false) {
        println!("创建数据库 {}", DB_URL);
        match Postgres::create_database(DB_URL).await {
            Ok(_) => println!("创建数据库成功"),
            Err(err) => panic!("创建数据库失败: {}", err)
        }
    } else {
        println!("创建库已存在, 无需创建");
    }

    // 创建连接池
    let db: sqlx::Pool<_> = PgPoolOptions::new()
        // 设置最大连接数
        .max_connections(20)
        .connect(DB_URL)
        .await.unwrap();

    let app = Router::new()
        .route("/", get(handler))
        .route("/todos/:id", get(show_todo))
        .route("/todos", get(create_todo))
    	// 注意: 这要放到最后
        .with_state(db);


    let addr = "0.0.0.0:8080";
    axum::Server::bind(&addr.parse().unwrap())
      .serve(app.into_make_service())
      .await
      .unwrap();
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}

#[derive(Debug, Deserialize)]
struct CreateTodo {
    description: String
}

async fn create_todo(
    State(state): State<sqlx::Pool<Postgres>>,
    Query(payload): Query<CreateTodo>
) -> Result<String, String> {
    let description = payload.description;
    match sqlx::query!(
        r#"
    INSERT into todos (description)
        values($1)
        RETURNING id
        "#,
        description)
        .fetch_one(&state)
        .await {
            Ok(record) => {
                Ok(format!("插入数据成功, 插入的数据id是: {}", record.id))
            },
            Err(err) => {
                Err(format!("插入数据失败: {err:?}"))
            }
        }
}

async fn show_todo(
    State(state): State<sqlx::Pool<Postgres>>,
    Path(id): Path<i64>
) -> Result<String, String> {
    match sqlx::query!(r#"SELECT * from todos where id = $1"#, id)
        .fetch_one(&state)
        .await {
            Ok(ret) => {
                Ok(format!("you todo is {ret:?}"))
            },
            Err(err) => {
                Err(format!("查询数据错误: {err}"))
            }
        }
}

上面的代码为了简单起见,创建的接口使用的是GET方法,一般来说创建数据会有POST方法,然后数据放在请求体中以JSON的格式发送。

可以使用以下命令测试上面的代码。

$ curl   http://127.0.0.1:8080/todos?description=youerning.top
插入数据成功, 插入的数据id是: 9
$ curl   http://127.0.0.1:8080/todos/9
you todo is Record { id: 9, description: "youerning.top", done: false }

Extension

Extension与State的一点很大不同在于,前者需要时刻保证处理函数的类型正确,如果路由端设置的类型和处理函数(handler)设置的类型不匹配,会导致运行时错误,也就是状态码为500的服务器内部错误(Internal Server Error).

use serde::Deserialize;
use axum::{
    response::Html,
    routing::get, Router, extract::{Path, Query, Extension},
};
use sqlx::{
    Postgres,
    postgres::PgPoolOptions, migrate::MigrateDatabase
};

const DB_URL: &str = "postgres://用户名:密码@服务器地址:服务器端口/数据库名";


#[tokio::main]
async fn main() {
    // 判断数据库是否存在,不存在则创建
    if !Postgres::database_exists(DB_URL).await.unwrap_or(false) {
        println!("创建数据库 {}", DB_URL);
        match Postgres::create_database(DB_URL).await {
            Ok(_) => println!("创建数据库成功"),
            Err(err) => panic!("创建数据库失败: {}", err)
        }
    } else {
        println!("创建库已存在, 无需创建");
    }

    // 创建连接池
    let db: sqlx::Pool<_> = PgPoolOptions::new()
        // 设置最大连接数
        .max_connections(20)
        .connect(DB_URL)
        .await.unwrap();

    let app = Router::new()
        .route("/", get(handler))
        .route("/todos/:id", get(show_todo))
        .route("/todos", get(create_todo))
        .layer(Extension(db));
        // .route("/request_handler", get(request_handler));


    let addr = "0.0.0.0:8080";
    axum::Server::bind(&addr.parse().unwrap())
      .serve(app.into_make_service())
      .await
      .unwrap();
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}

#[derive(Debug, Deserialize)]
struct CreateTodo {
    description: String
}

async fn create_todo(
    Extension(state): Extension<sqlx::Pool<Postgres>>,
    Query(payload): Query<CreateTodo>
) -> String {
    let description = payload.description;
    match sqlx::query!(
        r#"
    INSERT into todos (description)
        values($1)
        RETURNING id
        "#,
        description)
        .fetch_one(&state)
        .await {
            Ok(record) => {
                format!("插入数据成功, 插入的数据id是: {}", record.id)
            },
            Err(err) => {
                format!("插入数据失败: {err:?}")
            }
        }
}

async fn show_todo(
    Extension(state): Extension<sqlx::Pool<Postgres>>,
    Path(id): Path<i64>
) -> String {
    match sqlx::query!(r#"SELECT * from todos where id = $1"#, id)
        .fetch_one(&state)
        .await {
            Ok(ret) => {
                format!("you todo is {ret:?}")
            },
            Err(err) => {
                format!("查询数据错误: {err}")
            }
        }
}

两者代码差别不大,这里就不展开了,唯一的区别是共享的方式不同以及处理函数端设置的类型不同。

闭包

闭包的方式要比前两者要复杂一些,但是共享的数据类型会清晰很多,到底选哪种需要根据自己的需要来决定,个人也建议选择第一种,比较简单也类型安全。

use serde::Deserialize;
use axum::{
    response::Html,
    routing::get, Router, extract::{Path, Query},
};
use sqlx::{
    Postgres,
    postgres::PgPoolOptions, migrate::MigrateDatabase
};

const DB_URL: &str = "postgres://用户名:密码@服务器地址:服务器端口/数据库名";


#[tokio::main]
async fn main() {
    // 判断数据库是否存在,不存在则创建
    if !Postgres::database_exists(DB_URL).await.unwrap_or(false) {
        println!("创建数据库 {}", DB_URL);
        match Postgres::create_database(DB_URL).await {
            Ok(_) => println!("创建数据库成功"),
            Err(err) => panic!("创建数据库失败: {}", err)
        }
    } else {
        println!("创建库已存在, 无需创建");
    }

    // 创建连接池
    let db: sqlx::Pool<_> = PgPoolOptions::new()
        // 设置最大连接数
        .max_connections(20)
        .connect(DB_URL)
        .await.unwrap();

    let app = Router::new()
        .route("/", get(handler))
        .route("/todos/:id", get({
            let state = db.clone();
            move |query| show_todo(state, query)
        }))
        .route("/todos", get({
            let state = db.clone();
            move |query| create_todo(state, query)
        }));
        // .layer(Extension(db));
        // .route("/request_handler", get(request_handler));


    let addr = "0.0.0.0:8080";
    axum::Server::bind(&addr.parse().unwrap())
      .serve(app.into_make_service())
      .await
      .unwrap();
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}

#[derive(Debug, Deserialize)]
struct CreateTodo {
    description: String
}

async fn create_todo(
    state: sqlx::Pool<Postgres>,
    Query(payload): Query<CreateTodo>
) -> String {
    let description = payload.description;
    match sqlx::query!(
        r#"
    INSERT into todos (description)
        values($1)
        RETURNING id
        "#,
        description)
        .fetch_one(&state)
        .await {
            Ok(record) => {
                format!("插入数据成功, 插入的数据id是: {}", record.id)
            },
            Err(err) => {
                format!("插入数据失败: {err:?}")
            }
        }
}

async fn show_todo(
    state: sqlx::Pool<Postgres>,
    Path(id): Path<i64>
) -> String {
    match sqlx::query!(r#"SELECT * from todos where id = $1"#, id)
        .fetch_one(&state)
        .await {
            Ok(ret) => {
                format!("you todo is {ret:?}")
            },
            Err(err) => {
                format!("查询数据错误: {err}")
            }
        }
}

小结

一般来说,共享对象都会用ARC包装一下的,这是为了线程安全,但是sqlx的Pool对象本身就是一个包裹在ARC里面的对象,所以不需要额外的用ARC在包一层了,如果是自定义的数据,则需要用ARC包装一下。

参考链接