Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added whoami command #1193

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions truss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def version():
return __version__


from truss.api import login, push
from truss.api import login, push, whoami
from truss.build import from_directory, init, kill_all, load

__all__ = ["from_directory", "init", "kill_all", "load", "push", "login"]
__all__ = ["from_directory", "init", "kill_all", "load", "push", "login", "whoami"]
21 changes: 21 additions & 0 deletions truss/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@ def login(api_key: str):
RemoteFactory.update_remote_config(remote_config)


def whoami(remote: Optional[str] = None):
"""
Returns account information for the current user.
"""
if not remote:
available_remotes = RemoteFactory.get_available_config_names()
if len(available_remotes) == 1:
remote = available_remotes[0]
elif len(available_remotes) == 0:
raise ValueError(
"Please authenticate via truss.login and pass it as an argument."
)
else:
raise ValueError(
"Multiple remotes found. Please pass the remote as an argument."
)

remote_provider = RemoteFactory.create(remote=remote)
return remote_provider.whoami()


def push(
target_directory: str,
remote: Optional[str] = None,
Expand Down
19 changes: 19 additions & 0 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,25 @@ def login(api_key: Optional[str]):
login(api_key)


@truss_cli.command()
@click.option(
"--remote",
type=str,
required=False,
help="Name of the remote in .trussrc to check whoami.",
)
@error_handling
def whoami(remote: Optional[str]):
"""
Shows user information and exit.
"""
from truss.api import whoami

user = whoami(remote)

console.print(f"{user.workspace_name}\{user.user_email}")


@truss_cli.command()
@click.argument("target_directory", required=False, default=os.getcwd())
@click.option(
Expand Down
13 changes: 12 additions & 1 deletion truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from truss.remote.baseten.error import ApiError, RemoteError
from truss.remote.baseten.service import BasetenService, URLConfig
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
from truss.remote.truss_remote import TrussRemote
from truss.remote.truss_remote import RemoteUser, TrussRemote
from truss.truss_config import ModelServer
from truss.truss_handle import TrussHandle
from truss.util.path import is_ignored, load_trussignore_patterns
Expand Down Expand Up @@ -115,6 +115,17 @@ def get_chainlets(
)
]

def whoami(self) -> RemoteUser:
resp = self._api._post_graphql_query(
"query{organization{workspace_name}user{email}}"
)
workspace_name = resp["data"]["organization"]["workspace_name"]
user_email = resp["data"]["user"]["email"]
return RemoteUser(
workspace_name,
user_email,
)

def push( # type: ignore
self,
truss_handle: TrussHandle,
Expand Down
21 changes: 21 additions & 0 deletions truss/remote/truss_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
from truss.truss_handle import TrussHandle


class RemoteUser:
"""Class to hold information about the remote user"""

workspace_name: str
user_email: str

def __init__(self, workspace_name: str, user_email: str):
self.workspace_name = workspace_name
self.user_email = user_email


class TrussService(ABC):
"""
Define the abstract base class for a TrussService.
Expand Down Expand Up @@ -209,6 +220,16 @@ def push(self, truss_handle: TrussHandle, **kwargs) -> TrussService:

"""

@abstractmethod
def whoami(self) -> RemoteUser:
"""
Returns account information for the current user.

This method should be implemented in subclasses and return a RemoteUser.


"""

@abstractmethod
def get_service(self, **kwargs) -> TrussService:
"""
Expand Down
5 changes: 4 additions & 1 deletion truss/tests/remote/test_remote_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from truss.remote.remote_factory import RemoteFactory
from truss.remote.truss_remote import RemoteConfig, TrussRemote
from truss.remote.truss_remote import RemoteConfig, RemoteUser, TrussRemote

SAMPLE_CONFIG = {"api_key": "test_key", "remote_url": "http://test.com"}

Expand Down Expand Up @@ -41,6 +41,9 @@ def get_service(self, **kwargs):
def sync_truss_to_dev_version_by_name(self, model_name: str, target_directory: str):
raise NotImplementedError

def whoami(self) -> RemoteUser:
return RemoteUser("test_user", "test_email")


def mock_service_config():
return RemoteConfig(
Expand Down
Loading