diff --git a/app/dbconfig.py b/app/dbconfig.py new file mode 100644 index 0000000..ca71d80 --- /dev/null +++ b/app/dbconfig.py @@ -0,0 +1,37 @@ +import os +import sqlalchemy + +from config import Config +from sqlalchemy.orm import (sessionmaker, scoped_session) + +MM_ENV = os.environ.get('MM_ENV', 'PRODUCTION') + +if MM_ENV == 'PRODUCTION': + dbname = os.environ.get('MM_PRODUCTION_DBNAME', 'musicmuster') + dbuser = os.environ.get('MM_PRODUCTION_DBUSER', 'musicmuster') + dbpw = os.environ.get('MM_PRODUCTION_DBPW', 'musicmuster') + dbhost = os.environ.get('MM_PRODUCTION_DBHOST', 'localhost') +elif MM_ENV == 'TESTING': + dbname = os.environ.get('MM_TESTING_DBNAME', 'musicmuster') + dbuser = os.environ.get('MM_TESTING_DBUSER', 'musicmuster') + dbpw = os.environ.get('MM_TESTING_DBPW', 'musicmuster') + dbhost = os.environ.get('MM_TESTING_DBHOST', 'localhost') +elif MM_ENV == 'DEVELOPMENT': + dbname = os.environ.get('MM_DEVELOPMENT_DBNAME', 'musicmuster') + dbuser = os.environ.get('MM_DEVELOPMENT_DBUSER', 'musicmuster') + dbpw = os.environ.get('MM_DEVELOPMENT_DBPW', 'musicmuster') + dbhost = os.environ.get('MM_DEVELOPMENT_DBHOST', 'localhost') +else: + raise ValueError(f"Unknown MusicMuster environment: {MM_ENV=}") + +MYSQL_CONNECT = f"mysql+mysqldb://{dbuser}:{dbpw}@{dbhost}/{dbname}" + +engine = sqlalchemy.create_engine( + MYSQL_CONNECT, + encoding='utf-8', + echo=Config.DISPLAY_SQL, + pool_pre_ping=True +) + +Session = scoped_session(sessionmaker(bind=engine)) + diff --git a/app/models.py b/app/models.py index 07b2275..b46cdae 100644 --- a/app/models.py +++ b/app/models.py @@ -3,7 +3,6 @@ import os.path import re -import sqlalchemy from datetime import datetime from typing import List, Optional @@ -25,8 +24,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import ( backref, relationship, - sessionmaker, - scoped_session, RelationshipProperty + RelationshipProperty ) from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound @@ -40,21 +38,6 @@ from helpers import ( ) from log import DEBUG, ERROR -# Create session at the global level as per -# https://docs.sqlalchemy.org/en/13/orm/session_basics.html -# and make objects persistent -# https://docs.sqlalchemy.org/en/14/orm/session_state_management.html - -engine = sqlalchemy.create_engine( - f"{Config.MYSQL_CONNECT}?charset=utf8", - encoding='utf-8', - echo=Config.DISPLAY_SQL, - pool_pre_ping=True) - -# Create a Session factory -Session = scoped_session(sessionmaker(bind=engine)) -# sm: sessionmaker = sessionmaker(bind=engine) # , expire_on_commit=False) -# Session = scoped_session(sm) Base: DeclarativeMeta = declarative_base() Base.metadata.create_all(engine)