diff --git a/app/models.py b/app/models.py index 7709aa3..0f00f31 100644 --- a/app/models.py +++ b/app/models.py @@ -9,7 +9,7 @@ from datetime import datetime from mutagen.flac import FLAC from mutagen.mp3 import MP3 from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta from sqlalchemy import ( Boolean, Column, @@ -21,8 +21,9 @@ from sqlalchemy import ( func ) from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound from sqlalchemy.orm import backref, relationship, sessionmaker, scoped_session +from sqlalchemy.orm.collections import attribute_mapped_collection +from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound from app.config import Config from app.helpers import ( @@ -237,6 +238,7 @@ class Playlists(Base): back_populates="playlist") tracks = association_proxy('playlist_tracks', 'tracks') + row = association_proxy('playlist_tracks', 'row') def __init__(self, session, name): self.name = name @@ -341,7 +343,13 @@ class PlaylistTracks(Base): track_id = Column(Integer, ForeignKey('tracks.id'), primary_key=True) row = Column(Integer, nullable=False) tracks = relationship("Tracks") - playlist = relationship(Playlists, backref=backref("playlist_tracks")) + playlist = relationship( + Playlists, + backref=backref( + "playlist_tracks", + collection_class=attribute_mapped_collection("row") + ) + ) def __init__(self, session, playlist_id, track_id, row): DEBUG(f"PlaylistTracks.__init__({playlist_id=}, {track_id=}, {row=})") diff --git a/test_models.py b/test_models.py index 6d52ea3..e33b32e 100644 --- a/test_models.py +++ b/test_models.py @@ -172,13 +172,36 @@ def test_playlist_add_track(session): track_path = "/a/b/c" track = Tracks(session, track_path) - playlist.add_track(session, track) + row = 17 + + playlist.add_track(session, track, row) assert len(playlist.tracks) == 1 - playlist_track = playlist.tracks[0] + playlist_track = playlist.tracks[row] assert playlist_track.path == track_path +def test_playlist_tracks(session): + # We need a playlist + playlist = Playlists(session, "my playlist") + + # We need two tracks + track1_path = "/a/b/c" + track1_row = 17 + track1 = Tracks(session, track1_path) + + track2_path = "/x/y/z" + track2_row = 29 + track2 = Tracks(session, track2_path) + + playlist.add_track(session, track1, track1_row) + playlist.add_track(session, track2, track2_row) + + tracks = playlist.tracks + assert tracks[track1_row] == track1 + assert tracks[track2_row] == track2 + + def test_playlist_open_and_close(session): # We need a playlist