diff --git a/database/crud.py b/database/crud.py index 20e39de..3796b65 100644 --- a/database/crud.py +++ b/database/crud.py @@ -1,6 +1,6 @@ from datetime import datetime + from fastapi import HTTPException -from hashlib import sha1 from passlib.context import CryptContext from app.schemas import * @@ -8,8 +8,7 @@ from constants import SHA1_SALT from database import SessionLocal from database.models import * - -pwd_context = CryptContext(schemes=["bcrypt"]) +pwd_context = CryptContext(schemes=["bcrypt", "hex_sha1"], deprecated=["hex_sha1"]) def get_db(): @@ -50,7 +49,7 @@ def fetch_user_by_email(data, db): def create_user(data, db): - data.password = pwd_context.hash(data.password) + data.password = pwd_context.hash(secret=data.password) user = insert_data(model="Users", data=data, db=db) return user @@ -62,43 +61,35 @@ def update_otp(data: OTPResend, db): db.commit() -def rehash_password(password): - return pwd_context.hash(secret=password) - - def update_password_hash(user, password, db): - new_hash = rehash_password(password=password) + new_hash = pwd_context.hash(secret=password) db.query(Users).filter(Users.email == user.email).update({Users.password: new_hash}) db.commit() db.refresh(user) -def check_sha1_hash(db_hash): - hash_length = len(db_hash) +def check_legacy_hash(db_hash): sha1_length = 40 - if hash_length == sha1_length: + if len(db_hash) == sha1_length: return True return False -def verify_legacy_password(user, password, db): - hash = SHA1_SALT + password - correct_password = user.password == sha1(hash.encode("utf-8")).hexdigest() - if correct_password: - update_password_hash(user=user, password=password, db=db) - return True - return False - - -def verify_updated_password(user, password): - return pwd_context.verify(secret=password, hash=user.password) +def construct_secret(db_hash, password): + legacy_hash = check_legacy_hash(db_hash=db_hash) + if legacy_hash: + return SHA1_SALT + password, legacy_hash + return password, legacy_hash def verify_password(user, password, db): - legacy_hash = check_sha1_hash(user.password) - if legacy_hash: - return verify_legacy_password(user=user, password=password, db=db) - return verify_updated_password(user=user, password=password) + secret, legacy_hash = construct_secret(db_hash=user.password, password=password) + correct_password = pwd_context.verify(secret=secret, hash=user.password) + if correct_password: + if legacy_hash: + update_password_hash(user=user, password=password, db=db) + return True + return False def authenticate_user(data: UserLogin, db):