WIP Remove repository.py dependency on helpers.py

This commit is contained in:
Keith Edmunds 2025-04-05 16:32:31 +01:00
parent e71f1d072f
commit f3b1e05e83
4 changed files with 62 additions and 18 deletions

View File

@ -36,6 +36,7 @@ from classes import (
from config import Config from config import Config
from helpers import ( from helpers import (
file_is_unreadable, file_is_unreadable,
get_all_track_metadata,
get_audio_metadata, get_audio_metadata,
get_tags, get_tags,
normalise_track, normalise_track,
@ -389,7 +390,9 @@ class FileImporter:
Lookup in existing track in the local cache and return it Lookup in existing track in the local cache and return it
""" """
existing_track_records = [a for a in self.existing_tracks if a.track_id == track_id] existing_track_records = [
a for a in self.existing_tracks if a.track_id == track_id
]
if len(existing_track_records) != 1: if len(existing_track_records) != 1:
raise ApplicationError( raise ApplicationError(
f"Internal error in _get_existing_track: {existing_track_records=}" f"Internal error in _get_existing_track: {existing_track_records=}"
@ -633,10 +636,13 @@ class DoTrackImport(QThread):
normalise_track(self.destination_track_path) normalise_track(self.destination_track_path)
# Update databse # Update databse
metadata = get_all_track_metadata(self.destination_track_path)
if self.track_id == 0: if self.track_id == 0:
track_dto = repository.create_track(self.destination_track_path) track_dto = repository.create_track(self.destination_track_path, metadata)
else: else:
track_dto = repository.update_track(self.destination_track_path, self.track_id) track_dto = repository.update_track(
self.destination_track_path, self.track_id, metadata
)
self.signals.status_message_signal.emit( self.signals.status_message_signal.emit(
f"{os.path.basename(self.import_file_path)} imported", 10000 f"{os.path.basename(self.import_file_path)} imported", 10000

View File

@ -41,6 +41,7 @@ from config import Config
from helpers import ( from helpers import (
ask_yes_no, ask_yes_no,
file_is_unreadable, file_is_unreadable,
get_all_track_metadata,
get_embedded_time, get_embedded_time,
get_relative_date, get_relative_date,
ms_to_mmss, ms_to_mmss,
@ -1066,7 +1067,8 @@ class PlaylistModel(QAbstractTableModel):
""" """
track = self.playlist_rows[row_number] track = self.playlist_rows[row_number]
_ = repository.update_track(track.path, track.track_id) metadata = get_all_track_metadata(track.path)
_ = repository.update_track(track.path, track.track_id, metadata)
roles = [ roles = [
Qt.ItemDataRole.BackgroundRole, Qt.ItemDataRole.BackgroundRole,

View File

@ -18,7 +18,6 @@ from classes import ApplicationError, PlaylistRowDTO
# App imports # App imports
from classes import PlaylistDTO, TrackDTO from classes import PlaylistDTO, TrackDTO
from config import Config from config import Config
import helpers
from log import log, log_call from log import log, log_call
from models import ( from models import (
db, db,
@ -31,6 +30,34 @@ from models import (
) )
# Helper functions
def _remove_substring_case_insensitive(parent_string: str, substring: str) -> str:
"""
Remove all instances of substring from parent string, case insensitively
"""
# Convert both strings to lowercase for case-insensitive comparison
lower_parent = parent_string.lower()
lower_substring = substring.lower()
# Initialize the result string
result = parent_string
# Continue removing the substring until it's no longer found
while lower_substring in lower_parent:
# Find the index of the substring
index = lower_parent.find(lower_substring)
# Remove the substring
result = result[:index] + result[index + len(substring) :]
# Update the lowercase versions
lower_parent = result.lower()
return result
# Notecolour functions # Notecolour functions
def _get_colour_record(text: str) -> tuple[NoteColours | None, str]: def _get_colour_record(text: str) -> tuple[NoteColours | None, str]:
""" """
@ -57,7 +84,7 @@ def _get_colour_record(text: str) -> tuple[NoteColours | None, str]:
return (rec, return_text) return (rec, return_text)
else: else:
if rec.substring.lower() in text.lower(): if rec.substring.lower() in text.lower():
return_text = helpers.remove_substring_case_insensitive( return_text = _remove_substring_case_insensitive(
text, rec.substring text, rec.substring
) )
return (rec, return_text) return (rec, return_text)
@ -105,12 +132,11 @@ def add_track_to_header(playlistrow_id: int, track_id: int) -> None:
session.commit() session.commit()
def create_track(path: str) -> TrackDTO: def create_track(path: str, metadata: dict[str, str | int | float]) -> TrackDTO:
""" """
Create a track db entry from a track path and return the DTO Create a track db entry from a track path and return the DTO
""" """
metadata = helpers.get_all_track_metadata(path)
with db.Session() as session: with db.Session() as session:
try: try:
track = Tracks( track = Tracks(
@ -137,12 +163,13 @@ def create_track(path: str) -> TrackDTO:
return new_track return new_track
def update_track(path: str, track_id: int) -> TrackDTO: def update_track(
path: str, track_id: int, metadata: dict[str, str | int | float]
) -> TrackDTO:
""" """
Update an existing track db entry return the DTO Update an existing track db entry return the DTO
""" """
metadata = helpers.get_all_track_metadata(path)
with db.Session() as session: with db.Session() as session:
track = session.get(Tracks, track_id) track = session.get(Tracks, track_id)
if not track: if not track:

View File

@ -10,6 +10,7 @@ from app import playlistmodel
from app import repository from app import repository
from app.models import db from app.models import db
from classes import PlaylistDTO from classes import PlaylistDTO
from helpers import get_all_track_metadata
from playlistmodel import PlaylistModel from playlistmodel import PlaylistModel
@ -50,9 +51,11 @@ class MyTestCase(unittest.TestCase):
def create_playlist_model_tracks(self, playlist_name: str): def create_playlist_model_tracks(self, playlist_name: str):
(playlist, model) = self.create_playlist_and_model("my playlist") (playlist, model) = self.create_playlist_and_model("my playlist")
# Create tracks # Create tracks
self.track1 = repository.create_track(self.isa_path) metadata1 = get_all_track_metadata(self.isa_path)
self.track1 = repository.create_track(self.isa_path, metadata1)
self.track2 = repository.create_track(self.mom_path) metadata2 = get_all_track_metadata(self.mom_path)
self.track2 = repository.create_track(self.mom_path, metadata2)
# Add tracks and header to playlist # Add tracks and header to playlist
self.row0 = repository.insert_row( self.row0 = repository.insert_row(
@ -99,19 +102,23 @@ class MyTestCase(unittest.TestCase):
assert result.track_id == self.track2.track_id assert result.track_id == self.track2.track_id
def test_create_track(self): def test_create_track(self):
repository.create_track(self.isa_path) metadata = get_all_track_metadata(self.isa_path)
repository.create_track(self.isa_path, metadata)
results = repository.get_all_tracks() results = repository.get_all_tracks()
assert len(results) == 1 assert len(results) == 1
assert results[0].path == self.isa_path assert results[0].path == self.isa_path
def test_get_track_by_id(self): def test_get_track_by_id(self):
dto = repository.create_track(self.isa_path) metadata = get_all_track_metadata(self.isa_path)
dto = repository.create_track(self.isa_path, metadata)
result = repository.track_by_id(dto.track_id) result = repository.track_by_id(dto.track_id)
assert result.path == self.isa_path assert result.path == self.isa_path
def test_get_track_by_artist(self): def test_get_track_by_artist(self):
_ = repository.create_track(self.isa_path) metadata = get_all_track_metadata(self.isa_path)
_ = repository.create_track(self.mom_path) _ = repository.create_track(self.isa_path, metadata)
metadata = get_all_track_metadata(self.mom_path)
_ = repository.create_track(self.mom_path, metadata)
result_isa = repository.tracks_like_artist(self.isa_artist) result_isa = repository.tracks_like_artist(self.isa_artist)
assert len(result_isa) == 1 assert len(result_isa) == 1
assert result_isa[0].artist == self.isa_artist assert result_isa[0].artist == self.isa_artist
@ -120,8 +127,10 @@ class MyTestCase(unittest.TestCase):
assert result_mom[0].artist == self.mom_artist assert result_mom[0].artist == self.mom_artist
def test_get_track_by_title(self): def test_get_track_by_title(self):
_ = repository.create_track(self.isa_path) metadata = get_all_track_metadata(self.isa_path)
_ = repository.create_track(self.mom_path) _ = repository.create_track(self.isa_path, metadata)
metadata = get_all_track_metadata(self.mom_path)
_ = repository.create_track(self.mom_path, metadata)
result_isa = repository.tracks_like_title(self.isa_title) result_isa = repository.tracks_like_title(self.isa_title)
assert len(result_isa) == 1 assert len(result_isa) == 1
assert result_isa[0].title == self.isa_title assert result_isa[0].title == self.isa_title