Guard against erroneous SQL statements in queries

This commit is contained in:
Keith Edmunds 2025-02-15 04:28:52 +00:00
parent e6404d075e
commit 678515403c
3 changed files with 65 additions and 50 deletions

View File

@ -15,14 +15,17 @@ from sqlalchemy import (
delete, delete,
func, func,
select, select,
text,
update, update,
) )
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError, ProgrammingError
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from sqlalchemy.engine.row import RowMapping
# App imports # App imports
from classes import ApplicationError
from config import Config from config import Config
from dbmanager import DatabaseManager from dbmanager import DatabaseManager
import dbtables import dbtables
@ -38,6 +41,17 @@ if "unittest" in sys.modules and "sqlite" not in DATABASE_URL:
db = DatabaseManager.get_instance(DATABASE_URL, engine_options=Config.ENGINE_OPTIONS).db db = DatabaseManager.get_instance(DATABASE_URL, engine_options=Config.ENGINE_OPTIONS).db
def run_sql(session: Session, sql: str) -> Sequence[RowMapping]:
"""
Run a sql string and return results
"""
try:
return session.execute(text(sql)).mappings().all()
except ProgrammingError as e:
raise ApplicationError(e)
# Database classes # Database classes
class NoteColours(dbtables.NoteColoursTable): class NoteColours(dbtables.NoteColoursTable):
def __init__( def __init__(

View File

@ -61,7 +61,7 @@ from classes import (
from config import Config from config import Config
from dialogs import TrackSelectDialog from dialogs import TrackSelectDialog
from file_importer import FileImporter from file_importer import FileImporter
from helpers import file_is_unreadable from helpers import ask_yes_no, file_is_unreadable, ms_to_mmss, show_OK
from log import log from log import log
from models import ( from models import (
db, db,
@ -83,7 +83,6 @@ from ui.main_window_playlist_ui import Ui_PlaylistSection # type: ignore
from ui.main_window_footer_ui import Ui_FooterSection # type: ignore from ui.main_window_footer_ui import Ui_FooterSection # type: ignore
from utilities import check_db, update_bitrates from utilities import check_db, update_bitrates
import helpers
class DownloadCSV(QDialog): class DownloadCSV(QDialog):
@ -869,8 +868,8 @@ class Window(QMainWindow):
# Don't allow window to close when a track is playing # Don't allow window to close when a track is playing
if track_sequence.current and track_sequence.current.is_playing(): if track_sequence.current and track_sequence.current.is_playing():
event.ignore() event.ignore()
helpers.show_warning( self.show_warning(
self, "Track playing", "Can't close application while track is playing" "Track playing", "Can't close application while track is playing"
) )
else: else:
with db.Session() as session: with db.Session() as session:
@ -927,7 +926,7 @@ class Window(QMainWindow):
current_track_playlist_id = track_sequence.current.playlist_id current_track_playlist_id = track_sequence.current.playlist_id
if current_track_playlist_id: if current_track_playlist_id:
if closing_tab_playlist_id == current_track_playlist_id: if closing_tab_playlist_id == current_track_playlist_id:
helpers.show_OK( show_OK(
"Current track", "Can't close current track playlist", self "Current track", "Can't close current track playlist", self
) )
return False return False
@ -937,7 +936,7 @@ class Window(QMainWindow):
next_track_playlist_id = track_sequence.next.playlist_id next_track_playlist_id = track_sequence.next.playlist_id
if next_track_playlist_id: if next_track_playlist_id:
if closing_tab_playlist_id == next_track_playlist_id: if closing_tab_playlist_id == next_track_playlist_id:
helpers.show_OK( show_OK(
"Next track", "Can't close next track playlist", self "Next track", "Can't close next track playlist", self
) )
return False return False
@ -1053,7 +1052,7 @@ class Window(QMainWindow):
playlist_id = self.current.playlist_id playlist_id = self.current.playlist_id
playlist = session.get(Playlists, playlist_id) playlist = session.get(Playlists, playlist_id)
if playlist: if playlist:
if helpers.ask_yes_no( if ask_yes_no(
"Delete playlist", "Delete playlist",
f"Delete playlist '{playlist.name}': " "Are you sure?", f"Delete playlist '{playlist.name}': " "Are you sure?",
): ):
@ -1323,7 +1322,7 @@ class Window(QMainWindow):
self.playlist_section.tabPlaylist.setCurrentIndex(idx) self.playlist_section.tabPlaylist.setCurrentIndex(idx)
elif action == "Delete": elif action == "Delete":
if helpers.ask_yes_no( if ask_yes_no(
"Delete template", "Delete template",
f"Delete template '{playlist.name}': " "Are you sure?", f"Delete template '{playlist.name}': " "Are you sure?",
): ):
@ -1478,14 +1477,17 @@ class Window(QMainWindow):
def open_querylist(self) -> None: def open_querylist(self) -> None:
"""Open existing querylist""" """Open existing querylist"""
with db.Session() as session: try:
dlg = QueryDialog(session) with db.Session() as session:
if dlg.exec(): dlg = QueryDialog(session)
new_row_number = self.current_row_or_end() if dlg.exec():
for track_id in dlg.selected_tracks: new_row_number = self.current_row_or_end()
self.current.base_model.insert_row(new_row_number, track_id) for track_id in dlg.selected_tracks:
else: self.current.base_model.insert_row(new_row_number, track_id)
return # User cancelled else:
return # User cancelled
except ApplicationError as e:
self.show_warning("Query error", f"Your query gave an error:\n\n{e}")
def open_songfacts_browser(self, title: str) -> None: def open_songfacts_browser(self, title: str) -> None:
"""Search Songfacts for title""" """Search Songfacts for title"""
@ -1759,7 +1761,7 @@ class Window(QMainWindow):
msg = "Hit return to play next track now" msg = "Hit return to play next track now"
else: else:
msg = "Press tab to select Yes and hit return to play next track" msg = "Press tab to select Yes and hit return to play next track"
if not helpers.ask_yes_no( if not ask_yes_no(
"Play next track", "Play next track",
msg, msg,
default_yes=default_yes, default_yes=default_yes,
@ -1829,12 +1831,12 @@ class Window(QMainWindow):
template_name = dlg.textValue() template_name = dlg.textValue()
if template_name not in template_names: if template_name not in template_names:
break break
helpers.show_warning( self.show_warning(
self, "Duplicate template", "Template name already in use" "Duplicate template", "Template name already in use"
) )
Playlists.save_as_template(session, self.current.playlist_id, template_name) Playlists.save_as_template(session, self.current.playlist_id, template_name)
session.commit() session.commit()
helpers.show_OK("Template", "Template saved", self) show_OK("Template", "Template saved", self)
def search_playlist(self) -> None: def search_playlist(self) -> None:
"""Show text box to search playlist""" """Show text box to search playlist"""
@ -1980,8 +1982,7 @@ class Window(QMainWindow):
if Playlists.name_is_available(session, proposed_name): if Playlists.name_is_available(session, proposed_name):
return proposed_name return proposed_name
else: else:
helpers.show_warning( self.show_warning(
self,
"Name in use", "Name in use",
f"There's already a playlist called '{proposed_name}'", f"There's already a playlist called '{proposed_name}'",
) )
@ -2084,16 +2085,16 @@ class Window(QMainWindow):
if track_sequence.current and track_sequence.current.is_playing(): if track_sequence.current and track_sequence.current.is_playing():
# Elapsed time # Elapsed time
self.header_section.label_elapsed_timer.setText( self.header_section.label_elapsed_timer.setText(
helpers.ms_to_mmss(track_sequence.current.time_playing()) ms_to_mmss(track_sequence.current.time_playing())
+ " / " + " / "
+ helpers.ms_to_mmss(track_sequence.current.duration) + ms_to_mmss(track_sequence.current.duration)
) )
# Time to fade # Time to fade
time_to_fade = track_sequence.current.time_to_fade() time_to_fade = track_sequence.current.time_to_fade()
time_to_silence = track_sequence.current.time_to_silence() time_to_silence = track_sequence.current.time_to_silence()
self.footer_section.label_fade_timer.setText( self.footer_section.label_fade_timer.setText(
helpers.ms_to_mmss(time_to_fade) ms_to_mmss(time_to_fade)
) )
# If silent in the next 5 seconds, put warning colour on # If silent in the next 5 seconds, put warning colour on
@ -2132,7 +2133,7 @@ class Window(QMainWindow):
self.footer_section.frame_fade.setStyleSheet("") self.footer_section.frame_fade.setStyleSheet("")
self.footer_section.label_silent_timer.setText( self.footer_section.label_silent_timer.setText(
helpers.ms_to_mmss(time_to_silence) ms_to_mmss(time_to_silence)
) )
def update_headers(self) -> None: def update_headers(self) -> None:

View File

@ -26,6 +26,7 @@ from sqlalchemy.orm.session import Session
# App imports # App imports
from classes import ( from classes import (
ApplicationError,
QueryCol, QueryCol,
) )
from config import Config from config import Config
@ -33,9 +34,10 @@ from helpers import (
file_is_unreadable, file_is_unreadable,
get_relative_date, get_relative_date,
ms_to_mmss, ms_to_mmss,
show_warning,
) )
from log import log from log import log
from models import db, Playdates from models import db, Playdates, run_sql
from music_manager import RowAndTrack from music_manager import RowAndTrack
@ -220,33 +222,31 @@ class QuerylistModel(QAbstractTableModel):
Populate self.querylist_rows Populate self.querylist_rows
""" """
# TODO: Move the SQLAlchemy parts to models later, but for now as proof
# of concept we'll keep it here.
from sqlalchemy import text
# Clear any exsiting rows # Clear any exsiting rows
self.querylist_rows = {} self.querylist_rows = {}
row = 0 row = 0
results = self.session.execute(text(self.sql)).mappings().all() try:
for result in results: results = run_sql(self.session, self.sql)
if hasattr(result, "lastplayed"): for result in results:
lastplayed = result["lastplayed"] if hasattr(result, "lastplayed"):
else: lastplayed = result["lastplayed"]
lastplayed = None else:
queryrow = QueryRow( lastplayed = None
artist=result["artist"], queryrow = QueryRow(
bitrate=result["bitrate"], artist=result["artist"],
duration=result["duration"], bitrate=result["bitrate"],
lastplayed=lastplayed, duration=result["duration"],
path=result["path"], lastplayed=lastplayed,
title=result["title"], path=result["path"],
track_id=result["id"], title=result["title"],
) track_id=result["id"],
)
self.querylist_rows[row] = queryrow self.querylist_rows[row] = queryrow
row += 1 row += 1
except ApplicationError as e:
show_warning(None, "Query error", f"Error loading query data ({e})")
def rowCount(self, index: QModelIndex = QModelIndex()) -> int: def rowCount(self, index: QModelIndex = QModelIndex()) -> int:
"""Standard function for view""" """Standard function for view"""