urma/app/models.py

306 lines
8.3 KiB
Python

#!/usr/bin/python3
import os.path
from dbconfig import Session, scoped_session
from typing import List, Optional
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
select,
String,
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import (
declarative_base,
relationship,
)
from sqlalchemy.orm.exc import (
NoResultFound
)
from config import Config
from log import log
# Declarative base class: every model below subclasses it, which registers
# the model's table on Base.metadata.
Base = declarative_base()


# Database classes
class Accounts(Base):
    """
    An account that authors Posts, plus a flag recording whether we
    follow it.
    """
    __tablename__ = 'accounts'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # External identifier supplied by callers. It is indexed but not
    # declared unique, so uniqueness relies on get_or_create().
    account_id = Column(String(32), index=True, nullable=False)
    username = Column(String(256), index=True, default=None)
    acct = Column(String(256), index=False, default=None)
    display_name = Column(String(256), index=False, default=None)
    bot = Column(Boolean, index=False, nullable=False, default=False)
    url = Column(String(256), index=False)
    followed = Column(Boolean, index=False, nullable=False, default=False)
    posts = relationship("Posts", back_populates="account")

    def __repr__(self) -> str:
        return (
            f"<Accounts(id={self.id}, username={self.username}, "
            f"acct={self.acct}, followed={self.followed})>"
        )

    def __init__(self, session: Session, account_id: str) -> None:
        """
        Create an account with the given account_id and persist it.

        The new row is added to *session* and committed before the
        constructor returns.
        """
        self.account_id = account_id
        session.add(self)
        session.commit()

    @classmethod
    def get_followed(cls, session: Session) -> List["Accounts"]:
        """
        Return a list of the accounts we follow (``followed`` is True).
        """
        return (
            session.execute(
                select(cls).where(cls.followed.is_(True))
            ).scalars().all()
        )

    @classmethod
    def get_or_create(cls, session: Session, account_id: str) -> "Accounts":
        """
        Return the existing account with this account_id, or create and
        commit a new one if none exists.

        ``scalar_one()`` raises MultipleResultsFound if more than one row
        already has this account_id.
        """
        query = select(cls).where(cls.account_id == account_id)
        try:
            return session.execute(query).scalar_one()
        except NoResultFound:
            # Instantiate via cls so subclasses get their own type back.
            return cls(session, account_id)
class Attachments(Base):
    """
    A media attachment belonging to a post.
    """
    __tablename__ = 'attachments'

    id = Column(Integer, primary_key=True, autoincrement=True)
    media_id = Column(String(32), index=True, nullable=False)
    url = Column(String(256), index=False)
    preview_url = Column(String(256), index=False)
    description = Column(String(2048), index=False)
    # References the internal posts.id primary key, not Posts.post_id.
    post_id = Column(Integer, ForeignKey("posts.id"))
    type = Column(String(256), index=False)

    def __repr__(self) -> str:
        return (
            f"<Attachments(id={self.id}, url={self.url}, "
            f"description={self.description})>"
        )

    def __init__(self, session: Session, media_id: str, post_id: int) -> None:
        """
        Create an attachment for a post and persist it.

        :param media_id: identifier of the media item.
        :param post_id: the owning post's ``posts.id`` primary key.

        The new row is added to *session* and committed before the
        constructor returns.
        """
        self.media_id = media_id
        self.post_id = post_id
        session.add(self)
        session.commit()

    @classmethod
    def get_or_create(cls, session: Session, media_id: str,
                      post_id: int) -> "Attachments":
        """
        Return the attachment matching both media_id and post_id, or
        create and commit a new one if none exists.
        """
        query = select(cls).where(
            cls.media_id == media_id,
            cls.post_id == post_id,
        )
        try:
            return session.execute(query).scalar_one()
        except NoResultFound:
            # Instantiate via cls so subclasses get their own type back.
            return cls(session, media_id, post_id)
class Hashtags(Base):
    """
    A hashtag, the posts that carry it, and whether we follow it.
    """
    __tablename__ = 'hashtags'

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(256), index=True, nullable=False)
    url = Column(String(256), index=False)
    tags_to_posts = relationship("PostTags", back_populates="hashtag")
    # The default association_proxy creator calls PostTags(value) with the
    # value as its single positional argument. PostTags.__init__ takes
    # ``hashtag`` first, so appending a post here would store the Post in
    # PostTags.hashtag. Pass it by keyword instead.
    posts = association_proxy(
        "tags_to_posts",
        "post",
        creator=lambda post: PostTags(post=post),
    )
    followed = Column(Boolean, index=False, nullable=False, default=False)

    def __repr__(self) -> str:
        return (
            f"<Hashtags(id={self.id}, name={self.name}, "
            f"url={self.url}, followed={self.followed})>"
        )

    def __init__(self, session: Session, name: str, url: str) -> None:
        """
        Create a hashtag with the given name and url and persist it.

        The new row is added to *session* and committed before the
        constructor returns.
        """
        self.name = name
        self.url = url
        session.add(self)
        session.commit()

    @classmethod
    def get_followed(cls, session: Session) -> List["Hashtags"]:
        """
        Return a list of the hashtags we follow (``followed`` is True).
        """
        return (
            session.execute(
                select(cls).where(cls.followed.is_(True))
            ).scalars().all()
        )

    @classmethod
    def get_or_create(cls, session: Session,
                      name: str, url: str) -> "Hashtags":
        """
        Return the existing hashtag with this name, or create and commit
        a new one (using *url*) if none exists.

        *url* is only used when creating; an existing row's url is not
        updated.
        """
        query = select(cls).where(cls.name == name)
        try:
            return session.execute(query).scalar_one()
        except NoResultFound:
            # Instantiate via cls so subclasses get their own type back.
            return cls(session, name, url)
class Posts(Base):
    """
    A post with its author, attachments, hashtags and our rating.
    """
    __tablename__ = 'posts'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # External identifier, stored as a string (see get_unrated_before).
    post_id = Column(String(32), index=True, nullable=False)
    created_at = Column(DateTime, index=True, default=None)
    uri = Column(String(256), index=False)
    url = Column(String(256), index=False)
    content = Column(String(2048), index=False, default="")
    account_id = Column(Integer, ForeignKey('accounts.id'), nullable=True)
    account = relationship("Accounts", back_populates="posts")
    # Self-referential link through boosted_post_id. With no remote_side,
    # SQLAlchemy configures this as one-to-many: the list of Posts whose
    # boosted_post_id points at this post.
    reblogged_by_post = relationship("Posts")
    boosted_post_id = Column(Integer, ForeignKey("posts.id"))
    media_attachments = relationship("Attachments")
    posts_to_tags = relationship("PostTags", back_populates="post")
    hashtags = association_proxy("posts_to_tags", "hashtag")
    # NULL means "not rated yet"; the get_unrated_* queries rely on this.
    rating = Column(Integer, index=True, default=None)

    def __repr__(self) -> str:
        # content is nullable, so avoid slicing None.
        preview = self.content[:60] if self.content is not None else None
        return f"<Posts(id={self.id}, content={preview})>"

    def __init__(self, session: Session, post_id: str) -> None:
        """
        Create a post with the given post_id and persist it.

        The new row is added to *session* and committed before the
        constructor returns.
        """
        self.post_id = post_id
        session.add(self)
        session.commit()

    @classmethod
    def get_unrated_before(cls, session: Session,
                           post_id: int) -> Optional["Posts"]:
        """
        Return the unrated post with the highest post_id below the given
        post_id, or None if there isn't one.

        NOTE(review): post_id is a String(32) column, so both the ``<``
        comparison and the ORDER BY operate on a string column. Lexical
        order only matches numeric order when the compared IDs have the
        same number of digits. Confirm that holds for the stored IDs.
        """
        query = (
            select(cls)
            .where(cls.rating.is_(None), cls.post_id < post_id)
            .order_by(cls.post_id.desc())
            .limit(1)
        )
        return session.scalars(query).first()

    @classmethod
    def get_unrated_newest(cls, session: Session) -> Optional["Posts"]:
        """
        Return the unrated post with the highest post_id, or None if
        every post has been rated.

        The query does not look at boosted_post_id, so boosted posts are
        not excluded. NOTE(review): if they should be skipped, add a
        filter on boosted_post_id, and confirm which side of the boost
        relation is meant.
        """
        query = (
            select(cls)
            .where(cls.rating.is_(None))
            .order_by(cls.post_id.desc())
            .limit(1)
        )
        return session.scalars(query).first()

    @classmethod
    def get_or_create(cls, session: Session, post_id: str) -> "Posts":
        """
        Return the existing post with this post_id, or create and commit
        a new one if none exists.
        """
        query = select(cls).where(cls.post_id == post_id)
        try:
            return session.execute(query).scalar_one()
        except NoResultFound:
            # Instantiate via cls so subclasses get their own type back.
            return cls(session, post_id)
class PostTags(Base):
    """
    Association row linking one Posts row to one Hashtags row.
    """
    __tablename__ = 'post_tags'

    id = Column(Integer, primary_key=True, autoincrement=True)
    post_id = Column(Integer, ForeignKey('posts.id'), nullable=False,
                     index=True)
    hashtag_id = Column(Integer, ForeignKey('hashtags.id'), nullable=False,
                        index=True)
    post = relationship(Posts, back_populates="posts_to_tags")
    # Hashtags.tags_to_posts already declares back_populates="hashtag".
    # Declaring the mirror here keeps both sides in sync in memory, the
    # same way PostTags.post and Posts.posts_to_tags are paired.
    hashtag = relationship("Hashtags", back_populates="tags_to_posts")

    def __init__(self, hashtag=None, post=None) -> None:
        """
        Link *hashtag* and *post*.

        association_proxy's default creator calls this class with the
        proxied value as the single positional argument, so appending a
        Hashtags object to Posts.hashtags lands in ``hashtag``.
        """
        self.post = post
        self.hashtag = hashtag