FastAPI 自动生成 Schema 和 Router:打造高效的 CRUD API
引言
在使用 FastAPI 开发 Web API 时,我们经常需要为每个数据模型重复编写类似的 CRUD(创建、读取、更新、删除)操作代码。这不仅耗时,而且容易出错。本文将介绍如何通过装饰器模式自动从 SQLAlchemy 模型生成 FastAPI 的 schema 和 router,大大提高开发效率。
项目特点
-
自动代码生成
- 自动从 SQLAlchemy 模型生成 Pydantic BaseModel
- 自动生成标准的 CRUD 路由
- 支持自定义路由配置
- 支持字段类型自动映射
-
数据库优化
- 智能连接池管理
- 自动会话管理
- 事务安全保证
- 异常自动回滚
-
高级特性
- 完整的 CRUD 操作支持
- 类型安全的数据验证
- 自动 API 文档生成
- 优雅的错误处理
项目结构
fastapi_auto/
├── database.py # 数据库配置和会话管理
├── decorators.py # 核心装饰器和路由生成逻辑
├── models.py # SQLAlchemy 模型定义
├── main.py # FastAPI 应用入口
├── requirements.txt # 项目依赖
└── .env # 环境变量配置
完整代码实现
1. 环境配置(requirements.txt)
fastapi>=0.68.0
sqlalchemy>=1.4.23
pydantic>=1.8.2
uvicorn>=0.15.0
python-dotenv>=0.19.0
pymysql>=1.0.2
2. 数据库配置(.env)
DB_HOST=localhost
DB_PORT=3306
DB_USER=root
DB_PASSWORD=your_password
DB_NAME=your_database
3. 数据库连接管理(database.py)
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import os
from dotenv import load_dotenv
from contextlib import contextmanager
from typing import Generator
from fastapi import Depends
# 加载环境变量
load_dotenv()
# 数据库配置
DB_HOST = os.getenv("DB_HOST", "localhost")
DB_PORT = os.getenv("DB_PORT", "3306")
DB_USER = os.getenv("DB_USER", "root")
DB_PASSWORD = os.getenv("DB_PASSWORD", "")
DB_NAME = os.getenv("DB_NAME", "testdb")
# 构建数据库URL
DATABASE_URL = f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
# 创建数据库引擎
engine = create_engine(
DATABASE_URL,
pool_pre_ping=True, # 心跳检测
pool_recycle=3600, # 连接回收时间
pool_size=5, # 连接池大小
max_overflow=10 # 最大溢出连接数
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建基类
Base = declarative_base()
@contextmanager
def get_db_session():
"""创建数据库会话的上下文管理器"""
db = SessionLocal()
try:
yield db
finally:
db.close()
def get_db() -> Generator:
"""FastAPI依赖注入的数据库会话获取器"""
with get_db_session() as db:
yield db
# 创建一个可重用的数据库依赖
db_dependency = Depends(get_db)
4. 装饰器实现(decorators.py)
from typing import Type, Any, Optional, List, Dict
from pydantic import BaseModel, Field, create_model
from sqlalchemy import Column, Integer, String, Boolean, DateTime, Float, Text
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from database import get_db, db_dependency
from datetime import datetime
from contextlib import contextmanager
class RouterBaseModel(BaseModel):
"""基础路由模型,提供标准CRUD操作"""
@staticmethod
def list(model: Type[Any], db: Session) -> List[Any]:
"""获取列表"""
try:
return db.query(model).all()
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))
@staticmethod
def get(model: Type[Any], db: Session, item_id: int) -> Optional[Any]:
"""获取单个项目"""
try:
item = db.query(model).filter(model.id == item_id).first()
if not item:
raise HTTPException(status_code=404, detail="Item not found")
return item
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))
@staticmethod
def create(model: Type[Any], db: Session, item: Dict[str, Any]) -> Any:
"""创建项目"""
try:
db_item = model(**item)
db.add(db_item)
db.commit()
db.refresh(db_item)
return db_item
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))
@staticmethod
def update(model: Type[Any], db: Session, item_id: int, item: Dict[str, Any]) -> Any:
"""更新项目"""
try:
db_item = db.query(model).filter(model.id == item_id).first()
if not db_item:
raise HTTPException(status_code=404, detail="Item not found")
for key, value in item.items():
if hasattr(db_item, key):
setattr(db_item, key, value)
db.commit()
db.refresh(db_item)
return db_item
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))
@staticmethod
def delete(model: Type[Any], db: Session, item_id: int) -> Dict[str, str]:
"""删除项目"""
try:
db_item = db.query(model).filter(model.id == item_id).first()
if not db_item:
raise HTTPException(status_code=404, detail="Item not found")
db.delete(db_item)
db.commit()
return {"message": "Item deleted successfully"}
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))
def write2route(self, prefix: str, app: APIRouter, model: Type[Any]):
"""注册路由到FastAPI应用"""
@app.get(prefix + "/", response_model=List[self.__class__])
def get_items(db: Session = db_dependency):
return self.list(model, db)
@app.get(prefix + "/{item_id}", response_model=self.__class__)
def get_item(item_id: int, db: Session = db_dependency):
return self.get(model, db, item_id)
@app.post(prefix + "/", response_model=self.__class__)
def create_item(item: self.__class__, db: Session = db_dependency):
return self.create(model, db, item.dict(exclude_unset=True))
@app.put(prefix + "/{item_id}", response_model=self.__class__)
def update_item(item_id: int, item: self.__class__, db: Session = db_dependency):
return self.update(model, db, item_id, item.dict(exclude_unset=True))
@app.delete(prefix + "/{item_id}")
def delete_item(item_id: int, db: Session = db_dependency):
return self.delete(model, db, item_id)
class Config:
orm_mode = True
def get_basemodel(cls: Type[Any]) -> Type[Any]:
"""将SQLAlchemy模型转换为Pydantic模型的装饰器"""
# 类型映射
type_mapping = {
Integer: int,
String: str,
Boolean: bool,
DateTime: datetime,
Float: float,
Text: str
}
# 创建字段映射
fields = {}
for column in cls.__table__.columns:
field_type = type_mapping.get(type(column.type), str)
default = column.default.arg if column.default else ...
description = column.comment if column.comment else None
# 创建字段配置
field = Field(
default=default,
description=description,
title=column.name
)
# 添加到字段映射
fields[column.name] = (field_type, field)
# 创建Pydantic模型
model_name = f"{cls.__name__}Schema"
schema_model = create_model(
model_name,
__base__=RouterBaseModel,
**fields
)
# 将schema绑定到原始模型
cls.__model__ = schema_model()
return cls
5. 模型定义(models.py)
from sqlalchemy import Column, Integer, String, Boolean
from database import Base
from decorators import get_basemodel
@get_basemodel
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
email = Column(String(64), unique=True, index=True)
hashed_password = Column(String(64))
is_active = Column(Boolean, default=True)
6. 应用入口(main.py)
from fastapi import FastAPI
from database import Base, engine
from models import User
# 创建数据库表
Base.metadata.create_all(bind=engine)
# 创建FastAPI应用
app = FastAPI(
title="FastAPI Auto Schema",
description="自动生成Schema和Router的FastAPI应用",
version="1.0.0"
)
# 注册用户模型的路由
User.__model__.write2route("/users", app, User)
使用说明
- 安装依赖
pip install -r requirements.txt
fastapi>=0.68.0
sqlalchemy>=1.4.23
pydantic>=1.8.2
uvicorn>=0.15.0
python-dotenv>=0.19.0
pymysql>=1.0.2
-
配置数据库
编辑.env
文件,设置数据库连接信息。 -
运行应用
uvicorn main:app --reload
- 访问API文档
打开浏览器访问http://localhost:8000/docs
性能优化
-
连接池管理
- 复用数据库连接
- 自动处理断开的连接
- 限制最大连接数
-
查询优化
- 只更新修改过的字段
- 自动处理关系映射
- 惰性加载
注意事项
- 确保正确配置数据库连接信息
- 在生产环境中适当调整连接池参数
- 根据需要自定义错误处理逻辑
- 注意保护敏感的数据库字段
总结
通过这个项目,我们实现了一个高效的 FastAPI 自动化工具,它可以:
- 自动生成 API 端点
- 确保类型安全
- 优化数据库操作
- 提供完整的错误处理
这大大减少了重复代码的编写,提高了开发效率,同时保证了代码质量和性能。