#!/usr/bin/python3
"""SQLAlchemy ORM models: accounts, hashtags, posts and the post/hashtag link table."""

import os.path
from typing import List, Optional

from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    func,
    Integer,
    select,
    String,
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import (
    declarative_base,
    relationship,
)
from sqlalchemy.orm.exc import (
    NoResultFound
)

from dbconfig import Session, scoped_session
from config import Config
from log import log

Base = declarative_base()


# Database classes

class Accounts(Base):
    """An account row, looked up by its external ``account_id`` string."""

    __tablename__ = 'accounts'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # External identifier (string); distinct from the integer primary key.
    account_id = Column(String(32), index=True, nullable=False)
    username = Column(String(256), index=True, default=None)
    acct = Column(String(256), index=False, default=None)
    display_name = Column(String(256), index=False, default=None)
    bot = Column(Boolean, index=False, nullable=False, default=False)
    url = Column(String(256), index=False)
    # True for accounts we follow; queried by get_followed().
    followed = Column(Boolean, index=False, nullable=False, default=False)

    def __repr__(self) -> str:
        # Only column attributes are used so repr never touches relationships.
        return (
            f"<Accounts(id={self.id!r}, account_id={self.account_id!r}, "
            f"username={self.username!r}, followed={self.followed!r})>"
        )

    def __init__(self, session: Session, account_id: str) -> None:
        """Create an account, add it to *session* and flush immediately.

        The flush sends the INSERT right away so the autoincrement ``id``
        is populated before the constructor returns.
        """
        self.account_id = account_id
        session.add(self)
        session.flush()

    @classmethod
    def get_followed(cls, session: Session) -> List["Accounts"]:
        """
        Return a list of account objects that we follow
        """
        records = (
            session.execute(
                select(cls)
                .where(cls.followed.is_(True))
            ).scalars().all()
        )
        return records

    @classmethod
    def get_or_create(cls, session: Session, account_id: str) -> "Accounts":
        """
        Return any existing account with this id or create a new one

        Raises sqlalchemy's MultipleResultsFound if more than one row has
        this ``account_id`` (the column is indexed but not unique).
        """
        try:
            rec = (
                session.execute(
                    select(cls)
                    .where(cls.account_id == account_id)
                ).scalar_one()
            )
        except NoResultFound:
            rec = cls(session, account_id)
        return rec


class Hashtags(Base):
    """A hashtag row, looked up by its ``name``."""

    __tablename__ = 'hashtags'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(256), index=True, nullable=False)
    url = Column(String(256), index=False)
    # Link rows to posts; kept in sync with PostTags.hashtag.
    tags_to_posts = relationship("PostTags", back_populates="hashtag")
    # Posts carrying this hashtag, proxied through the PostTags link rows.
    posts = association_proxy("tags_to_posts", "post")
    # True for hashtags we follow; queried by get_followed().
    followed = Column(Boolean, index=False, nullable=False, default=False)

    def __repr__(self) -> str:
        # Only column attributes are used so repr never touches relationships.
        return (
            f"<Hashtags(id={self.id!r}, name={self.name!r}, "
            f"followed={self.followed!r})>"
        )

    def __init__(self, session: Session, name: str, url: str) -> None:
        """Create a hashtag, add it to *session* and flush immediately.

        The flush sends the INSERT right away so the autoincrement ``id``
        is populated before the constructor returns.
        """
        self.name = name
        self.url = url
        session.add(self)
        session.flush()

    @classmethod
    def get_all(cls, session: Session) -> List["Hashtags"]:
        """
        Return a list of all hashtags
        """
        records = (
            session.execute(
                select(cls)
            ).scalars().all()
        )
        return records

    @classmethod
    def get_followed(cls, session: Session) -> List["Hashtags"]:
        """
        Return a list of hashtags objects that we follow
        """
        records = (
            session.execute(
                select(cls)
                .where(cls.followed.is_(True))
            ).scalars().all()
        )
        return records

    @classmethod
    def get_or_create(cls, session: Session, name: str, url: str) -> "Hashtags":
        """
        Return any existing hashtag with this name or create a new one

        *url* is only used when a new hashtag is created; an existing row's
        url is left untouched. Raises sqlalchemy's MultipleResultsFound if
        more than one row has this ``name`` (indexed but not unique).
        """
        try:
            rec = (
                session.execute(
                    select(cls)
                    .where(cls.name == name)
                ).scalar_one()
            )
        except NoResultFound:
            rec = cls(session, name, url)
        return rec


class Posts(Base):
    """A post row, looked up by its external ``post_id`` string."""

    __tablename__ = 'posts'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # External identifier (string); distinct from the integer primary key.
    post_id = Column(String(32), index=True, nullable=False)
    # Integer FK to accounts.id (the primary key, not Accounts.account_id).
    account_id = Column(Integer, ForeignKey('accounts.id'), nullable=True)
    account = relationship("Accounts", foreign_keys=[account_id])
    created_at = Column(DateTime, index=True, default=None)
    uri = Column(String(256), index=False)
    # Link rows to hashtags; kept in sync with PostTags.post.
    posts_to_tags = relationship("PostTags", back_populates="post")
    # Hashtags on this post, proxied through the PostTags link rows.
    hashtags = association_proxy("posts_to_tags", "hashtag")
    favourited = Column(Boolean, index=True, nullable=False, default=False)
    bookmarked = Column(Boolean, index=True, nullable=False, default=False)

    def __repr__(self) -> str:
        # Only column attributes are used so repr never touches relationships.
        return (
            f"<Posts(id={self.id!r}, post_id={self.post_id!r}, "
            f"created_at={self.created_at!r})>"
        )

    def __init__(self, session: Session, post_id: str) -> None:
        """Create a post, add it to *session* and flush immediately.

        The flush sends the INSERT right away so the autoincrement ``id``
        is populated before the constructor returns.
        """
        self.post_id = post_id
        session.add(self)
        session.flush()

    @classmethod
    def get_by_post_id(cls, session: Session, post_id: str) -> Optional["Posts"]:
        """
        Return post identified by post_id or None
        """
        return (
            session.scalars(
                select(cls)
                .where(cls.post_id == post_id)
                .limit(1)
            ).first()
        )

    @classmethod
    def get_or_create(cls, session: Session, post_id: str) -> "Posts":
        """
        Return any existing post with this id or create a new one

        Raises sqlalchemy's MultipleResultsFound if more than one row has
        this ``post_id`` (the column is indexed but not unique).
        """
        try:
            rec = (
                session.execute(
                    select(cls)
                    .where(cls.post_id == post_id)
                ).scalar_one()
            )
        except NoResultFound:
            rec = cls(session, post_id)
        return rec

    @staticmethod
    def max_post_id(session: Session) -> Optional[str]:
        """
        Return the maximum post_id, or None when the posts table is empty
        """
        # NOTE(review): post_id is a String column, so MAX compares
        # lexicographically ("9" > "10"). This only matches numeric order if
        # all stored ids have the same length -- confirm that holds.
        return session.scalars(select(func.max(Posts.post_id))).first()


class PostTags(Base):
    """Association row linking one post to one hashtag."""

    __tablename__ = 'post_tags'

    id = Column(Integer, primary_key=True, autoincrement=True)
    post_id = Column(Integer, ForeignKey('posts.id'), nullable=False, index=True)
    hashtag_id = Column(Integer, ForeignKey('hashtags.id'), nullable=False, index=True)
    # back_populates is declared on both ends of each relationship, so
    # assigning either attribute here also updates the parent's collection.
    post = relationship(Posts, back_populates="posts_to_tags")
    hashtag = relationship("Hashtags", back_populates="tags_to_posts")

    def __init__(
        self,
        hashtag: Optional["Hashtags"] = None,
        post: Optional[Posts] = None,
    ) -> None:
        """Link *post* and *hashtag*.

        The row is not added to a session here. It is presumably persisted
        via cascade from the related Posts/Hashtags objects -- confirm at
        call sites.
        """
        self.post = post
        self.hashtag = hashtag