44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
|
|
from sqlalchemy import create_engine
|
||
|
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
||
|
|
from sqlalchemy.exc import OperationalError
|
||
|
|
from fastapi import HTTPException, status
|
||
|
|
|
||
|
|
from .config import settings
|
||
|
|
|
||
|
|
DATABASE_URL = settings.DATABASE_URL

# Extra keyword arguments passed through to the DBAPI driver's connect() call.
connect_args = {}
if DATABASE_URL.startswith("mysql+pymysql://"):
    # Give up on an unreachable MySQL server after 5 seconds (PyMySQL's
    # connect_timeout is in seconds) so a database outage surfaces quickly.
    connect_args["connect_timeout"] = 5

# Pool sizing: fall back to the defaults only when a setting is unset (None).
# A truthiness check would also discard an explicit 0, which is meaningful to
# SQLAlchemy's QueuePool: max_overflow=0 forbids overflow connections and
# pool_size=0 means "no size limit".
# NOTE(review): assumes DB_POOL_SIZE / DB_POOL_MAX_OVERFLOW are ints or None —
# confirm against the settings model in .config.
engine = create_engine(
    DATABASE_URL,
    # Test each pooled connection on checkout and replace it if it has died.
    pool_pre_ping=True,
    # Recycle connections older than one hour, presumably to stay below the
    # server's idle-connection timeout (e.g. MySQL wait_timeout).
    pool_recycle=3600,
    pool_size=5 if settings.DB_POOL_SIZE is None else settings.DB_POOL_SIZE,
    max_overflow=(
        10
        if settings.DB_POOL_MAX_OVERFLOW is None
        else settings.DB_POOL_MAX_OVERFLOW
    ),
    # Seconds to wait for a free pooled connection before giving up.
    pool_timeout=30,
    future=True,
    connect_args=connect_args,
)

# Session factory; transactions are not auto-committed and pending changes are
# not auto-flushed, so callers (see get_db) commit explicitly.
SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False, future=True)

# Declarative base class that ORM models inherit from.
Base = declarative_base()
|
||
|
|
|
||
|
|
|
||
|
|
def get_db():
    """Yield a database session, then commit or roll back and always close it.

    Intended as a generator dependency (presumably wired up via FastAPI's
    ``Depends``). The session is created from ``SessionLocal``.

    * If the consumer finishes without raising, the transaction is committed.
    * If an ``OperationalError`` occurs (from the consumer or from the commit),
      the transaction is rolled back and an ``HTTPException`` with status 503 is
      raised, chained to the original error.
    * Any other ``Exception`` triggers a rollback and is re-raised unchanged.
    * ``BaseException``s such as ``GeneratorExit`` skip both commit and rollback.
    * In every case the session is closed.

    NOTE(review): the commit runs after the ``yield`` returns, i.e. during the
    dependency's teardown; depending on the FastAPI version the response may
    already have been sent by then — confirm this is acceptable.
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except Exception as err:
        session.rollback()
        if isinstance(err, OperationalError):
            # Connection/driver-level failure: report the service as unavailable
            # rather than surfacing a generic server error.
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Database unavailable, please check DATABASE_URL / MySQL service.",
            ) from err
        raise
    finally:
        session.close()
|