From f9943dc1c412ca7d36cda0689e0232800bb0e423 Mon Sep 17 00:00:00 2001 From: Keith Edmunds Date: Tue, 21 Jan 2025 21:26:06 +0000 Subject: [PATCH] WIP file_importer rewrite, one test written and working --- app/file_importer.py | 892 +++++++++++++++++++++--------------- app/musicmuster.py | 6 +- app/playlists.py | 2 +- app/ui/main_window_ui.py | 2 +- tests/test_file_importer.py | 82 ++++ tests/test_ui.py | 4 - 6 files changed, 601 insertions(+), 387 deletions(-) create mode 100644 tests/test_file_importer.py diff --git a/app/file_importer.py b/app/file_importer.py index d3c8c8b..7ddd643 100644 --- a/app/file_importer.py +++ b/app/file_importer.py @@ -1,10 +1,9 @@ -# Standard library imports from __future__ import annotations from dataclasses import dataclass, field from fuzzywuzzy import fuzz # type: ignore import os.path -from typing import Optional +from typing import Optional, Sequence import os import shutil @@ -46,12 +45,522 @@ from playlistmodel import PlaylistModel import helpers +@dataclass +class ThreadData: + """ + Data structure to hold details of the import thread context + """ + + base_model: PlaylistModel + row_number: int + worker: Optional[DoTrackImport] = None + + +@dataclass +class TrackFileData: + """ + Data structure to hold details of file to be imported + """ + + tags: Tags = Tags() + destination_path: str = "" + import_this_file: bool = True + error: str = "" + file_path_to_remove: Optional[str] = None + track_id: int = 0 + track_match_data: list[TrackMatchData] = field(default_factory=list) + + +@dataclass +class TrackMatchData: + """ + Data structure to hold details of existing files that are similar to + the file being imported. + """ + + artist: str + artist_match: float + title: str + title_match: float + track_id: int + + +class FileImporter: + """ + Class to manage the import of new tracks. Sanity checks are carried + out before processing each track. + + They may replace existing tracks, be imported as new tracks, or the + import may be skipped altogether. The user decides which of these in + the UI managed by the PickMatch class. + + The actual import is handled by the DoTrackImport class. + """ + + def __init__( + self, base_model: PlaylistModel, row_number: Optional[int] = None + ) -> None: + """ + Set up class + """ + + # Create ModelData + if not row_number: + row_number = base_model.rowCount() + self.model_data = ThreadData(base_model=base_model, row_number=row_number) + + # Populate self.import_files_data + for infile in [ + os.path.join(Config.REPLACE_FILES_DEFAULT_SOURCE, f) + for f in os.listdir(Config.REPLACE_FILES_DEFAULT_SOURCE) + if f.endswith((".mp3", ".flac")) + ]: + self.import_files_data[infile] = TrackFileData() + + # Place to keep a reference to importer threads + self.threads: list[QThread] = [] + + # Data structure to track files to import + self.import_files_data: dict[str, TrackFileData] = {} + + # Dictionary of exsting tracks indexed by track.id + self.existing_tracks = self._get_existing_tracks() + + self.signals = MusicMusterSignals() + + def _get_existing_tracks(self) -> Sequence[Tracks]: + """ + Return a list of all existing Tracks + """ + + with db.Session() as session: + return Tracks.get_all(session) + + def do_import(self) -> None: + """ + Populate self.import_files_data, which is a TrackFileData object for each entry. + + - Validate files to be imported + - Find matches and similar files + - Get user choices for each import file + - Validate self.import_files_data integrity + - Tell the user which files won't be imported and why + - Import the files, one by one. + """ + + if not self.import_files_data: + show_OK( + "File import", + f"No files in {Config.REPLACE_FILES_DEFAULT_SOURCE} to import", + None, + ) + return + + for path in self.import_files_data.keys(): + self.validate_file(path) + if self.import_files_data[path].import_this_file: + self.find_similar(path) + if len(self.import_files_data[path].track_match_data) > 1: + self.sort_track_match_data(path) + selection = self.get_user_choices(path) + self.process_selection(path, selection) + if self.import_files_data[path].import_this_file: + self.validate_file_data(path) + + # Tell user which files won't be imported and why + self.inform_user() + # Start the import of all other files + self.import_next_file() + + def validate_file(self, path: str) -> None: + """ + - check all files are readable + - check all files have tags + - Mark failures not to be imported and populate error text. + + On return, the following TrackFileData fields should be set: + + tags: Yes + destination_path: No + import_this_file: Yes (set by default) + error: No (only set if an error is detected) + file_path_to_remove: No + track_id: No + track_match_data: No + """ + + for path in self.import_files_data.keys(): + if file_is_unreadable(path): + self.import_files_data[path].import_this_file = False + self.import_files_data[path].error = f"{path} is unreadable" + continue + + try: + self.import_files_data[path].tags = get_tags(path) + except ApplicationError as e: + self.import_files_data[path].import_this_file = False + self.import_files_data[path].error = f"Tag errors ({str(e)})" + continue + + def find_similar(self, path: str) -> None: + """ + - Search title in existing tracks + - if score >= Config.FUZZYMATCH_MINIMUM_LIST: + - get artist score + - add TrackMatchData to self.import_files_data[path].track_match_data + + On return, the following TrackFileData fields should be set: + + tags: Yes + destination_path: No + import_this_file: Yes (set by default) + error: No (only set if an error is detected) + file_path_to_remove: No + track_id: No + track_match_data: YES, IN THIS FUNCTION + """ + + title = self.import_files_data[path].tags.title + artist = self.import_files_data[path].tags.artist + + for existing_track in self.existing_tracks: + title_score = self._get_match_score(title, existing_track.title) + if title_score >= Config.FUZZYMATCH_MINIMUM_LIST: + artist_score = self._get_match_score(artist, existing_track.artist) + self.import_files_data[path].track_match_data.append( + TrackMatchData( + artist=existing_track.artist, + artist_match=artist_score, + title=existing_track.title, + title_match=title_score, + track_id=existing_track.id, + ) + ) + + def sort_track_match_data(self, path: str) -> None: + """ + Sort matched tracks in artist-similarity order + """ + + self.import_files_data[path].track_match_data.sort( + key=lambda x: x.artist_match, reverse=True + ) + + def _get_match_score(self, str1: str, str2: str) -> float: + """ + Return the score of how well str1 matches str2. + """ + + ratio = fuzz.ratio(str1, str2) + partial_ratio = fuzz.partial_ratio(str1, str2) + token_sort_ratio = fuzz.token_sort_ratio(str1, str2) + token_set_ratio = fuzz.token_set_ratio(str1, str2) + + # Combine scores + combined_score = ( + ratio * 0.25 + + partial_ratio * 0.25 + + token_sort_ratio * 0.25 + + token_set_ratio * 0.25 + ) + + return combined_score + + def get_user_choices(self, path: str) -> int: + """ + Find out whether user wants to import this as a new track, + overwrite an existing track or not import it at all. + + Return -1 (user cancelled) 0 (import as new) >0 (replace track id) + """ + + # Build a list of (track title and artist, track_id, track path) + choices: list[tuple[str, int, str]] = [] + + # First choices are always a) don't import 2) import as a new track + choices.append((Config.DO_NOT_IMPORT, -1, "")) + choices.append((Config.IMPORT_AS_NEW, 0, "")) + + # New track details + new_track_description = ( + f"{self.import_files_data[path].tags.title} " + f"({self.import_files_data[path].tags.artist})" + ) + + # Select 'import as new' as default unless the top match is good + # enough + default = 1 + track_match_data = self.import_files_data[path].track_match_data + if track_match_data: + if ( + track_match_data[0].artist_match + >= Config.FUZZYMATCH_MINIMUM_SELECT_ARTIST + and track_match_data[0].title_match + >= Config.FUZZYMATCH_MINIMUM_SELECT_TITLE + ): + default = 2 + + for xt in track_match_data: + xt_description = f"{xt.title} ({xt.artist})" + if Config.FUZZYMATCH_SHOW_SCORES: + xt_description += f" ({xt.title_match:.0f}%)" + existing_track_path = self._get_existing_track(xt.track_id).path + choices.append( + ( + xt_description, + xt.track_id, + existing_track_path, + ) + ) + + dialog = PickMatch( + new_track_description=new_track_description, + choices=choices, + default=default, + ) + if dialog.exec(): + return dialog.selected_track_id + else: + return -1 + + def process_selection(self, path: str, selection: int) -> None: + """ + Process selection from PickMatch + """ + + if selection < 0: + # User cancelled + self.import_files_data[path].import_this_file = False + self.import_files_data[path].error = "you asked not to import this file" + + elif selection > 0: + # Import and replace track + self.replace_file(path=path, track_id=selection) + + else: + # Import as new + self.import_as_new(path=path) + + def replace_file(self, path: str, track_id: int) -> None: + """ + Set up to replace an existing file. + + On return, the following TrackFileData fields should be set: + + tags: Yes + destination_path: YES, IN THIS FUNCTION + import_this_file: Yes (set by default) + error: No (only set if an error is detected) + file_path_to_remove: YES, IN THIS FUNCTION + track_id: YES, IN THIS FUNCTION + track_match_data: Yes + """ + + ifd = self.import_files_data[path] + + if track_id < 1: + raise ApplicationError(f"No track ID: replace_file({path=}, {track_id=})") + + ifd.track_id = track_id + + existing_track_path = self._get_existing_track(track_id).path + ifd.file_path_to_remove = existing_track_path + + # If the existing file in the Config.IMPORT_DESTINATION + # directory, replace it with the imported file name; otherwise, + # use the existing file name. This so that we don't change file + # names from CDs, etc. + + if os.path.dirname(existing_track_path) == Config.IMPORT_DESTINATION: + ifd.destination_path = os.path.join( + Config.IMPORT_DESTINATION, os.path.basename(path) + ) + else: + ifd.destination_path = existing_track_path + + def _get_existing_track(self, track_id: int) -> Tracks: + """ + Lookup in existing track in the local cache and return it + """ + + existing_track_records = [a for a in self.existing_tracks if a.id == track_id] + if len(existing_track_records) != 1: + raise ApplicationError( + f"Internal error in _get_existing_track: {existing_track_records=}" + ) + + return existing_track_records[0] + + def import_as_new(self, path: str) -> None: + """ + Set up to import as a new file. + + On return, the following TrackFileData fields should be set: + + tags: Yes + destination_path: YES, IN THIS FUNCTION + import_this_file: Yes (set by default) + error: No (only set if an error is detected) + file_path_to_remove: No (not needed now) + track_id: Yes + track_match_data: Yes + """ + + ifd = self.import_files_data[path] + ifd.destination_path = os.path.join( + Config.IMPORT_DESTINATION, os.path.basename(path) + ) + + def validate_file_data(self, path: str) -> None: + """ + Check the data structures for integrity + """ + + ifd = self.import_files_data[path] + + # Check import_this_file + if not ifd.import_this_file: + return + + # Check tags + if not (ifd.tags.artist and ifd.tags.title): + raise ApplicationError(f"validate_file_data: {ifd.tags=}, {path=}") + + # Check file_path_to_remove + if ifd.file_path_to_remove and not os.path.exists(ifd.file_path_to_remove): + # File to remove is missing, but this isn't a major error. We + # may be importing to replace a deleted file. + ifd.file_path_to_remove = "" + + # Check destination_path + if not ifd.destination_path: + raise ApplicationError( + f"validate_file_data: no destination path set ({path=})" + ) + + # If destination path is the same as file_path_to_remove, that's + # OK, otherwise if this is a new import then check check + # destination path doesn't already exists + if ifd.track_id == 0 and ifd.destination_path != ifd.file_path_to_remove: + while os.path.exists(ifd.destination_path): + msg = ( + "New import requested but default destination path ({ifd.destination_path}) " + "already exists. Click OK and choose where to save this track" + ) + show_OK(title="Desintation path exists", msg=msg, parent=None) + # Get output filename + pathspec = QFileDialog.getSaveFileName( + None, + "Save imported track", + directory=Config.IMPORT_DESTINATION, + ) + if pathspec: + ifd.destination_path = pathspec[0] + else: + ifd.import_this_file = False + ifd.error = "destination file already exists" + return + + # Check track_id + if ifd.track_id < 0: + raise ApplicationError(f"validate_file_data: track_id < 0, {path=}") + + def inform_user(self) -> None: + """ + Tell user about files that won't be imported + """ + + msgs: list[str] = [] + for path, entry in self.import_files_data.items(): + if entry.import_this_file is False: + msgs.append( + f"{os.path.basename(path)} will not be imported because {entry.error}" + ) + if msgs: + show_OK("File not imported", "\r\r".join(msgs)) + + def import_next_file(self) -> None: + """ + Import the next file sequentially. + """ + + while True: + if not self.import_files_data: + self.signals.status_message_signal.emit("All files imported", 10000) + return + + # Get details for next file to import + path, tfd = self.import_files_data.popitem() + if tfd.import_this_file: + break + + print(f"import_next_file {path=}") + + # Create and start a thread for processing + worker = DoTrackImport( + import_file_path=path, + tags=tfd.tags, + destination_path=tfd.destination_path, + track_id=tfd.track_id, + ) + thread = QThread() + self.threads.append(thread) + + # Move worker to thread + worker.moveToThread(thread) + + # Connect signals and slots + thread.started.connect(worker.run) + thread.started.connect(lambda: print(f"Thread starting for {path=}")) + + worker.import_finished.connect(self.post_import_processing) + worker.import_finished.connect(thread.quit) + worker.import_finished.connect(lambda: print(f"Worker ended for {path=}")) + + # Ensure cleanup only after thread is fully stopped + thread.finished.connect(lambda: self.cleanup_thread(thread, worker)) + thread.finished.connect(lambda: print(f"Thread ended for {path=}")) + + # Start the thread + print(f"Calling thread.start() for {path=}") + thread.start() + + def cleanup_thread(self, thread, worker): + """ + Remove references to finished threads/workers to prevent leaks. + """ + + worker.deleteLater() + thread.deleteLater() + if thread in self.threads: + self.threads.remove(thread) + + def post_import_processing(self, track_id: int) -> None: + """ + If track already in playlist, refresh it else insert it + """ + + log.debug(f"post_import_processing({track_id=})") + + if self.model_data: + if self.model_data.base_model: + self.model_data.base_model.update_or_insert( + track_id, self.model_data.row_number + ) + + # Process next file + self.import_next_file() + + class DoTrackImport(QObject): - import_finished = pyqtSignal(int, QThread) + """ + Class to manage the actual import of tracks in a thread. + """ + + import_finished = pyqtSignal(int) def __init__( self, - associated_thread: QThread, import_file_path: str, tags: Tags, destination_path: str, @@ -62,7 +571,6 @@ class DoTrackImport(QObject): """ super().__init__() - self.associated_thread = associated_thread self.import_file_path = import_file_path self.tags = tags self.destination_track_path = destination_path @@ -129,355 +637,7 @@ class DoTrackImport(QObject): self.signals.status_message_signal.emit( f"{os.path.basename(self.import_file_path)} imported", 10000 ) - self.import_finished.emit(track.id, self.associated_thread) - - -class FileImporter: - """ - Manage importing of files - """ - - def __init__( - self, base_model: PlaylistModel, row_number: Optional[int] = None - ) -> None: - """ - Set up class - """ - - # Create ModelData - if not row_number: - row_number = base_model.rowCount() - self.model_data = ThreadData(base_model=base_model, row_number=row_number) - - # Place to keep reference to importer threads and data - self.thread_data: dict[QThread, ThreadData] = {} - - # Data structure to track files to import - self.import_files_data: dict[str, TrackFileData] = {} - - # Dictionary of exsting tracks indexed by track.id - self.existing_tracks = self._get_existing_tracks() - - # Populate self.import_files_data - for infile in [ - os.path.join(Config.REPLACE_FILES_DEFAULT_SOURCE, f) - for f in os.listdir(Config.REPLACE_FILES_DEFAULT_SOURCE) - if f.endswith((".mp3", ".flac")) - ]: - self.import_files_data[infile] = TrackFileData() - - def do_import(self) -> None: - """ - Scan source directory and: - - check all file are readable - - load readable files and tags into self.import_files - - check all files are tagged - - check for exact match of existing file - - check for duplicates and replacements - - allow deselection of import for any one file - - import files and either replace existing or add to pool - """ - - # Check all file are readable and have tags. Mark failures not to - # be imported and populate error text. - for path in self.import_files_data.keys(): - if file_is_unreadable(path): - self.import_files_data[path].import_this_file = False - self.import_files_data[path].error = f"{path} is unreadable" - continue - - # Get tags - try: - self.import_files_data[path].tags = get_tags(path) - except ApplicationError as e: - self.import_files_data[path].import_this_file = False - self.import_files_data[path].error = f"Tag errors ({str(e)})" - continue - - # Get track match data - self.populate_track_match_data(path) - # Sort with best artist match first - self.import_files_data[path].track_match_data.sort( - key=lambda rec: rec.artist_match, reverse=True - ) - - # Process user choices - self.process_user_choices(path) - - # Import files and tell users about files that won't be imported - msgs: list[str] = [] - for (path, entry) in self.import_files_data.items(): - if entry.import_this_file: - self._start_thread(path) - else: - msgs.append( - f"{os.path.basename(path)} will not be imported because {entry.error}" - ) - if msgs: - show_OK("File not imported", "\r\r".join(msgs)) - - def _start_thread(self, path: str) -> None: - """ - Import the file specified by path - """ - - log.debug(f"_start_thread({path=})") - - # Create thread and worker - thread = QThread() - worker = DoTrackImport( - associated_thread=thread, - import_file_path=path, - tags=self.import_files_data[path].tags, - destination_path=self.import_files_data[path].destination_path, - track_id=self.import_files_data[path].track_id, - ) - - # Associate data with the thread - self.model_data.worker = worker - self.thread_data[thread] = self.model_data - - # Move worker to thread - worker.moveToThread(thread) - log.debug(f"_start_thread_worker started ({path=}, {id(thread)=}, {id(worker)=})") - - # Connect signals - thread.started.connect(lambda: log.debug(f"Thread {thread} started")) - thread.started.connect(worker.run) - - thread.finished.connect(lambda: log.debug(f"Thread {thread} finished")) - thread.finished.connect(thread.deleteLater) - - worker.import_finished.connect( - lambda: log.debug(f"Worker task finished for thread {thread}") - ) - worker.import_finished.connect(self._thread_finished) - worker.import_finished.connect(thread.quit) - worker.import_finished.connect(worker.deleteLater) - - # Start thread - thread.start() - - def _thread_finished(self, track_id: int, thread: QThread) -> None: - """ - If track already in playlist, refresh it else insert it - """ - - log.debug(f" Ending thread {thread}") - - model_data = self.thread_data.pop(thread, None) - if model_data: - if model_data.base_model: - model_data.base_model.update_or_insert(track_id, model_data.row_number) - - def _get_existing_track(self, track_id: int) -> Tracks: - """ - Lookup in existing track in the local cache and return it - """ - - existing_track_records = [a for a in self.existing_tracks if a.id == track_id] - if len(existing_track_records) != 1: - raise ApplicationError( - f"Internal error in _get_existing_track: {existing_track_records=}" - ) - - return existing_track_records[0] - - def _get_existing_tracks(self): - """ - Return a dictionary {title: Track} for all existing tracks - """ - - with db.Session() as session: - return Tracks.get_all(session) - - def _get_match_score(self, str1: str, str2: str) -> float: - """ - Return the score of how well str1 matches str2. - """ - - ratio = fuzz.ratio(str1, str2) - partial_ratio = fuzz.partial_ratio(str1, str2) - token_sort_ratio = fuzz.token_sort_ratio(str1, str2) - token_set_ratio = fuzz.token_set_ratio(str1, str2) - - # Combine scores - combined_score = ( - ratio * 0.25 - + partial_ratio * 0.25 - + token_sort_ratio * 0.25 - + token_set_ratio * 0.25 - ) - - return combined_score - - def populate_track_match_data(self, path: str) -> None: - """ - Populate self.import_files_data[path].track_match_data - - - Search title in existing tracks - - if score >= Config.FUZZYMATCH_MINIMUM_LIST: - - get artist score - - add TrackMatchData to self.import_files_data[path].track_match_data - """ - - title = self.import_files_data[path].tags.title - artist = self.import_files_data[path].tags.artist - - for track in self.existing_tracks: - title_score = self._get_match_score(title, track.title) - if title_score >= Config.FUZZYMATCH_MINIMUM_LIST: - artist_score = self._get_match_score(artist, track.artist) - self.import_files_data[path].track_match_data.append( - TrackMatchData( - artist=track.artist, - artist_match=artist_score, - title=track.title, - title_match=title_score, - track_id=track.id, - ) - ) - - def process_user_choices(self, path: str) -> None: - """ - Find out whether user wants to import this as a new track, - overwrite an existing track or not import it at all. - """ - - # Build a list of (track title and artist, track_id, track path) - choices: list[tuple[str, int, str]] = [] - - # First choice is always to import as a new track - choices.append((Config.DO_NOT_IMPORT, -1, "")) - choices.append((Config.IMPORT_AS_NEW, 0, "")) - - # New track details - importing_track_description = ( - f"{self.import_files_data[path].tags.title} " - f"({self.import_files_data[path].tags.artist})" - ) - - # Select 'import as new' as default unless the top match is good - # enough - default = 1 # default choice is import as new - track_match_data = self.import_files_data[path].track_match_data - try: - if track_match_data: - if ( - track_match_data[0].artist_match - >= Config.FUZZYMATCH_MINIMUM_SELECT_ARTIST - and track_match_data[0].title_match - >= Config.FUZZYMATCH_MINIMUM_SELECT_TITLE - ): - default = 2 - - for rec in track_match_data: - existing_track_description = f"{rec.title} ({rec.artist})" - if Config.FUZZYMATCH_SHOW_SCORES: - existing_track_description += f" ({rec.title_match:.0f}%)" - existing_track_path = self._get_existing_track(rec.track_id).path - choices.append( - (existing_track_description, rec.track_id, existing_track_path) - ) - except IndexError: - import pdb - - pdb.set_trace() - print(2) - dialog = PickMatch( - new_track_description=importing_track_description, - choices=choices, - default=default, - ) - if dialog.exec(): - if dialog.selected_track_id < 0: - self.import_files_data[path].import_this_file = False - self.import_files_data[path].error = "you asked not to import this file" - elif dialog.selected_track_id > 0: - self.replace_file(path=path, track_id=dialog.selected_track_id) - else: - # Import as new, but check destination path doesn't - # already exists - while os.path.exists(self.import_files_data[path].destination_path): - msg = ( - "New import requested but default destination path ({path}) " - "already exists. Click OK and choose where to save this track" - ) - import pdb - - pdb.set_trace() - show_OK(None, title="Desintation path exists", msg=msg) - # Get output filename - pathspec = QFileDialog.getSaveFileName( - None, - "Save imported track", - directory=Config.IMPORT_DESTINATION, - ) - if not pathspec: - self.import_files_data[path].import_this_file = False - self.import_files_data[ - path - ].error = "destination file already exists" - return - - self.import_files_data[path].destination_path = pathspec[0] - - self.import_as_new(path=path) - else: - # User cancelled dialog - self.import_files_data[path].import_this_file = False - self.import_files_data[path].error = "you cancelled the import of this file" - - def import_as_new(self, path: str) -> None: - """ - Import passed path as a new file - """ - - log.debug(f"Import as new, {path=}") - - tfd = self.import_files_data[path] - - destination_path = os.path.join( - Config.IMPORT_DESTINATION, os.path.basename(path) - ) - if os.path.exists(destination_path): - tfd.import_this_file = False - tfd.error = f"this is a new import but destination file already exists ({destination_path})" - return - - tfd.destination_path = destination_path - - def replace_file(self, path: str, track_id: int) -> None: - """ - Replace existing track {track_id=} with passed path - """ - - log.debug(f"Replace {track_id=} with {path=}") - - tfd = self.import_files_data[path] - - existing_track_path = self._get_existing_track(track_id).path - proposed_destination_path = os.path.join( - os.path.dirname(existing_track_path), os.path.basename(path) - ) - # if the destination path exists and it's not the path the - # track_id points to, abort - if existing_track_path != proposed_destination_path and os.path.exists( - proposed_destination_path - ): - tfd.import_this_file = False - tfd.error = f"New import would overwrite existing file ({proposed_destination_path})" - return - tfd.file_path_to_remove = existing_track_path - tfd.destination_path = proposed_destination_path - tfd.track_id = track_id - - -@dataclass -class ThreadData: - base_model: PlaylistModel - row_number: int - worker: Optional[DoTrackImport] = None + self.import_finished.emit(track.id) class PickMatch(QDialog): @@ -558,27 +718,3 @@ class PickMatch(QDialog): # Get the ID of the selected button self.selected_track_id = self.button_group.checkedId() self.accept() - - -@dataclass -class TrackFileData: - """ - Simple class to track details changes to a track file - """ - - tags: Tags = Tags() - destination_path: str = "" - import_this_file: bool = True - error: str = "" - file_path_to_remove: Optional[str] = None - track_id: int = 0 - track_match_data: list[TrackMatchData] = field(default_factory=list) - - -@dataclass -class TrackMatchData: - artist: str - artist_match: float - title: str - title_match: float - track_id: int diff --git a/app/musicmuster.py b/app/musicmuster.py index 6ecc41d..83c3594 100755 --- a/app/musicmuster.py +++ b/app/musicmuster.py @@ -533,7 +533,7 @@ class Window(QMainWindow, Ui_MainWindow): if current_track_playlist_id: if closing_tab_playlist_id == current_track_playlist_id: helpers.show_OK( - self, "Current track", "Can't close current track playlist" + "Current track", "Can't close current track playlist", self ) return False @@ -543,7 +543,7 @@ class Window(QMainWindow, Ui_MainWindow): if next_track_playlist_id: if closing_tab_playlist_id == next_track_playlist_id: helpers.show_OK( - self, "Next track", "Can't close next track playlist" + "Next track", "Can't close next track playlist", self ) return False @@ -1480,7 +1480,7 @@ class Window(QMainWindow, Ui_MainWindow): session, self.current.playlist_id, template_name ) session.commit() - helpers.show_OK(self, "Template", "Template saved") + helpers.show_OK("Template", "Template saved", self) def search_playlist(self) -> None: """Show text box to search playlist""" diff --git a/app/playlists.py b/app/playlists.py index 6497564..d994e6f 100644 --- a/app/playlists.py +++ b/app/playlists.py @@ -876,7 +876,7 @@ class PlaylistTab(QTableView): else: txt = f"Can't find info about row{row_number}" - show_OK(self.musicmuster, "Track info", txt) + show_OK("Track info", txt, self.musicmuster) def _mark_as_unplayed(self, row_numbers: list[int]) -> None: """Mark row as unplayed""" diff --git a/app/ui/main_window_ui.py b/app/ui/main_window_ui.py index a33a770..114c163 100644 --- a/app/ui/main_window_ui.py +++ b/app/ui/main_window_ui.py @@ -678,4 +678,4 @@ class Ui_MainWindow(object): self.actionSelect_duplicate_rows.setText(_translate("MainWindow", "Select duplicate rows...")) self.actionImport_files.setText(_translate("MainWindow", "Import files...")) from infotabs import InfoTabs -from pyqtgraph import PlotWidget +from pyqtgraph import PlotWidget # type: ignore diff --git a/tests/test_file_importer.py b/tests/test_file_importer.py new file mode 100644 index 0000000..9143790 --- /dev/null +++ b/tests/test_file_importer.py @@ -0,0 +1,82 @@ +# Standard library imports +import unittest +from unittest.mock import patch + +# PyQt imports + +# Third party imports +import pytest +from pytestqt.plugin import QtBot # type: ignore + +# App imports +from config import Config +from app.models import ( + db, + Playlists, +) +from app import musicmuster + + +# Custom fixture to adapt qtbot for use with unittest.TestCase +@pytest.fixture(scope="class") +def qtbot_adapter(qapp, request): + """Adapt qtbot fixture for usefixtures and unittest.TestCase""" + request.cls.qtbot = QtBot(request) + + +# Wrapper to handle setup/teardown operations +def with_updown(function): + def test_wrapper(self, *args, **kwargs): + if callable(getattr(self, "up", None)): + self.up() + try: + function(self, *args, **kwargs) + finally: + if callable(getattr(self, "down", None)): + self.down() + + test_wrapper.__doc__ = function.__doc__ + return test_wrapper + + +# Apply the custom fixture to the test class +@pytest.mark.usefixtures("qtbot_adapter") +class MyTestCase(unittest.TestCase): + + def up(self): + db.create_all() + self.widget = musicmuster.Window() + + playlist_name = "file importer playlist" + + with db.Session() as session: + playlist = Playlists(session, playlist_name) + self.widget.create_playlist_tab(playlist) + with self.qtbot.waitExposed(self.widget): + self.widget.show() + + def down(self): + db.drop_all() + + @with_updown + @patch("file_importer.show_OK") + def test_import_no_files(self, mock_show_ok): + """Try importing with no files to import""" + + self.widget.import_files_wrapper() + mock_show_ok.assert_called_once_with( + "File import", + f"No files in {Config.REPLACE_FILES_DEFAULT_SOURCE} to import", + None, + ) + # @with_updown + # def test_import_no_files(self): + # """Try importing with no files to import""" + + # with patch("file_importer.show_OK") as mock_show_ok: + # self.widget.import_files_wrapper() + # mock_show_ok.assert_called_once_with( + # "File import", + # f"No files in {Config.REPLACE_FILES_DEFAULT_SOURCE} to import", + # None, + # ) diff --git a/tests/test_ui.py b/tests/test_ui.py index e230def..2d6f34b 100644 --- a/tests/test_ui.py +++ b/tests/test_ui.py @@ -3,19 +3,15 @@ import os import unittest # PyQt imports -from PyQt6.QtCore import Qt -from PyQt6.QtGui import QColor # Third party imports import pytest from pytestqt.plugin import QtBot # type: ignore # App imports -from config import Config from app import playlistmodel, utilities from app.models import ( db, - NoteColours, Playlists, Tracks, )