From cdfac0ac9c73b81d5028a182a8634f8b42320c8e Mon Sep 17 00:00:00 2001 From: Patrick Ogenstad Date: Sat, 26 Oct 2024 11:09:45 +0200 Subject: [PATCH] Support SSO provider config from infrahub.toml --- backend/infrahub/config.py | 54 +++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/backend/infrahub/config.py b/backend/infrahub/config.py index 5a714c8bc6..74f9cb90dc 100644 --- a/backend/infrahub/config.py +++ b/backend/infrahub/config.py @@ -451,6 +451,14 @@ class SecurityOIDCProvider2(SecurityOIDCSettings): model_config = SettingsConfigDict(env_prefix="INFRAHUB_OIDC_PROVIDER2_") +class SecurityOIDCProviderSettings(BaseModel): + """This class is meant to facilitate configuration of OIDC providers when loading configuration from a infrahub.toml file.""" + + google: Optional[SecurityOIDCGoogle] = Field(default=None) + provider1: Optional[SecurityOIDCProvider1] = Field(default=None) + provider2: Optional[SecurityOIDCProvider2] = Field(default=None) + + class SecurityOAuth2BaseSettings(BaseSettings): """Baseclass for typing""" @@ -490,6 +498,14 @@ class SecurityOAuth2Google(SecurityOAuth2Settings): display_label: str = Field(default="Google") +class SecurityOAuth2ProviderSettings(BaseModel): + """This class is meant to facilitate configuration of OAuth2 providers when loading configuration from a infrahub.toml file.""" + + google: Optional[SecurityOAuth2Google] = Field(default=None) + provider1: Optional[SecurityOAuth2Provider1] = Field(default=None) + provider2: Optional[SecurityOAuth2Provider2] = Field(default=None) + + class MiscellaneousSettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="INFRAHUB_MISC_") print_query_details: bool = False @@ -535,7 +551,9 @@ class SecuritySettings(BaseSettings): default_factory=generate_uuid, description="The secret key used to validate authentication tokens" ) oauth2_providers: list[Oauth2Provider] = Field(default_factory=list, description="The selected OAuth2 providers") + oauth2_provider_settings: SecurityOAuth2ProviderSettings = Field(default_factory=SecurityOAuth2ProviderSettings) oidc_providers: list[OIDCProvider] = Field(default_factory=list, description="The selected OIDC providers") + oidc_provider_settings: SecurityOIDCProviderSettings = Field(default_factory=SecurityOIDCProviderSettings) _oauth2_settings: dict[str, SecurityOAuth2Settings] = PrivateAttr(default_factory=dict) _oidc_settings: dict[str, SecurityOIDCSettings] = PrivateAttr(default_factory=dict) @@ -547,9 +565,21 @@ def check_oauth2_provider_settings(self) -> Self: Oauth2Provider.GOOGLE: SecurityOAuth2Google, } for oauth2_provider in self.oauth2_providers: - provider = mapped_providers[oauth2_provider]() - if isinstance(provider, SecurityOAuth2Settings): - self._oauth2_settings[oauth2_provider.value] = provider + match oauth2_provider: + case Oauth2Provider.GOOGLE: + if self.oauth2_provider_settings.google: + self._oauth2_settings[oauth2_provider.value] = self.oauth2_provider_settings.google + case Oauth2Provider.PROVIDER1: + if self.oauth2_provider_settings.provider1: + self._oauth2_settings[oauth2_provider.value] = self.oauth2_provider_settings.provider1 + case Oauth2Provider.PROVIDER2: + if self.oauth2_provider_settings.provider2: + self._oauth2_settings[oauth2_provider.value] = self.oauth2_provider_settings.provider2 + + if oauth2_provider.value not in self._oauth2_settings: + provider = mapped_providers[oauth2_provider]() + if isinstance(provider, SecurityOAuth2Settings): + self._oauth2_settings[oauth2_provider.value] = provider return self @@ -561,9 +591,21 @@ def check_oidc_provider_settings(self) -> Self: OIDCProvider.PROVIDER2: SecurityOIDCProvider2, } for oidc_provider in self.oidc_providers: - provider = mapped_providers[oidc_provider]() - if isinstance(provider, SecurityOIDCSettings): - self._oidc_settings[oidc_provider.value] = provider + match oidc_provider: + case OIDCProvider.GOOGLE: + if self.oidc_provider_settings.google: + self._oidc_settings[oidc_provider.value] = self.oidc_provider_settings.google + case OIDCProvider.PROVIDER1: + if self.oidc_provider_settings.provider1: + self._oidc_settings[oidc_provider.value] = self.oidc_provider_settings.provider1 + case OIDCProvider.PROVIDER2: + if self.oidc_provider_settings.provider2: + self._oidc_settings[oidc_provider.value] = self.oidc_provider_settings.provider2 + + if oidc_provider.value not in self._oidc_settings: + provider = mapped_providers[oidc_provider]() + if isinstance(provider, SecurityOIDCSettings): + self._oidc_settings[oidc_provider.value] = provider return self