From 3832d9300c030a3c6be138b85269e139dd24978d Mon Sep 17 00:00:00 2001 From: Keith Edmunds Date: Sat, 28 Oct 2023 11:30:37 +0100 Subject: [PATCH] move_rows implemented; all tests pass --- app/playlistmodel.py | 86 +++++++++++++++++++++------------------ archive/db_experiments.py | 53 ++++++++++++++---------- test_playlistmodel.py | 12 ++++-- 3 files changed, 88 insertions(+), 63 deletions(-) diff --git a/app/playlistmodel.py b/app/playlistmodel.py index ed409c8..1494076 100644 --- a/app/playlistmodel.py +++ b/app/playlistmodel.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import auto, Enum -from sqlalchemy import update +from sqlalchemy import bindparam, update from typing import List, Optional, TYPE_CHECKING from dbconfig import scoped_session, Session @@ -364,56 +364,64 @@ class PlaylistModel(QAbstractTableModel): for modified_row in modified_rows: self.invalidate_row(modified_row) - def move_rows(self, from_rows: List[int], to_row: int) -> None: + def move_rows(self, from_rows: List[int], to_row_number: int) -> None: """ Move the playlist rows given to to_row and below. """ - new_playlist_rows: dict[int, PlaylistRowData] = {} + # Build a {current_row_number: new_row_number} dictionary + row_map: dict[int, int] = {} - # Move the from_row records from the playlist_rows dict to the - # new_playlist_rows dict. The total number of elements in the - # playlist doesn't change, so check that adding the moved rows - # starting at to_row won't overshoot the end of the playlist. - if to_row + len(from_rows) > len(self.playlist_rows): + # Put the from_row row numbers into the row_map. Ultimately the + # total number of elements in the playlist doesn't change, so + # check that adding the moved rows starting at to_row won't + # overshoot the end of the playlist. + if to_row_number + len(from_rows) > len(self.playlist_rows): next_to_row = len(self.playlist_rows) - len(from_rows) else: - next_to_row = to_row + next_to_row = to_row_number - for from_row in from_rows: - new_playlist_rows[next_to_row] = self.playlist_rows[from_row] - del self.playlist_rows[from_row] - next_to_row += 1 + for from_row, to_row in zip( + from_rows, range(next_to_row, next_to_row + len(from_rows)) + ): + row_map[from_row] = to_row + # Move the remaining rows to the row_map. We want to fill it + # before (if there are gaps) and after (likewise) the rows that + # are moving. + # This iterates old_row and new_row simultaneously. + for old_row, new_row in zip( + [x for x in self.playlist_rows.keys() if x not in from_rows], + [y for y in range(len(self.playlist_rows)) if y not in row_map.values()], + ): + # Optimise: only add to map if there is a change + row_map[old_row] = new_row - # Move the remaining rows to the gaps in new_playlist_rows - new_row = 0 - for old_row in self.playlist_rows.keys(): - # Find next gap - while new_row in new_playlist_rows: - new_row += 1 - new_playlist_rows[new_row] = self.playlist_rows[old_row] - new_row += 1 + # For SQLAlchemy, build a list of dictionaries that map plrid to + # new row number: + sqla_map: List[dict[str, int]] = [] + for oldrow, newrow in row_map.items(): + plrid = self.playlist_rows[oldrow].plrid + sqla_map.append({"plrid": plrid, "plr_rownum": newrow}) - # Make copy of rows live - self.playlist_rows = new_playlist_rows + # Update database. Ref: + # https://docs.sqlalchemy.org/en/20/core/sqlelement.html#sqlalchemy.sql.expression.case + stmt = ( + update(PlaylistRows) + .where( + PlaylistRows.playlist_id == self.playlist_id, + PlaylistRows.id == bindparam("plrid"), + ) + .values(plr_rownum=bindparam("plr_rownum")) + ) - # Update PlaylistRows table and notify display of rows that - # moved with Session() as session: - for idx in range(len(self.playlist_rows)): - if self.playlist_rows[idx].plr_rownum == idx: - continue - # Row number in this row is incorred. Fix it in - # database: - plr = session.get(PlaylistRows, self.playlist_rows[idx].plrid) - if not plr: - print(f"\nCan't find plr in playlistmodel:move_rows {idx=}") - continue - plr.plr_rownum = idx - # Fix in self.playlist_rows - self.playlist_rows[idx].plr_rownum = idx - # Update display - self.invalidate_row(idx) + session.connection().execute(stmt, sqla_map) + + # Update playlist_rows + self.refresh_data(session) + + # Update display + self.invalidate_rows(list(row_map.keys())) def refresh_data(self, session: scoped_session): """Populate dicts for data calls""" diff --git a/archive/db_experiments.py b/archive/db_experiments.py index ef03e05..352b143 100755 --- a/archive/db_experiments.py +++ b/archive/db_experiments.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from sqlalchemy import create_engine, String, update +from sqlalchemy import create_engine, String, update, bindparam, case from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -45,29 +45,40 @@ Base.metadata.create_all(engine) inital_number_of_records = 10 + +def move_rows(session): + new_row = 6 + + with Session() as session: + # new_record = Rhys(session, new_row, f"new {new_row=}") + # Move rows + + stmt = ( + update(Rhys) + .where(Rhys.ref_number > new_row) + # .where(Rhys.id.in_(session.query(Rhys.id).order_by(Rhys.id.desc()))) + .values({Rhys.ref_number: Rhys.ref_number + 1}) + ) + + session.execute(stmt) + + +sqla_map = [] +for k, v in zip(range(11), [0, 1, 2, 3, 4, 7, 8, 10, 5, 6, 9]): + sqla_map.append({"oldrow": k, "newrow": v}) + +# for a, b in sqla_map.items(): +# print(f"{a} > {b}") + with Session() as session: for a in range(inital_number_of_records): _ = Rhys(session, a, f"record: {a}") -new_row = 6 - -with Session() as session: - # new_record = Rhys(session, new_row, f"new {new_row=}") - # Move rows - - stmt = ( - update(Rhys) - .where(Rhys.ref_number > new_row) - # .where(Rhys.id.in_(session.query(Rhys.id).order_by(Rhys.id.desc()))) - .values({Rhys.ref_number: Rhys.ref_number + 1}) + stmt = update(Rhys).values( + ref_number=case( + {item['oldrow']: item['newrow'] for item in sqla_map}, + value=Rhys.ref_number + ) ) - session.execute(stmt) - -# for rec in range(new_row, inital_number_of_records): -# session.connection().execute( -# update(Rhys) -# .execution_options(synchronize_session=None) -# .where(Rhys.ref_number > new_row), -# [{"ref_number": Rhys.ref_number + 1}], -# ) + session.connection().execute(stmt, sqla_map) diff --git a/test_playlistmodel.py b/test_playlistmodel.py index ac8a993..7981e8d 100644 --- a/test_playlistmodel.py +++ b/test_playlistmodel.py @@ -14,8 +14,8 @@ def create_model_with_playlist_rows( for row in range(rows): plr = model._insert_row(session, row) newrow = plr.plr_rownum + plr.note = str(newrow) model.playlist_rows[newrow] = playlistmodel.PlaylistRowData(plr) - model.playlist_rows[newrow].note = str(newrow) session.commit() return model @@ -174,7 +174,10 @@ def test_insert_header_row_end(monkeypatch, session): prd = model.playlist_rows[model.rowCount() - 1] # Test against edit_role because display_role for headers is # handled differently (sets up row span) - assert model.edit_role(model.rowCount(), playlistmodel.Col.NOTE.value, prd) == note_text + assert ( + model.edit_role(model.rowCount(), playlistmodel.Col.NOTE.value, prd) + == note_text + ) def test_insert_header_row_middle(monkeypatch, session): @@ -191,4 +194,7 @@ def test_insert_header_row_middle(monkeypatch, session): prd = model.playlist_rows[insert_row] # Test against edit_role because display_role for headers is # handled differently (sets up row span) - assert model.edit_role(model.rowCount(), playlistmodel.Col.NOTE.value, prd) == note_text + assert ( + model.edit_role(model.rowCount(), playlistmodel.Col.NOTE.value, prd) + == note_text + )