diff --git a/app/crud.py b/app/crud.py index d22dfaf..71151b7 100644 --- a/app/crud.py +++ b/app/crud.py @@ -1,6 +1,8 @@ import app.models as models -from sqlalchemy import func +from sqlalchemy import func, and_ from sqlalchemy.orm.attributes import flag_modified +from starlette_context import context +from starlette_context.header_keys import HeaderKeys def get_songs_and_vote_for_session(db, session_name) -> list[models.Song]: @@ -80,6 +82,7 @@ def create_or_update_vote(db, song_id, session_name, vote): db.add(vote_entry) db.commit() + def create_or_update_comment(db, song_id, session_name, comment): session_entry = activate_session(db, session_name) @@ -96,14 +99,20 @@ def create_or_update_comment(db, song_id, session_name, comment): db.add(vote_entry) db.commit() + def activate_session(db, session_name): - session_entry = db.query(models.Session).filter( - (models.Session.session_name == session_name)).first() + ip = context.data[HeaderKeys.forwarded_for] + + session_entry = db.query(models.Session).filter(and_( + models.Session.session_name == session_name, models.Session.ip == ip)).first() if session_entry: + print(session_entry.__dict__) session_entry.active = True else: - session_entry = models.Session(session_name=session_name, active=True) + session_entry = models.Session( + session_name=session_name, active=True, ip=ip) db.add(session_entry) + flag_modified(session_entry, "active") db.commit() @@ -111,28 +120,35 @@ def activate_session(db, session_name): def deactivate_session(db, session_name): - session_entry = db.query(models.Session).filter( - (models.Session.session_name == session_name)).first() + ip = context.data[HeaderKeys.forwarded_for] + + session_entry = db.query(models.Session).filter(and_( + models.Session.session_name == session_name, models.Session.ip == ip)).first() if session_entry: session_entry.active = False + flag_modified(session_entry, "active") + db.commit() else: - session_entry = models.Session(session_name=session_name, active=False) - db.add(session_entry) - db.commit() + pass + # session_entry = models.Session(session_name=session_name, ip=ip, active=False) + # db.add(session_entry) def get_setting(db, key): - entry = db.query(models.Config.value).filter(models.Config.key == key).first() + entry = db.query(models.Config.value).filter( + models.Config.key == key).first() if entry: return entry[0] else: return None + def set_setting(db, key, value): - setting_entry = db.query(models.Config).filter(models.Config.key == key).first() + setting_entry = db.query(models.Config).filter( + models.Config.key == key).first() if setting_entry: setting_entry.value = value else: setting_entry = models.Config(key=key, value=value) db.add(setting_entry) - db.commit() \ No newline at end of file + db.commit() diff --git a/app/database.py b/app/database.py index f3a63c6..67dfb4c 100644 --- a/app/database.py +++ b/app/database.py @@ -22,5 +22,6 @@ async def get_db(): class Base(DeclarativeBase): type_annotation_map = { dict[str, bool]: PickleType, - object: PickleType + object: PickleType, + set: PickleType } \ No newline at end of file diff --git a/app/main.py b/app/main.py index 6ff1bfc..bdc4891 100644 --- a/app/main.py +++ b/app/main.py @@ -10,9 +10,23 @@ from typing import Annotated from app.schemas import Song import json +from starlette.middleware import Middleware + +from starlette_context import context, plugins +from starlette_context.middleware import RawContextMiddleware + Base.metadata.create_all(engine) -app = FastAPI() +middleware = [ + Middleware( + RawContextMiddleware, + plugins=( + plugins.ForwardedForPlugin(), + ) + ) +] + +app = FastAPI(middleware=middleware) app.include_router(admin.router) app.include_router(user.router) diff --git a/app/models.py b/app/models.py index f0534fa..9c1a76d 100644 --- a/app/models.py +++ b/app/models.py @@ -30,10 +30,11 @@ class Song(Base): class Session(Base): __tablename__ = 'sessions' id: Mapped[int] = mapped_column(primary_key=True) - session_name: Mapped[int] + session_name: Mapped[str] active: Mapped[bool] - time_created: Mapped[datetime] = mapped_column(server_default=func.now()) - time_updated: Mapped[Optional[datetime] + ip: Mapped[str] + first_seen: Mapped[datetime] = mapped_column(server_default=func.now()) + last_seen: Mapped[Optional[datetime] ] = mapped_column(onupdate=func.now()) diff --git a/app/routers/session.py b/app/routers/session.py index 0703d73..020bf2b 100644 --- a/app/routers/session.py +++ b/app/routers/session.py @@ -1,5 +1,5 @@ from typing import Annotated -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from sqlalchemy.orm import Session import app.models as models diff --git a/app/routers/songs.py b/app/routers/songs.py index 524f8d7..6991114 100644 --- a/app/routers/songs.py +++ b/app/routers/songs.py @@ -1,5 +1,5 @@ from typing import Annotated -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from sqlalchemy.orm import Session import app.models as models diff --git a/requirements.txt b/requirements.txt index 9c7501d..abe054c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ passlib[bcrypt] python-multipart jinja2 openpyxl -requests \ No newline at end of file +requests +starlette-context \ No newline at end of file