Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions auth_backend/auth_plugins/email.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import hashlib
import logging

import pydantic
from fastapi import Depends, Header, HTTPException, Request
from fastapi.background import BackgroundTasks
from fastapi_sqlalchemy import db
from pydantic import constr, validator
from pydantic import constr, field_validator
from sqlalchemy import func

from auth_backend.base import Base, StatusResponseModel
Expand Down Expand Up @@ -67,34 +68,34 @@ class EmailLogin(Base):
password: constr(min_length=1)
scopes: list[Scope] | None = None
session_name: str | None = None
email_validator = validator("email", allow_reuse=True)(check_email)
email_validator = field_validator("email")(check_email)


class EmailRegister(Base):
email: constr(min_length=1)
password: constr(min_length=1)
email_validator = validator("email", allow_reuse=True)(check_email)
email_validator = field_validator("email")(check_email)


class EmailChange(Base):
email: constr(min_length=1)

email_validator = validator("email", allow_reuse=True)(check_email)
email_validator = field_validator("email")(check_email)


class RequestResetPassword(Base):
email: constr(min_length=1)
password: str | None = None
new_password: str | None = None

email_validator = validator("email", allow_reuse=True)(check_email)
email_validator = field_validator("email")(check_email)


class ResetPassword(Base):
email: constr(min_length=1)
new_password: constr(min_length=1)

email_validator = validator("email", allow_reuse=True)(check_email)
email_validator = field_validator("email")(check_email)


class EmailParams(MethodMeta):
Expand Down Expand Up @@ -270,7 +271,7 @@ async def _approve_email(token: str) -> StatusResponseModel:
)
if not auth_method:
raise HTTPException(
status_code=403, detail=StatusResponseModel(status="Error", message="Incorrect link").dict()
status_code=403, detail=StatusResponseModel(status="Error", message="Incorrect link").model_dump()
)
auth_method.user.auth_methods.email.confirmed.value = "true"
db.session.commit()
Expand All @@ -293,7 +294,7 @@ async def _request_reset_email(
)
if user_session.user.auth_methods.email.email.value == scheme.email:
raise HTTPException(
status_code=401, detail=StatusResponseModel(status="Error", message="Email incorrect").dict()
status_code=401, detail=StatusResponseModel(status="Error", message="Email incorrect").model_dump()
)
token = random_string()
await user_session.user.auth_methods.email.bulk_create(
Expand Down Expand Up @@ -324,7 +325,7 @@ async def _reset_email(token: str) -> StatusResponseModel:
if not auth:
raise HTTPException(
status_code=403,
detail=StatusResponseModel(status="Error", message="Incorrect confirmation token").dict(),
detail=StatusResponseModel(status="Error", message="Incorrect confirmation token").model_dump(),
)
user: User = auth.user
if user.auth_methods.email.confirmed.value == "false":
Expand All @@ -351,7 +352,9 @@ async def _request_reset_password(
if not user_session.user.auth_methods.email:
raise HTTPException(
status_code=401,
detail=StatusResponseModel(status="Error", message="Auth method restricted for this user").dict(),
detail=StatusResponseModel(
status="Error", message="Auth method restricted for this user"
).model_dump(),
)
if not Email._validate_password(
schema.password,
Expand All @@ -370,7 +373,8 @@ async def _request_reset_password(
)
if auth_method_email.user_id != user_session.user_id:
raise HTTPException(
status_code=403, detail=StatusResponseModel(status="Error", message="Incorrect user session").dict()
status_code=403,
detail=StatusResponseModel(status="Error", message="Incorrect user session").model_dump(),
)
user_session.user.auth_methods.email.hashed_password.value = Email._hash_password(schema.new_password, salt)
user_session.user.auth_methods.email.salt.value = salt
Expand All @@ -396,12 +400,14 @@ async def _request_reset_password(
)
if not auth_method_email:
raise HTTPException(
status_code=404, detail=StatusResponseModel(status="Error", message="Email not found").dict()
status_code=404, detail=StatusResponseModel(status="Error", message="Email not found").model_dump()
)
if not auth_method_email.user.auth_methods.email:
raise HTTPException(
status_code=401,
detail=StatusResponseModel(status="Error", message="Auth method restricted for this user").dict(),
detail=StatusResponseModel(
status="Error", message="Auth method restricted for this user"
).model_dump(),
)
if auth_method_email.user.auth_methods.email.confirmed.value.lower() == "false":
raise AuthFailed(
Expand All @@ -420,10 +426,10 @@ async def _request_reset_password(
return StatusResponseModel(status="Success", message="Reset link has been successfully mailed")
elif not user_session and schema.password and schema.new_password:
raise HTTPException(
status_code=403, detail=StatusResponseModel(status="Error", message="Missing session").dict()
status_code=403, detail=StatusResponseModel(status="Error", message="Missing session").model_dump()
)
raise HTTPException(
status_code=422, detail=StatusResponseModel(status="Error", message="Unprocessable entity").dict()
status_code=422, detail=StatusResponseModel(status="Error", message="Unprocessable entity").model_dump()
)

@staticmethod
Expand All @@ -445,7 +451,7 @@ async def _reset_password(schema: ResetPassword, reset_token: str = Header(min_l
):
raise HTTPException(
status_code=403,
detail=StatusResponseModel(status="Error", message="Incorrect reset token").dict(),
detail=StatusResponseModel(status="Error", message="Incorrect reset token").model_dump(),
)
salt = random_string()
auth_method.user.auth_methods.email.hashed_password.value = Email._hash_password(schema.new_password, salt)
Expand Down
4 changes: 2 additions & 2 deletions auth_backend/auth_plugins/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def _register(
if not user_inp.id_token:
flow = await cls._default_flow()
try:
credentials = flow.fetch_token(**user_inp.dict(exclude_unset=True))
credentials = flow.fetch_token(**user_inp.model_dump(exclude_unset=True))
except oauthlib.oauth2.rfc6749.errors.InvalidGrantError as exc:
raise OauthCredentialsIncorrect(f'Google account response invalid: {exc}')
id_token = credentials.get("id_token")
Expand Down Expand Up @@ -122,7 +122,7 @@ async def _login(cls, user_inp: OauthResponseSchema):
"""
flow = await cls._default_flow()
try:
credentials = flow.fetch_token(**user_inp.dict(exclude_unset=True))
credentials = flow.fetch_token(**user_inp.model_dump(exclude_unset=True))
except oauthlib.oauth2.rfc6749.errors.OAuth2Error as exc:
raise OauthCredentialsIncorrect(f'Google account response invalid: {exc}')
try:
Expand Down
4 changes: 3 additions & 1 deletion auth_backend/auth_plugins/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ async def _register(
user = user_session.user
await cls._register_auth_method('user_id', telegram_user_id, user, db_session=db.session)

return await cls._create_session(user, user_inp.scopes, db_session=db.session, session_name=session_name)
return await cls._create_session(
user, user_inp.scopes, db_session=db.session, session_name=user_inp.session_name
)

@classmethod
async def _login(cls, user_inp: OauthResponseSchema) -> Session:
Expand Down
2 changes: 1 addition & 1 deletion auth_backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
class Base(BaseModel):
def __repr__(self) -> str:
attrs = []
for k, v in self.__class__.schema().items():
for k, v in self.__class__.model_json_schema().items():
attrs.append(f"{k}={v}")
return "{}({})".format(self.__class__.__name__, ', '.join(attrs))

Expand Down
22 changes: 12 additions & 10 deletions auth_backend/routes/exc_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,22 @@

@app.exception_handler(ObjectNotFound)
async def not_found_handler(req: starlette.requests.Request, exc: ObjectNotFound):
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").dict(), status_code=404)
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").model_dump(), status_code=404)


@app.exception_handler(IncorrectUserAuthType)
async def incorrect_auth_type_handler(req: starlette.requests.Request, exc: IncorrectUserAuthType):
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").dict(), status_code=403)
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").model_dump(), status_code=403)


@app.exception_handler(AlreadyExists)
async def already_exists_handler(req: starlette.requests.Request, exc: AlreadyExists):
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").dict(), status_code=409)
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").model_dump(), status_code=409)


@app.exception_handler(AuthFailed)
async def auth_failed_handler(req: starlette.requests.Request, exc: AuthFailed):
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").dict(), status_code=401)
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").model_dump(), status_code=401)


class OauthAuthFailedStatusResponseModel(StatusResponseModel):
Expand All @@ -48,25 +48,25 @@ async def oauth_failed_handler(req: starlette.requests.Request, exc: OauthAuthFa
status="Error",
message=f"{exc}",
id_token=exc.id_token,
).dict(exclude_none=True),
).model_dump(exclude_none=True),
status_code=exc.status_code,
)


@app.exception_handler(OauthCredentialsIncorrect)
async def oauth_creds_failed_handler(req: starlette.requests.Request, exc: OauthCredentialsIncorrect):
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").dict(), status_code=406)
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").model_dump(), status_code=406)


@app.exception_handler(SessionExpired)
async def session_expired_handler(req: starlette.requests.Request, exc: SessionExpired):
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").dict(), status_code=403)
return JSONResponse(content=StatusResponseModel(status="Error", message=f"{exc}").model_dump(), status_code=403)


@app.exception_handler(Exception)
async def http_error_handler(req: starlette.requests.Request, exc: Exception):
return JSONResponse(
content=StatusResponseModel(status="Error", message="Internal server error").dict(), status_code=500
content=StatusResponseModel(status="Error", message="Internal server error").model_dump(), status_code=500
)


Expand All @@ -75,14 +75,16 @@ async def too_many_requests_handler(req: starlette.requests.Request, exc: TooMan
return JSONResponse(
content=StatusResponseModel(
status="Error", message=f"Too many requests. Delay time: {int(exc.delay_time.total_seconds())} seconds."
).dict(),
).model_dump(),
status_code=429,
)


@app.exception_handler(LastAuthMethodDelete)
async def last_auth_method_delete_handler(req: starlette.requests.Request, exc: LastAuthMethodDelete):
return JSONResponse(
content=StatusResponseModel(status="Error", message=f"Unable to remove last authentication method").dict(),
content=StatusResponseModel(
status="Error", message=f"Unable to remove last authentication method"
).model_dump(),
status_code=403,
)
20 changes: 10 additions & 10 deletions auth_backend/routes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def get_group(
"""
group = DbGroup.get(id, session=db.session)
result = {}
result = result | Group.from_orm(group).dict()
result = result | Group.model_validate(group).model_dump()
if "child" in info:
result["child"] = group.child
if "scopes" in info:
Expand All @@ -34,7 +34,7 @@ async def get_group(
result["indirect_scopes"] = group.indirect_scopes
if "users" in info:
result["users"] = [user.id for user in group.users]
return GroupGet(**result).dict(exclude_unset=True)
return GroupGet(**result).model_dump(exclude_unset=True)


@groups.post("", response_model=Group)
Expand All @@ -49,7 +49,7 @@ async def create_group(
raise ObjectNotFound(Group, group_inp.parent_id)
if DbGroup.query(session=db.session).filter(DbGroup.name == group_inp.name).one_or_none():
raise HTTPException(
status_code=409, detail=StatusResponseModel(status="Error", message="Name already exists").dict()
status_code=409, detail=StatusResponseModel(status="Error", message="Name already exists").model_dump()
)
scopes = set()
if group_inp.scopes:
Expand All @@ -62,7 +62,7 @@ async def create_group(
for scope in scopes:
GroupScope.create(session=db.session, group_id=group.id, scope_id=scope.id)
db.session.commit()
return Group(**result).dict(exclude_unset=True)
return Group(**result).model_dump(exclude_unset=True)


@groups.patch("/{id}", response_model=Group)
Expand All @@ -83,19 +83,19 @@ async def patch_group(
group = DbGroup.get(id, session=db.session)
if group_inp.parent_id in (row.id for row in group.child):
raise HTTPException(
status_code=400, detail=StatusResponseModel(status="Error", message="Cycle detected").dict()
status_code=400, detail=StatusResponseModel(status="Error", message="Cycle detected").model_dump()
)
result = Group.from_orm(
DbGroup.update(id, session=db.session, **group_inp.dict(exclude_unset=True, exclude={"scopes"}))
).dict(exclude_unset=True)
result = Group.model_validate(
DbGroup.update(id, session=db.session, **group_inp.model_dump(exclude_unset=True, exclude={"scopes"}))
).model_dump(exclude_unset=True)
scopes = set()
if group_inp.scopes:
for _scope_id in group_inp.scopes:
scopes.add(Scope.get(session=db.session, id=_scope_id))
if scopes:
group.scopes = scopes
db.session.commit()
return Group.from_orm(group)
return Group.model_validate(group)


@groups.delete("/{id}", response_model=None)
Expand Down Expand Up @@ -137,4 +137,4 @@ async def get_groups(
if "users" in info:
add["users"] = [user.id for user in group.users]
result["items"].append(add)
return GroupsGet(**result).dict(exclude_unset=True)
return GroupsGet(**result).model_dump(exclude_unset=True)
17 changes: 11 additions & 6 deletions auth_backend/routes/scopes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi_sqlalchemy import db
from pydantic import parse_obj_as
from pydantic import TypeAdapter
from sqlalchemy import func

from auth_backend.base import StatusResponseModel
Expand All @@ -22,10 +22,12 @@ async def create_scope(
"""
if Scope.query(session=db.session).filter(func.lower(Scope.name) == scope.name.lower()).all():
raise HTTPException(
status_code=409, detail=StatusResponseModel(status="Error", message="Already exists").dict()
status_code=409, detail=StatusResponseModel(status="Error", message="Already exists").model_dump()
)
scope.name = scope.name.lower()
return ScopeGet.from_orm(Scope.create(**scope.dict(), creator_id=user_session.user_id, session=db.session))
return ScopeGet.model_validate(
Scope.create(**scope.model_dump(), creator_id=user_session.user_id, session=db.session)
)


@scopes.get("/{id}", response_model=ScopeGet)
Expand All @@ -35,7 +37,7 @@ async def get_scope(
"""
Scopes: `["auth.scope.read"]`
"""
return ScopeGet.from_orm(Scope.get(id, session=db.session))
return ScopeGet.model_validate(Scope.get(id, session=db.session))


@scopes.get("", response_model=list[ScopeGet])
Expand All @@ -45,7 +47,8 @@ async def get_scopes(
"""
Scopes: `["auth.scope.read"]`
"""
return parse_obj_as(list[ScopeGet], Scope.query(session=db.session).all())
adapter = TypeAdapter(list[ScopeGet])
return adapter.validate_python(Scope.query(session=db.session).all())


@scopes.patch("/{id}", response_model=ScopeGet)
Expand All @@ -58,7 +61,9 @@ async def update_scope(
Scopes: `["auth.scope.update"]`
"""
scope = Scope.get(id, session=db.session)
return ScopeGet.from_orm(Scope.update(scope.id, **scope_inp.dict(exclude_unset=True), session=db.session))
return ScopeGet.model_validate(
Scope.update(scope.id, **scope_inp.model_dump(exclude_unset=True), session=db.session)
)


@scopes.delete("/{id}", response_model=StatusResponseModel)
Expand Down
Loading