auth_service.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. """
  2. Authentication service - business logic for auth operations.
  3. """
  4. from datetime import datetime, timedelta, timezone
  5. from fastapi import HTTPException, status
  6. from sqlalchemy import select
  7. from sqlalchemy.ext.asyncio import AsyncSession
  8. from app.config import settings
  9. from app.core.security import (
  10. create_access_token,
  11. create_refresh_token,
  12. hash_password,
  13. verify_password,
  14. verify_token,
  15. )
  16. from app.models.organization import Organization
  17. from app.models.refresh_token import RefreshToken
  18. from app.models.user import User
  19. from app.schemas.auth import AuthResponse, OrganizationInfo, UserInfo
  20. async def register_user(
  21. db: AsyncSession,
  22. email: str,
  23. password: str,
  24. full_name: str | None,
  25. phone: str | None,
  26. organization_name: str,
  27. ) -> dict:
  28. """
  29. Register new user with organization.
  30. Creates:
  31. 1. Organization (status=pending, all products disabled)
  32. 2. User (role=owner, status=pending, email_verified=False)
  33. Args:
  34. db: Database session
  35. email: User email
  36. password: Plain password
  37. full_name: User full name
  38. phone: User phone
  39. organization_name: Organization name
  40. Returns:
  41. Dict with message
  42. Raises:
  43. HTTPException: If email already exists
  44. """
  45. # Check if email already exists
  46. result = await db.execute(select(User).where(User.email == email))
  47. if result.scalar_one_or_none():
  48. raise HTTPException(
  49. status_code=status.HTTP_400_BAD_REQUEST,
  50. detail="Email already registered",
  51. )
  52. # Create organization
  53. org = Organization(
  54. name=organization_name,
  55. contact_email=email,
  56. contact_phone=phone,
  57. wifi_enabled=False,
  58. ble_enabled=False,
  59. status="pending",
  60. )
  61. db.add(org)
  62. await db.flush()
  63. # Create user (owner of organization)
  64. user = User(
  65. email=email,
  66. hashed_password=hash_password(password),
  67. full_name=full_name,
  68. phone=phone,
  69. role="owner",
  70. status="pending",
  71. organization_id=org.id,
  72. email_verified=False,
  73. )
  74. db.add(user)
  75. await db.commit()
  76. return {"message": "Registration successful. Awaiting admin approval."}
  77. async def login_user(
  78. db: AsyncSession,
  79. email: str,
  80. password: str,
  81. ) -> AuthResponse:
  82. """
  83. Authenticate user and return tokens.
  84. Args:
  85. db: Database session
  86. email: User email
  87. password: Plain password
  88. Returns:
  89. AuthResponse with tokens and user info
  90. Raises:
  91. HTTPException: If credentials are invalid
  92. """
  93. # Get user
  94. result = await db.execute(select(User).where(User.email == email))
  95. user = result.scalar_one_or_none()
  96. if user is None or not verify_password(password, user.hashed_password):
  97. raise HTTPException(
  98. status_code=status.HTTP_401_UNAUTHORIZED,
  99. detail="Incorrect email or password",
  100. )
  101. # For MVP, allow login even if status is pending or email not verified
  102. # In production, you might want to enforce verification
  103. # Update last login
  104. user.last_login_at = datetime.now(timezone.utc)
  105. await db.commit()
  106. # Create tokens
  107. token_data = {"sub": user.id}
  108. access_token = create_access_token(token_data)
  109. refresh_token_str = create_refresh_token(token_data)
  110. # Save refresh token to database
  111. refresh_token = RefreshToken(
  112. user_id=user.id,
  113. token=refresh_token_str,
  114. expires_at=datetime.now(timezone.utc)
  115. + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
  116. )
  117. db.add(refresh_token)
  118. await db.commit()
  119. # Get organization info (if user has one)
  120. org_info = None
  121. if user.organization_id:
  122. result = await db.execute(
  123. select(Organization).where(Organization.id == user.organization_id)
  124. )
  125. org = result.scalar_one_or_none()
  126. if org:
  127. org_info = OrganizationInfo.model_validate(org)
  128. return AuthResponse(
  129. access_token=access_token,
  130. refresh_token=refresh_token_str,
  131. token_type="bearer",
  132. expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
  133. user=UserInfo.model_validate(user),
  134. organization=org_info,
  135. )
  136. async def refresh_access_token(
  137. db: AsyncSession,
  138. refresh_token_str: str,
  139. ) -> AuthResponse:
  140. """
  141. Refresh access token using refresh token.
  142. Args:
  143. db: Database session
  144. refresh_token_str: Refresh token string
  145. Returns:
  146. AuthResponse with new tokens
  147. Raises:
  148. HTTPException: If refresh token is invalid
  149. """
  150. # Verify refresh token
  151. payload = verify_token(refresh_token_str, expected_type="refresh")
  152. if payload is None:
  153. raise HTTPException(
  154. status_code=status.HTTP_401_UNAUTHORIZED,
  155. detail="Invalid or expired refresh token",
  156. )
  157. # Check if refresh token exists in database and is not revoked
  158. result = await db.execute(
  159. select(RefreshToken).where(RefreshToken.token == refresh_token_str)
  160. )
  161. db_token = result.scalar_one_or_none()
  162. if db_token is None or not db_token.is_valid:
  163. raise HTTPException(
  164. status_code=status.HTTP_401_UNAUTHORIZED,
  165. detail="Invalid or expired refresh token",
  166. )
  167. # Get user
  168. user_id = payload.get("sub")
  169. result = await db.execute(select(User).where(User.id == user_id))
  170. user = result.scalar_one_or_none()
  171. if user is None:
  172. raise HTTPException(
  173. status_code=status.HTTP_401_UNAUTHORIZED,
  174. detail="User not found",
  175. )
  176. # Revoke old refresh token
  177. db_token.revoked_at = datetime.now(timezone.utc)
  178. # Create new tokens
  179. token_data = {"sub": user.id}
  180. access_token = create_access_token(token_data)
  181. new_refresh_token = create_refresh_token(token_data)
  182. # Save new refresh token
  183. new_token = RefreshToken(
  184. user_id=user.id,
  185. token=new_refresh_token,
  186. expires_at=datetime.now(timezone.utc)
  187. + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
  188. )
  189. db.add(new_token)
  190. await db.commit()
  191. # Get organization info
  192. org_info = None
  193. if user.organization_id:
  194. result = await db.execute(
  195. select(Organization).where(Organization.id == user.organization_id)
  196. )
  197. org = result.scalar_one_or_none()
  198. if org:
  199. org_info = OrganizationInfo.model_validate(org)
  200. return AuthResponse(
  201. access_token=access_token,
  202. refresh_token=new_refresh_token,
  203. token_type="bearer",
  204. expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
  205. user=UserInfo.model_validate(user),
  206. organization=org_info,
  207. )
  208. async def logout_user(
  209. db: AsyncSession,
  210. refresh_token_str: str,
  211. ) -> dict:
  212. """
  213. Logout user by revoking refresh token.
  214. Args:
  215. db: Database session
  216. refresh_token_str: Refresh token to revoke
  217. Returns:
  218. Dict with message
  219. """
  220. # Find and revoke refresh token
  221. result = await db.execute(
  222. select(RefreshToken).where(RefreshToken.token == refresh_token_str)
  223. )
  224. db_token = result.scalar_one_or_none()
  225. if db_token:
  226. db_token.revoked_at = datetime.now(timezone.utc)
  227. await db.commit()
  228. return {"message": "Logged out successfully"}