move_rows implemented; all tests pass

This commit is contained in:
Keith Edmunds 2023-10-28 11:30:37 +01:00
parent afb8ddfaf5
commit 3832d9300c
3 changed files with 88 additions and 63 deletions

View File

@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from enum import auto, Enum from enum import auto, Enum
from sqlalchemy import update from sqlalchemy import bindparam, update
from typing import List, Optional, TYPE_CHECKING from typing import List, Optional, TYPE_CHECKING
from dbconfig import scoped_session, Session from dbconfig import scoped_session, Session
@ -364,56 +364,64 @@ class PlaylistModel(QAbstractTableModel):
for modified_row in modified_rows: for modified_row in modified_rows:
self.invalidate_row(modified_row) self.invalidate_row(modified_row)
def move_rows(self, from_rows: List[int], to_row: int) -> None: def move_rows(self, from_rows: List[int], to_row_number: int) -> None:
""" """
Move the playlist rows given to to_row and below. Move the playlist rows given to to_row and below.
""" """
new_playlist_rows: dict[int, PlaylistRowData] = {} # Build a {current_row_number: new_row_number} dictionary
row_map: dict[int, int] = {}
# Move the from_row records from the playlist_rows dict to the # Put the from_row row numbers into the row_map. Ultimately the
# new_playlist_rows dict. The total number of elements in the # total number of elements in the playlist doesn't change, so
# playlist doesn't change, so check that adding the moved rows # check that adding the moved rows starting at to_row won't
# starting at to_row won't overshoot the end of the playlist. # overshoot the end of the playlist.
if to_row + len(from_rows) > len(self.playlist_rows): if to_row_number + len(from_rows) > len(self.playlist_rows):
next_to_row = len(self.playlist_rows) - len(from_rows) next_to_row = len(self.playlist_rows) - len(from_rows)
else: else:
next_to_row = to_row next_to_row = to_row_number
for from_row in from_rows: for from_row, to_row in zip(
new_playlist_rows[next_to_row] = self.playlist_rows[from_row] from_rows, range(next_to_row, next_to_row + len(from_rows))
del self.playlist_rows[from_row] ):
next_to_row += 1 row_map[from_row] = to_row
# Move the remaining rows to the row_map. We want to fill it
# before (if there are gaps) and after (likewise) the rows that
# are moving.
# This iterates old_row and new_row simultaneously.
for old_row, new_row in zip(
[x for x in self.playlist_rows.keys() if x not in from_rows],
[y for y in range(len(self.playlist_rows)) if y not in row_map.values()],
):
# Optimise: only add to map if there is a change
row_map[old_row] = new_row
# Move the remaining rows to the gaps in new_playlist_rows # For SQLAlchemy, build a list of dictionaries that map plrid to
new_row = 0 # new row number:
for old_row in self.playlist_rows.keys(): sqla_map: List[dict[str, int]] = []
# Find next gap for oldrow, newrow in row_map.items():
while new_row in new_playlist_rows: plrid = self.playlist_rows[oldrow].plrid
new_row += 1 sqla_map.append({"plrid": plrid, "plr_rownum": newrow})
new_playlist_rows[new_row] = self.playlist_rows[old_row]
new_row += 1
# Make copy of rows live # Update database. Ref:
self.playlist_rows = new_playlist_rows # https://docs.sqlalchemy.org/en/20/core/sqlelement.html#sqlalchemy.sql.expression.case
stmt = (
update(PlaylistRows)
.where(
PlaylistRows.playlist_id == self.playlist_id,
PlaylistRows.id == bindparam("plrid"),
)
.values(plr_rownum=bindparam("plr_rownum"))
)
# Update PlaylistRows table and notify display of rows that
# moved
with Session() as session: with Session() as session:
for idx in range(len(self.playlist_rows)): session.connection().execute(stmt, sqla_map)
if self.playlist_rows[idx].plr_rownum == idx:
continue # Update playlist_rows
# Row number in this row is incorred. Fix it in self.refresh_data(session)
# database:
plr = session.get(PlaylistRows, self.playlist_rows[idx].plrid)
if not plr:
print(f"\nCan't find plr in playlistmodel:move_rows {idx=}")
continue
plr.plr_rownum = idx
# Fix in self.playlist_rows
self.playlist_rows[idx].plr_rownum = idx
# Update display # Update display
self.invalidate_row(idx) self.invalidate_rows(list(row_map.keys()))
def refresh_data(self, session: scoped_session): def refresh_data(self, session: scoped_session):
"""Populate dicts for data calls""" """Populate dicts for data calls"""

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from sqlalchemy import create_engine, String, update from sqlalchemy import create_engine, String, update, bindparam, case
from sqlalchemy.orm import ( from sqlalchemy.orm import (
DeclarativeBase, DeclarativeBase,
Mapped, Mapped,
@ -45,10 +45,8 @@ Base.metadata.create_all(engine)
inital_number_of_records = 10 inital_number_of_records = 10
with Session() as session:
for a in range(inital_number_of_records):
_ = Rhys(session, a, f"record: {a}")
def move_rows(session):
new_row = 6 new_row = 6
with Session() as session: with Session() as session:
@ -64,10 +62,23 @@ with Session() as session:
session.execute(stmt) session.execute(stmt)
# for rec in range(new_row, inital_number_of_records):
# session.connection().execute( sqla_map = []
# update(Rhys) for k, v in zip(range(11), [0, 1, 2, 3, 4, 7, 8, 10, 5, 6, 9]):
# .execution_options(synchronize_session=None) sqla_map.append({"oldrow": k, "newrow": v})
# .where(Rhys.ref_number > new_row),
# [{"ref_number": Rhys.ref_number + 1}], # for a, b in sqla_map.items():
# ) # print(f"{a} > {b}")
with Session() as session:
for a in range(inital_number_of_records):
_ = Rhys(session, a, f"record: {a}")
stmt = update(Rhys).values(
ref_number=case(
{item['oldrow']: item['newrow'] for item in sqla_map},
value=Rhys.ref_number
)
)
session.connection().execute(stmt, sqla_map)

View File

@ -14,8 +14,8 @@ def create_model_with_playlist_rows(
for row in range(rows): for row in range(rows):
plr = model._insert_row(session, row) plr = model._insert_row(session, row)
newrow = plr.plr_rownum newrow = plr.plr_rownum
plr.note = str(newrow)
model.playlist_rows[newrow] = playlistmodel.PlaylistRowData(plr) model.playlist_rows[newrow] = playlistmodel.PlaylistRowData(plr)
model.playlist_rows[newrow].note = str(newrow)
session.commit() session.commit()
return model return model
@ -174,7 +174,10 @@ def test_insert_header_row_end(monkeypatch, session):
prd = model.playlist_rows[model.rowCount() - 1] prd = model.playlist_rows[model.rowCount() - 1]
# Test against edit_role because display_role for headers is # Test against edit_role because display_role for headers is
# handled differently (sets up row span) # handled differently (sets up row span)
assert model.edit_role(model.rowCount(), playlistmodel.Col.NOTE.value, prd) == note_text assert (
model.edit_role(model.rowCount(), playlistmodel.Col.NOTE.value, prd)
== note_text
)
def test_insert_header_row_middle(monkeypatch, session): def test_insert_header_row_middle(monkeypatch, session):
@ -191,4 +194,7 @@ def test_insert_header_row_middle(monkeypatch, session):
prd = model.playlist_rows[insert_row] prd = model.playlist_rows[insert_row]
# Test against edit_role because display_role for headers is # Test against edit_role because display_role for headers is
# handled differently (sets up row span) # handled differently (sets up row span)
assert model.edit_role(model.rowCount(), playlistmodel.Col.NOTE.value, prd) == note_text assert (
model.edit_role(model.rowCount(), playlistmodel.Col.NOTE.value, prd)
== note_text
)