diff --git a/app/main.py b/app/main.py index e1b5cae..e5e3b06 100644 --- a/app/main.py +++ b/app/main.py @@ -1,5 +1,5 @@ -from fastapi import FastAPI, Request, Depends -from app.routers import admin, user, songs, session +from fastapi import FastAPI, Request, Depends, Cookie, Security +from app.routers import admin, songs, session from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates @@ -11,6 +11,8 @@ from app.schemas import Song import json import os import asyncio +from jose import JWTError, jwt +from app.security import get_current_user from starlette.middleware import Middleware @@ -23,7 +25,7 @@ if os.path.isfile("first_run") and (os.environ.get("RELOAD_ON_FIRST_RUN").lower( asyncio.run(admin.create_upload_file(include_non_singable=True, db=db)) os.remove("first_run") -#Base.metadata.create_all(engine) +# Base.metadata.create_all(engine) middleware = [ Middleware( @@ -38,7 +40,6 @@ middleware = [ app = FastAPI(middleware=middleware) app.include_router(admin.router) -app.include_router(user.router) app.include_router(songs.router) app.include_router(session.router) @@ -55,14 +56,18 @@ async def root(request: Request) -> HTMLResponse: @app.get("/vote") -async def vote(request: Request, session_id: str, unordered: bool = False, db: Session = Depends(get_db)) -> HTMLResponse: +async def vote(request: Request, session_id: str, unordered: bool = False, user = Security(get_current_user, scopes=[]), + db: Session = Depends(get_db)) -> HTMLResponse: + + print(user) + veto_mode = get_setting(db, "veto_mode") songs = [Song(**s.__dict__, vote=v, vote_comment=c) for s, v, c in get_songs_and_vote_for_session(db, session_id)] if unordered: - songs_by_category = {"Alle Lieder" : songs} + songs_by_category = {"Alle Lieder": songs} all_categories = {"Alle Lieder"} for song in songs: all_categories.update(song.categories.keys()) diff --git a/app/routers/admin.py b/app/routers/admin.py index 362c8df..f85c77b 100644 --- a/app/routers/admin.py +++ b/app/routers/admin.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, Security, Depends from sqlalchemy.orm import Session from app.database import get_db, engine, Base -from app.routers.user import get_current_user +from app.security import get_current_user from app.crud import create_song, get_setting, set_setting router = APIRouter( diff --git a/app/routers/user.py b/app/routers/user.py deleted file mode 100644 index 5824a6d..0000000 --- a/app/routers/user.py +++ /dev/null @@ -1,169 +0,0 @@ -from datetime import datetime, timedelta, timezone -from typing import Annotated - -from fastapi import Depends, APIRouter, HTTPException, Security, status -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes -from jose import JWTError, jwt -from passlib.context import CryptContext -from pydantic import BaseModel, ValidationError -import os - -#from app.secrets import SECRET_KEY, fake_users_db -# to get a string like this run: -# openssl rand -hex 32 - -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 31 -SECRET_KEY = os.environ['SECRET_KEY'] - -fake_users_db = { - "admin": { - "username": "admin", - "email": "admin@example.com", - "hashed_password": os.environ["ADMIN_PWD"], - "disabled": False, - "scopes" : ["admin", "public"] - } -} - - - -class Token(BaseModel): - access_token: str - token_type: str - - -class TokenData(BaseModel): - username: str | None = None - scopes: list[str] = [] - - -class User(BaseModel): - username: str - email: str | None = None - #full_name: str | None = None - disabled: bool | None = None - - -class UserInDB(User): - hashed_password: str - scopes: list[str] = [] - - -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -oauth2_scheme = OAuth2PasswordBearer( - tokenUrl="user/token", - scopes={ - "admin": "Perform admin actions.", - "public": "Perform public actions." - } -) - -router = APIRouter( - prefix="/user" -) - - -def verify_password(plain_password, hashed_password): - print(get_password_hash(plain_password)) - return pwd_context.verify(plain_password, hashed_password) - - -def get_password_hash(password): - return pwd_context.hash(password) - - -def get_user(db, username: str): - if username in db: - user_dict = db[username] - return UserInDB(**user_dict) - - - -def authenticate_user(fake_db, username: str, password: str): - user = get_user(fake_db, username) - if not user: - return False - if not verify_password(password, user.hashed_password): - return False - return user - - -def create_access_token(data: dict, expires_delta: timedelta | None = None): - to_encode = data.copy() - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + timedelta(minutes=15) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt - -async def get_current_user( - security_scopes: SecurityScopes, token: Annotated[str, Depends(oauth2_scheme)] -): - if security_scopes.scopes: - authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' - else: - authenticate_value = "Bearer" - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": authenticate_value}, - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") # type: ignore - if username is None: - raise credentials_exception - token_scopes = payload.get("scopes", []) - token_data = TokenData(scopes=token_scopes, username=username) - except (JWTError, ValidationError): - raise credentials_exception - user = get_user(fake_users_db, username=token_data.username or "") - if user is None: - raise credentials_exception - for scope in security_scopes.scopes: - if scope not in token_data.scopes: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not enough permissions", - headers={"WWW-Authenticate": authenticate_value}, - ) - return user - - -async def get_current_active_user( - current_user: Annotated[User, Security(get_current_user, scopes=["me"])], -): - if current_user.disabled: - raise HTTPException(status_code=400, detail="Inactive user") - return current_user - - -@router.post("/token") -async def login_for_access_token( - form_data: Annotated[OAuth2PasswordRequestForm, Depends()], -) -> Token: - user = authenticate_user( - fake_users_db, form_data.username, form_data.password) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - headers={"WWW-Authenticate": "Bearer"}, - ) - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": user.username, "scopes": user.scopes}, expires_delta=access_token_expires - ) - return Token(access_token=access_token, token_type="bearer") - -# @router.get("/public_token") -# async def get_public_access_token(secret_identity : str) -> Token: -# access_token_expires = timedelta(minutes=60*24*365) -# access_token = create_access_token( -# data={"sub": "public", "secret_identity" : secret_identity, "scopes": ["public"]}, expires_delta=access_token_expires -# ) -# return Token(access_token=access_token, token_type="bearer") \ No newline at end of file diff --git a/app/security.py b/app/security.py new file mode 100644 index 0000000..9586d99 --- /dev/null +++ b/app/security.py @@ -0,0 +1,44 @@ +from typing import Annotated + +from fastapi import HTTPException, Cookie, status +from fastapi.security import SecurityScopes +from jose import JWTError, jwt +from pydantic import ValidationError +import os + +#from app.secrets import SECRET_KEY, fake_users_db +# to get a string like this run: +# openssl rand -hex 32 + +ALGORITHM = "HS512" +SECRET_KEY = os.environ['SECRET_KEY'] + +fake_user_db = { + os.environ['ADMIN_EMAIL'] : { + "scopes" : ["admin"] + } +} + +credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions" + ) + +async def get_current_user( + security_scopes: SecurityScopes, access_token: Annotated[str, Cookie()] = "" +): + try: + payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") # type: ignore + if username is None: + raise credentials_exception + email: str = payload.get("email") # type: ignore + except (JWTError, ValidationError): + raise credentials_exception + user = fake_user_db.get(email) + if user is None: + raise credentials_exception + for scope in security_scopes.scopes: + if scope not in user["scopes"]: + raise credentials_exception + return user | {"token_payload" : payload} \ No newline at end of file