# Standard library imports from __future__ import annotations from typing import List, Optional, Sequence import datetime as dt import os import re import sys # PyQt imports # Third party imports import line_profiler from sqlalchemy import ( bindparam, delete, func, select, update, ) from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm import joinedload from sqlalchemy.orm.session import Session # App imports from config import Config from dbmanager import DatabaseManager import dbtables from log import log # Establish database connection DATABASE_URL = os.environ.get("DATABASE_URL") if DATABASE_URL is None: raise ValueError("DATABASE_URL is undefined") if "unittest" in sys.modules and "sqlite" not in DATABASE_URL: raise ValueError("Unit tests running on non-Sqlite database") db = DatabaseManager.get_instance(DATABASE_URL, engine_options=Config.ENGINE_OPTIONS).db db.create_all() # Database classes class NoteColours(dbtables.NoteColoursTable): def __init__( self, session: Session, substring: str, colour: str, enabled: bool = True, is_regex: bool = False, is_casesensitive: bool = False, order: Optional[int] = 0, ) -> None: self.substring = substring self.colour = colour self.enabled = enabled self.is_regex = is_regex self.is_casesensitive = is_casesensitive self.order = order session.add(self) session.commit() @classmethod def get_all(cls, session: Session) -> Sequence["NoteColours"]: """ Return all records """ result = session.scalars(select(cls)).all() return result @staticmethod def get_colour( session: Session, text: str, foreground: bool = False ) -> Optional[str]: """ Parse text and return background (foreground if foreground==True) colour string if matched, else None """ if not text: return None match = False for rec in session.scalars( select(NoteColours) .where( NoteColours.enabled.is_(True), ) .order_by(NoteColours.order) ).all(): if rec.is_regex: flags = re.UNICODE if not rec.is_casesensitive: flags |= re.IGNORECASE p = re.compile(rec.substring, flags) if p.match(text): match = True else: if rec.is_casesensitive: if rec.substring in text: match = True else: if rec.substring.lower() in text.lower(): match = True if match: if foreground: return rec.foreground else: return rec.colour return None class Playdates(dbtables.PlaydatesTable): def __init__(self, session: Session, track_id: int) -> None: """Record that track was played""" self.lastplayed = dt.datetime.now() self.track_id = track_id session.add(self) session.commit() @staticmethod def last_playdates( session: Session, track_id: int, limit: int = 5 ) -> Sequence["Playdates"]: """ Return a list of the last limit playdates for this track, sorted earliest to latest. """ return session.scalars( Playdates.select() .where(Playdates.track_id == track_id) .order_by(Playdates.lastplayed.asc()) .limit(limit) ).all() @staticmethod def last_played(session: Session, track_id: int) -> dt.datetime: """Return datetime track last played or None""" last_played = session.execute( select(Playdates.lastplayed) .where(Playdates.track_id == track_id) .order_by(Playdates.lastplayed.desc()) .limit(1) ).first() if last_played: return last_played[0] else: # Should never be reached as we create record with a # last_played value return Config.EPOCH # pragma: no cover @staticmethod def last_played_tracks(session: Session, limit: int = 5) -> Sequence["Playdates"]: """ Return a list of the last limit tracks played, sorted earliest to latest. """ return session.scalars( Playdates.select().order_by(Playdates.lastplayed.desc()).limit(limit) ).all() @staticmethod def played_after(session: Session, since: dt.datetime) -> Sequence["Playdates"]: """Return a list of Playdates objects since passed time""" return session.scalars( select(Playdates) .where(Playdates.lastplayed >= since) .order_by(Playdates.lastplayed) ).all() class Playlists(dbtables.PlaylistsTable): def __init__(self, session: Session, name: str): self.name = name self.last_used = dt.datetime.now() session.add(self) session.commit() @staticmethod def clear_tabs(session: Session, playlist_ids: List[int]) -> None: """ Make all tab records NULL """ session.execute( update(Playlists).where((Playlists.id.in_(playlist_ids))).values(tab=None) ) def close(self, session: Session) -> None: """Mark playlist as unloaded""" self.open = False session.commit() @classmethod def create_playlist_from_template( cls, session: Session, template: "Playlists", playlist_name: str ) -> Optional["Playlists"]: """Create a new playlist from template""" # Sanity check if not template.id: return None playlist = cls(session, playlist_name) # Sanity / mypy checks if not playlist or not playlist.id: return None PlaylistRows.copy_playlist(session, template.id, playlist.id) return playlist def delete(self, session: Session) -> None: """ Delete playlist """ session.execute(delete(Playlists).where(Playlists.id == self.id)) session.commit() @classmethod def get_all(cls, session: Session) -> Sequence["Playlists"]: """Returns a list of all playlists ordered by last use""" return session.scalars( select(cls) .filter(cls.is_template.is_(False)) .order_by(cls.last_used.desc()) ).all() @classmethod def get_all_templates(cls, session: Session) -> Sequence["Playlists"]: """Returns a list of all templates ordered by name""" return session.scalars( select(cls) .where(cls.is_template.is_(True)) .order_by(cls.name) ).all() @classmethod def get_closed(cls, session: Session) -> Sequence["Playlists"]: """Returns a list of all closed playlists ordered by last use""" return session.scalars( select(cls) .filter( cls.open.is_(False), cls.is_template.is_(False), ) .order_by(cls.last_used.desc()) ).all() @classmethod def get_open(cls, session: Session) -> Sequence[Optional["Playlists"]]: """ Return a list of loaded playlists ordered by tab. """ return session.scalars( select(cls).where(cls.open.is_(True)).order_by(cls.tab) ).all() def mark_open(self) -> None: """Mark playlist as loaded and used now""" self.open = True self.last_used = dt.datetime.now() @staticmethod def name_is_available(session: Session, name: str) -> bool: """ Return True if no playlist of this name exists else false. """ return ( session.execute(select(Playlists).where(Playlists.name == name)).first() is None ) def rename(self, session: Session, new_name: str) -> None: """ Rename playlist """ self.name = new_name session.commit() @staticmethod def save_as_template( session: Session, playlist_id: int, template_name: str ) -> None: """Save passed playlist as new template""" template = Playlists(session, template_name) if not template or not template.id: return template.is_template = True session.commit() PlaylistRows.copy_playlist(session, playlist_id, template.id) class PlaylistRows(dbtables.PlaylistRowsTable): def __init__( self, session: Session, playlist_id: int, row_number: int, note: str = "", track_id: Optional[int] = None, ) -> None: """Create PlaylistRows object""" self.playlist_id = playlist_id self.track_id = track_id self.row_number = row_number self.note = note session.add(self) session.commit() def append_note(self, extra_note: str) -> None: """Append passed note to any existing note""" current_note = self.note if current_note: self.note = current_note + "\n" + extra_note else: self.note = extra_note @staticmethod def copy_playlist(session: Session, src_id: int, dst_id: int) -> None: """Copy playlist entries""" src_rows = session.scalars( select(PlaylistRows).filter(PlaylistRows.playlist_id == src_id) ).all() for plr in src_rows: PlaylistRows( session=session, playlist_id=dst_id, row_number=plr.row_number, note=plr.note, track_id=plr.track_id, ) @classmethod def deep_row( cls, session: Session, playlist_id: int, row_number: int ) -> "PlaylistRows": """ Return a playlist row that includes full track and lastplayed data for given playlist_id and row """ stmt = ( select(PlaylistRows) .options(joinedload(cls.track)) .where( PlaylistRows.playlist_id == playlist_id, PlaylistRows.row_number == row_number, ) # .options(joinedload(Tracks.playdates)) ) return session.execute(stmt).unique().scalar_one() @staticmethod def delete_higher_rows(session: Session, playlist_id: int, maxrow: int) -> None: """ Delete rows in given playlist that have a higher row number than 'maxrow' """ session.execute( delete(PlaylistRows).where( PlaylistRows.playlist_id == playlist_id, PlaylistRows.row_number > maxrow, ) ) session.commit() @staticmethod def delete_row(session: Session, playlist_id: int, row_number: int) -> None: """ Delete passed row in given playlist. """ session.execute( delete(PlaylistRows).where( PlaylistRows.playlist_id == playlist_id, PlaylistRows.row_number == row_number, ) ) @staticmethod def fixup_rownumbers(session: Session, playlist_id: int) -> None: """ Ensure the row numbers for passed playlist have no gaps """ plrs = session.scalars( select(PlaylistRows) .where(PlaylistRows.playlist_id == playlist_id) .order_by(PlaylistRows.row_number) ).all() for i, plr in enumerate(plrs): plr.row_number = i # Ensure new row numbers are available to the caller session.commit() @classmethod def plrids_to_plrs( cls, session: Session, playlist_id: int, plr_ids: List[int] ) -> Sequence["PlaylistRows"]: """ Take a list of PlaylistRows ids and return a list of corresponding PlaylistRows objects """ plrs = session.scalars( select(cls) .where(cls.playlist_id == playlist_id, cls.id.in_(plr_ids)) .order_by(cls.row_number) ).all() return plrs @staticmethod def get_last_used_row(session: Session, playlist_id: int) -> Optional[int]: """Return the last used row for playlist, or None if no rows""" return session.execute( select(func.max(PlaylistRows.row_number)).where( PlaylistRows.playlist_id == playlist_id ) ).scalar_one() @staticmethod def get_track_plr( session: Session, track_id: int, playlist_id: int ) -> Optional["PlaylistRows"]: """Return first matching PlaylistRows object or None""" return session.scalars( select(PlaylistRows) .where( PlaylistRows.track_id == track_id, PlaylistRows.playlist_id == playlist_id, ) .limit(1) ).first() @classmethod def get_played_rows( cls, session: Session, playlist_id: int ) -> Sequence["PlaylistRows"]: """ For passed playlist, return a list of rows that have been played. """ plrs = session.scalars( select(cls) .where(cls.playlist_id == playlist_id, cls.played.is_(True)) .order_by(cls.row_number) ).all() return plrs @classmethod def get_playlist_rows( cls, session: Session, playlist_id: int ) -> Sequence["PlaylistRows"]: """ For passed playlist, return a list of rows. """ plrs = session.scalars( select(cls).where(cls.playlist_id == playlist_id).order_by(cls.row_number) ).all() return plrs @classmethod def get_rows_with_tracks( cls, session: Session, playlist_id: int, ) -> Sequence["PlaylistRows"]: """ For passed playlist, return a list of rows that contain tracks """ query = select(cls).where( cls.playlist_id == playlist_id, cls.track_id.is_not(None) ) plrs = session.scalars((query).order_by(cls.row_number)).all() return plrs @classmethod def get_unplayed_rows( cls, session: Session, playlist_id: int ) -> Sequence["PlaylistRows"]: """ For passed playlist, return a list of playlist rows that have not been played. """ plrs = session.scalars( select(cls) .where( cls.playlist_id == playlist_id, cls.track_id.is_not(None), cls.played.is_(False), ) .order_by(cls.row_number) ).all() return plrs @classmethod def insert_row( cls, session: Session, playlist_id: int, new_row_number: int, note: str = "", track_id: Optional[int] = None, ) -> "PlaylistRows": cls.move_rows_down(session, playlist_id, new_row_number, 1) return cls( session, playlist_id=playlist_id, row_number=new_row_number, note=note, track_id=track_id, ) @staticmethod def move_rows_down( session: Session, playlist_id: int, starting_row: int, move_by: int ) -> None: """ Create space to insert move_by additional rows by incremented row number from starting_row to end of playlist """ log.debug(f"(move_rows_down({playlist_id=}, {starting_row=}, {move_by=}") session.execute( update(PlaylistRows) .where( (PlaylistRows.playlist_id == playlist_id), (PlaylistRows.row_number >= starting_row), ) .values(row_number=PlaylistRows.row_number + move_by) ) @staticmethod @line_profiler.profile def update_plr_row_numbers( session: Session, playlist_id: int, sqla_map: List[dict[str, int]], dummy_for_profiling: Optional[int] = None, ) -> None: """ Take a {plrid: row_number} dictionary and update the row numbers accordingly """ # Update database. Ref: # https://docs.sqlalchemy.org/en/20/tutorial/data_update.html#the-update-sql-expression-construct stmt = ( update(PlaylistRows) .where( PlaylistRows.playlist_id == playlist_id, PlaylistRows.id == bindparam("playlistrow_id"), ) .values(row_number=bindparam("row_number")) ) session.connection().execute(stmt, sqla_map) class Settings(dbtables.SettingsTable): def __init__(self, session: Session, name: str): self.name = name session.add(self) session.commit() @classmethod def get_setting(cls, session: Session, name: str) -> "Settings": """Get existing setting or return new setting record""" try: return session.execute(select(cls).where(cls.name == name)).scalar_one() except NoResultFound: return Settings(session, name) class Tracks(dbtables.TracksTable): def __init__( self, session: Session, path: str, title: str, artist: str, duration: int, start_gap: int, fade_at: int, silence_at: int, bitrate: int, ): self.path = path self.title = title self.artist = artist self.bitrate = bitrate self.duration = duration self.start_gap = start_gap self.fade_at = fade_at self.silence_at = silence_at try: session.add(self) session.commit() except IntegrityError as error: session.rollback() log.error(f"Error ({error=}) importing track ({path=})") raise ValueError(error) @classmethod def get_all(cls, session: Session) -> Sequence["Tracks"]: """Return a list of all tracks""" return session.scalars(select(cls)).unique().all() @classmethod def all_tracks_indexed_by_id(cls, session: Session) -> dict[int, Tracks]: """ Return a dictionary of all tracks, keyed by title """ result: dict[int, Tracks] = {} for track in cls.get_all(session): result[track.id] = track return result @classmethod def exact_title_and_artist( cls, session: Session, title: str, artist: str ) -> Sequence["Tracks"]: """ Search for exact but case-insensitive match of title and artist """ return ( session.scalars( select(cls) .where(cls.title.ilike(title), cls.artist.ilike(artist)) .order_by(cls.title) ) .unique() .all() ) @classmethod def get_by_path(cls, session: Session, path: str) -> Optional["Tracks"]: """ Return track with passed path, or None. """ try: return ( session.execute(select(Tracks).where(Tracks.path == path)) .unique() .scalar_one() ) except NoResultFound: return None @classmethod def search_artists(cls, session: Session, text: str) -> Sequence["Tracks"]: """ Search case-insenstively for artists containing str The query performs an outer join with 'joinedload' to populate the results from the Playdates table at the same time. unique() needed; see https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#joined-eager-loading """ return ( session.scalars( select(cls) .options(joinedload(Tracks.playdates)) .where(cls.artist.ilike(f"%{text}%")) .order_by(cls.title) ) .unique() .all() ) @classmethod def search_titles(cls, session: Session, text: str) -> Sequence["Tracks"]: """ Search case-insenstively for titles containing str The query performs an outer join with 'joinedload' to populate the results from the Playdates table at the same time. unique() needed; see https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#joined-eager-loading """ return ( session.scalars( select(cls) .options(joinedload(Tracks.playdates)) .where(cls.title.like(f"{text}%")) .order_by(cls.title) ) .unique() .all() )