WIP: sessions only in repository and mocels

This commit is contained in:
Keith Edmunds 2025-04-07 18:33:16 +01:00
parent f3b1e05e83
commit 5e492f4569
11 changed files with 976 additions and 672 deletions

View File

@ -46,6 +46,54 @@ def singleton(cls):
return wrapper_singleton
# DTOs
@dataclass
class PlaylistDTO:
playlist_id: int
name: str
open: bool = False
favourite: bool = False
is_template: bool = False
@dataclass
class QueryDTO:
query_id: int
name: str
favourite: bool
filter: Filter
@dataclass
class TrackDTO:
track_id: int
artist: str
bitrate: int
duration: int
fade_at: int
intro: int | None
path: str
silence_at: int
start_gap: int
title: str
lastplayed: dt.datetime | None
@dataclass
class PlaylistRowDTO(TrackDTO):
note: str
played: bool
playlist_id: int
playlistrow_id: int
row_number: int
@dataclass
class PlaydatesDTO(TrackDTO):
playdate_id: int
lastplayed: dt.datetime
class ApplicationError(Exception):
"""
Custom exception
@ -124,39 +172,6 @@ class Tags(NamedTuple):
duration: int = 0
@dataclass
class PlaylistDTO:
name: str
playlist_id: int
favourite: bool = False
is_template: bool = False
open: bool = False
@dataclass
class TrackDTO:
track_id: int
artist: str
bitrate: int
duration: int
fade_at: int
intro: int | None
path: str
silence_at: int
start_gap: int
title: str
lastplayed: dt.datetime | None
@dataclass
class PlaylistRowDTO(TrackDTO):
note: str
played: bool
playlist_id: int
playlistrow_id: int
row_number: int
class TrackInfo(NamedTuple):
track_id: int
row_number: int
@ -177,6 +192,12 @@ class InsertTrack:
note: str
@dataclass
class PlayTrack:
playlist_id: int
track_id: int
@singleton
@dataclass
class MusicMusterSignals(QObject):
@ -204,6 +225,7 @@ class MusicMusterSignals(QObject):
# specify that here as it requires us to import PlaylistRow from
# playlistrow.py, which itself imports MusicMusterSignals
signal_set_next_track = pyqtSignal(object)
signal_track_started = pyqtSignal(PlayTrack)
span_cells_signal = pyqtSignal(int, int, int, int, int)
status_message_signal = pyqtSignal(str, int)
track_ended_signal = pyqtSignal()

View File

@ -34,6 +34,7 @@ class Config(object):
COLOUR_QUERYLIST_SELECTED = "#d3ffd3"
COLOUR_UNREADABLE = "#dc3545"
COLOUR_WARNING_TIMER = "#ffc107"
DB_NOT_FOUND = "Database not found"
DBFS_SILENCE = -50
DEFAULT_COLUMN_WIDTH = 200
DISPLAY_SQL = False

View File

@ -2,13 +2,7 @@
from typing import Optional
# PyQt imports
from PyQt6.QtCore import QEvent, Qt
from PyQt6.QtGui import QKeyEvent
from PyQt6.QtWidgets import (
QDialog,
QListWidgetItem,
QMainWindow,
)
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import (
QDialog,
QHBoxLayout,
@ -16,6 +10,7 @@ from PyQt6.QtWidgets import (
QLineEdit,
QListWidget,
QListWidgetItem,
QMainWindow,
QPushButton,
QVBoxLayout,
)
@ -98,12 +93,10 @@ class TrackInsertDialog(QDialog):
self.setLayout(layout)
self.resize(800, 600)
# TODO
# record = Settings.get_setting(self.session, "dbdialog_width")
# width = record.f_int or 800
# record = Settings.get_setting(self.session, "dbdialog_height")
# height = record.f_int or 600
# self.resize(width, height)
width = repository.get_setting("dbdialog_width") or 800
height = repository.get_setting("dbdialog_height") or 800
self.resize(width, height)
self.signals = MusicMusterSignals()
@ -114,9 +107,9 @@ class TrackInsertDialog(QDialog):
return
if text.startswith("a/") and len(text) > 2:
self.tracks = repository.tracks_like_artist(text[2:])
self.tracks = repository.tracks_by_artist(text[2:])
else:
self.tracks = repository.tracks_like_title(text)
self.tracks = repository.tracks_by_title(text)
for track in self.tracks:
duration_str = ms_to_mmss(track.duration)

View File

@ -465,7 +465,7 @@ class FileImporter:
# file). Check that because the path field in the database is
# unique and so adding a duplicate will give a db integrity
# error.
if repository.track_with_path(tfd.destination_path):
if repository.track_by_path(tfd.destination_path):
tfd.error = (
"Importing a new track but destination path already exists "
f"in database ({tfd.destination_path})"

View File

@ -331,32 +331,6 @@ def normalise_track(path: str) -> None:
os.remove(temp_path)
def remove_substring_case_insensitive(parent_string: str, substring: str) -> str:
"""
Remove all instances of substring from parent string, case insensitively
"""
# Convert both strings to lowercase for case-insensitive comparison
lower_parent = parent_string.lower()
lower_substring = substring.lower()
# Initialize the result string
result = parent_string
# Continue removing the substring until it's no longer found
while lower_substring in lower_parent:
# Find the index of the substring
index = lower_parent.find(lower_substring)
# Remove the substring
result = result[:index] + result[index + len(substring) :]
# Update the lowercase versions
lower_parent = result.lower()
return result
def send_mail(to_addr: str, from_addr: str, subj: str, body: str) -> None:
# From https://docs.python.org/3/library/email.examples.html

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,6 @@ from PyQt6.QtGui import (
)
# Third party imports
from sqlalchemy.orm.session import Session
# import snoop # type: ignore
@ -39,8 +38,9 @@ from helpers import (
show_warning,
)
from log import log
from models import db, Playdates, Tracks
from models import db, Playdates
from playlistrow import PlaylistRow
import repository
@dataclass
@ -64,7 +64,7 @@ class QuerylistModel(QAbstractTableModel):
"""
def __init__(self, session: Session, filter: Filter) -> None:
def __init__(self, filter: Filter) -> None:
"""
Load query
"""
@ -72,7 +72,6 @@ class QuerylistModel(QAbstractTableModel):
log.debug(f"QuerylistModel.__init__({filter=})")
super().__init__()
self.session = session
self.filter = filter
self.querylist_rows: dict[int, QueryRow] = {}
@ -230,7 +229,7 @@ class QuerylistModel(QAbstractTableModel):
row = 0
try:
results = Tracks.get_filtered_tracks(self.session, self.filter)
results = repository.get_filtered_tracks(self.filter)
for result in results:
lastplayed = None
if hasattr(result, "playdates"):
@ -244,7 +243,7 @@ class QuerylistModel(QAbstractTableModel):
lastplayed=lastplayed,
path=result.path,
title=result.title,
track_id=result.id,
track_id=result.track_id,
)
self.querylist_rows[row] = queryrow
@ -275,16 +274,7 @@ class QuerylistModel(QAbstractTableModel):
if column != QueryCol.LAST_PLAYED.value:
return QVariant()
with db.Session() as session:
track_id = self.querylist_rows[row].track_id
if not track_id:
return QVariant()
playdates = Playdates.last_playdates(session, track_id)
return (
"<br>".join(
[
a.lastplayed.strftime(Config.LAST_PLAYED_TOOLTIP_DATE_FORMAT)
for a in reversed(playdates)
]
)
)
track_id = self.querylist_rows[row].track_id
if not track_id:
return QVariant()
return repository.get_last_played_dates(track_id)

View File

@ -1,4 +1,5 @@
# Standard library imports
import datetime as dt
import re
# PyQt imports
@ -13,10 +14,17 @@ from sqlalchemy import (
from sqlalchemy.orm import aliased
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BinaryExpression, ColumnElement
from classes import ApplicationError, PlaylistRowDTO
# App imports
from classes import PlaylistDTO, TrackDTO
from classes import (
ApplicationError,
Filter,
PlaydatesDTO,
PlaylistDTO,
PlaylistRowDTO,
QueryDTO,
TrackDTO,
)
from config import Config
from log import log, log_call
from models import (
@ -25,13 +33,14 @@ from models import (
Playdates,
PlaylistRows,
Playlists,
Queries,
Settings,
Tracks,
)
# Helper functions
@log_call
def _remove_substring_case_insensitive(parent_string: str, substring: str) -> str:
"""
Remove all instances of substring from parent string, case insensitively
@ -107,6 +116,7 @@ def get_colour(text: str, foreground: bool = False) -> str:
return rec.colour
@log_call
def remove_colour_substring(text: str) -> str:
"""
Remove text that identifies the colour to be used if strip_substring is True
@ -118,6 +128,131 @@ def remove_colour_substring(text: str) -> str:
# Track functions
@log_call
def _tracks_where(
query: BinaryExpression | ColumnElement[bool],
filter_by_last_played: bool = False,
last_played_before: dt.datetime | None = None,
) -> list[TrackDTO]:
"""
Return tracks selected by query
"""
# Alibas PlaydatesTable for subquery
LatestPlaydate = aliased(Playdates)
# Create a 'latest playdate' subquery
latest_playdate_subq = (
select(
LatestPlaydate.track_id,
func.max(LatestPlaydate.lastplayed).label("lastplayed"),
)
.group_by(LatestPlaydate.track_id)
.subquery()
)
if not filter_by_last_played:
query = query.outerjoin(
latest_playdate_subq, Tracks.id == latest_playdate_subq.c.track_id
)
else:
# We are filtering by last played. If last_played_before is None,
# we want tracks that have never been played
if last_played_before is None:
query = query.outerjoin(Playdates, Tracks.id == Playdates.track_id).where(
Playdates.id.is_(None)
)
else:
query = query.join(
latest_playdate_subq, Tracks.id == latest_playdate_subq.c.track_id
).where(latest_playdate_subq.c.max_last_played < last_played_before)
pass
stmt = select(
Tracks.id.label("track_id"),
Tracks.artist,
Tracks.bitrate,
Tracks.duration,
Tracks.fade_at,
Tracks.intro,
Tracks.path,
Tracks.silence_at,
Tracks.start_gap,
Tracks.title,
latest_playdate_subq.c.lastplayed,
).where(query)
results: list[TrackDTO] = []
with db.Session() as session:
records = session.execute(stmt).all()
for record in records:
dto = TrackDTO(
artist=record.artist,
bitrate=record.bitrate,
duration=record.duration,
fade_at=record.fade_at,
intro=record.intro,
lastplayed=record.lastplayed,
path=record.path,
silence_at=record.silence_at,
start_gap=record.start_gap,
title=record.title,
track_id=record.track_id,
)
results.append(dto)
return results
def track_by_path(path: str) -> TrackDTO | None:
"""
Return track with passed path or None
"""
track_list = _tracks_where(Tracks.path.ilike(path))
if not track_list:
return None
if len(track_list) > 1:
raise ApplicationError(f"Duplicate {path=}")
return track_list[0]
def track_by_id(track_id: int) -> TrackDTO | None:
"""
Return track with specified id
"""
track_list = _tracks_where(Tracks.id == track_id)
if not track_list:
return None
if len(track_list) > 1:
raise ApplicationError(f"Duplicate {track_id=}")
return track_list[0]
def tracks_by_artist(filter_str: str) -> list[TrackDTO]:
"""
Return tracks where artist is like filter
"""
return _tracks_where(Tracks.artist.ilike(f"%{filter_str}%"))
def tracks_by_title(filter_str: str) -> list[TrackDTO]:
"""
Return tracks where title is like filter
"""
return _tracks_where(Tracks.title.ilike(f"%{filter_str}%"))
def get_all_tracks() -> list[TrackDTO]:
"""Return a list of all tracks"""
return _tracks_where(Tracks.id > 0)
@log_call
def add_track_to_header(playlistrow_id: int, track_id: int) -> None:
"""
Add a track to this (header) row
@ -132,6 +267,7 @@ def add_track_to_header(playlistrow_id: int, track_id: int) -> None:
session.commit()
@log_call
def create_track(path: str, metadata: dict[str, str | int | float]) -> TrackDTO:
"""
Create a track db entry from a track path and return the DTO
@ -163,6 +299,7 @@ def create_track(path: str, metadata: dict[str, str | int | float]) -> TrackDTO:
return new_track
@log_call
def update_track(
path: str, track_id: int, metadata: dict[str, str | int | float]
) -> TrackDTO:
@ -192,183 +329,208 @@ def update_track(
return updated_track
def get_all_tracks() -> list[TrackDTO]:
"""Return a list of all tracks"""
return _tracks_where(Tracks.id > 0)
def track_by_id(track_id: int) -> TrackDTO | None:
@log_call
def get_filtered_tracks(filter: Filter) -> list[TrackDTO]:
"""
Return track with specified id
Return tracks matching filter
"""
# Alias PlaydatesTable for subquery
LatestPlaydate = aliased(Playdates)
# Create a base query
query = Tracks.id > 0
# Subquery: latest playdate for each track
latest_playdate_subq = (
select(
LatestPlaydate.track_id,
func.max(LatestPlaydate.lastplayed).label("lastplayed"),
)
.group_by(LatestPlaydate.track_id)
.subquery()
)
# Path specification
if filter.path:
if filter.path_type == "contains":
query = query.where(Tracks.path.ilike(f"%{filter.path}%"))
elif filter.path_type == "excluding":
query = query.where(Tracks.path.notilike(f"%{filter.path}%"))
else:
raise ApplicationError(f"Can't process filter path ({filter=})")
stmt = (
select(
Tracks.id.label("track_id"),
Tracks.artist,
Tracks.bitrate,
Tracks.duration,
Tracks.fade_at,
Tracks.intro,
Tracks.path,
Tracks.silence_at,
Tracks.start_gap,
Tracks.title,
latest_playdate_subq.c.lastplayed,
# Duration specification
seconds_duration = filter.duration_number
if filter.duration_unit == Config.FILTER_DURATION_MINUTES:
seconds_duration *= 60
elif filter.duration_unit != Config.FILTER_DURATION_SECONDS:
raise ApplicationError(f"Can't process filter duration ({filter=})")
if filter.duration_type == Config.FILTER_DURATION_LONGER:
query = query.where(Tracks.duration >= seconds_duration)
elif filter.duration_unit == Config.FILTER_DURATION_SHORTER:
query = query.where(Tracks.duration <= seconds_duration)
else:
raise ApplicationError(f"Can't process filter duration type ({filter=})")
# Process comparator
if filter.last_played_comparator == Config.FILTER_PLAYED_COMPARATOR_ANYTIME:
return _tracks_where(query, filter_by_last_played=False)
elif filter.last_played_comparator == Config.FILTER_PLAYED_COMPARATOR_NEVER:
return _tracks_where(query, filter_by_last_played=True, last_played_before=None)
else:
# Last played specification
now = dt.datetime.now()
# Set sensible default, and correct for Config.FILTER_PLAYED_COMPARATOR_ANYTIME
before = now
# If not ANYTIME, set 'before' appropriates
if filter.last_played_unit == Config.FILTER_PLAYED_DAYS:
before = now - dt.timedelta(days=filter.last_played_number)
elif filter.last_played_unit == Config.FILTER_PLAYED_WEEKS:
before = now - dt.timedelta(days=7 * filter.last_played_number)
elif filter.last_played_unit == Config.FILTER_PLAYED_MONTHS:
before = now - dt.timedelta(days=30 * filter.last_played_number)
elif filter.last_played_unit == Config.FILTER_PLAYED_YEARS:
before = now - dt.timedelta(days=365 * filter.last_played_number)
else:
raise ApplicationError("Can't determine last played criteria")
return _tracks_where(
query, filter_by_last_played=True, last_played_before=before
)
.outerjoin(latest_playdate_subq, Tracks.id == latest_playdate_subq.c.track_id)
.where(Tracks.id == track_id)
)
def set_track_intro(track_id: int, intro: int) -> None:
"""
Set track intro time
"""
with db.Session() as session:
record = session.execute(stmt).one_or_none()
if not record:
return None
dto = TrackDTO(
artist=record.artist,
bitrate=record.bitrate,
duration=record.duration,
fade_at=record.fade_at,
intro=record.intro,
lastplayed=record.lastplayed,
path=record.path,
silence_at=record.silence_at,
start_gap=record.start_gap,
title=record.title,
track_id=record.track_id,
session.execute(
update(Tracks)
.where(Tracks.id == track_id)
.values(intro=intro)
)
return dto
session.commit()
def _tracks_where(where: BinaryExpression | ColumnElement[bool]) -> list[TrackDTO]:
# Playlist functions
@log_call
def _playlists_where(
query: BinaryExpression | ColumnElement[bool],
) -> list[PlaylistDTO]:
"""
Return tracks selected by where
Return playlists selected by query
"""
# Alias PlaydatesTable for subquery
LatestPlaydate = aliased(Playdates)
stmt = select(
Playlists.favourite,
Playlists.is_template,
Playlists.id.label("playlist_id"),
Playlists.name,
Playlists.open,
).where(query)
# Subquery: latest playdate for each track
latest_playdate_subq = (
select(
LatestPlaydate.track_id,
func.max(LatestPlaydate.lastplayed).label("lastplayed"),
)
.group_by(LatestPlaydate.track_id)
.subquery()
)
stmt = (
select(
Tracks.id.label("track_id"),
Tracks.artist,
Tracks.bitrate,
Tracks.duration,
Tracks.fade_at,
Tracks.intro,
Tracks.path,
Tracks.silence_at,
Tracks.start_gap,
Tracks.title,
latest_playdate_subq.c.lastplayed,
)
.outerjoin(latest_playdate_subq, Tracks.id == latest_playdate_subq.c.track_id)
.where(where)
)
results: list[TrackDTO] = []
results: list[PlaylistDTO] = []
with db.Session() as session:
records = session.execute(stmt).all()
for record in records:
dto = TrackDTO(
artist=record.artist,
bitrate=record.bitrate,
duration=record.duration,
fade_at=record.fade_at,
intro=record.intro,
lastplayed=record.lastplayed,
path=record.path,
silence_at=record.silence_at,
start_gap=record.start_gap,
title=record.title,
track_id=record.track_id,
dto = PlaylistDTO(
favourite=record.favourite,
is_template=record.is_template,
playlist_id=record.playlist_id,
name=record.name,
open=record.open,
)
results.append(dto)
return results
def track_with_path(path: str) -> bool:
@log_call
def playlist_by_id(playlist_id: int) -> PlaylistDTO | None:
"""
Return True if a track with passed path exists, else False
Return playlist with specified id
"""
playlist_list = _playlists_where(Playlists.id == playlist_id)
if not playlist_list:
return None
if len(playlist_list) > 1:
raise ApplicationError(f"Duplicate {playlist_id=}")
return playlist_list[0]
def playlists_closed() -> list[PlaylistDTO]:
"""
Return a list of closed playlists
"""
return _playlists_where(Playlists.open.is_(False))
def playlists_open() -> list[PlaylistDTO]:
"""
Return a list of open playlists
"""
return _playlists_where(Playlists.open.is_(True))
def playlists_template_by_id(playlist_id: int) -> PlaylistDTO | None:
"""
Return a list of closed playlists
"""
playlist_list = _playlists_where(
Playlists.playlist_id == playlist_id, Playlists.is_template.is_(True)
)
if not playlist_list:
return None
if len(playlist_list) > 1:
raise ApplicationError(f"Duplicate {playlist_id=}")
return playlist_list[0]
def playlists_templates() -> list[PlaylistDTO]:
"""
Return a list of playlist templates
"""
return _playlists_where(Playlists.is_template.is_(True))
def get_all_playlists():
"""Return all playlists"""
return _playlists_where(Playlists.id > 0)
def delete_playlist(playlist_id: int) -> None:
"""Delete playlist"""
with db.Session() as session:
query = session.get(Playlists, playlist_id)
session.delete(query)
session.commit()
def save_as_template(playlist_id: int, template_name: str) -> None:
"""
Save playlist as templated
"""
new_template = create_playlist(template_name, 0, as_template=True)
copy_playlist(playlist_id, new_template.id)
def playlist_rename(playlist_id: int, new_name: str) -> None:
"""
Rename playlist
"""
with db.Session() as session:
track = (
session.execute(select(Tracks).where(Tracks.path == path))
.scalars()
.one_or_none()
session.execute(
update(Playlists)
.where(Playlists.id == playlist_id)
.values(name=new_name)
)
return track is not None
session.commit()
def tracks_like_artist(filter_str: str) -> list[TrackDTO]:
"""
Return tracks where artist is like filter
"""
return _tracks_where(Tracks.artist.ilike(f"%{filter_str}%"))
def tracks_like_title(filter_str: str) -> list[TrackDTO]:
"""
Return tracks where title is like filter
"""
return _tracks_where(Tracks.title.ilike(f"%{filter_str}%"))
def get_last_played_dates(track_id: int, limit: int = 5) -> str:
"""
Return the most recent 'limit' dates that this track has been played
as a text list
"""
with db.Session() as session:
playdates = session.scalars(
Playdates.select()
.where(Playdates.track_id == track_id)
.order_by(Playdates.lastplayed.desc())
.limit(limit)
).all()
return "<br>".join(
[
a.lastplayed.strftime(Config.LAST_PLAYED_TOOLTIP_DATE_FORMAT)
for a in playdates
]
)
# Playlist functions
def _check_playlist_integrity(
session: Session, playlist_id: int, fix: bool = False
) -> None:
@ -401,6 +563,20 @@ def _check_playlist_integrity(
raise ApplicationError(msg)
@log_call
def playlist_mark_status(playlist_id: int, open: bool) -> None:
"""Mark playlist as open or closed"""
with db.Session() as session:
session.execute(
update(Playlists)
.where(Playlists.id == playlist_id)
.values(open=open)
)
session.commit()
@log_call
def _shift_rows(
session: Session, playlist_id: int, starting_row: int, shift_by: int
@ -512,15 +688,7 @@ def move_rows(
_check_playlist_integrity(session, to_playlist_id, fix=False)
def update_playdates(track_id: int) -> None:
"""
Update playdates for passed track
"""
with db.Session() as session:
_ = Playdates(session, track_id)
@log_call
def update_row_numbers(
playlist_id: int, id_to_row_number: list[dict[int, int]]
) -> None:
@ -537,7 +705,8 @@ def update_row_numbers(
_check_playlist_integrity(session, playlist_id, fix=False)
def create_playlist(name: str, template_id: int) -> PlaylistDTO:
@log_call
def create_playlist(name: str, template_id: int, as_template: bool = False) -> PlaylistDTO:
"""
Create playlist and return DTO.
"""
@ -545,11 +714,15 @@ def create_playlist(name: str, template_id: int) -> PlaylistDTO:
with db.Session() as session:
try:
playlist = Playlists(session, name, template_id)
playlist.is_template = as_template
playlist_id = playlist.id
session.commit()
except Exception:
raise ApplicationError("Can't create Playlist")
if template_id != 0:
copy_playlist(template_id, playlist_id)
new_playlist = playlist_by_id(playlist_id)
if not new_playlist:
raise ApplicationError("Can't retrieve new Playlist")
@ -557,6 +730,7 @@ def create_playlist(name: str, template_id: int) -> PlaylistDTO:
return new_playlist
@log_call
def get_playlist_row(playlistrow_id: int) -> PlaylistRowDTO | None:
"""
Return specific row DTO
@ -746,6 +920,40 @@ def get_playlist_rows(
return dto_list
def copy_playlist(src_id: int, dst_id: int) -> None:
"""Copy playlist entries"""
with db.Session() as session:
src_rows = session.scalars(
select(PlaylistRows).filter(PlaylistRows.playlist_id == src_id)
).all()
for plr in src_rows:
PlaylistRows(
session=session,
playlist_id=dst_id,
row_number=plr.row_number,
note=plr.note,
track_id=plr.track_id,
)
def playlist_row_count(playlist_id: int) -> int:
"""
Return number of rows in playlist
"""
with db.Session() as session:
count = session.scalar(
select(func.count())
.select_from(PlaylistRows)
.where(PlaylistRows.playlist_id == playlist_id)
)
return count
@log_call
def insert_row(
playlist_id: int, row_number: int, track_id: int | None, note: str
) -> PlaylistRowDTO:
@ -825,33 +1033,220 @@ def remove_rows(playlist_id: int, row_numbers: list[int]) -> None:
session.commit()
def playlist_by_id(playlist_id: int) -> PlaylistDTO | None:
@log_call
def update_template_favourite(template_id: int, favourite: bool) -> None:
"""Update template favourite"""
with db.Session() as session:
session.execute(
update(Playlists)
.where(Playlists.id == template_id)
.values(favourite=favourite)
)
session.commit()
@log_call
def playlist_save_tabs(playlist_id_to_tab: dict[int, int]) -> None:
"""
Return playlist with specified id
Save the tab numbers of the open playlists.
"""
with db.Session() as session:
# Clear all existing tab numbers
session.execute(
update(Playlists)
.where(Playlists.id.in_(playlist_id_to_tab.keys()))
.values(tab=None)
)
for (playlist_id, tab) in playlist_id_to_tab.items():
session.execute(
update(Playlists)
.where(Playlists.id == playlist_id)
.values(tab=tab)
)
session.commit()
# Playdates
@log_call
def get_last_played_dates(track_id: int, limit: int = 5) -> str:
"""
Return the most recent 'limit' dates that this track has been played
as a text list
"""
with db.Session() as session:
playdates = session.scalars(
Playdates.select()
.where(Playdates.track_id == track_id)
.order_by(Playdates.lastplayed.desc())
.limit(limit)
).all()
return "<br>".join(
[
a.lastplayed.strftime(Config.LAST_PLAYED_TOOLTIP_DATE_FORMAT)
for a in playdates
]
)
def update_playdates(track_id: int) -> None:
"""
Update playdates for passed track
"""
with db.Session() as session:
_ = Playdates(session, track_id)
def playdates_between_dates(
start: dt.datetime, end: dt.datetime | None = None
) -> list[PlaydatesDTO]:
"""
Return a list of PlaydateDTO objects from between times (until now if end is None)
"""
if end is None:
end = dt.datetime.now()
stmt = select(
Playdates.id.label("playdate_id"),
Playdates.lastplayed,
Playdates.track_id,
Playdates.track,
).where(
Playdates.lastplayed >= start,
Playdates.lastplayed <= end
)
results: list[PlaydatesDTO] = []
with db.Session() as session:
records = session.execute(stmt).all()
for record in records:
dto = PlaydatesDTO(
playdate_id=record.playdate_id,
lastplayed=record.lastplayed,
track_id=record.track_id,
artist=record.track.artist,
bitrate=record.track.bitrate,
duration=record.track.duration,
fade_at=record.track.fade_at,
intro=record.track.intro,
path=record.track.path,
silence_at=record.track.silence_at,
start_gap=record.track.start_gap,
title=record.track.title,
)
results.append(dto)
return results
# Queries
@log_call
def _queries_where(
query: BinaryExpression | ColumnElement[bool],
) -> list[QueryDTO]:
"""
Return queries selected by query
"""
stmt = select(
Playlists.id.label("playlist_id"),
Playlists.name,
Playlists.favourite,
Playlists.is_template,
Playlists.open,
).where(Playlists.id == playlist_id)
Queries.id.label("query_id"), Queries.name, Queries.favourite, Queries.filter
).where(query)
results: list[QueryDTO] = []
with db.Session() as session:
record = session.execute(stmt).one_or_none()
if not record:
return None
records = session.execute(stmt).one_or_none()
for record in records:
dto = QueryDTO(
favourite=record.favourite,
filter=record.filter,
name=record.name,
query_id=record.query_id,
)
results.append(dto)
dto = PlaylistDTO(
name=record.name,
playlist_id=record.playlist_id,
favourite=record.favourite,
is_template=record.is_template,
open=record.open,
return results
def get_all_queries(favourites_only: bool = False) -> list[QueryDTO]:
"""Return a list of all queries"""
query = Queries.id > 0
return _queries_where(query)
def query_by_id(query_id: int) -> QueryDTO | None:
"""Return query"""
query_list = _queries_where(Queries.id == query_id)
if not query_list:
return None
if len(query_list) > 1:
raise ApplicationError(f"Duplicate {query_id=}")
return query_list[0]
def update_query_filter(query_id: int, filter: Filter) -> None:
"""Update query filter"""
with db.Session() as session:
session.execute(
update(Queries).where(Queries.id == query_id).values(filter=filter)
)
session.commit()
return dto
def delete_query(query_id: int) -> None:
"""Delete query"""
with db.Session() as session:
query = session.get(Queries, query_id)
session.delete(query)
session.commit()
def update_query_name(query_id: int, name: str) -> None:
"""Update query name"""
with db.Session() as session:
session.execute(update(Queries).where(Queries.id == query_id).values(name=name))
session.commit()
def update_query_favourite(query_id: int, favourite: bool) -> None:
"""Update query favourite"""
with db.Session() as session:
session.execute(
update(Queries).where(Queries.id == query_id).values(favourite=favourite)
)
session.commit()
def create_query(name: str, filter: Filter) -> QueryDTO:
"""
Create a query and return the DTO
"""
with db.Session() as session:
try:
query = Queries(session=session, name=name, filter=filter)
query_id = query.id
session.commit()
except Exception:
raise ApplicationError("Can't create Query")
new_query = query_by_id(query_id)
if not new_query:
raise ApplicationError("Unable to create new query")
return new_query
# Misc
@ -889,3 +1284,14 @@ def set_setting(name: str, value: int) -> None:
raise ApplicationError("Can't create Settings record")
record.f_int = value
session.commit()
def get_db_name() -> str:
"""Return database name"""
with db.Session() as session:
if session.bind:
dbname = session.bind.engine.url.database
return dbname
return Config.DB_NOT_FOUND

View File

@ -5,7 +5,6 @@ import os
# PyQt imports
# Third party imports
from sqlalchemy.orm.session import Session
# App imports
from config import Config
@ -13,10 +12,10 @@ from helpers import (
get_tags,
)
from log import log
from models import Tracks
import repository
def check_db(session: Session) -> None:
def check_db() -> None:
"""
Database consistency check.
@ -27,7 +26,7 @@ def check_db(session: Session) -> None:
Check all paths in database exist
"""
db_paths = set([a.path for a in Tracks.get_all(session)])
db_paths = set([a.path for a in repository.get_all_tracks()])
os_paths_list = []
for root, _dirs, files in os.walk(Config.ROOT):
@ -52,7 +51,7 @@ def check_db(session: Session) -> None:
missing_file_count += 1
track = Tracks.get_by_path(session, path)
track = repository.track_by_path(path)
if not track:
# This shouldn't happen as we're looking for paths in
# database that aren't in filesystem, but just in case...
@ -74,7 +73,7 @@ def check_db(session: Session) -> None:
for t in paths_not_found:
print(
f"""
Track ID: {t.id}
Track ID: {t.track_id}
Path: {t.path}
Title: {t.title}
Artist: {t.artist}
@ -84,14 +83,15 @@ def check_db(session: Session) -> None:
print("There were more paths than listed that were not found")
def update_bitrates(session: Session) -> None:
def update_bitrates() -> None:
"""
Update bitrates on all tracks in database
"""
for track in Tracks.get_all(session):
for track in repository.get_all_tracks():
try:
t = get_tags(track.path)
# TODO this won't persist as we're updating DTO
track.bitrate = t.bitrate
except FileNotFoundError:
continue

View File

@ -131,10 +131,10 @@ class MyTestCase(unittest.TestCase):
_ = repository.create_track(self.isa_path, metadata)
metadata = get_all_track_metadata(self.mom_path)
_ = repository.create_track(self.mom_path, metadata)
result_isa = repository.tracks_like_title(self.isa_title)
result_isa = repository.tracks_by_title(self.isa_title)
assert len(result_isa) == 1
assert result_isa[0].title == self.isa_title
result_mom = repository.tracks_like_title(self.mom_title)
result_mom = repository.tracks_by_title(self.mom_title)
assert len(result_mom) == 1
assert result_mom[0].title == self.mom_title

View File

@ -132,9 +132,8 @@ class MyTestCase(unittest.TestCase):
Config.ROOT = os.path.join(os.path.dirname(__file__), "testdata")
with db.Session() as session:
utilities.check_db(session)
utilities.update_bitrates(session)
utilities.check_db()
utilities.update_bitrates()
# def test_meta_all_clear(qtbot, session):