WIP: sessions only in repository and mocels

This commit is contained in:
Keith Edmunds 2025-04-07 18:33:16 +01:00
parent f3b1e05e83
commit 5e492f4569
11 changed files with 976 additions and 672 deletions

View File

@ -46,6 +46,54 @@ def singleton(cls):
return wrapper_singleton return wrapper_singleton
# DTOs
@dataclass
class PlaylistDTO:
playlist_id: int
name: str
open: bool = False
favourite: bool = False
is_template: bool = False
@dataclass
class QueryDTO:
query_id: int
name: str
favourite: bool
filter: Filter
@dataclass
class TrackDTO:
track_id: int
artist: str
bitrate: int
duration: int
fade_at: int
intro: int | None
path: str
silence_at: int
start_gap: int
title: str
lastplayed: dt.datetime | None
@dataclass
class PlaylistRowDTO(TrackDTO):
note: str
played: bool
playlist_id: int
playlistrow_id: int
row_number: int
@dataclass
class PlaydatesDTO(TrackDTO):
playdate_id: int
lastplayed: dt.datetime
class ApplicationError(Exception): class ApplicationError(Exception):
""" """
Custom exception Custom exception
@ -124,39 +172,6 @@ class Tags(NamedTuple):
duration: int = 0 duration: int = 0
@dataclass
class PlaylistDTO:
name: str
playlist_id: int
favourite: bool = False
is_template: bool = False
open: bool = False
@dataclass
class TrackDTO:
track_id: int
artist: str
bitrate: int
duration: int
fade_at: int
intro: int | None
path: str
silence_at: int
start_gap: int
title: str
lastplayed: dt.datetime | None
@dataclass
class PlaylistRowDTO(TrackDTO):
note: str
played: bool
playlist_id: int
playlistrow_id: int
row_number: int
class TrackInfo(NamedTuple): class TrackInfo(NamedTuple):
track_id: int track_id: int
row_number: int row_number: int
@ -177,6 +192,12 @@ class InsertTrack:
note: str note: str
@dataclass
class PlayTrack:
playlist_id: int
track_id: int
@singleton @singleton
@dataclass @dataclass
class MusicMusterSignals(QObject): class MusicMusterSignals(QObject):
@ -204,6 +225,7 @@ class MusicMusterSignals(QObject):
# specify that here as it requires us to import PlaylistRow from # specify that here as it requires us to import PlaylistRow from
# playlistrow.py, which itself imports MusicMusterSignals # playlistrow.py, which itself imports MusicMusterSignals
signal_set_next_track = pyqtSignal(object) signal_set_next_track = pyqtSignal(object)
signal_track_started = pyqtSignal(PlayTrack)
span_cells_signal = pyqtSignal(int, int, int, int, int) span_cells_signal = pyqtSignal(int, int, int, int, int)
status_message_signal = pyqtSignal(str, int) status_message_signal = pyqtSignal(str, int)
track_ended_signal = pyqtSignal() track_ended_signal = pyqtSignal()

View File

@ -34,6 +34,7 @@ class Config(object):
COLOUR_QUERYLIST_SELECTED = "#d3ffd3" COLOUR_QUERYLIST_SELECTED = "#d3ffd3"
COLOUR_UNREADABLE = "#dc3545" COLOUR_UNREADABLE = "#dc3545"
COLOUR_WARNING_TIMER = "#ffc107" COLOUR_WARNING_TIMER = "#ffc107"
DB_NOT_FOUND = "Database not found"
DBFS_SILENCE = -50 DBFS_SILENCE = -50
DEFAULT_COLUMN_WIDTH = 200 DEFAULT_COLUMN_WIDTH = 200
DISPLAY_SQL = False DISPLAY_SQL = False

View File

@ -2,13 +2,7 @@
from typing import Optional from typing import Optional
# PyQt imports # PyQt imports
from PyQt6.QtCore import QEvent, Qt from PyQt6.QtCore import Qt
from PyQt6.QtGui import QKeyEvent
from PyQt6.QtWidgets import (
QDialog,
QListWidgetItem,
QMainWindow,
)
from PyQt6.QtWidgets import ( from PyQt6.QtWidgets import (
QDialog, QDialog,
QHBoxLayout, QHBoxLayout,
@ -16,6 +10,7 @@ from PyQt6.QtWidgets import (
QLineEdit, QLineEdit,
QListWidget, QListWidget,
QListWidgetItem, QListWidgetItem,
QMainWindow,
QPushButton, QPushButton,
QVBoxLayout, QVBoxLayout,
) )
@ -98,12 +93,10 @@ class TrackInsertDialog(QDialog):
self.setLayout(layout) self.setLayout(layout)
self.resize(800, 600) self.resize(800, 600)
# TODO
# record = Settings.get_setting(self.session, "dbdialog_width") width = repository.get_setting("dbdialog_width") or 800
# width = record.f_int or 800 height = repository.get_setting("dbdialog_height") or 800
# record = Settings.get_setting(self.session, "dbdialog_height") self.resize(width, height)
# height = record.f_int or 600
# self.resize(width, height)
self.signals = MusicMusterSignals() self.signals = MusicMusterSignals()
@ -114,9 +107,9 @@ class TrackInsertDialog(QDialog):
return return
if text.startswith("a/") and len(text) > 2: if text.startswith("a/") and len(text) > 2:
self.tracks = repository.tracks_like_artist(text[2:]) self.tracks = repository.tracks_by_artist(text[2:])
else: else:
self.tracks = repository.tracks_like_title(text) self.tracks = repository.tracks_by_title(text)
for track in self.tracks: for track in self.tracks:
duration_str = ms_to_mmss(track.duration) duration_str = ms_to_mmss(track.duration)

View File

@ -465,7 +465,7 @@ class FileImporter:
# file). Check that because the path field in the database is # file). Check that because the path field in the database is
# unique and so adding a duplicate will give a db integrity # unique and so adding a duplicate will give a db integrity
# error. # error.
if repository.track_with_path(tfd.destination_path): if repository.track_by_path(tfd.destination_path):
tfd.error = ( tfd.error = (
"Importing a new track but destination path already exists " "Importing a new track but destination path already exists "
f"in database ({tfd.destination_path})" f"in database ({tfd.destination_path})"

View File

@ -331,32 +331,6 @@ def normalise_track(path: str) -> None:
os.remove(temp_path) os.remove(temp_path)
def remove_substring_case_insensitive(parent_string: str, substring: str) -> str:
"""
Remove all instances of substring from parent string, case insensitively
"""
# Convert both strings to lowercase for case-insensitive comparison
lower_parent = parent_string.lower()
lower_substring = substring.lower()
# Initialize the result string
result = parent_string
# Continue removing the substring until it's no longer found
while lower_substring in lower_parent:
# Find the index of the substring
index = lower_parent.find(lower_substring)
# Remove the substring
result = result[:index] + result[index + len(substring) :]
# Update the lowercase versions
lower_parent = result.lower()
return result
def send_mail(to_addr: str, from_addr: str, subj: str, body: str) -> None: def send_mail(to_addr: str, from_addr: str, subj: str, body: str) -> None:
# From https://docs.python.org/3/library/email.examples.html # From https://docs.python.org/3/library/email.examples.html

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,6 @@ from PyQt6.QtGui import (
) )
# Third party imports # Third party imports
from sqlalchemy.orm.session import Session
# import snoop # type: ignore # import snoop # type: ignore
@ -39,8 +38,9 @@ from helpers import (
show_warning, show_warning,
) )
from log import log from log import log
from models import db, Playdates, Tracks from models import db, Playdates
from playlistrow import PlaylistRow from playlistrow import PlaylistRow
import repository
@dataclass @dataclass
@ -64,7 +64,7 @@ class QuerylistModel(QAbstractTableModel):
""" """
def __init__(self, session: Session, filter: Filter) -> None: def __init__(self, filter: Filter) -> None:
""" """
Load query Load query
""" """
@ -72,7 +72,6 @@ class QuerylistModel(QAbstractTableModel):
log.debug(f"QuerylistModel.__init__({filter=})") log.debug(f"QuerylistModel.__init__({filter=})")
super().__init__() super().__init__()
self.session = session
self.filter = filter self.filter = filter
self.querylist_rows: dict[int, QueryRow] = {} self.querylist_rows: dict[int, QueryRow] = {}
@ -230,7 +229,7 @@ class QuerylistModel(QAbstractTableModel):
row = 0 row = 0
try: try:
results = Tracks.get_filtered_tracks(self.session, self.filter) results = repository.get_filtered_tracks(self.filter)
for result in results: for result in results:
lastplayed = None lastplayed = None
if hasattr(result, "playdates"): if hasattr(result, "playdates"):
@ -244,7 +243,7 @@ class QuerylistModel(QAbstractTableModel):
lastplayed=lastplayed, lastplayed=lastplayed,
path=result.path, path=result.path,
title=result.title, title=result.title,
track_id=result.id, track_id=result.track_id,
) )
self.querylist_rows[row] = queryrow self.querylist_rows[row] = queryrow
@ -275,16 +274,7 @@ class QuerylistModel(QAbstractTableModel):
if column != QueryCol.LAST_PLAYED.value: if column != QueryCol.LAST_PLAYED.value:
return QVariant() return QVariant()
with db.Session() as session:
track_id = self.querylist_rows[row].track_id track_id = self.querylist_rows[row].track_id
if not track_id: if not track_id:
return QVariant() return QVariant()
playdates = Playdates.last_playdates(session, track_id) return repository.get_last_played_dates(track_id)
return (
"<br>".join(
[
a.lastplayed.strftime(Config.LAST_PLAYED_TOOLTIP_DATE_FORMAT)
for a in reversed(playdates)
]
)
)

View File

@ -1,4 +1,5 @@
# Standard library imports # Standard library imports
import datetime as dt
import re import re
# PyQt imports # PyQt imports
@ -13,10 +14,17 @@ from sqlalchemy import (
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BinaryExpression, ColumnElement from sqlalchemy.sql.elements import BinaryExpression, ColumnElement
from classes import ApplicationError, PlaylistRowDTO
# App imports # App imports
from classes import PlaylistDTO, TrackDTO from classes import (
ApplicationError,
Filter,
PlaydatesDTO,
PlaylistDTO,
PlaylistRowDTO,
QueryDTO,
TrackDTO,
)
from config import Config from config import Config
from log import log, log_call from log import log, log_call
from models import ( from models import (
@ -25,13 +33,14 @@ from models import (
Playdates, Playdates,
PlaylistRows, PlaylistRows,
Playlists, Playlists,
Queries,
Settings, Settings,
Tracks, Tracks,
) )
# Helper functions # Helper functions
@log_call
def _remove_substring_case_insensitive(parent_string: str, substring: str) -> str: def _remove_substring_case_insensitive(parent_string: str, substring: str) -> str:
""" """
Remove all instances of substring from parent string, case insensitively Remove all instances of substring from parent string, case insensitively
@ -107,6 +116,7 @@ def get_colour(text: str, foreground: bool = False) -> str:
return rec.colour return rec.colour
@log_call
def remove_colour_substring(text: str) -> str: def remove_colour_substring(text: str) -> str:
""" """
Remove text that identifies the colour to be used if strip_substring is True Remove text that identifies the colour to be used if strip_substring is True
@ -118,6 +128,131 @@ def remove_colour_substring(text: str) -> str:
# Track functions # Track functions
@log_call
def _tracks_where(
query: BinaryExpression | ColumnElement[bool],
filter_by_last_played: bool = False,
last_played_before: dt.datetime | None = None,
) -> list[TrackDTO]:
"""
Return tracks selected by query
"""
# Alibas PlaydatesTable for subquery
LatestPlaydate = aliased(Playdates)
# Create a 'latest playdate' subquery
latest_playdate_subq = (
select(
LatestPlaydate.track_id,
func.max(LatestPlaydate.lastplayed).label("lastplayed"),
)
.group_by(LatestPlaydate.track_id)
.subquery()
)
if not filter_by_last_played:
query = query.outerjoin(
latest_playdate_subq, Tracks.id == latest_playdate_subq.c.track_id
)
else:
# We are filtering by last played. If last_played_before is None,
# we want tracks that have never been played
if last_played_before is None:
query = query.outerjoin(Playdates, Tracks.id == Playdates.track_id).where(
Playdates.id.is_(None)
)
else:
query = query.join(
latest_playdate_subq, Tracks.id == latest_playdate_subq.c.track_id
).where(latest_playdate_subq.c.max_last_played < last_played_before)
pass
stmt = select(
Tracks.id.label("track_id"),
Tracks.artist,
Tracks.bitrate,
Tracks.duration,
Tracks.fade_at,
Tracks.intro,
Tracks.path,
Tracks.silence_at,
Tracks.start_gap,
Tracks.title,
latest_playdate_subq.c.lastplayed,
).where(query)
results: list[TrackDTO] = []
with db.Session() as session:
records = session.execute(stmt).all()
for record in records:
dto = TrackDTO(
artist=record.artist,
bitrate=record.bitrate,
duration=record.duration,
fade_at=record.fade_at,
intro=record.intro,
lastplayed=record.lastplayed,
path=record.path,
silence_at=record.silence_at,
start_gap=record.start_gap,
title=record.title,
track_id=record.track_id,
)
results.append(dto)
return results
def track_by_path(path: str) -> TrackDTO | None:
"""
Return track with passed path or None
"""
track_list = _tracks_where(Tracks.path.ilike(path))
if not track_list:
return None
if len(track_list) > 1:
raise ApplicationError(f"Duplicate {path=}")
return track_list[0]
def track_by_id(track_id: int) -> TrackDTO | None:
"""
Return track with specified id
"""
track_list = _tracks_where(Tracks.id == track_id)
if not track_list:
return None
if len(track_list) > 1:
raise ApplicationError(f"Duplicate {track_id=}")
return track_list[0]
def tracks_by_artist(filter_str: str) -> list[TrackDTO]:
"""
Return tracks where artist is like filter
"""
return _tracks_where(Tracks.artist.ilike(f"%{filter_str}%"))
def tracks_by_title(filter_str: str) -> list[TrackDTO]:
"""
Return tracks where title is like filter
"""
return _tracks_where(Tracks.title.ilike(f"%{filter_str}%"))
def get_all_tracks() -> list[TrackDTO]:
"""Return a list of all tracks"""
return _tracks_where(Tracks.id > 0)
@log_call
def add_track_to_header(playlistrow_id: int, track_id: int) -> None: def add_track_to_header(playlistrow_id: int, track_id: int) -> None:
""" """
Add a track to this (header) row Add a track to this (header) row
@ -132,6 +267,7 @@ def add_track_to_header(playlistrow_id: int, track_id: int) -> None:
session.commit() session.commit()
@log_call
def create_track(path: str, metadata: dict[str, str | int | float]) -> TrackDTO: def create_track(path: str, metadata: dict[str, str | int | float]) -> TrackDTO:
""" """
Create a track db entry from a track path and return the DTO Create a track db entry from a track path and return the DTO
@ -163,6 +299,7 @@ def create_track(path: str, metadata: dict[str, str | int | float]) -> TrackDTO:
return new_track return new_track
@log_call
def update_track( def update_track(
path: str, track_id: int, metadata: dict[str, str | int | float] path: str, track_id: int, metadata: dict[str, str | int | float]
) -> TrackDTO: ) -> TrackDTO:
@ -192,183 +329,208 @@ def update_track(
return updated_track return updated_track
def get_all_tracks() -> list[TrackDTO]: @log_call
"""Return a list of all tracks""" def get_filtered_tracks(filter: Filter) -> list[TrackDTO]:
return _tracks_where(Tracks.id > 0)
def track_by_id(track_id: int) -> TrackDTO | None:
""" """
Return track with specified id Return tracks matching filter
""" """
# Alias PlaydatesTable for subquery # Create a base query
LatestPlaydate = aliased(Playdates) query = Tracks.id > 0
# Subquery: latest playdate for each track # Path specification
latest_playdate_subq = ( if filter.path:
select( if filter.path_type == "contains":
LatestPlaydate.track_id, query = query.where(Tracks.path.ilike(f"%{filter.path}%"))
func.max(LatestPlaydate.lastplayed).label("lastplayed"), elif filter.path_type == "excluding":
) query = query.where(Tracks.path.notilike(f"%{filter.path}%"))
.group_by(LatestPlaydate.track_id) else:
.subquery() raise ApplicationError(f"Can't process filter path ({filter=})")
# Duration specification
seconds_duration = filter.duration_number
if filter.duration_unit == Config.FILTER_DURATION_MINUTES:
seconds_duration *= 60
elif filter.duration_unit != Config.FILTER_DURATION_SECONDS:
raise ApplicationError(f"Can't process filter duration ({filter=})")
if filter.duration_type == Config.FILTER_DURATION_LONGER:
query = query.where(Tracks.duration >= seconds_duration)
elif filter.duration_unit == Config.FILTER_DURATION_SHORTER:
query = query.where(Tracks.duration <= seconds_duration)
else:
raise ApplicationError(f"Can't process filter duration type ({filter=})")
# Process comparator
if filter.last_played_comparator == Config.FILTER_PLAYED_COMPARATOR_ANYTIME:
return _tracks_where(query, filter_by_last_played=False)
elif filter.last_played_comparator == Config.FILTER_PLAYED_COMPARATOR_NEVER:
return _tracks_where(query, filter_by_last_played=True, last_played_before=None)
else:
# Last played specification
now = dt.datetime.now()
# Set sensible default, and correct for Config.FILTER_PLAYED_COMPARATOR_ANYTIME
before = now
# If not ANYTIME, set 'before' appropriates
if filter.last_played_unit == Config.FILTER_PLAYED_DAYS:
before = now - dt.timedelta(days=filter.last_played_number)
elif filter.last_played_unit == Config.FILTER_PLAYED_WEEKS:
before = now - dt.timedelta(days=7 * filter.last_played_number)
elif filter.last_played_unit == Config.FILTER_PLAYED_MONTHS:
before = now - dt.timedelta(days=30 * filter.last_played_number)
elif filter.last_played_unit == Config.FILTER_PLAYED_YEARS:
before = now - dt.timedelta(days=365 * filter.last_played_number)
else:
raise ApplicationError("Can't determine last played criteria")
return _tracks_where(
query, filter_by_last_played=True, last_played_before=before
) )
stmt = (
select( def set_track_intro(track_id: int, intro: int) -> None:
Tracks.id.label("track_id"), """
Tracks.artist, Set track intro time
Tracks.bitrate, """
Tracks.duration,
Tracks.fade_at,
Tracks.intro,
Tracks.path,
Tracks.silence_at,
Tracks.start_gap,
Tracks.title,
latest_playdate_subq.c.lastplayed,
)
.outerjoin(latest_playdate_subq, Tracks.id == latest_playdate_subq.c.track_id)
.where(Tracks.id == track_id)
)
with db.Session() as session: with db.Session() as session:
record = session.execute(stmt).one_or_none() session.execute(
if not record: update(Tracks)
return None .where(Tracks.id == track_id)
.values(intro=intro)
dto = TrackDTO(
artist=record.artist,
bitrate=record.bitrate,
duration=record.duration,
fade_at=record.fade_at,
intro=record.intro,
lastplayed=record.lastplayed,
path=record.path,
silence_at=record.silence_at,
start_gap=record.start_gap,
title=record.title,
track_id=record.track_id,
) )
return dto session.commit()
def _tracks_where(where: BinaryExpression | ColumnElement[bool]) -> list[TrackDTO]: # Playlist functions
@log_call
def _playlists_where(
query: BinaryExpression | ColumnElement[bool],
) -> list[PlaylistDTO]:
""" """
Return tracks selected by where Return playlists selected by query
""" """
# Alias PlaydatesTable for subquery stmt = select(
LatestPlaydate = aliased(Playdates) Playlists.favourite,
Playlists.is_template,
Playlists.id.label("playlist_id"),
Playlists.name,
Playlists.open,
).where(query)
# Subquery: latest playdate for each track results: list[PlaylistDTO] = []
latest_playdate_subq = (
select(
LatestPlaydate.track_id,
func.max(LatestPlaydate.lastplayed).label("lastplayed"),
)
.group_by(LatestPlaydate.track_id)
.subquery()
)
stmt = (
select(
Tracks.id.label("track_id"),
Tracks.artist,
Tracks.bitrate,
Tracks.duration,
Tracks.fade_at,
Tracks.intro,
Tracks.path,
Tracks.silence_at,
Tracks.start_gap,
Tracks.title,
latest_playdate_subq.c.lastplayed,
)
.outerjoin(latest_playdate_subq, Tracks.id == latest_playdate_subq.c.track_id)
.where(where)
)
results: list[TrackDTO] = []
with db.Session() as session: with db.Session() as session:
records = session.execute(stmt).all() records = session.execute(stmt).all()
for record in records: for record in records:
dto = TrackDTO( dto = PlaylistDTO(
artist=record.artist, favourite=record.favourite,
bitrate=record.bitrate, is_template=record.is_template,
duration=record.duration, playlist_id=record.playlist_id,
fade_at=record.fade_at, name=record.name,
intro=record.intro, open=record.open,
lastplayed=record.lastplayed,
path=record.path,
silence_at=record.silence_at,
start_gap=record.start_gap,
title=record.title,
track_id=record.track_id,
) )
results.append(dto) results.append(dto)
return results return results
def track_with_path(path: str) -> bool: @log_call
def playlist_by_id(playlist_id: int) -> PlaylistDTO | None:
""" """
Return True if a track with passed path exists, else False Return playlist with specified id
"""
playlist_list = _playlists_where(Playlists.id == playlist_id)
if not playlist_list:
return None
if len(playlist_list) > 1:
raise ApplicationError(f"Duplicate {playlist_id=}")
return playlist_list[0]
def playlists_closed() -> list[PlaylistDTO]:
"""
Return a list of closed playlists
"""
return _playlists_where(Playlists.open.is_(False))
def playlists_open() -> list[PlaylistDTO]:
"""
Return a list of open playlists
"""
return _playlists_where(Playlists.open.is_(True))
def playlists_template_by_id(playlist_id: int) -> PlaylistDTO | None:
"""
Return a list of closed playlists
"""
playlist_list = _playlists_where(
Playlists.playlist_id == playlist_id, Playlists.is_template.is_(True)
)
if not playlist_list:
return None
if len(playlist_list) > 1:
raise ApplicationError(f"Duplicate {playlist_id=}")
return playlist_list[0]
def playlists_templates() -> list[PlaylistDTO]:
"""
Return a list of playlist templates
"""
return _playlists_where(Playlists.is_template.is_(True))
def get_all_playlists():
"""Return all playlists"""
return _playlists_where(Playlists.id > 0)
def delete_playlist(playlist_id: int) -> None:
"""Delete playlist"""
with db.Session() as session:
query = session.get(Playlists, playlist_id)
session.delete(query)
session.commit()
def save_as_template(playlist_id: int, template_name: str) -> None:
"""
Save playlist as templated
"""
new_template = create_playlist(template_name, 0, as_template=True)
copy_playlist(playlist_id, new_template.id)
def playlist_rename(playlist_id: int, new_name: str) -> None:
"""
Rename playlist
""" """
with db.Session() as session: with db.Session() as session:
track = ( session.execute(
session.execute(select(Tracks).where(Tracks.path == path)) update(Playlists)
.scalars() .where(Playlists.id == playlist_id)
.one_or_none() .values(name=new_name)
) )
return track is not None session.commit()
def tracks_like_artist(filter_str: str) -> list[TrackDTO]:
"""
Return tracks where artist is like filter
"""
return _tracks_where(Tracks.artist.ilike(f"%{filter_str}%"))
def tracks_like_title(filter_str: str) -> list[TrackDTO]:
"""
Return tracks where title is like filter
"""
return _tracks_where(Tracks.title.ilike(f"%{filter_str}%"))
def get_last_played_dates(track_id: int, limit: int = 5) -> str:
"""
Return the most recent 'limit' dates that this track has been played
as a text list
"""
with db.Session() as session:
playdates = session.scalars(
Playdates.select()
.where(Playdates.track_id == track_id)
.order_by(Playdates.lastplayed.desc())
.limit(limit)
).all()
return "<br>".join(
[
a.lastplayed.strftime(Config.LAST_PLAYED_TOOLTIP_DATE_FORMAT)
for a in playdates
]
)
# Playlist functions
def _check_playlist_integrity( def _check_playlist_integrity(
session: Session, playlist_id: int, fix: bool = False session: Session, playlist_id: int, fix: bool = False
) -> None: ) -> None:
@ -401,6 +563,20 @@ def _check_playlist_integrity(
raise ApplicationError(msg) raise ApplicationError(msg)
@log_call
def playlist_mark_status(playlist_id: int, open: bool) -> None:
"""Mark playlist as open or closed"""
with db.Session() as session:
session.execute(
update(Playlists)
.where(Playlists.id == playlist_id)
.values(open=open)
)
session.commit()
@log_call @log_call
def _shift_rows( def _shift_rows(
session: Session, playlist_id: int, starting_row: int, shift_by: int session: Session, playlist_id: int, starting_row: int, shift_by: int
@ -512,15 +688,7 @@ def move_rows(
_check_playlist_integrity(session, to_playlist_id, fix=False) _check_playlist_integrity(session, to_playlist_id, fix=False)
def update_playdates(track_id: int) -> None: @log_call
"""
Update playdates for passed track
"""
with db.Session() as session:
_ = Playdates(session, track_id)
def update_row_numbers( def update_row_numbers(
playlist_id: int, id_to_row_number: list[dict[int, int]] playlist_id: int, id_to_row_number: list[dict[int, int]]
) -> None: ) -> None:
@ -537,7 +705,8 @@ def update_row_numbers(
_check_playlist_integrity(session, playlist_id, fix=False) _check_playlist_integrity(session, playlist_id, fix=False)
def create_playlist(name: str, template_id: int) -> PlaylistDTO: @log_call
def create_playlist(name: str, template_id: int, as_template: bool = False) -> PlaylistDTO:
""" """
Create playlist and return DTO. Create playlist and return DTO.
""" """
@ -545,11 +714,15 @@ def create_playlist(name: str, template_id: int) -> PlaylistDTO:
with db.Session() as session: with db.Session() as session:
try: try:
playlist = Playlists(session, name, template_id) playlist = Playlists(session, name, template_id)
playlist.is_template = as_template
playlist_id = playlist.id playlist_id = playlist.id
session.commit() session.commit()
except Exception: except Exception:
raise ApplicationError("Can't create Playlist") raise ApplicationError("Can't create Playlist")
if template_id != 0:
copy_playlist(template_id, playlist_id)
new_playlist = playlist_by_id(playlist_id) new_playlist = playlist_by_id(playlist_id)
if not new_playlist: if not new_playlist:
raise ApplicationError("Can't retrieve new Playlist") raise ApplicationError("Can't retrieve new Playlist")
@ -557,6 +730,7 @@ def create_playlist(name: str, template_id: int) -> PlaylistDTO:
return new_playlist return new_playlist
@log_call
def get_playlist_row(playlistrow_id: int) -> PlaylistRowDTO | None: def get_playlist_row(playlistrow_id: int) -> PlaylistRowDTO | None:
""" """
Return specific row DTO Return specific row DTO
@ -746,6 +920,40 @@ def get_playlist_rows(
return dto_list return dto_list
def copy_playlist(src_id: int, dst_id: int) -> None:
"""Copy playlist entries"""
with db.Session() as session:
src_rows = session.scalars(
select(PlaylistRows).filter(PlaylistRows.playlist_id == src_id)
).all()
for plr in src_rows:
PlaylistRows(
session=session,
playlist_id=dst_id,
row_number=plr.row_number,
note=plr.note,
track_id=plr.track_id,
)
def playlist_row_count(playlist_id: int) -> int:
"""
Return number of rows in playlist
"""
with db.Session() as session:
count = session.scalar(
select(func.count())
.select_from(PlaylistRows)
.where(PlaylistRows.playlist_id == playlist_id)
)
return count
@log_call
def insert_row( def insert_row(
playlist_id: int, row_number: int, track_id: int | None, note: str playlist_id: int, row_number: int, track_id: int | None, note: str
) -> PlaylistRowDTO: ) -> PlaylistRowDTO:
@ -825,33 +1033,220 @@ def remove_rows(playlist_id: int, row_numbers: list[int]) -> None:
session.commit() session.commit()
def playlist_by_id(playlist_id: int) -> PlaylistDTO | None: @log_call
def update_template_favourite(template_id: int, favourite: bool) -> None:
"""Update template favourite"""
with db.Session() as session:
session.execute(
update(Playlists)
.where(Playlists.id == template_id)
.values(favourite=favourite)
)
session.commit()
@log_call
def playlist_save_tabs(playlist_id_to_tab: dict[int, int]) -> None:
""" """
Return playlist with specified id Save the tab numbers of the open playlists.
"""
with db.Session() as session:
# Clear all existing tab numbers
session.execute(
update(Playlists)
.where(Playlists.id.in_(playlist_id_to_tab.keys()))
.values(tab=None)
)
for (playlist_id, tab) in playlist_id_to_tab.items():
session.execute(
update(Playlists)
.where(Playlists.id == playlist_id)
.values(tab=tab)
)
session.commit()
# Playdates
@log_call
def get_last_played_dates(track_id: int, limit: int = 5) -> str:
"""
Return the most recent 'limit' dates that this track has been played
as a text list
"""
with db.Session() as session:
playdates = session.scalars(
Playdates.select()
.where(Playdates.track_id == track_id)
.order_by(Playdates.lastplayed.desc())
.limit(limit)
).all()
return "<br>".join(
[
a.lastplayed.strftime(Config.LAST_PLAYED_TOOLTIP_DATE_FORMAT)
for a in playdates
]
)
def update_playdates(track_id: int) -> None:
"""
Update playdates for passed track
"""
with db.Session() as session:
_ = Playdates(session, track_id)
def playdates_between_dates(
start: dt.datetime, end: dt.datetime | None = None
) -> list[PlaydatesDTO]:
"""
Return a list of PlaydateDTO objects from between times (until now if end is None)
"""
if end is None:
end = dt.datetime.now()
stmt = select(
Playdates.id.label("playdate_id"),
Playdates.lastplayed,
Playdates.track_id,
Playdates.track,
).where(
Playdates.lastplayed >= start,
Playdates.lastplayed <= end
)
results: list[PlaydatesDTO] = []
with db.Session() as session:
records = session.execute(stmt).all()
for record in records:
dto = PlaydatesDTO(
playdate_id=record.playdate_id,
lastplayed=record.lastplayed,
track_id=record.track_id,
artist=record.track.artist,
bitrate=record.track.bitrate,
duration=record.track.duration,
fade_at=record.track.fade_at,
intro=record.track.intro,
path=record.track.path,
silence_at=record.track.silence_at,
start_gap=record.track.start_gap,
title=record.track.title,
)
results.append(dto)
return results
# Queries
@log_call
def _queries_where(
query: BinaryExpression | ColumnElement[bool],
) -> list[QueryDTO]:
"""
Return queries selected by query
""" """
stmt = select( stmt = select(
Playlists.id.label("playlist_id"), Queries.id.label("query_id"), Queries.name, Queries.favourite, Queries.filter
Playlists.name, ).where(query)
Playlists.favourite,
Playlists.is_template, results: list[QueryDTO] = []
Playlists.open,
).where(Playlists.id == playlist_id)
with db.Session() as session: with db.Session() as session:
record = session.execute(stmt).one_or_none() records = session.execute(stmt).one_or_none()
if not record: for record in records:
return None dto = QueryDTO(
dto = PlaylistDTO(
name=record.name,
playlist_id=record.playlist_id,
favourite=record.favourite, favourite=record.favourite,
is_template=record.is_template, filter=record.filter,
open=record.open, name=record.name,
query_id=record.query_id,
) )
results.append(dto)
return dto return results
def get_all_queries(favourites_only: bool = False) -> list[QueryDTO]:
"""Return a list of all queries"""
query = Queries.id > 0
return _queries_where(query)
def query_by_id(query_id: int) -> QueryDTO | None:
"""Return query"""
query_list = _queries_where(Queries.id == query_id)
if not query_list:
return None
if len(query_list) > 1:
raise ApplicationError(f"Duplicate {query_id=}")
return query_list[0]
def update_query_filter(query_id: int, filter: Filter) -> None:
"""Update query filter"""
with db.Session() as session:
session.execute(
update(Queries).where(Queries.id == query_id).values(filter=filter)
)
session.commit()
def delete_query(query_id: int) -> None:
"""Delete query"""
with db.Session() as session:
query = session.get(Queries, query_id)
session.delete(query)
session.commit()
def update_query_name(query_id: int, name: str) -> None:
"""Update query name"""
with db.Session() as session:
session.execute(update(Queries).where(Queries.id == query_id).values(name=name))
session.commit()
def update_query_favourite(query_id: int, favourite: bool) -> None:
"""Update query favourite"""
with db.Session() as session:
session.execute(
update(Queries).where(Queries.id == query_id).values(favourite=favourite)
)
session.commit()
def create_query(name: str, filter: Filter) -> QueryDTO:
"""
Create a query and return the DTO
"""
with db.Session() as session:
try:
query = Queries(session=session, name=name, filter=filter)
query_id = query.id
session.commit()
except Exception:
raise ApplicationError("Can't create Query")
new_query = query_by_id(query_id)
if not new_query:
raise ApplicationError("Unable to create new query")
return new_query
# Misc # Misc
@ -889,3 +1284,14 @@ def set_setting(name: str, value: int) -> None:
raise ApplicationError("Can't create Settings record") raise ApplicationError("Can't create Settings record")
record.f_int = value record.f_int = value
session.commit() session.commit()
def get_db_name() -> str:
"""Return database name"""
with db.Session() as session:
if session.bind:
dbname = session.bind.engine.url.database
return dbname
return Config.DB_NOT_FOUND

View File

@ -5,7 +5,6 @@ import os
# PyQt imports # PyQt imports
# Third party imports # Third party imports
from sqlalchemy.orm.session import Session
# App imports # App imports
from config import Config from config import Config
@ -13,10 +12,10 @@ from helpers import (
get_tags, get_tags,
) )
from log import log from log import log
from models import Tracks import repository
def check_db(session: Session) -> None: def check_db() -> None:
""" """
Database consistency check. Database consistency check.
@ -27,7 +26,7 @@ def check_db(session: Session) -> None:
Check all paths in database exist Check all paths in database exist
""" """
db_paths = set([a.path for a in Tracks.get_all(session)]) db_paths = set([a.path for a in repository.get_all_tracks()])
os_paths_list = [] os_paths_list = []
for root, _dirs, files in os.walk(Config.ROOT): for root, _dirs, files in os.walk(Config.ROOT):
@ -52,7 +51,7 @@ def check_db(session: Session) -> None:
missing_file_count += 1 missing_file_count += 1
track = Tracks.get_by_path(session, path) track = repository.track_by_path(path)
if not track: if not track:
# This shouldn't happen as we're looking for paths in # This shouldn't happen as we're looking for paths in
# database that aren't in filesystem, but just in case... # database that aren't in filesystem, but just in case...
@ -74,7 +73,7 @@ def check_db(session: Session) -> None:
for t in paths_not_found: for t in paths_not_found:
print( print(
f""" f"""
Track ID: {t.id} Track ID: {t.track_id}
Path: {t.path} Path: {t.path}
Title: {t.title} Title: {t.title}
Artist: {t.artist} Artist: {t.artist}
@ -84,14 +83,15 @@ def check_db(session: Session) -> None:
print("There were more paths than listed that were not found") print("There were more paths than listed that were not found")
def update_bitrates(session: Session) -> None: def update_bitrates() -> None:
""" """
Update bitrates on all tracks in database Update bitrates on all tracks in database
""" """
for track in Tracks.get_all(session): for track in repository.get_all_tracks():
try: try:
t = get_tags(track.path) t = get_tags(track.path)
# TODO this won't persist as we're updating DTO
track.bitrate = t.bitrate track.bitrate = t.bitrate
except FileNotFoundError: except FileNotFoundError:
continue continue

View File

@ -131,10 +131,10 @@ class MyTestCase(unittest.TestCase):
_ = repository.create_track(self.isa_path, metadata) _ = repository.create_track(self.isa_path, metadata)
metadata = get_all_track_metadata(self.mom_path) metadata = get_all_track_metadata(self.mom_path)
_ = repository.create_track(self.mom_path, metadata) _ = repository.create_track(self.mom_path, metadata)
result_isa = repository.tracks_like_title(self.isa_title) result_isa = repository.tracks_by_title(self.isa_title)
assert len(result_isa) == 1 assert len(result_isa) == 1
assert result_isa[0].title == self.isa_title assert result_isa[0].title == self.isa_title
result_mom = repository.tracks_like_title(self.mom_title) result_mom = repository.tracks_by_title(self.mom_title)
assert len(result_mom) == 1 assert len(result_mom) == 1
assert result_mom[0].title == self.mom_title assert result_mom[0].title == self.mom_title

View File

@ -132,9 +132,8 @@ class MyTestCase(unittest.TestCase):
Config.ROOT = os.path.join(os.path.dirname(__file__), "testdata") Config.ROOT = os.path.join(os.path.dirname(__file__), "testdata")
with db.Session() as session: utilities.check_db()
utilities.check_db(session) utilities.update_bitrates()
utilities.update_bitrates(session)
# def test_meta_all_clear(qtbot, session): # def test_meta_all_clear(qtbot, session):