diff --git a/app/models.py b/app/models.py index 07f5699..6ba4708 100644 --- a/app/models.py +++ b/app/models.py @@ -291,59 +291,56 @@ class Playlists(Base): session.add(self) session.commit() - @staticmethod - def get_all_closed_playlists(session): + @classmethod + def get_all_closed_playlists(cls, session): "Returns a list of all playlists not currently open" return ( - session.query(Playlists) + session.query(cls) .filter( - (Playlists.loaded == False) | # noqa E712 - (Playlists.loaded == None) + (cls.loaded == False) | # noqa E712 + (cls.loaded == None) ) - .order_by(Playlists.last_used.desc()) + .order_by(cls.last_used.desc()) ).all() - @staticmethod - def get_all_playlists(session): + @classmethod + def get_all_playlists(cls, session): "Returns a list of all playlists" - return session.query(Playlists).all() + return session.query(cls).all() - @staticmethod - def get_last_used(session): + @classmethod + def get_last_used(cls, session): """ Return a list of playlists marked "loaded", ordered by loaded date. """ return ( - session.query(Playlists) - .filter(Playlists.loaded == True) # noqa E712 - .order_by(Playlists.last_used.desc()) + session.query(cls) + .filter(cls.loaded == True) # noqa E712 + .order_by(cls.last_used.desc()) ).all() def get_notes(self): return [a.note for a in self.notes] - @staticmethod - def get_playlist(session, playlist_id): - return ( - session.query(Playlists) - .filter( - Playlists.id == playlist_id # noqa E712 - ) - ).one() + @classmethod + def get_playlist_by_id(cls, session, playlist_id): + + return (session.query(cls).filter(cls.id == playlist_id)).one() def get_tracks(self): return [a.tracks for a in self.tracks] - @staticmethod - def new(session, name): + @classmethod + def new(cls, session, name): DEBUG(f"Playlists.new(name={name})") - playlist = Playlists() + playlist = cls() playlist.name = name session.add(playlist) session.commit() + return playlist def open(self, session): @@ -418,10 +415,17 @@ class PlaylistTracks(Base): @staticmethod def new_row(session, playlist_id): - "Return row number > largest existing row number" + """ + Return row number > largest existing row number + + If there are no existing rows, return 0 (ie, first row number) + """ last_row = session.query(func.max(PlaylistTracks.row)).one()[0] - return last_row + 1 + if last_row: + return last_row + 1 + else: + return 0 @staticmethod def remove_all_tracks(session, playlist_id): diff --git a/app/musicmuster.py b/app/musicmuster.py index 4d7d765..3fec216 100755 --- a/app/musicmuster.py +++ b/app/musicmuster.py @@ -356,7 +356,7 @@ class Window(QMainWindow, Ui_MainWindow): # Get playlist db object with Session() as session: - playlist_db = Playlists.get_playlist( + playlist_db = Playlists.get_playlist_by_id( session, self.visible_playlist_tab().id) with open(path, "w") as f: # Required directive on first line @@ -552,7 +552,7 @@ class Window(QMainWindow, Ui_MainWindow): dlg = SelectPlaylistDialog(self, playlist_dbs=playlist_dbs) dlg.exec() if dlg.plid: - playlist_db = Playlists.get_playlist(session, dlg.plid) + playlist_db = Playlists.get_playlist_by_id(session, dlg.plid) self.load_playlist(session, playlist_db) def select_next_row(self): diff --git a/test_models.py b/test_models.py index a39d2be..a9c1657 100644 --- a/test_models.py +++ b/test_models.py @@ -54,7 +54,7 @@ def test_notes_add_note(session): Notes.add_note(session, pl.id, 1, note_text) # We retrieve notes via playlist - playlist = Playlists.get_playlist(session, pl.id) + playlist = Playlists.get_playlist_by_id(session, pl.id) notes = playlist.get_notes() assert len(notes) == 1 assert notes[0] == note_text @@ -75,7 +75,7 @@ def test_notes_delete_note(session): Notes.delete_note(session, rec.id) # We retrieve notes via playlist - playlist = Playlists.get_playlist(session, pl.id) + playlist = Playlists.get_playlist_by_id(session, pl.id) notes = playlist.get_notes() assert len(notes) == 0 @@ -96,7 +96,7 @@ def test_notes_update_note(session): Notes.update_note(session, rec.id, 1, text=replacement_text) # We retrieve notes via playlist - playlist = Playlists.get_playlist(session, pl.id) + playlist = Playlists.get_playlist_by_id(session, pl.id) notes = playlist.get_notes() assert len(notes) == 1 assert notes[0] == replacement_text @@ -117,3 +117,40 @@ def test_playdates_add_playdate(session): assert playdate.lastplayed == last_played + +def test_playdates_remove_track(session): + """Test removing a track from a playdate""" + + # We need a track + track_path = "/a/b/c" + track = Tracks.get_or_create(session, track_path) + # Need to commit because track record is updated in Playdates.add_playdate() + session.commit() + + playdate = Playdates.add_playdate(session, track) + Playdates.remove_track(session, track.id) + + last_played = Playdates.last_played(session, track.id) + + assert last_played is None + + +def test_playlist_add_track(session): + """Test adding track to playlist""" + + # We need a track + track_path = "/a/b/c" + track = Tracks.get_or_create(session, track_path) + # Need to commit because track record is updated in Playdates.add_playdate() + session.commit() + + playlist = Playlists() + playlist.name = "Test playlist" + session.add(playlist) + session.commit() + + playlist.add_track(session, track) + + tracks = playlist.get_tracks() + assert len(tracks) == 1 + assert tracks[0].path == track_path