Skip to content

Commit

Permalink
Verification Token Handling and Bug Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
anshulg954 committed Oct 17, 2024
1 parent 4b2c1c7 commit 7eb9963
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 23 deletions.
35 changes: 33 additions & 2 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,36 @@ def register(
access_token = response.json()["token"] if is_created else None
return is_created, message, access_token

def verify_email(self, token: str) -> tuple[bool, str]:
"""
Verify the email with the provided token.
Parameters
----------
token : str
Returns
-------
is_verified : bool
True if the email is verified successfully.
message : str
The message returned from the server.
"""

response = self.httpx_client.get(
self.server_endpoints.verify_email.path,
params={"token": token},
)
self._validate_response(response, "verify_email", only_version_check=True)
if response.status_code == 200:
is_verified = True
message = response.json()["message"]
else:
is_verified = False
message = response.json()["detail"]

return is_verified, message

def login(self, email: str, password: str) -> tuple[str, str]:
"""
Login with the provided credentials and return the access token if successful.
Expand Down Expand Up @@ -385,8 +415,9 @@ def login(self, email: str, password: str) -> tuple[str, str]:
message = ""
else:
message = response.json()["detail"]

return access_token, message
# status code signifies the success of the login, issues with password, and email verification
# 200 : success, 401 : wrong password, 403 : email not verified yet
return access_token, message, response.status_code

def get_password_policy(self) -> {}:
"""
Expand Down
1 change: 1 addition & 0 deletions tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def init(use_server=True):
):
print("Your email is not verified. Please verify your email to continue...")
PromptAgent.reverify_email(is_valid_token_set[1], user_auth_handler)
user_auth_handler.set_token(is_valid_token_set[1])
else:
PromptAgent.prompt_welcome()
if not PromptAgent.prompt_terms_and_cond():
Expand Down
46 changes: 42 additions & 4 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,28 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
)
)
additional_info = cls.prompt_add_user_information()
is_created, message = user_auth_handler.set_token_by_registration(
is_created, message, access_token = user_auth_handler.set_token_by_registration(
email, password, password_confirm, validation_link, additional_info
)
if not is_created:
raise RuntimeError("User registration failed: " + str(message) + "\n")

print(
cls.indent(
"Account created successfully! To start using TabPFN please click on the link in the verification email we sent you."
"Account created successfully! To start using TabPFN please enter the secret key in the verification email we sent you."
)
+ "\n"
)
# verify token from email
verified = False
while not verified:
token = input(cls.indent("Verification Token: "))
verified, message = user_auth_handler.verify_email(token)
if not verified:
print("\n" + cls.indent(str(message) + "Please try again!") + "\n")

print(cls.indent("Thank you for verifying your email successfully! Your access token is: ") +
access_token + r" and we have stored it for you in the file in directory: '.\tabpfn\config.'\n\n")

# Login
elif choice == "2":
Expand All @@ -120,12 +130,27 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
email = input(cls.indent("Please enter your email: "))
password = getpass.getpass(cls.indent("Please enter your password: "))

successful, message = user_auth_handler.set_token_by_login(
successful, message, status_code = user_auth_handler.set_token_by_login(
email, password
)
if successful:
break
print(cls.indent("Login failed: " + message) + "\n")
print(cls.indent("Login failed: " + str(message)) + "\n")
if status_code == 403:
# Verify email
verified = False
while not verified:
token = input(cls.indent("Verification Token: "))
verified, message = user_auth_handler.verify_email(token)
if not verified:
print(
"\n" + cls.indent(str(message) + "Please try again!") + "\n"
)
else:
print(cls.indent("Email verified successfully!") + "\n")
user_auth_handler.set_token_by_login(email, password)
break
break

prompt = "\n".join(
[
Expand Down Expand Up @@ -239,6 +264,19 @@ def reverify_email(
)
+ "\n"
)
# verify token from email
verified = False
while not verified:
token = input(cls.indent("Please enter the correct secret key sent to your email to verify: "))
# get user email from user_auth_handler
verified, message = user_auth_handler.verify_email(token)
if not verified:
print(
"\n" + cls.indent(str(message) + "Please try again!") + "\n"
)
else:
print(cls.indent("Email verified successfully!") + "\n")
break
return

@classmethod
Expand Down
21 changes: 13 additions & 8 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
## testing
# protocol: "http"
# host: "localhost"
# port: "8080"
# testing
protocol: "http"
host: "localhost"
port: "80"

# production
protocol: "https"
host: "tabpfn-server-wjedmz7r5a-ez.a.run.app"
# host: tabpfn-server-preprod-wjedmz7r5a-ez.a.run.app # preprod
port: "443"
# protocol: "https"
# host: "tabpfn-server-wjedmz7r5a-ez.a.run.app"
# # host: tabpfn-server-preprod-wjedmz7r5a-ez.a.run.app # preprod
# port: "443"
endpoints:
root:
path: "/"
Expand Down Expand Up @@ -39,6 +39,11 @@ endpoints:
methods: [ "POST" ]
description: "Send verifiaction email or for reverification"

verify_email:
path: "/auth/verify_email/"
methods: [ "GET" ]
description: "Verify email"

send_reset_password_email:
path: "/auth/send_reset_password_email/"
methods: [ "POST" ]
Expand Down
14 changes: 9 additions & 5 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ def set_token_by_registration(
)
if access_token is not None:
self.set_token(access_token)
return is_created, message
return is_created, message, access_token

def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
access_token, message = self.service_client.login(email, password)
def set_token_by_login(self, email: str, password: str) -> tuple[bool, str, int]:
access_token, message, status_code = self.service_client.login(email, password)

if access_token is None:
return False, message
return False, message, status_code

self.set_token(access_token)
return True, message
return True, message, status_code

def try_reuse_existing_token(self) -> bool | tuple[bool, str]:
if self.service_client.access_token is None:
Expand Down Expand Up @@ -103,6 +103,10 @@ def send_verification_email(self, access_token: str) -> tuple[bool, str]:
sent, message = self.service_client.send_verification_email(access_token)
return sent, message

def verify_email(self, token: str) -> tuple[bool, str]:
verified, message = self.service_client.verify_email(token)
return verified, message


class UserDataClient(ServiceClientWrapper):
"""
Expand Down
6 changes: 4 additions & 2 deletions tabpfn_client/tests/unit/test_prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_password_req_to_policy(self):
@patch("getpass.getpass", side_effect=["Password123!", "Password123!"])
@patch(
"builtins.input",
side_effect=["1", "[email protected]", "test", "test", "test", "y"],
side_effect=["1", "[email protected]", "test", "test", "test", "y", "test"],
)
def test_prompt_and_set_token_registration(
self, mock_input, mock_getpass, mock_server
Expand All @@ -30,16 +30,18 @@ def test_prompt_and_set_token_registration(
mock_auth_client.set_token_by_registration.return_value = (
True,
"Registration successful",
"dummy_token"
)
mock_auth_client.validate_email.return_value = (True, "")
mock_auth_client.verify_email.return_value = (True, "Verification successful")
PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client)
mock_auth_client.set_token_by_registration.assert_called_once()

@patch("getpass.getpass", side_effect=["password123"])
@patch("builtins.input", side_effect=["2", "[email protected]"])
def test_prompt_and_set_token_login(self, mock_input, mock_getpass):
mock_auth_client = MagicMock()
mock_auth_client.set_token_by_login.return_value = (True, "Login successful")
mock_auth_client.set_token_by_login.return_value = (True, "Login successful", 200)
PromptAgent.prompt_and_set_token(user_auth_handler=mock_auth_client)
mock_auth_client.set_token_by_login.assert_called_once()

Expand Down
4 changes: 2 additions & 2 deletions tabpfn_client/tests/unit/test_service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_set_token_by_invalid_login(self, mock_server):
401, json={"detail": "Incorrect email or password"}
)
self.assertEqual(
(False, "Incorrect email or password"),
(False, "Incorrect email or password", 401),
UserAuthenticationClient(ServiceClient()).set_token_by_login(
"dummy_email", "dummy_password"
),
Expand Down Expand Up @@ -84,7 +84,7 @@ def test_set_token_by_invalid_registration(self, mock_server):
401, json={"detail": "Password mismatch"}
)
self.assertEqual(
(False, "Password mismatch"),
(False, "Password mismatch", None),
UserAuthenticationClient(ServiceClient()).set_token_by_registration(
"dummy_email",
"dummy_password",
Expand Down

0 comments on commit 7eb9963

Please sign in to comment.