urma/app/models.py
2023-01-18 21:23:23 +00:00

253 lines
6.5 KiB
Python

#!/usr/bin/python3
import os.path
from dbconfig import Session, scoped_session
from typing import List, Optional
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
func,
Integer,
select,
String,
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import (
declarative_base,
relationship,
)
from sqlalchemy.orm.exc import (
NoResultFound
)
from config import Config
from log import log
# Declarative base class; every ORM model in this module (Accounts,
# Hashtags, Posts, PostTags) subclasses it.
Base = declarative_base()


# Database classes
class Accounts(Base):
    """
    ORM model for the ``accounts`` table.

    Rows are looked up by their external string ``account_id``; the
    integer ``id`` is the local primary key that other tables reference.
    """

    __tablename__ = 'accounts'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # External account identifier; indexed but NOT declared unique.
    account_id = Column(String(32), index=True, nullable=False)
    username = Column(String(256), index=True, default=None)
    acct = Column(String(256), index=False, default=None)
    display_name = Column(String(256), index=False, default=None)
    bot = Column(Boolean, index=False, nullable=False, default=False)
    url = Column(String(256), index=False)
    # True for accounts we follow; queried by get_followed().
    followed = Column(Boolean, index=False, nullable=False, default=False)

    def __repr__(self) -> str:
        return (
            f"<Accounts(id={self.id}, username={self.username}, "
            f"acct={self.acct}, followed={self.followed})>"
        )

    def __init__(self, session: Session, account_id: str) -> None:
        """
        Create an account, add it to ``session`` and flush.

        Flushing emits the INSERT immediately, so the database-assigned
        ``id`` is populated when the constructor returns.
        """
        self.account_id = account_id
        session.add(self)
        session.flush()

    @classmethod
    def get_followed(cls, session: Session) -> List["Accounts"]:
        """
        Return a list of account objects that we follow
        (``followed`` is True).
        """
        records = (
            session.execute(
                select(cls)
                .where(cls.followed.is_(True))
            ).scalars().all()
        )
        return records

    @classmethod
    def get_or_create(cls, session: Session, account_id: str) -> "Accounts":
        """
        Return any existing account with this id or create a new one.

        Raises sqlalchemy's MultipleResultsFound if more than one row
        already has this account_id (the column is not unique).
        """
        rec = session.execute(
            select(cls).where(cls.account_id == account_id)
        ).scalar_one_or_none()
        if rec is None:
            # Use cls so subclasses construct instances of themselves.
            rec = cls(session, account_id)
        return rec


class Hashtags(Base):
    """
    ORM model for the ``hashtags`` table.

    Hashtags are linked to ``Posts`` through the ``PostTags`` association
    table; ``posts`` exposes the linked Posts objects directly.
    """

    __tablename__ = 'hashtags'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Looked up by exact match in get_or_create(); indexed, not unique.
    name = Column(String(256), index=True, nullable=False)
    url = Column(String(256), index=False)
    tags_to_posts = relationship("PostTags", back_populates="hashtag")
    # Posts objects reached through the PostTags rows.
    posts = association_proxy("tags_to_posts", "post")
    # True for hashtags we follow; queried by get_followed().
    followed = Column(Boolean, index=False, nullable=False, default=False)

    def __repr__(self) -> str:
        return (
            f"<Hashtags(id={self.id}, name={self.name}, "
            f"url={self.url}, followed={self.followed})>"
        )

    def __init__(self, session: Session, name: str, url: str) -> None:
        """
        Create a hashtag, add it to ``session`` and flush so that the
        database-assigned ``id`` is populated immediately.
        """
        self.name = name
        self.url = url
        session.add(self)
        session.flush()

    @classmethod
    def get_all(cls, session: Session) -> List["Hashtags"]:
        """
        Return a list of all hashtags
        """
        records = (
            session.execute(
                select(cls)
            ).scalars().all()
        )
        return records

    @classmethod
    def get_followed(cls, session: Session) -> List["Hashtags"]:
        """
        Return a list of hashtags objects that we follow
        (``followed`` is True).
        """
        records = (
            session.execute(
                select(cls)
                .where(cls.followed.is_(True))
            ).scalars().all()
        )
        return records

    @classmethod
    def get_or_create(cls, session: Session,
                      name: str, url: str) -> "Hashtags":
        """
        Return any existing hashtag with this name or create a new one.

        ``url`` is only used when creating; an existing hashtag's url is
        left unchanged. Raises sqlalchemy's MultipleResultsFound if more
        than one row already has this name.
        """
        rec = session.execute(
            select(cls).where(cls.name == name)
        ).scalar_one_or_none()
        if rec is None:
            # Use cls so subclasses construct instances of themselves.
            rec = cls(session, name, url)
        return rec


class Posts(Base):
    """
    ORM model for the ``posts`` table.

    Each row may reference an ``Accounts`` row and is linked to
    ``Hashtags`` through the ``PostTags`` association table.
    """

    __tablename__ = 'posts'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # External post identifier, stored as a string; indexed, not unique.
    post_id = Column(String(32), index=True, nullable=False)
    # Foreign key to accounts.id (the local integer key, not the external
    # Accounts.account_id string).
    account_id = Column(Integer, ForeignKey('accounts.id'), nullable=True)
    account = relationship("Accounts", foreign_keys=[account_id])
    created_at = Column(DateTime, index=True, default=None)
    uri = Column(String(256), index=False)
    posts_to_tags = relationship("PostTags", back_populates="post")
    # Hashtags objects reached through the PostTags rows.
    hashtags = association_proxy("posts_to_tags", "hashtag")
    favourited = Column(Boolean, index=True, nullable=False, default=False)
    bookmarked = Column(Boolean, index=True, nullable=False, default=False)

    def __repr__(self) -> str:
        return f"<Posts(id={self.id})>"

    def __init__(self, session: Session, post_id: str) -> None:
        """
        Create a post, add it to ``session`` and flush so that the
        database-assigned ``id`` is populated immediately.
        """
        self.post_id = post_id
        session.add(self)
        session.flush()

    @classmethod
    def get_by_post_id(cls, session: Session,
                       post_id: str) -> Optional["Posts"]:
        """
        Return post identified by post_id or None.

        If several rows share this post_id, one of them is returned.
        """
        return (
            session.scalars(
                select(cls)
                .where(cls.post_id == post_id)
                .limit(1)
            ).first()
        )

    @classmethod
    def get_or_create(cls, session: Session, post_id: str) -> "Posts":
        """
        Return any existing post with this id or create a new one.

        Raises sqlalchemy's MultipleResultsFound if more than one row
        already has this post_id.
        """
        rec = session.execute(
            select(cls).where(cls.post_id == post_id)
        ).scalar_one_or_none()
        if rec is None:
            # Use cls so subclasses construct instances of themselves.
            rec = cls(session, post_id)
        return rec

    @staticmethod
    def max_post_id(session: Session) -> Optional[str]:
        """
        Return the maximum post_id, or None if there are no posts.

        post_id is a string column, so SQL MAX() would compare lexically
        and rank e.g. "99" above "100". Ordering by length first and then
        by value yields numeric order for digit-only ids, and gives the
        same answer as MAX() whenever all ids have equal length.
        NOTE(review): assumes post_ids are decimal digit strings without
        leading zeros; confirm against the data source.
        """
        return session.scalars(
            select(Posts.post_id)
            .order_by(func.length(Posts.post_id).desc(),
                      Posts.post_id.desc())
            .limit(1)
        ).first()


class PostTags(Base):
    """
    Association row linking one ``Posts`` row to one ``Hashtags`` row.
    """

    __tablename__ = 'post_tags'

    id = Column(Integer, primary_key=True, autoincrement=True)
    post_id = Column(Integer, ForeignKey('posts.id'), nullable=False,
                     index=True)
    hashtag_id = Column(Integer, ForeignKey('hashtags.id'), nullable=False,
                        index=True)

    # Both relationships declare back_populates, matching the reciprocal
    # declarations on Posts.posts_to_tags and Hashtags.tags_to_posts, so
    # assigning either attribute here also updates the other side's
    # collection in memory.
    post = relationship(Posts, back_populates="posts_to_tags")
    hashtag = relationship("Hashtags", back_populates="tags_to_posts")

    def __init__(self, hashtag: Optional["Hashtags"] = None,
                 post: Optional[Posts] = None) -> None:
        """
        Link ``post`` and ``hashtag``; either may be left as None and
        assigned later.
        """
        self.post = post
        self.hashtag = hashtag