| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- """
- Authentication service - business logic for auth operations.
- """
- from datetime import datetime, timedelta, timezone
- from fastapi import HTTPException, status
- from sqlalchemy import select
- from sqlalchemy.ext.asyncio import AsyncSession
- from app.config import settings
- from app.core.security import (
- create_access_token,
- create_refresh_token,
- hash_password,
- verify_password,
- verify_token,
- )
- from app.models.organization import Organization
- from app.models.refresh_token import RefreshToken
- from app.models.user import User
- from app.schemas.auth import AuthResponse, OrganizationInfo, UserInfo
async def register_user(
    db: AsyncSession,
    email: str,
    password: str,
    full_name: str | None,
    phone: str | None,
    organization_name: str,
) -> dict:
    """
    Register new user with organization.

    Creates, in a single transaction:
    1. Organization (status=pending, wifi and BLE products disabled)
    2. User (role=owner, status=pending, email_verified=False) linked to it

    Args:
        db: Database session
        email: User email (also stored as the organization contact email)
        password: Plain password; only its hash is stored
        full_name: User full name
        phone: User phone (also stored as the organization contact phone)
        organization_name: Organization name

    Returns:
        Dict with message

    Raises:
        HTTPException: 400 if the email is already registered. This is raised
            either by the up-front lookup or when a concurrent registration
            wins the race and the final commit fails with an integrity error.
    """
    # Local import keeps this block self-contained; sqlalchemy is already used
    # by this module.
    from sqlalchemy.exc import IntegrityError

    # Fast path: reject obvious duplicates before creating anything.
    result = await db.execute(select(User).where(User.email == email))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    # Create organization
    org = Organization(
        name=organization_name,
        contact_email=email,
        contact_phone=phone,
        wifi_enabled=False,
        ble_enabled=False,
        status="pending",
    )
    db.add(org)
    # Flush so org.id is populated before the user row references it.
    await db.flush()

    # Create user (owner of organization)
    user = User(
        email=email,
        hashed_password=hash_password(password),
        full_name=full_name,
        phone=phone,
        role="owner",
        status="pending",
        organization_id=org.id,
        email_verified=False,
    )
    db.add(user)

    try:
        await db.commit()
    except IntegrityError as exc:
        # The lookup above and this insert are not atomic: another request may
        # have registered the same email in between. Roll back (discarding the
        # flushed organization too) and report it as a duplicate instead of
        # letting a 500 escape with the session left in a failed state.
        # NOTE(review): assumes users.email has a unique constraint and that it
        # is the only constraint the user insert can violate -- confirm.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        ) from exc

    return {"message": "Registration successful. Awaiting admin approval."}
# Lazily computed hash of a throwaway password. login_user checks against it
# when the email is unknown, so a password check always runs.
_DUMMY_PASSWORD_HASH = None


def _dummy_password_hash():
    """Return a cached hash of a throwaway password, computing it on first use."""
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = hash_password("timing-equalization-dummy-password")
    return _DUMMY_PASSWORD_HASH


async def login_user(
    db: AsyncSession,
    email: str,
    password: str,
) -> AuthResponse:
    """
    Authenticate user and return tokens.

    On success this updates the user's last_login_at and persists a new
    refresh token, both in a single commit. It then returns an access token,
    the refresh token, and user/organization info.

    Args:
        db: Database session
        email: User email
        password: Plain password

    Returns:
        AuthResponse with tokens and user info

    Raises:
        HTTPException: 401 if the email is unknown or the password is wrong
            (the same message is used for both cases).
    """
    result = await db.execute(select(User).where(User.email == email))
    user = result.scalar_one_or_none()

    # Always run a password verification, even for unknown emails. Skipping it
    # makes failed logins for unregistered emails measurably faster, which
    # lets an attacker enumerate accounts.
    stored_hash = (
        user.hashed_password if user is not None else _dummy_password_hash()
    )
    password_ok = verify_password(password, stored_hash)
    if user is None or not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
        )

    # For MVP, allow login even if status is pending or email not verified.
    # In production, you might want to enforce verification.

    now = datetime.now(timezone.utc)
    user.last_login_at = now

    # Create tokens
    token_data = {"sub": user.id}
    access_token = create_access_token(token_data)
    refresh_token_str = create_refresh_token(token_data)

    # Save refresh token to database so it can later be validated or revoked.
    db.add(
        RefreshToken(
            user_id=user.id,
            token=refresh_token_str,
            expires_at=now + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
        )
    )
    # Single commit: the last-login update and the new refresh token succeed
    # or fail together.
    await db.commit()

    # Get organization info (if user has one)
    org_info = None
    if user.organization_id:
        result = await db.execute(
            select(Organization).where(Organization.id == user.organization_id)
        )
        org = result.scalar_one_or_none()
        if org:
            org_info = OrganizationInfo.model_validate(org)

    return AuthResponse(
        access_token=access_token,
        refresh_token=refresh_token_str,
        token_type="bearer",
        expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        user=UserInfo.model_validate(user),
        organization=org_info,
    )
async def refresh_access_token(
    db: AsyncSession,
    refresh_token_str: str,
) -> AuthResponse:
    """
    Rotate a refresh token: revoke the presented one and issue a new pair.

    The presented token must pass signature/type verification and must also
    exist in the database in a valid state. The stored token is marked
    revoked, and a new access token plus a new persisted refresh token are
    returned.

    Args:
        db: Database session
        refresh_token_str: Refresh token string presented by the client

    Returns:
        AuthResponse with new tokens, user info and organization info

    Raises:
        HTTPException: 401 if the token fails verification, is unknown or no
            longer valid, or its subject user no longer exists.
    """

    def unauthorized(detail: str) -> HTTPException:
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
        )

    # Cryptographic check first; no DB work for forged or expired tokens.
    claims = verify_token(refresh_token_str, expected_type="refresh")
    if claims is None:
        raise unauthorized("Invalid or expired refresh token")

    # The token must also be on record and still valid (e.g. not revoked).
    token_query = select(RefreshToken).where(RefreshToken.token == refresh_token_str)
    stored = (await db.execute(token_query)).scalar_one_or_none()
    if stored is None or not stored.is_valid:
        raise unauthorized("Invalid or expired refresh token")

    user_query = select(User).where(User.id == claims.get("sub"))
    owner = (await db.execute(user_query)).scalar_one_or_none()
    if owner is None:
        raise unauthorized("User not found")

    # Rotation: the presented token can no longer be used.
    stored.revoked_at = datetime.now(timezone.utc)

    subject = {"sub": owner.id}
    fresh_access = create_access_token(subject)
    fresh_refresh = create_refresh_token(subject)

    db.add(
        RefreshToken(
            user_id=owner.id,
            token=fresh_refresh,
            expires_at=datetime.now(timezone.utc)
            + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
        )
    )
    await db.commit()

    organization = None
    if owner.organization_id:
        org_query = select(Organization).where(
            Organization.id == owner.organization_id
        )
        org_row = (await db.execute(org_query)).scalar_one_or_none()
        if org_row:
            organization = OrganizationInfo.model_validate(org_row)

    return AuthResponse(
        access_token=fresh_access,
        refresh_token=fresh_refresh,
        token_type="bearer",
        expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        user=UserInfo.model_validate(owner),
        organization=organization,
    )
async def logout_user(
    db: AsyncSession,
    refresh_token_str: str,
) -> dict:
    """
    Logout user by revoking refresh token.

    Unknown or already-revoked tokens are accepted silently. The same success
    message is returned either way, and an existing revocation timestamp is
    never overwritten.

    Args:
        db: Database session
        refresh_token_str: Refresh token to revoke

    Returns:
        Dict with message
    """
    result = await db.execute(
        select(RefreshToken).where(RefreshToken.token == refresh_token_str)
    )
    db_token = result.scalar_one_or_none()

    # Only stamp tokens that are still active. Re-revoking a token that was
    # already revoked (e.g. by rotation in refresh_access_token) would replace
    # its original revocation time and lose that history.
    # NOTE(review): assumes revoked_at is None for non-revoked tokens -- confirm
    # against the RefreshToken model.
    if db_token is not None and db_token.revoked_at is None:
        db_token.revoked_at = datetime.now(timezone.utc)
        await db.commit()

    return {"message": "Logged out successfully"}
|