diff --git a/db/user.py b/db/user.py index 3c6c0efed562fa8997294e77c0a6675b410e929f..8a267ca3f1e6f2d8ae188388e28300cb43ea4b98 100644 --- a/db/user.py +++ b/db/user.py @@ -1,10 +1,22 @@ -from abc import ABC - +from bson.objectid import ObjectId from .mongo import Model +from pymongo.errors import DuplicateKeyError +import hashlib class User(Model): + @staticmethod + def dict_to_object(dictionary): + u = User( + first_name=dictionary["first_name"], + last_name=dictionary["last_name"], + email=dictionary["email"], + hashed_password=dictionary["hashed_password"] + ) + u.__id = str(dictionary["_id"]) + return u + def to_dict(self): res = {} if self.__id is not None: @@ -13,29 +25,48 @@ class User(Model): res["first_name"] = self.__first_name if self.__last_name is not None: res["last_name"] = self.__last_name + if self.__email is not None: + res["email"] = self.__email + if self.__hashed_password is not None: + res["hashed_password"] = self.__hashed_password + + return res def __init__(self, first_name, last_name, email, hashed_password): - self.__id = "" + self.__id = None self.__first_name = first_name self.__last_name = last_name self.__email = email self.__hashed_password = hashed_password - def get_collection(self): - return self.get_db()["users"] + @staticmethod + def get_collection(): + return Model.get_db()["users"] def get_id(self): return self.__id def store(self): - add_document = self.get_collection().insert_one({ - "first_name": self.__first_name, - "last_name": self.__last_name, - "email": self.__email, - "hashed_password": self.__hashed_password, - }) - self.__id = str(add_document.inserted_id) + try: + add_document = User.get_collection().insert_one(self.to_dict()) + self.__id = str(add_document.inserted_id) + except DuplicateKeyError: + print("Error occurred") + + @staticmethod + def find_by_username(email): + u = User.get_collection().find_one({"email": email}) + if u is not None: + return User.dict_to_object(u) + return u @staticmethod - def find_by_username(self, email): - return self.get_collection().find_one({"email": email}) + def find_by_id(object_id): + u = User.get_collection().find_one({"_id": ObjectId(object_id)}) + if u is not None: + return User.dict_to_object(u) + return u + + def compare_password(self, new_password): + new_hashed_password = hashlib.md5(new_password.encode()).hexdigest() + return self.__hashed_password == new_hashed_password