Skip to content

Commit

Permalink
Support SSO provider config from infrahub.toml
Browse files Browse the repository at this point in the history
  • Loading branch information
ogenstad committed Oct 26, 2024
1 parent 57f3889 commit cdfac0a
Showing 1 changed file with 48 additions and 6 deletions.
54 changes: 48 additions & 6 deletions backend/infrahub/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,14 @@ class SecurityOIDCProvider2(SecurityOIDCSettings):
model_config = SettingsConfigDict(env_prefix="INFRAHUB_OIDC_PROVIDER2_")


class SecurityOIDCProviderSettings(BaseModel):
"""This class is meant to facilitate configuration of OIDC providers when loading configuration from a infrahub.toml file."""

google: Optional[SecurityOIDCGoogle] = Field(default=None)
provider1: Optional[SecurityOIDCProvider1] = Field(default=None)
provider2: Optional[SecurityOIDCProvider2] = Field(default=None)


class SecurityOAuth2BaseSettings(BaseSettings):
"""Baseclass for typing"""

Expand Down Expand Up @@ -490,6 +498,14 @@ class SecurityOAuth2Google(SecurityOAuth2Settings):
display_label: str = Field(default="Google")


class SecurityOAuth2ProviderSettings(BaseModel):
"""This class is meant to facilitate configuration of OAuth2 providers when loading configuration from a infrahub.toml file."""

google: Optional[SecurityOAuth2Google] = Field(default=None)
provider1: Optional[SecurityOAuth2Provider1] = Field(default=None)
provider2: Optional[SecurityOAuth2Provider2] = Field(default=None)


class MiscellaneousSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="INFRAHUB_MISC_")
print_query_details: bool = False
Expand Down Expand Up @@ -535,7 +551,9 @@ class SecuritySettings(BaseSettings):
default_factory=generate_uuid, description="The secret key used to validate authentication tokens"
)
oauth2_providers: list[Oauth2Provider] = Field(default_factory=list, description="The selected OAuth2 providers")
oauth2_provider_settings: SecurityOAuth2ProviderSettings = Field(default_factory=SecurityOAuth2ProviderSettings)
oidc_providers: list[OIDCProvider] = Field(default_factory=list, description="The selected OIDC providers")
oidc_provider_settings: SecurityOIDCProviderSettings = Field(default_factory=SecurityOIDCProviderSettings)
_oauth2_settings: dict[str, SecurityOAuth2Settings] = PrivateAttr(default_factory=dict)
_oidc_settings: dict[str, SecurityOIDCSettings] = PrivateAttr(default_factory=dict)

Expand All @@ -547,9 +565,21 @@ def check_oauth2_provider_settings(self) -> Self:
Oauth2Provider.GOOGLE: SecurityOAuth2Google,
}
for oauth2_provider in self.oauth2_providers:
provider = mapped_providers[oauth2_provider]()
if isinstance(provider, SecurityOAuth2Settings):
self._oauth2_settings[oauth2_provider.value] = provider
match oauth2_provider:
case Oauth2Provider.GOOGLE:
if self.oauth2_provider_settings.google:
self._oauth2_settings[oauth2_provider.value] = self.oauth2_provider_settings.google
case Oauth2Provider.PROVIDER1:
if self.oauth2_provider_settings.provider1:
self._oauth2_settings[oauth2_provider.value] = self.oauth2_provider_settings.provider1
case Oauth2Provider.PROVIDER2:
if self.oauth2_provider_settings.provider2:
self._oauth2_settings[oauth2_provider.value] = self.oauth2_provider_settings.provider2

if oauth2_provider.value not in self._oauth2_settings:
provider = mapped_providers[oauth2_provider]()
if isinstance(provider, SecurityOAuth2Settings):
self._oauth2_settings[oauth2_provider.value] = provider

return self

Expand All @@ -561,9 +591,21 @@ def check_oidc_provider_settings(self) -> Self:
OIDCProvider.PROVIDER2: SecurityOIDCProvider2,
}
for oidc_provider in self.oidc_providers:
provider = mapped_providers[oidc_provider]()
if isinstance(provider, SecurityOIDCSettings):
self._oidc_settings[oidc_provider.value] = provider
match oidc_provider:
case OIDCProvider.GOOGLE:
if self.oidc_provider_settings.google:
self._oidc_settings[oidc_provider.value] = self.oidc_provider_settings.google
case OIDCProvider.PROVIDER1:
if self.oidc_provider_settings.provider1:
self._oidc_settings[oidc_provider.value] = self.oidc_provider_settings.provider1
case OIDCProvider.PROVIDER2:
if self.oidc_provider_settings.provider2:
self._oidc_settings[oidc_provider.value] = self.oidc_provider_settings.provider2

if oidc_provider.value not in self._oidc_settings:
provider = mapped_providers[oidc_provider]()
if isinstance(provider, SecurityOIDCSettings):
self._oidc_settings[oidc_provider.value] = provider

return self

Expand Down

0 comments on commit cdfac0a

Please sign in to comment.