diff --git a/backend/config/config.go b/backend/config/config.go index f40abfa96..66a5fc10a 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log" + "net/url" "strings" "time" @@ -15,6 +16,7 @@ import ( "github.com/knadh/koanf/providers/file" zeroLogger "github.com/rs/zerolog/log" "github.com/teamhanko/hanko/backend/ee/saml/config" + "github.com/teamhanko/hanko/backend/persistence/models" "golang.org/x/exp/slices" ) @@ -25,6 +27,7 @@ type Config struct { Smtp SMTP `yaml:"smtp" json:"smtp,omitempty" koanf:"smtp"` EmailDelivery EmailDelivery `yaml:"email_delivery" json:"email_delivery,omitempty" koanf:"email_delivery" split_words:"true"` Passcode Passcode `yaml:"passcode" json:"passcode" koanf:"passcode"` + Passlink Passlink `yaml:"passlink" json:"passlink,omitempty" koanf:"passlink"` Password Password `yaml:"password" json:"password,omitempty" koanf:"password"` Database Database `yaml:"database" json:"database" koanf:"database"` Secrets Secrets `yaml:"secrets" json:"secrets" koanf:"secrets"` @@ -135,6 +138,16 @@ func DefaultConfig() *Config { Password: Password{ MinPasswordLength: 8, }, + Passlink: Passlink{ + Enabled: false, + URL: "http://localhost:8888", + TTL: 300, // 5 minutes + Strictness: models.PasslinkStrictnessNone, + Email: Email{ + FromAddress: "passcode@hanko.io", + FromName: "Hanko", + }, + }, Database: Database{ Database: "hanko", }, @@ -167,6 +180,10 @@ func DefaultConfig() *Config { Tokens: 3, Interval: 1 * time.Minute, }, + PasslinkLimits: RateLimits{ + Tokens: 3, + Interval: 1 * time.Minute, + }, TokenLimits: RateLimits{ Tokens: 3, Interval: 1 * time.Minute, @@ -419,9 +436,36 @@ func (p *Passcode) Validate() error { return nil } +type Passlink struct { + Enabled bool `yaml:"enabled" json:"enabled,omitempty" koanf:"enabled" jsonschema:"default=false"` + URL string `yaml:"url" json:"url,omitempty" koanf:"url"` + TTL int `yaml:"ttl" json:"ttl,omitempty" koanf:"ttl" jsonschema:"default=300"` + Email Email `yaml:"email" json:"email,omitempty" koanf:"email"` + Strictness models.PasslinkStrictness `yaml:"strictness" json:"strictness,omitempty" koanf:"strictness" jsonschema:"default=none,enum=browser,enum=device,enum=none"` +} + +func (p *Passlink) Validate() error { + err := p.Email.Validate() + if err != nil { + return fmt.Errorf("failed to validate email settings: %w", err) + } + if len(strings.TrimSpace(p.URL)) == 0 { + return errors.New("url must not be empty") + } + if url, err := url.Parse(p.URL); err != nil { + return fmt.Errorf("failed to parse url: %w", err) + } else if url.Scheme == "" || url.Host == "" { + return errors.New("url must be a valid URL") + } + if !p.Strictness.Valid() { + return fmt.Errorf("invalid passlink strictness: %s", p.Strictness) + } + return nil +} + // Database connection settings type Database struct { - Database string `yaml:"database" json:"database,omitempty" koanf:"database" jsonschema:"default=hanko" jsonschema:"oneof_required=config"` + Database string `yaml:"database" json:"database,omitempty" koanf:"database" jsonschema:"oneof_required=config,default=hanko"` User string `yaml:"user" json:"user,omitempty" koanf:"user" jsonschema:"oneof_required=config"` Password string `yaml:"password" json:"password,omitempty" koanf:"password" jsonschema:"oneof_required=config"` Host string `yaml:"host" json:"host,omitempty" koanf:"host" jsonschema:"oneof_required=config"` @@ -527,6 +571,7 @@ type RateLimiter struct { Redis *RedisConfig `yaml:"redis_config" json:"redis_config,omitempty" koanf:"redis_config"` PasscodeLimits RateLimits `yaml:"passcode_limits" json:"passcode_limits,omitempty" koanf:"passcode_limits" split_words:"true"` PasswordLimits RateLimits `yaml:"password_limits" json:"password_limits,omitempty" koanf:"password_limits" split_words:"true"` + PasslinkLimits RateLimits `yaml:"passlink_limits" json:"passlink_limits,omitempty" koanf:"passlink_limits" split_words:"true"` TokenLimits RateLimits `yaml:"token_limits" json:"token_limits,omitempty" koanf:"token_limits" split_words:"true"` } @@ -539,7 +584,7 @@ type RateLimiterStoreType string const ( RATE_LIMITER_STORE_IN_MEMORY RateLimiterStoreType = "in_memory" - RATE_LIMITER_STORE_REDIS = "redis" + RATE_LIMITER_STORE_REDIS RateLimiterStoreType = "redis" ) func (r *RateLimiter) Validate() error { @@ -673,7 +718,7 @@ func (p *ThirdPartyProviders) HasEnabled() bool { func (p *ThirdPartyProviders) Get(provider string) *ThirdPartyProvider { s := structs.New(p) for _, field := range s.Fields() { - if strings.ToLower(field.Name()) == strings.ToLower(provider) { + if strings.EqualFold(field.Name(), provider) { p := field.Value().(ThirdPartyProvider) return &p } diff --git a/backend/crypto/passlink.go b/backend/crypto/passlink.go new file mode 100644 index 000000000..85fb383cd --- /dev/null +++ b/backend/crypto/passlink.go @@ -0,0 +1,28 @@ +package crypto + +import ( + "crypto/rand" + "encoding/hex" + "log" +) + +// PasslinkGenerator will generate a random passlink token +type PasslinkGenerator interface { + Generate() (string, error) +} + +type passlinkGenerator struct { +} + +func NewPasslinkGenerator() PasslinkGenerator { + return &passlinkGenerator{} +} + +func (g *passlinkGenerator) Generate() (string, error) { + bytes := make([]byte, 32) + _, err := rand.Read(bytes) + if err != nil { + log.Fatal(err) + } + return hex.EncodeToString(bytes), nil +} diff --git a/backend/dto/config.go b/backend/dto/config.go index a1efe0497..8108a00eb 100644 --- a/backend/dto/config.go +++ b/backend/dto/config.go @@ -9,6 +9,7 @@ import ( // PublicConfig is the part of the configuration that will be shared with the frontend type PublicConfig struct { Password config.Password `json:"password"` + Passlink bool `json:"passlink"` Emails config.Emails `json:"emails"` Providers []string `json:"providers"` Account config.Account `json:"account"` @@ -19,6 +20,7 @@ type PublicConfig struct { func FromConfig(config config.Config) PublicConfig { return PublicConfig{ Password: config.Password, + Passlink: config.Passlink.Enabled, Emails: config.Emails, Providers: GetEnabledProviders(config.ThirdParty.Providers), Account: config.Account, diff --git a/backend/dto/passlink.go b/backend/dto/passlink.go new file mode 100644 index 000000000..533bd47aa --- /dev/null +++ b/backend/dto/passlink.go @@ -0,0 +1,22 @@ +package dto + +import ( + "time" +) + +type PasslinkFinishRequest struct { + ID string `json:"id" validate:"required,uuid4"` + Token string `json:"token" validate:"required"` +} + +type PasslinkInitRequest struct { + UserID string `json:"user_id" validate:"required,uuid4"` + EmailID *string `json:"email_id"` + RedirectPath string `json:"redirect_path" validate:"required"` +} + +type PasslinkReturn struct { + ID string `json:"id"` + CreatedAt time.Time `json:"created_at"` + UserID string `json:"user_id"` +} diff --git a/backend/dto/webhook/email.go b/backend/dto/webhook/email.go index 40d49237a..c5eaac56b 100644 --- a/backend/dto/webhook/email.go +++ b/backend/dto/webhook/email.go @@ -1,5 +1,7 @@ package webhook +import "github.com/teamhanko/hanko/backend/persistence/models" + type EmailSend struct { Subject string `json:"subject"` // subject BodyPlain string `json:"body_plain"` // used for string templates @@ -19,8 +21,20 @@ type PasscodeData struct { ValidUntil int64 `json:"valid_until"` // UnixTimestamp } +type PasslinkData struct { + ServiceName string `json:"service_name"` + Token string `json:"token"` + URL string `json:"url"` + TTL int `json:"ttl"` + ValidUntil int64 `json:"valid_until"` // UnixTimestamp + RedirectPath string `json:"redirect_path"` + RetryLimit int `json:"retry_limit"` + Strictness models.PasslinkStrictness `json:"strictness"` +} + type EmailType string var ( EmailTypePasscode EmailType = "passcode" + EmailTypePasslink EmailType = "passlink" ) diff --git a/backend/go.mod b/backend/go.mod index 38f357712..d40efd8f8 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,6 +1,8 @@ module github.com/teamhanko/hanko/backend -go 1.20 +go 1.21 + +toolchain go1.22.3 require ( github.com/brianvoe/gofakeit/v6 v6.28.0 diff --git a/backend/go.sum b/backend/go.sum index 6ce5a85fa..53e1876ce 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -6,6 +6,7 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25 github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= +github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/ClickHouse/ch-go v0.55.0 h1:jw4Tpx887YXrkyL5DfgUome/po8MLz92nz2heOQ6RjQ= github.com/ClickHouse/ch-go v0.55.0/go.mod h1:kQT2f+yp2p+sagQA/7kS6G3ukym+GQ5KAu1kuFAFDiU= github.com/ClickHouse/clickhouse-go/v2 v2.9.1 h1:IeE2bwVvAba7Yw5ZKu98bKI4NpDmykEy6jUaQdJJCk8= @@ -76,12 +77,14 @@ github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= +github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+8Y+8shw= +github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= github.com/docker/cli v23.0.1+incompatible h1:LRyWITpGzl2C9e9uGxzisptnxAn1zfZKXy13Ul2Q5oM= github.com/docker/cli v23.0.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/docker v24.0.9+incompatible h1:HPGzNmwfLZWdxHqK9/II92pyi1EpYKsAqcl4G0Of9v0= @@ -122,6 +125,7 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= @@ -184,7 +188,9 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -217,6 +223,7 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-tpm v0.9.0 h1:sQF6YqWMi+SCXpsmS3fd21oPy/vSddwZry4JnmltHVk= github.com/google/go-tpm v0.9.0/go.mod h1:FkNVkc6C+IsvDI9Jw1OveJmxGZUUaKxtrpOS47QWKfU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -367,6 +374,7 @@ github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -855,6 +863,7 @@ gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.3.0 h1:MfDY1b1/0xN1CyMlQDac0ziEy9zJQd9CXBRRDHw2jJo= +gotest.tools/v3 v3.3.0/go.mod h1:Mcr9QNxkg0uMvy/YElmo4SpXgJKWgQvYrT7Kw5RzJ1A= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/backend/handler/passlink.go b/backend/handler/passlink.go new file mode 100644 index 000000000..183b65eb9 --- /dev/null +++ b/backend/handler/passlink.go @@ -0,0 +1,523 @@ +package handler + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/labstack/echo/v4" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/rs/zerolog/log" + "github.com/sethvargo/go-limiter" + auditlog "github.com/teamhanko/hanko/backend/audit_log" + "github.com/teamhanko/hanko/backend/config" + "github.com/teamhanko/hanko/backend/crypto" + "github.com/teamhanko/hanko/backend/dto" + "github.com/teamhanko/hanko/backend/dto/webhook" + "github.com/teamhanko/hanko/backend/mail" + "github.com/teamhanko/hanko/backend/persistence" + "github.com/teamhanko/hanko/backend/persistence/models" + "github.com/teamhanko/hanko/backend/rate_limiter" + "github.com/teamhanko/hanko/backend/session" + "github.com/teamhanko/hanko/backend/webhooks/events" + "github.com/teamhanko/hanko/backend/webhooks/utils" + "golang.org/x/crypto/bcrypt" + "gopkg.in/gomail.v2" +) + +// TODO: garbage collect passlinks + +type PasslinkHandler struct { + mailer mail.Mailer + renderer *mail.Renderer + passlinkGenerator crypto.PasslinkGenerator + persister persistence.Persister + emailConfig config.Email + serviceConfig config.Service + URL string + TTL int + Strictness models.PasslinkStrictness + sessionManager session.Manager + cfg *config.Config + auditLogger auditlog.Logger + rateLimiter limiter.Store +} + +func NewPasslinkHandler(cfg *config.Config, persister persistence.Persister, sessionManager session.Manager, mailer mail.Mailer, auditLogger auditlog.Logger) (*PasslinkHandler, error) { + renderer, err := mail.NewRenderer() + if err != nil { + return nil, fmt.Errorf("failed to create new renderer: %w", err) + } + var rateLimiter limiter.Store + if cfg.RateLimiter.Enabled { + rateLimiter = rate_limiter.NewRateLimiter(cfg.RateLimiter, cfg.RateLimiter.PasslinkLimits) + } + return &PasslinkHandler{ + mailer: mailer, + renderer: renderer, + passlinkGenerator: crypto.NewPasslinkGenerator(), + persister: persister, + emailConfig: cfg.Passlink.Email, + serviceConfig: cfg.Service, + URL: cfg.Passlink.URL, + TTL: cfg.Passlink.TTL, + Strictness: cfg.Passlink.Strictness, + sessionManager: sessionManager, + cfg: cfg, + auditLogger: auditLogger, + rateLimiter: rateLimiter, + }, nil +} + +func (h *PasslinkHandler) Init(c echo.Context) error { + + var body dto.PasslinkInitRequest + if err := (&echo.DefaultBinder{}).BindBody(c, &body); err != nil { + return dto.ToHttpError(err) + } + + if err := c.Validate(body); err != nil { + return dto.ToHttpError(err) + } + + userId, err := uuid.FromString(body.UserID) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to parse userId as uuid").SetInternal(err) + } + + user, err := h.persister.GetUserPersister().Get(userId) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } + if user == nil { + err = h.auditLogger.Create(c, models.AuditLogPasslinkLoginInitFailed, nil, fmt.Errorf("unknown user")) + if err != nil { + return fmt.Errorf("failed to create audit log: %w", err) + } + return echo.NewHTTPError(http.StatusBadRequest).SetInternal(errors.New("user not found")) + } + + if h.rateLimiter != nil { + err := rate_limiter.Limit(h.rateLimiter, userId, c) + if err != nil { + return err + } + } + + var emailId uuid.UUID + if body.EmailID != nil { + emailId, err = uuid.FromString(*body.EmailID) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to parse emailId as uuid").SetInternal(err) + } + } + + // Determine where to send the passlink + var email *models.Email + if !emailId.IsNil() { + // Send the passlink to the specified email address + email, err = h.persister.GetEmailPersister().Get(emailId) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to get email by id").SetInternal(err) + } + if email == nil { + return echo.NewHTTPError(http.StatusBadRequest, "the specified emailId is not available") + } + } else if e := user.Emails.GetPrimary(); e != nil { + // Send the passlink to the primary email address + email = e + } else { + // Workaround to support hanko element versions before v0.1.0-alpha: + // If user has no primary email, check if a cookie with an email id is present + emailIdCookie, err := c.Cookie("hanko_email_id") + if err != nil { + return fmt.Errorf("failed to get email id cookie: %w", err) + } + + if emailIdCookie != nil && emailIdCookie.Value != "" { + emailId, err = uuid.FromString(emailIdCookie.Value) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to parse emailId as uuid").SetInternal(err) + } + email, err = h.persister.GetEmailPersister().Get(emailId) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to get email by id").SetInternal(err) + } + if email == nil { + return echo.NewHTTPError(http.StatusBadRequest, "the specified emailId is not available") + } + } else { + // Can't determine email address to which the passlink should be sent to + return echo.NewHTTPError(http.StatusBadRequest, "an emailId needs to be specified") + } + } + + sessionToken := h.GetSessionToken(c) + if sessionToken != nil && sessionToken.Subject() != user.ID.String() { + // if the user is logged in and the requested user in the body does not match the user from the session then sending and finalizing passlinks is not allowed + return echo.NewHTTPError(http.StatusForbidden).SetInternal(errors.New("session.userId does not match requested userId")) + } + + if email.User != nil && email.User.ID.String() != user.ID.String() { + return echo.NewHTTPError(http.StatusForbidden).SetInternal(errors.New("email address is assigned to another user")) + } + + redirectPath := "/" + if strings.HasPrefix(body.RedirectPath, "/") { + redirectPath = body.RedirectPath + } + + now := time.Now().UTC() + id, err := uuid.NewV4() + if err != nil { + return fmt.Errorf("failed to create passlinkId: %w", err) + } + token, err := h.passlinkGenerator.Generate() + if err != nil { + return fmt.Errorf("failed to generate passlink: %w", err) + } + tokenHashed, err := bcrypt.GenerateFromPassword([]byte(token), 12) + if err != nil { + return fmt.Errorf("failed to hash passlink: %w", err) + } + + passlinkModel := models.Passlink{ + ID: id, + UserId: userId, + EmailID: email.ID, + Strictness: h.Strictness.String(), + IP: c.RealIP(), + TTL: h.TTL, + LoginCount: 0, + Reusable: false, + Token: string(tokenHashed), + CreatedAt: now, + UpdatedAt: now, + } + + redirectURL, err := h.createRedirectURL(c, id, token, redirectPath) + if err != nil { + return fmt.Errorf("failed to create passlink redirect URL: %w", err) + } + + err = h.persister.GetPasslinkPersister().Create(passlinkModel) + if err != nil { + return fmt.Errorf("failed to store passlink: %w", err) + } + + durationTTL := time.Duration(h.TTL) * time.Second + data := map[string]interface{}{ + "ServiceName": h.serviceConfig.Name, + "Token": token, + "URL": redirectURL, + "TTL": fmt.Sprintf("%.0f", durationTTL.Minutes()), + } + + lang := c.Request().Header.Get("Accept-Language") + subject := h.renderer.Translate(lang, "email_subject_login_passlink", data) + bodyPlain, err := h.renderer.Render("passlinkLoginTextMail", lang, data) + if err != nil { + return fmt.Errorf("failed to render email template: %w", err) + } + + webhookData := webhook.EmailSend{ + Subject: subject, + BodyPlain: bodyPlain, + ToEmailAddress: email.Address, + DeliveredByHanko: true, + AcceptLanguage: lang, + Type: webhook.EmailTypePasslink, + Data: webhook.PasslinkData{ + ServiceName: h.cfg.Service.Name, + Token: token, + URL: redirectURL, + TTL: h.TTL, + ValidUntil: passlinkModel.CreatedAt.Add(time.Duration(h.TTL) * time.Second).UTC().Unix(), + RedirectPath: redirectPath, + RetryLimit: 1, + Strictness: h.Strictness, + }, + } + + if h.cfg.EmailDelivery.Enabled { + message := gomail.NewMessage() + message.SetAddressHeader("To", email.Address, "") + message.SetAddressHeader("From", h.emailConfig.FromAddress, h.emailConfig.FromName) + + message.SetHeader("Subject", subject) + + message.SetBody("text/plain", bodyPlain) + + err = h.mailer.Send(message) + if err != nil { + return fmt.Errorf("failed to send passlink: %w", err) + } + + err = utils.TriggerWebhooks(c, events.EmailSend, webhookData) + + if err != nil { + log.Warn().Err(err).Msg("failed to trigger webhook") + } + } else { + webhookData.DeliveredByHanko = false + err = utils.TriggerWebhooks(c, events.EmailSend, webhookData) + + if err != nil { + return fmt.Errorf(fmt.Sprintf("failed to trigger webhook: %s", err)) + } + } + + err = h.auditLogger.Create(c, models.AuditLogPasslinkLoginInitSucceeded, user, nil) + if err != nil { + return fmt.Errorf("failed to create audit log: %w", err) + } + + // TODO: set cookie based on the passlink strictness + + return c.JSON(http.StatusOK, dto.PasslinkReturn{ + ID: id.String(), + CreatedAt: passlinkModel.CreatedAt, + UserID: userId.String(), + }) +} + +func (h *PasslinkHandler) Finish(c echo.Context) error { + startTime := time.Now().UTC() + var body dto.PasslinkFinishRequest + if err := (&echo.DefaultBinder{}).BindBody(c, &body); err != nil { + return dto.ToHttpError(err) + } + + if err := c.Validate(body); err != nil { + return dto.ToHttpError(err) + } + + passlinkID, err := uuid.FromString(body.ID) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to parse passlinkId as uuid").SetInternal(err) + } + + // only if an internal server error occurs the transaction should be rolled back + var businessError error + transactionError := h.persister.Transaction(func(tx *pop.Connection) error { + passlinkPersister := h.persister.GetPasslinkPersisterWithConnection(tx) + userPersister := h.persister.GetUserPersisterWithConnection(tx) + emailPersister := h.persister.GetEmailPersisterWithConnection(tx) + primaryEmailPersister := h.persister.GetPrimaryEmailPersisterWithConnection(tx) + passlink, err := passlinkPersister.Get(passlinkID) + if err != nil { + return fmt.Errorf("failed to get passlink: %w", err) + } + if passlink == nil { + err = h.auditLogger.CreateWithConnection(tx, c, models.AuditLogPasslinkLoginFinalFailed, nil, fmt.Errorf("unknown passlink")) + if err != nil { + return fmt.Errorf("failed to create audit log: %w", err) + } + businessError = echo.NewHTTPError(http.StatusUnauthorized, "passlink not found") + return nil + } + + user, err := userPersister.Get(passlink.UserId) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } + + lastVerificationTime := passlink.CreatedAt.Add(time.Duration(passlink.TTL) * time.Second) + if lastVerificationTime.Before(startTime) { + err = passlinkPersister.Delete(*passlink) + if err != nil { + return fmt.Errorf("failed to delete passlink: %w", err) + } + + err = h.auditLogger.CreateWithConnection(tx, c, models.AuditLogPasslinkLoginFinalFailed, user, fmt.Errorf("timed out passlink: createdAt: %s -> lastVerificationTime: %s", passlink.CreatedAt, lastVerificationTime)) + if err != nil { + return fmt.Errorf("failed to create audit log: %w", err) + } + businessError = echo.NewHTTPError(http.StatusRequestTimeout, "passlink request timed out").SetInternal(fmt.Errorf("createdAt: %s -> lastVerificationTime: %s", passlink.CreatedAt, lastVerificationTime)) // TODO: maybe we should use BadRequest, because RequestTimeout might be too technical and can refer to different error + return nil + } + + // TODO: handle passlink strictness + // TODO: check IP address if strictness is device + + err = bcrypt.CompareHashAndPassword([]byte(passlink.Token), []byte(body.Token)) + if err != nil { + err = passlinkPersister.Delete(*passlink) + if err != nil { + return fmt.Errorf("failed to delete passlink: %w", err) + } + err = h.auditLogger.CreateWithConnection(tx, c, models.AuditLogPasslinkLoginFinalFailed, user, fmt.Errorf("invalid token")) + if err != nil { + return fmt.Errorf("failed to create audit log: %w", err) + } + businessError = echo.NewHTTPError(http.StatusForbidden, "invalid token") + return nil + } + + // a passlink is valid only once, except it is explicitly marked as reusable + // a reusable passlink token is a security risk, but might be useful to authenticate a again and again from same link (e.g. link in a newsletter) + if passlink.Reusable { + passlink.LoginCount += 1 + + err = passlinkPersister.Update(*passlink) + if err != nil { + return fmt.Errorf("failed to update passlink: %w", err) + } + } else { + err = passlinkPersister.Delete(*passlink) + if err != nil { + return fmt.Errorf("failed to delete passlink: %w", err) + } + } + + if passlink.Email.User != nil && passlink.Email.User.ID.String() != user.ID.String() { + return echo.NewHTTPError(http.StatusForbidden, "email address has been claimed by another user") + } + + emailExistsForUser := false + for _, email := range user.Emails { + emailExistsForUser = email.ID == passlink.Email.ID + if emailExistsForUser { + break + } + } + + existingSessionToken := h.GetSessionToken(c) + // return forbidden when none of these cases matches + if !((existingSessionToken == nil && emailExistsForUser) || // normal login: when user logs in and the email used is associated with the user + (existingSessionToken == nil && len(user.Emails) == 0) || // register: when user register and the user has no emails + (existingSessionToken != nil && existingSessionToken.Subject() == user.ID.String())) { // add email through profile: when the user adds an email while having a session and the userIds requested in the passlink and the one in the session matches + return echo.NewHTTPError(http.StatusForbidden).SetInternal(errors.New("passlink finalization not allowed")) + } + + wasUnverified := false + hasEmails := len(user.Emails) >= 1 // check if we need to trigger a UserCreate webhook or a UserEmailCreate one + + if !passlink.Email.Verified { + wasUnverified = true + + // Update email verified status and assign the email address to the user. + passlink.Email.Verified = true + passlink.Email.UserID = &user.ID + + err = emailPersister.Update(passlink.Email) + if err != nil { + return fmt.Errorf("failed to update the email verified status: %w", err) + } + + if user.Emails.GetPrimary() == nil { + primaryEmail := models.NewPrimaryEmail(passlink.Email.ID, user.ID) + err = primaryEmailPersister.Create(*primaryEmail) + if err != nil { + return fmt.Errorf("failed to create primary email: %w", err) + } + + user.Emails = models.Emails{passlink.Email} + user.Emails.SetPrimary(primaryEmail) + err = h.auditLogger.CreateWithConnection(tx, c, models.AuditLogPrimaryEmailChanged, user, nil) + if err != nil { + return fmt.Errorf("failed to create audit log: %w", err) + } + } + + err = h.auditLogger.CreateWithConnection(tx, c, models.AuditLogEmailVerified, user, nil) + if err != nil { + return fmt.Errorf("failed to create audit log: %w", err) + } + } + + var emailJwt *dto.EmailJwt + if e := user.Emails.GetPrimary(); e != nil { + emailJwt = dto.JwtFromEmailModel(e) + } + + token, err := h.sessionManager.GenerateJWT(passlink.UserId, emailJwt) + if err != nil { + return fmt.Errorf("failed to generate jwt: %w", err) + } + + cookie, err := h.sessionManager.GenerateCookie(token) + if err != nil { + return fmt.Errorf("failed to create session token: %w", err) + } + + c.Response().Header().Set("X-Session-Lifetime", fmt.Sprintf("%d", cookie.MaxAge)) + + if h.cfg.Session.EnableAuthTokenHeader { + c.Response().Header().Set("X-Auth-Token", token) + } else { + c.SetCookie(cookie) + } + + err = h.auditLogger.CreateWithConnection(tx, c, models.AuditLogPasslinkLoginFinalSucceeded, user, nil) + if err != nil { + return fmt.Errorf("failed to create audit log: %w", err) + } + + // notify about email verification result. Last step to prevent a trigger and rollback scenario + if h.cfg.Emails.RequireVerification && wasUnverified { + var evt events.Event + + if hasEmails { + evt = events.UserEmailCreate + } else { + evt = events.UserCreate + } + + utils.NotifyUserChange(c, tx, h.persister, evt, user.ID) + } + + return c.JSON(http.StatusOK, dto.PasslinkReturn{ + ID: passlink.ID.String(), + CreatedAt: passlink.CreatedAt, + UserID: passlink.UserId.String(), + }) + }) + + if businessError != nil { + return businessError + } + + return transactionError +} + +func (h *PasslinkHandler) GetSessionToken(c echo.Context) jwt.Token { + var token jwt.Token + sessionCookie, _ := c.Cookie("hanko") + // we don't need to check the error, because when the cookie can not be found, the user is not logged in + if sessionCookie != nil { + token, _ = h.sessionManager.Verify(sessionCookie.Value) + // we don't need to check the error, because when the token is not returned, the user is not logged in + } + + if token == nil { + authorizationHeader := c.Request().Header.Get("Authorization") + sessionToken := strings.TrimPrefix(authorizationHeader, "Bearer") + if strings.TrimSpace(sessionToken) != "" { + token, _ = h.sessionManager.Verify(sessionToken) + } + } + + return token +} + +func (h *PasslinkHandler) createRedirectURL(c echo.Context, id uuid.UUID, token string, path string) (string, error) { + redirect, err := url.Parse(h.URL) + if err != nil { + return "", fmt.Errorf("failed to parse URL for passlink finalization: %w", err) + } + + redirect.Path = path + + queryValues := redirect.Query() + queryValues.Add("plid", id.String()) + queryValues.Add("pltk", token) + redirect.RawQuery = queryValues.Encode() + + return redirect.String(), nil +} diff --git a/backend/handler/password.go b/backend/handler/password.go index bb638a7af..a9d2f3212 100644 --- a/backend/handler/password.go +++ b/backend/handler/password.go @@ -3,12 +3,15 @@ package handler import ( "errors" "fmt" + "net/http" + "unicode/utf8" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/labstack/echo/v4" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/sethvargo/go-limiter" - "github.com/teamhanko/hanko/backend/audit_log" + auditlog "github.com/teamhanko/hanko/backend/audit_log" "github.com/teamhanko/hanko/backend/config" "github.com/teamhanko/hanko/backend/dto" "github.com/teamhanko/hanko/backend/persistence" @@ -16,8 +19,6 @@ import ( "github.com/teamhanko/hanko/backend/rate_limiter" "github.com/teamhanko/hanko/backend/session" "golang.org/x/crypto/bcrypt" - "net/http" - "unicode/utf8" ) type PasswordHandler struct { diff --git a/backend/handler/public_router.go b/backend/handler/public_router.go index ed750d030..0d89f4457 100644 --- a/backend/handler/public_router.go +++ b/backend/handler/public_router.go @@ -2,10 +2,11 @@ package handler import ( "fmt" + "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/sethvargo/go-limiter/httplimit" - "github.com/teamhanko/hanko/backend/audit_log" + auditlog "github.com/teamhanko/hanko/backend/audit_log" "github.com/teamhanko/hanko/backend/config" "github.com/teamhanko/hanko/backend/crypto/jwk" "github.com/teamhanko/hanko/backend/dto" @@ -109,11 +110,24 @@ func NewPublicRouter(cfg *config.Config, persister persistence.Persister, promet if err != nil { panic(fmt.Errorf("failed to create public webauthn handler: %w", err)) } + passcodeHandler, err := NewPasscodeHandler(cfg, persister, sessionManager, mailer, auditLogger) if err != nil { panic(fmt.Errorf("failed to create public passcode handler: %w", err)) } + if cfg.Passlink.Enabled { + passlinkHandler, err := NewPasslinkHandler(cfg, persister, sessionManager, mailer, auditLogger) + if err != nil { + panic(fmt.Errorf("failed to create public passlink handler: %w", err)) + } + + passlink := g.Group("/passlink") + passlinkLogin := passlink.Group("/login", webhookMiddlware) + passlinkLogin.POST("/initialize", passlinkHandler.Init).Name = "passlink_login_initialize" + passlinkLogin.POST("/finalize", passlinkHandler.Finish).Name = "passlink_login_finalize" + } + health := e.Group("/health") health.GET("/alive", healthHandler.Alive) health.GET("/ready", healthHandler.Ready) diff --git a/backend/mail/locales/passcode.en.yaml b/backend/mail/locales/passcode.en.yaml index 446eb1fcb..e629d1f1d 100644 --- a/backend/mail/locales/passcode.en.yaml +++ b/backend/mail/locales/passcode.en.yaml @@ -1,9 +1,18 @@ login_text: description: "The sign in content of the text email." other: "Enter the following passcode on your login screen:" +passlink_login_text: + description: "The sign in content of the text email." + other: "Click the link below to securely log into your account at {{ .ServiceName }}:" ttl_text: description: "The length how long the passcode is valid." other: "The passcode is valid for {{ .TTL }} minutes." +passlink_ttl_text: + description: "The length how long the passcode is valid." + other: "This link is valid for {{ .TTL }} minutes and can only be used once. If you did not request this link, please ignore this email." email_subject_login: description: "" other: "Use passcode {{ .Code }} to sign in to {{ .ServiceName }}" +email_subject_login_passlink: + description: "" + other: "Confirm your sign in request to {{ .ServiceName }}" diff --git a/backend/mail/templates/passlink-login.tmpl b/backend/mail/templates/passlink-login.tmpl new file mode 100644 index 000000000..7c21a7c1f --- /dev/null +++ b/backend/mail/templates/passlink-login.tmpl @@ -0,0 +1,7 @@ +{{define "passlinkLoginTextMail"}} +{{t "passlink_login_text" .}} + +{{ .URL }} + +{{t "passlink_ttl_text" .}} +{{end}} diff --git a/backend/persistence/migrations/20240522233121_create_passlinks.down.fizz b/backend/persistence/migrations/20240522233121_create_passlinks.down.fizz new file mode 100644 index 000000000..6056dc5ef --- /dev/null +++ b/backend/persistence/migrations/20240522233121_create_passlinks.down.fizz @@ -0,0 +1 @@ +drop_table("passlinks") diff --git a/backend/persistence/migrations/20240522233121_create_passlinks.up.fizz b/backend/persistence/migrations/20240522233121_create_passlinks.up.fizz new file mode 100644 index 000000000..0c01c14b4 --- /dev/null +++ b/backend/persistence/migrations/20240522233121_create_passlinks.up.fizz @@ -0,0 +1,14 @@ +create_table("passlinks") { + t.Column("id", "uuid", {primary: true}) + t.Column("user_id", "uuid", {}) + t.Column("email_id", "uuid", {null: true}) + t.Column("ttl", "integer", {}) + t.Column("strictness", "string", {}) + t.Column("ip", "string", {}) + t.Column("token", "string", {}) + t.Column("login_count", "integer", {}) + t.Column("reusable", "bool", {}) + t.Timestamps() + t.ForeignKey("user_id", {"users": ["id"]}, {"on_delete": "cascade", "on_update": "cascade"}) + t.ForeignKey("email_id", {"emails": ["id"]}, {"on_delete": "cascade", "on_update": "cascade"}) +} diff --git a/backend/persistence/models/audit_log.go b/backend/persistence/models/audit_log.go index a8a6d04de..bbc24e58c 100644 --- a/backend/persistence/models/audit_log.go +++ b/backend/persistence/models/audit_log.go @@ -1,8 +1,9 @@ package models import ( - "github.com/gofrs/uuid" "time" + + "github.com/gofrs/uuid" ) type AuditLog struct { @@ -36,6 +37,11 @@ var ( AuditLogPasscodeLoginFinalSucceeded AuditLogType = "passcode_login_final_succeeded" AuditLogPasscodeLoginFinalFailed AuditLogType = "passcode_login_final_failed" + AuditLogPasslinkLoginInitSucceeded AuditLogType = "passlink_login_init_succeeded" + AuditLogPasslinkLoginInitFailed AuditLogType = "passlink_login_init_failed" + AuditLogPasslinkLoginFinalSucceeded AuditLogType = "passlink_login_final_succeeded" + AuditLogPasslinkLoginFinalFailed AuditLogType = "passlink_login_final_failed" + AuditLogWebAuthnRegistrationInitSucceeded AuditLogType = "webauthn_registration_init_succeeded" AuditLogWebAuthnRegistrationInitFailed AuditLogType = "webauthn_registration_init_failed" AuditLogWebAuthnRegistrationFinalSucceeded AuditLogType = "webauthn_registration_final_succeeded" diff --git a/backend/persistence/models/passlink.go b/backend/persistence/models/passlink.go new file mode 100644 index 000000000..6548d490e --- /dev/null +++ b/backend/persistence/models/passlink.go @@ -0,0 +1,68 @@ +package models + +import ( + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/validate/v3" + "github.com/gobuffalo/validate/v3/validators" + "github.com/gofrs/uuid" +) + +// Passlink is used by pop to map your passlink database table to your go code. +type Passlink struct { + ID uuid.UUID `db:"id"` + UserId uuid.UUID `db:"user_id"` + EmailID uuid.UUID `db:"email_id"` + TTL int `db:"ttl"` // in seconds + Strictness string `db:"strictness"` + IP string `db:"ip"` + Token string `db:"token"` + LoginCount int `db:"login_count"` + Reusable bool `db:"reusable"` // by default a passlink can only used once, if reusable is set true, it can be used to authenticate the user multiple times by clicking the same link (e.g. in a newsletter) + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + Email Email `belongs_to:"email"` +} + +// Validate gets run every time you call a "pop.Validate*" (pop.ValidateAndSave, pop.ValidateAndCreate, pop.ValidateAndUpdate) method. +func (passlink *Passlink) Validate(tx *pop.Connection) (*validate.Errors, error) { + tests := []validate.Validator{ + &validators.UUIDIsPresent{Name: "ID", Field: passlink.ID}, + &validators.UUIDIsPresent{Name: "UserID", Field: passlink.UserId}, + &validators.StringLengthInRange{Name: "Token", Field: passlink.Token, Min: 16}, + &validators.TimeIsPresent{Name: "CreatedAt", Field: passlink.CreatedAt}, + &validators.TimeIsPresent{Name: "UpdatedAt", Field: passlink.UpdatedAt}, + } + return validate.Validate(tests...), nil +} + +type PasslinkStrictness string + +const ( + PasslinkStrictnessBrowser PasslinkStrictness = "browser" // only allow passlink usage in the same browser based on a session cookie + PasslinkStrictnessDevice PasslinkStrictness = "device" // only allow passlink usage on the same device based on the ip address + PasslinkStrictnessNone PasslinkStrictness = "" // no strictness, allow passlink usage from any device +) + +// AllPasslinkStrictness represents the list of all valid types +var AllPasslinkStrictness = []PasslinkStrictness{ + PasslinkStrictnessBrowser, + PasslinkStrictnessDevice, + PasslinkStrictnessNone, +} + +// String returns the string representation +func (ps PasslinkStrictness) String() string { + return string(ps) +} + +// Valid check if the given value is included +func (ps PasslinkStrictness) Valid() bool { + for _, v := range AllPasslinkStrictness { + if v == ps { + return true + } + } + return false +} diff --git a/backend/persistence/passlink_persister.go b/backend/persistence/passlink_persister.go new file mode 100644 index 000000000..4a07d4cea --- /dev/null +++ b/backend/persistence/passlink_persister.go @@ -0,0 +1,74 @@ +package persistence + +import ( + "database/sql" + "errors" + "fmt" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/teamhanko/hanko/backend/persistence/models" +) + +type PasslinkPersister interface { + Get(uuid.UUID) (*models.Passlink, error) + Create(models.Passlink) error + Update(models.Passlink) error + Delete(models.Passlink) error +} + +type passlinkPersister struct { + db *pop.Connection +} + +func NewPasslinkPersister(db *pop.Connection) PasslinkPersister { + return &passlinkPersister{db: db} +} + +func (p *passlinkPersister) Get(id uuid.UUID) (*models.Passlink, error) { + passlink := models.Passlink{} + err := p.db.EagerPreload("Email.User").Find(&passlink, id) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get passlink: %w", err) + } + + return &passlink, nil +} + +func (p *passlinkPersister) Create(passlink models.Passlink) error { + vErr, err := p.db.ValidateAndCreate(&passlink) + if err != nil { + return fmt.Errorf("failed to store passlink: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("passlink object validation failed: %w", vErr) + } + + return nil +} + +func (p *passlinkPersister) Update(passlink models.Passlink) error { + vErr, err := p.db.ValidateAndUpdate(&passlink) + if err != nil { + return fmt.Errorf("failed to update passlink: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("passlink object validation failed: %w", vErr) + } + + return nil +} + +func (p *passlinkPersister) Delete(passlink models.Passlink) error { + err := p.db.Destroy(&passlink) + if err != nil { + return fmt.Errorf("failed to delete passlink: %w", err) + } + + return nil +} diff --git a/backend/persistence/persister.go b/backend/persistence/persister.go index 47e2e7f83..d12542653 100644 --- a/backend/persistence/persister.go +++ b/backend/persistence/persister.go @@ -2,6 +2,7 @@ package persistence import ( "embed" + "github.com/gobuffalo/pop/v6" "github.com/teamhanko/hanko/backend/config" ) @@ -23,6 +24,8 @@ type Persister interface { GetUserPersisterWithConnection(tx *pop.Connection) UserPersister GetPasscodePersister() PasscodePersister GetPasscodePersisterWithConnection(tx *pop.Connection) PasscodePersister + GetPasslinkPersister() PasslinkPersister + GetPasslinkPersisterWithConnection(tx *pop.Connection) PasslinkPersister GetPasswordCredentialPersister() PasswordCredentialPersister GetPasswordCredentialPersisterWithConnection(tx *pop.Connection) PasswordCredentialPersister GetWebauthnCredentialPersister() WebauthnCredentialPersister @@ -142,6 +145,14 @@ func (p *persister) GetPasscodePersisterWithConnection(tx *pop.Connection) Passc return NewPasscodePersister(tx) } +func (p *persister) GetPasslinkPersister() PasslinkPersister { + return NewPasslinkPersister(p.DB) +} + +func (p *persister) GetPasslinkPersisterWithConnection(tx *pop.Connection) PasslinkPersister { + return NewPasslinkPersister(tx) +} + func (p *persister) GetPasswordCredentialPersister() PasswordCredentialPersister { return NewPasswordCredentialPersister(p.DB) } diff --git a/backend/test/passlink_persister.go b/backend/test/passlink_persister.go new file mode 100644 index 000000000..be5c8d7bb --- /dev/null +++ b/backend/test/passlink_persister.go @@ -0,0 +1,54 @@ +package test + +import ( + "github.com/gofrs/uuid" + "github.com/teamhanko/hanko/backend/persistence" + "github.com/teamhanko/hanko/backend/persistence/models" +) + +func NewPasslinkPersister(init []models.Passlink) persistence.PasslinkPersister { + return &passlinkPersister{append([]models.Passlink{}, init...)} +} + +type passlinkPersister struct { + passlinks []models.Passlink +} + +func (p *passlinkPersister) Get(id uuid.UUID) (*models.Passlink, error) { + var found *models.Passlink + for _, data := range p.passlinks { + if data.ID == id { + d := data + found = &d + } + } + return found, nil +} + +func (p *passlinkPersister) Create(passlink models.Passlink) error { + p.passlinks = append(p.passlinks, passlink) + return nil +} + +func (p *passlinkPersister) Update(passlink models.Passlink) error { + for i, data := range p.passlinks { + if data.ID == passlink.ID { + p.passlinks[i] = passlink + } + } + return nil +} + +func (p *passlinkPersister) Delete(passlink models.Passlink) error { + index := -1 + for i, data := range p.passlinks { + if data.ID == passlink.ID { + index = i + } + } + if index > -1 { + p.passlinks = append(p.passlinks[:index], p.passlinks[index+1:]...) + } + + return nil +} diff --git a/backend/test/persister.go b/backend/test/persister.go index 603c59a01..9c485b358 100644 --- a/backend/test/persister.go +++ b/backend/test/persister.go @@ -10,6 +10,7 @@ import ( func NewPersister( user []models.User, passcodes []models.Passcode, + passlinks []models.Passlink, jwks []models.Jwk, credentials []models.WebauthnCredential, sessionData []models.WebauthnSessionData, @@ -27,6 +28,7 @@ func NewPersister( return &persister{ userPersister: NewUserPersister(user), passcodePersister: NewPasscodePersister(passcodes), + passlinkPersister: NewPasslinkPersister(passlinks), jwkPersister: NewJwkPersister(jwks), webauthnCredentialPersister: NewWebauthnCredentialPersister(credentials), webauthnSessionDataPersister: NewWebauthnSessionDataPersister(sessionData), @@ -45,6 +47,7 @@ func NewPersister( type persister struct { userPersister persistence.UserPersister passcodePersister persistence.PasscodePersister + passlinkPersister persistence.PasslinkPersister jwkPersister persistence.JwkPersister webauthnCredentialPersister persistence.WebauthnCredentialPersister webauthnSessionDataPersister persistence.WebauthnSessionDataPersister @@ -91,6 +94,14 @@ func (p *persister) GetPasscodePersisterWithConnection(tx *pop.Connection) persi return p.passcodePersister } +func (p *persister) GetPasslinkPersister() persistence.PasslinkPersister { + return p.passlinkPersister +} + +func (p *persister) GetPasslinkPersisterWithConnection(tx *pop.Connection) persistence.PasslinkPersister { + return p.passlinkPersister +} + func (p *persister) GetWebauthnCredentialPersister() persistence.WebauthnCredentialPersister { return p.webauthnCredentialPersister }