diff --git a/.gitignore b/.gitignore index f24947dc5..0178cf5fe 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ node_modules/ .virtualgo boilerplate/lyft/end2end/tmp +dist diff --git a/admin.db b/admin.db new file mode 100644 index 000000000..b858ef3a8 Binary files /dev/null and b/admin.db differ diff --git a/cmd/entrypoints/clusterresource.go b/cmd/entrypoints/clusterresource.go index ab63b3377..4091ce25d 100644 --- a/cmd/entrypoints/clusterresource.go +++ b/cmd/entrypoints/clusterresource.go @@ -3,22 +3,12 @@ package entrypoints import ( "context" - "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + errors2 "github.com/pkg/errors" - "github.com/flyteorg/flyteadmin/pkg/clusterresource/impl" - "github.com/flyteorg/flyteadmin/pkg/clusterresource/interfaces" - execClusterIfaces "github.com/flyteorg/flyteadmin/pkg/executioncluster/interfaces" - "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" - "github.com/flyteorg/flyteadmin/pkg/repositories" "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flyteidl/clients/go/admin" - "github.com/flyteorg/flyteadmin/pkg/clusterresource" - "github.com/flyteorg/flyteadmin/pkg/config" - executioncluster "github.com/flyteorg/flyteadmin/pkg/executioncluster/impl" "github.com/flyteorg/flyteadmin/pkg/runtime" - runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flytestdlib/logger" "github.com/spf13/cobra" @@ -30,74 +20,40 @@ var parentClusterResourceCmd = &cobra.Command{ Short: "This command administers the ClusterResourceController. 
Please choose a subcommand.", } -func getClusterResourceController(ctx context.Context, scope promutils.Scope, configuration runtimeInterfaces.Configuration) clusterresource.Controller { - initializationErrorCounter := scope.MustNewCounter( - "flyteclient_initialization_error", - "count of errors encountered initializing a flyte client from kube config") - var listTargetsProvider execClusterIfaces.ListTargetsInterface - var err error - if len(configuration.ClusterConfiguration().GetClusterConfigs()) == 0 { - serverConfig := config.GetConfig() - listTargetsProvider, err = executioncluster.NewInCluster(initializationErrorCounter, serverConfig.KubeConfig, serverConfig.Master) - } else { - listTargetsProvider, err = executioncluster.NewListTargets(initializationErrorCounter, executioncluster.NewExecutionTargetProvider(), configuration.ClusterConfiguration()) - } - if err != nil { - panic(err) - } - - var adminDataProvider interfaces.FlyteAdminDataProvider - if configuration.ClusterResourceConfiguration().IsStandaloneDeployment() { - clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)).Build(ctx) - if err != nil { - panic(err) - } - adminDataProvider = impl.NewAdminServiceDataProvider(clientSet.AdminClient()) - } else { - dbConfig := runtime.NewConfigurationProvider().ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, dbConfig, logConfig) - if err != nil { - logger.Fatal(ctx, err) - } - dbScope := scope.NewSubScope("db") - - repo := repositories.NewGormRepo( - db, errors.NewPostgresErrorTransformer(dbScope.NewSubScope("errors")), dbScope) - - adminDataProvider = impl.NewDatabaseAdminDataProvider(repo, configuration, resources.NewResourceManager(repo, configuration.ApplicationConfiguration())) - } - - return clusterresource.NewClusterResourceController(adminDataProvider, listTargetsProvider, scope) -} - var controllerRunCmd = &cobra.Command{ Use: "run", Short: "This command will start a 
cluster resource controller to periodically sync cluster resources", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() configuration := runtime.NewConfigurationProvider() scope := promutils.NewScope(configuration.ApplicationConfiguration().GetTopLevelConfig().MetricsScope).NewSubScope("clusterresource") - clusterResourceController := getClusterResourceController(ctx, scope, configuration) + clusterResourceController, err := clusterresource.NewClusterResourceControllerFromConfig(ctx, scope, configuration) + if err != nil { + return err + } clusterResourceController.Run() logger.Infof(ctx, "ClusterResourceController started running successfully") + return nil }, } var controllerSyncCmd = &cobra.Command{ Use: "sync", Short: "This command will sync cluster resources", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() configuration := runtime.NewConfigurationProvider() scope := promutils.NewScope(configuration.ApplicationConfiguration().GetTopLevelConfig().MetricsScope).NewSubScope("clusterresource") - clusterResourceController := getClusterResourceController(ctx, scope, configuration) - err := clusterResourceController.Sync(ctx) + clusterResourceController, err := clusterresource.NewClusterResourceControllerFromConfig(ctx, scope, configuration) + if err != nil { + return err + } + err = clusterResourceController.Sync(ctx) if err != nil { - logger.Fatalf(ctx, "Failed to sync cluster resources [%+v]", err) + return errors2.Wrap(err, "Failed to sync cluster resources ") } logger.Infof(ctx, "ClusterResourceController synced successfully") + return nil }, } diff --git a/cmd/entrypoints/migrate.go b/cmd/entrypoints/migrate.go index df1393f6c..030a39390 100644 --- a/cmd/entrypoints/migrate.go +++ b/cmd/entrypoints/migrate.go @@ -3,12 +3,8 @@ package entrypoints import ( "context" - 
"github.com/flyteorg/flyteadmin/pkg/repositories" - "github.com/flyteorg/flyteadmin/pkg/repositories/config" - "github.com/flyteorg/flyteadmin/pkg/runtime" - "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flyteadmin/pkg/server" - "github.com/go-gormigrate/gormigrate/v2" "github.com/spf13/cobra" _ "gorm.io/driver/postgres" // Required to import database driver. ) @@ -22,35 +18,9 @@ var parentMigrateCmd = &cobra.Command{ var migrateCmd = &cobra.Command{ Use: "run", Short: "This command will run all the migrations for the database", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - configuration := runtime.NewConfigurationProvider() - databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, databaseConfig, logConfig) - if err != nil { - logger.Fatal(ctx, err) - } - sqlDB, err := db.DB() - if err != nil { - logger.Fatal(ctx, err) - } - - defer func(deferCtx context.Context) { - if err = sqlDB.Close(); err != nil { - logger.Fatal(deferCtx, err) - } - }(ctx) - - if err = sqlDB.Ping(); err != nil { - logger.Fatal(ctx, err) - } - m := gormigrate.New(db, gormigrate.DefaultOptions, config.Migrations) - if err = m.Migrate(); err != nil { - logger.Fatalf(ctx, "Could not migrate: %v", err) - } - logger.Infof(ctx, "Migration ran successfully") + return server.Migrate(ctx) }, } @@ -58,36 +28,9 @@ var migrateCmd = &cobra.Command{ var rollbackCmd = &cobra.Command{ Use: "rollback", Short: "This command will rollback one migration", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - configuration := runtime.NewConfigurationProvider() - databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, databaseConfig, logConfig) - if err 
!= nil { - logger.Fatal(ctx, err) - } - sqlDB, err := db.DB() - if err != nil { - logger.Fatal(ctx, err) - } - defer func(deferCtx context.Context) { - if err = sqlDB.Close(); err != nil { - logger.Fatal(deferCtx, err) - } - }(ctx) - - if err = sqlDB.Ping(); err != nil { - logger.Fatal(ctx, err) - } - - m := gormigrate.New(db, gormigrate.DefaultOptions, config.Migrations) - err = m.RollbackLast() - if err != nil { - logger.Fatalf(ctx, "Could not rollback latest migration: %v", err) - } - logger.Infof(ctx, "Rolled back one migration successfully") + return server.Rollback(ctx) }, } @@ -95,36 +38,9 @@ var rollbackCmd = &cobra.Command{ var seedProjectsCmd = &cobra.Command{ Use: "seed-projects", Short: "Seed projects in the database.", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - configuration := runtime.NewConfigurationProvider() - databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, databaseConfig, logConfig) - if err != nil { - logger.Fatal(ctx, err) - } - - sqlDB, err := db.DB() - if err != nil { - logger.Fatal(ctx, err) - } - - defer func(deferCtx context.Context) { - if err = sqlDB.Close(); err != nil { - logger.Fatal(deferCtx, err) - } - }(ctx) - - if err = sqlDB.Ping(); err != nil { - logger.Fatal(ctx, err) - } - - if err = config.SeedProjects(db, args); err != nil { - logger.Fatalf(ctx, "Could not add projects to database with err: %v", err) - } - logger.Infof(ctx, "Successfully added projects to database") + return server.SeedProjects(ctx, args) }, } diff --git a/cmd/entrypoints/serve.go b/cmd/entrypoints/serve.go index 2028dfeb9..86ae225bb 100644 --- a/cmd/entrypoints/serve.go +++ b/cmd/entrypoints/serve.go @@ -2,66 +2,38 @@ package entrypoints import ( "context" - "crypto/tls" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + 
"github.com/flyteorg/flytestdlib/profutils" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" - - authConfig "github.com/flyteorg/flyteadmin/auth/config" - - "github.com/flyteorg/flyteadmin/auth/authzserver" - - "github.com/gorilla/handlers" - - "github.com/flyteorg/flyteadmin/auth" - "github.com/flyteorg/flyteadmin/auth/interfaces" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" - - "net" - "net/http" _ "net/http/pprof" // Required to serve application. - "strings" - - "github.com/flyteorg/flyteadmin/pkg/server" - "github.com/pkg/errors" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/health" - "google.golang.org/grpc/health/grpc_health_v1" "github.com/flyteorg/flyteadmin/pkg/common" - flyteService "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyteadmin/pkg/server" "github.com/flyteorg/flytestdlib/logger" - "github.com/grpc-ecosystem/grpc-gateway/runtime" - - "github.com/flyteorg/flyteadmin/pkg/config" - "github.com/flyteorg/flyteadmin/pkg/rpc/adminservice" "github.com/spf13/cobra" + runtimeConfig "github.com/flyteorg/flyteadmin/pkg/runtime" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled" - grpcPrometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "google.golang.org/grpc" - "google.golang.org/grpc/reflection" ) -var defaultCorsHeaders = []string{"Content-Type"} - // serveCmd represents the serve command var serveCmd = &cobra.Command{ Use: "serve", Short: "Launches the Flyte admin server", RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - serverConfig := config.GetConfig() - - if serverConfig.Security.Secure { - return serveGatewaySecure(ctx, serverConfig, authConfig.GetConfig()) - } + // Serve profiling endpoints. 
+ cfg := runtimeConfig.NewConfigurationProvider() + go func() { + err := profutils.StartProfilingServerWithDefaultHandlers( + ctx, cfg.ApplicationConfiguration().GetTopLevelConfig().GetProfilerPort(), nil) + if err != nil { + logger.Panicf(ctx, "Failed to Start profiling and Metrics server. Error, %v", err) + } + }() - return serveGatewayInsecure(ctx, serverConfig, authConfig.GetConfig()) + return server.Serve(ctx, nil) }, } @@ -75,322 +47,3 @@ func init() { contextutils.ExecIDKey, contextutils.WorkflowIDKey, contextutils.NodeIDKey, contextutils.TaskIDKey, contextutils.TaskTypeKey, common.RuntimeTypeKey, common.RuntimeVersionKey) } - -func blanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( - resp interface{}, err error) { - - identityContext := auth.IdentityContextFromContext(ctx) - if identityContext.IsEmpty() { - return handler(ctx, req) - } - - if !identityContext.Scopes().Has(auth.ScopeAll) { - return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope") - } - - return handler(ctx, req) -} - -// Creates a new gRPC Server with all the configuration -func newGRPCServer(ctx context.Context, cfg *config.ServerConfig, authCtx interfaces.AuthenticationContext, - opts ...grpc.ServerOption) (*grpc.Server, error) { - // Not yet implemented for streaming - var chainedUnaryInterceptors grpc.UnaryServerInterceptor - if cfg.Security.UseAuth { - logger.Infof(ctx, "Creating gRPC server with authentication") - chainedUnaryInterceptors = grpc_middleware.ChainUnaryServer(grpcPrometheus.UnaryServerInterceptor, - auth.GetAuthenticationCustomMetadataInterceptor(authCtx), - grpcauth.UnaryServerInterceptor(auth.GetAuthenticationInterceptor(authCtx)), - auth.AuthenticationLoggingInterceptor, - blanketAuthorization, - ) - } else { - logger.Infof(ctx, "Creating gRPC server without authentication") - chainedUnaryInterceptors = 
grpc_middleware.ChainUnaryServer(grpcPrometheus.UnaryServerInterceptor) - } - - serverOpts := []grpc.ServerOption{ - grpc.StreamInterceptor(grpcPrometheus.StreamServerInterceptor), - grpc.UnaryInterceptor(chainedUnaryInterceptors), - } - if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { - serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes)) - } - serverOpts = append(serverOpts, opts...) - grpcServer := grpc.NewServer(serverOpts...) - grpcPrometheus.Register(grpcServer) - flyteService.RegisterAdminServiceServer(grpcServer, adminservice.NewAdminServer(ctx, cfg.KubeConfig, cfg.Master)) - if cfg.Security.UseAuth { - flyteService.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService()) - flyteService.RegisterIdentityServiceServer(grpcServer, authCtx.IdentityService()) - } - - healthServer := health.NewServer() - healthServer.SetServingStatus("flyteadmin", grpc_health_v1.HealthCheckResponse_SERVING) - grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) - if cfg.GrpcConfig.ServerReflection || cfg.GrpcServerReflection { - reflection.Register(grpcServer) - } - return grpcServer, nil -} - -func GetHandleOpenapiSpec(ctx context.Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - swaggerBytes, err := flyteService.Asset("admin.swagger.json") - if err != nil { - logger.Warningf(ctx, "Err %v", err) - w.WriteHeader(http.StatusFailedDependency) - } else { - w.WriteHeader(http.StatusOK) - _, err := w.Write(swaggerBytes) - if err != nil { - logger.Errorf(ctx, "failed to write openAPI information, error: %s", err.Error()) - } - } - } -} - -func healthCheckFunc(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -} - -func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config, authCtx interfaces.AuthenticationContext, - grpcAddress string, grpcConnectionOpts ...grpc.DialOption) (*http.ServeMux, error) { - - // Register the server that will 
serve HTTP/REST Traffic - mux := http.NewServeMux() - - // Register healthcheck - mux.HandleFunc("/healthcheck", healthCheckFunc) - - // Register OpenAPI endpoint - // This endpoint will serve the OpenAPI2 spec generated by the swagger protoc plugin, and bundled by go-bindata - mux.HandleFunc("/api/v1/openapi", GetHandleOpenapiSpec(ctx)) - - var gwmuxOptions = make([]runtime.ServeMuxOption, 0) - // This option means that http requests are served with protobufs, instead of json. We always want this. - gwmuxOptions = append(gwmuxOptions, runtime.WithMarshalerOption("application/octet-stream", &runtime.ProtoMarshaller{})) - - if cfg.Security.UseAuth { - // Add HTTP handlers for OIDC endpoints - auth.RegisterHandlers(ctx, mux, authCtx) - - // Add HTTP handlers for OAuth2 endpoints - authzserver.RegisterHandlers(mux, authCtx) - - // This option translates HTTP authorization data (cookies) into a gRPC metadata field - gwmuxOptions = append(gwmuxOptions, runtime.WithMetadata(auth.GetHTTPRequestCookieToMetadataHandler(authCtx))) - - // In an attempt to be able to selectively enforce whether or not authentication is required, we're going to tag - // the requests that come from the HTTP gateway. See the enforceHttp/Grpc options for more information. - gwmuxOptions = append(gwmuxOptions, runtime.WithMetadata(auth.GetHTTPMetadataTaggingHandler())) - } - - // Create the grpc-gateway server with the options specified - gwmux := runtime.NewServeMux(gwmuxOptions...) 
- - err := flyteService.RegisterAdminServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) - if err != nil { - return nil, errors.Wrap(err, "error registering admin service") - } - - err = flyteService.RegisterAuthMetadataServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) - if err != nil { - return nil, errors.Wrap(err, "error registering auth service") - } - - err = flyteService.RegisterIdentityServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) - if err != nil { - return nil, errors.Wrap(err, "error registering identity service") - } - - mux.Handle("/", gwmux) - - return mux, nil -} - -func serveGatewayInsecure(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config) error { - logger.Infof(ctx, "Serving Flyte Admin Insecure") - - // This will parse configuration and create the necessary objects for dealing with auth - var authCtx interfaces.AuthenticationContext - var err error - // This code is here to support authentication without SSL. This setup supports a network topology where - // Envoy does the SSL termination. The final hop is made over localhost only on a trusted machine. - // Warning: Running authentication without SSL in any other topology is a severe security flaw. - // See the auth.Config object for additional settings as well. 
- if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) - var oauth2Provider interfaces.OAuth2Provider - var oauth2ResourceServer interfaces.OAuth2ResourceServer - if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { - oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm) - if err != nil { - logger.Errorf(ctx, "Error creating authorization server %s", err) - return err - } - - oauth2ResourceServer = oauth2Provider - } else { - oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) - if err != nil { - logger.Errorf(ctx, "Error creating resource server %s", err) - return err - } - } - - oauth2MetadataProvider := authzserver.NewService(authCfg) - oidcUserInfoProvider := auth.NewUserInfoProvider() - - authCtx, err = auth.NewAuthenticationContext(ctx, sm, oauth2Provider, oauth2ResourceServer, oauth2MetadataProvider, oidcUserInfoProvider, authCfg) - if err != nil { - logger.Errorf(ctx, "Error creating auth context %s", err) - return err - } - } - - grpcServer, err := newGRPCServer(ctx, cfg, authCtx) - if err != nil { - return errors.Wrap(err, "failed to create GRPC server") - } - - logger.Infof(ctx, "Serving GRPC Traffic on: %s", cfg.GetGrpcHostAddress()) - lis, err := net.Listen("tcp", cfg.GetGrpcHostAddress()) - if err != nil { - return errors.Wrapf(err, "failed to listen on GRPC port: %s", cfg.GetGrpcHostAddress()) - } - - go func() { - err := grpcServer.Serve(lis) - logger.Fatalf(ctx, "Failed to create GRPC Server, Err: ", err) - }() - - logger.Infof(ctx, "Starting HTTP/1 Gateway server on %s", cfg.GetHostAddress()) - grpcOptions := []grpc.DialOption{ - grpc.WithInsecure(), - grpc.WithMaxHeaderListSize(common.MaxResponseStatusBytes), - } - if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { - grpcOptions = append(grpcOptions, - 
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) - } - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, cfg.GetGrpcHostAddress(), grpcOptions...) - if err != nil { - return err - } - - var handler http.Handler - if cfg.Security.AllowCors { - handler = handlers.CORS( - handlers.AllowCredentials(), - handlers.AllowedOrigins(cfg.Security.AllowedOrigins), - handlers.AllowedHeaders(append(defaultCorsHeaders, cfg.Security.AllowedHeaders...)), - handlers.AllowedMethods([]string{"GET", "POST", "DELETE", "HEAD", "PUT", "PATCH"}), - )(httpServer) - } else { - handler = httpServer - } - - err = http.ListenAndServe(cfg.GetHostAddress(), handler) - if err != nil { - return errors.Wrapf(err, "failed to Start HTTP Server") - } - - return nil -} - -// grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC -// connections or otherHandler otherwise. -// See https://github.com/philips/grpc-gateway-example/blob/master/cmd/serve.go for reference -func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // This is a partial recreation of gRPC's internal checks - if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { - grpcServer.ServeHTTP(w, r) - } else { - otherHandler.ServeHTTP(w, r) - } - }) -} - -func serveGatewaySecure(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config) error { - certPool, cert, err := server.GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile) - if err != nil { - return err - } - // This will parse configuration and create the necessary objects for dealing with auth - var authCtx interfaces.AuthenticationContext - if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) - var oauth2Provider interfaces.OAuth2Provider - var oauth2ResourceServer 
interfaces.OAuth2ResourceServer - if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { - oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm) - if err != nil { - logger.Errorf(ctx, "Error creating authorization server %s", err) - return err - } - - oauth2ResourceServer = oauth2Provider - } else { - oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) - if err != nil { - logger.Errorf(ctx, "Error creating resource server %s", err) - return err - } - } - - oauth2MetadataProvider := authzserver.NewService(authCfg) - oidcUserInfoProvider := auth.NewUserInfoProvider() - - authCtx, err = auth.NewAuthenticationContext(ctx, sm, oauth2Provider, oauth2ResourceServer, oauth2MetadataProvider, oidcUserInfoProvider, authCfg) - if err != nil { - logger.Errorf(ctx, "Error creating auth context %s", err) - return err - } - } - - grpcServer, err := newGRPCServer(ctx, cfg, authCtx, - grpc.Creds(credentials.NewServerTLSFromCert(cert))) - if err != nil { - return errors.Wrap(err, "failed to create GRPC server") - } - - // Whatever certificate is used, pass it along for easier development - dialCreds := credentials.NewTLS(&tls.Config{ - ServerName: cfg.GetHostAddress(), - RootCAs: certPool, - }) - serverOpts := []grpc.DialOption{ - grpc.WithTransportCredentials(dialCreds), - } - if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { - serverOpts = append(serverOpts, - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) - } - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, cfg.GetHostAddress(), serverOpts...) 
- if err != nil { - return err - } - - conn, err := net.Listen("tcp", cfg.GetHostAddress()) - if err != nil { - panic(err) - } - - srv := &http.Server{ - Addr: cfg.GetHostAddress(), - Handler: grpcHandlerFunc(grpcServer, httpServer), - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{*cert}, - NextProtos: []string{"h2"}, - }, - } - - err = srv.Serve(tls.NewListener(conn, srv.TLSConfig)) - - if err != nil { - return errors.Wrapf(err, "failed to Start HTTP/2 Server") - } - return nil -} diff --git a/cmd/scheduler/entrypoints/scheduler.go b/cmd/scheduler/entrypoints/scheduler.go index b981a3452..e7ebe9f09 100644 --- a/cmd/scheduler/entrypoints/scheduler.go +++ b/cmd/scheduler/entrypoints/scheduler.go @@ -2,20 +2,13 @@ package entrypoints import ( "context" - "fmt" - "runtime/debug" - - "github.com/flyteorg/flyteadmin/pkg/repositories" - "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/runtime" "github.com/flyteorg/flyteadmin/scheduler" - "github.com/flyteorg/flyteidl/clients/go/admin" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/profutils" - "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/spf13/cobra" @@ -27,45 +20,7 @@ var schedulerRunCmd = &cobra.Command{ Short: "This command will start the flyte native scheduler and periodically get new schedules from the db for scheduling", RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - configuration := runtime.NewConfigurationProvider() - applicationConfiguration := configuration.ApplicationConfiguration().GetTopLevelConfig() - schedulerConfiguration := configuration.ApplicationConfiguration().GetSchedulerConfig() - - // Define the schedulerScope for prometheus metrics - schedulerScope := 
promutils.NewScope(applicationConfiguration.MetricsScope).NewSubScope("flytescheduler") - schedulerPanics := schedulerScope.MustNewCounter("initialization_panic", - "panics encountered initializing the flyte native scheduler") - - defer func() { - if err := recover(); err != nil { - schedulerPanics.Inc() - logger.Fatalf(ctx, fmt.Sprintf("caught panic: %v [%+v]", err, string(debug.Stack()))) - } - }() - - databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, databaseConfig, logConfig) - if err != nil { - logger.Fatal(ctx, err) - } - dbScope := schedulerScope.NewSubScope("database") - repo := repositories.NewGormRepo( - db, errors.NewPostgresErrorTransformer(schedulerScope.NewSubScope("errors")), dbScope) - - clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)).Build(ctx) - if err != nil { - logger.Fatalf(ctx, "Flyte native scheduler failed to start due to %v", err) - return err - } - adminServiceClient := clientSet.AdminClient() - - scheduleExecutor := scheduler.NewScheduledExecutor(repo, - configuration.ApplicationConfiguration().GetSchedulerConfig().GetWorkflowExecutorConfig(), schedulerScope, adminServiceClient) - - logger.Info(ctx, "Successfully initialized a native flyte scheduler") - + schedulerConfiguration := runtime.NewConfigurationProvider().ApplicationConfiguration().GetSchedulerConfig() // Serve profiling endpoints. go func() { err := profutils.StartProfilingServerWithDefaultHandlers( @@ -74,13 +29,7 @@ var schedulerRunCmd = &cobra.Command{ logger.Panicf(ctx, "Failed to Start profiling and Metrics server. 
Error, %v", err) } }() - - err = scheduleExecutor.Run(ctx) - if err != nil { - logger.Fatalf(ctx, "Flyte native scheduler failed to start due to %v", err) - return err - } - return nil + return scheduler.StartScheduler(ctx) }, } diff --git a/flyteadmin_config.yaml b/flyteadmin_config.yaml index 08a70756a..914e83ce6 100644 --- a/flyteadmin_config.yaml +++ b/flyteadmin_config.yaml @@ -61,12 +61,8 @@ flyteadmin: - "metadata" - "admin" database: - postgres: - port: 5432 - username: postgres - host: localhost - dbname: postgres - options: "sslmode=disable" + sqlite: + file: admin.db scheduler: eventScheduler: scheme: local diff --git a/pkg/clusterresource/controller.go b/pkg/clusterresource/controller.go index 677d2894e..b53b43339 100644 --- a/pkg/clusterresource/controller.go +++ b/pkg/clusterresource/controller.go @@ -12,6 +12,14 @@ import ( "strings" "time" + impl2 "github.com/flyteorg/flyteadmin/pkg/clusterresource/impl" + "github.com/flyteorg/flyteadmin/pkg/config" + "github.com/flyteorg/flyteadmin/pkg/executioncluster/impl" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" + "github.com/flyteorg/flyteadmin/pkg/repositories" + errors2 "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + admin2 "github.com/flyteorg/flyteidl/clients/go/admin" + "google.golang.org/grpc/status" "github.com/flyteorg/flyteadmin/pkg/clusterresource/interfaces" @@ -633,13 +641,55 @@ func newMetrics(scope promutils.Scope) controllerMetrics { } func NewClusterResourceController(adminDataProvider interfaces.FlyteAdminDataProvider, listTargets executionclusterIfaces.ListTargetsInterface, scope promutils.Scope) Controller { - config := runtime.NewConfigurationProvider() + cfg := runtime.NewConfigurationProvider() return &controller{ adminDataProvider: adminDataProvider, - config: config, + config: cfg, listTargets: listTargets, poller: make(chan struct{}), metrics: newMetrics(scope), appliedTemplates: make(map[string]map[string]time.Time), } } + +func 
NewClusterResourceControllerFromConfig(ctx context.Context, scope promutils.Scope, configuration runtimeInterfaces.Configuration) (Controller, error) { + initializationErrorCounter := scope.MustNewCounter( + "flyteclient_initialization_error", + "count of errors encountered initializing a flyte client from kube config") + var listTargetsProvider executionclusterIfaces.ListTargetsInterface + var err error + if len(configuration.ClusterConfiguration().GetClusterConfigs()) == 0 { + serverConfig := config.GetConfig() + listTargetsProvider, err = impl.NewInCluster(initializationErrorCounter, serverConfig.KubeConfig, serverConfig.Master) + } else { + listTargetsProvider, err = impl.NewListTargets(initializationErrorCounter, impl.NewExecutionTargetProvider(), configuration.ClusterConfiguration()) + } + if err != nil { + return nil, err + } + + var adminDataProvider interfaces.FlyteAdminDataProvider + if configuration.ClusterResourceConfiguration().IsStandaloneDeployment() { + clientSet, err := admin2.ClientSetBuilder().WithConfig(admin2.GetConfig(ctx)).Build(ctx) + if err != nil { + return nil, err + } + adminDataProvider = impl2.NewAdminServiceDataProvider(clientSet.AdminClient()) + } else { + dbConfig := runtime.NewConfigurationProvider().ApplicationConfiguration().GetDbConfig() + logConfig := logger.GetConfig() + + db, _, err := repositories.GetDB(ctx, dbConfig, logConfig) + if err != nil { + return nil, err + } + dbScope := scope.NewSubScope("db") + + repo := repositories.NewGormRepo( + db, errors2.NewPostgresErrorTransformer(dbScope.NewSubScope("errors")), dbScope) + + adminDataProvider = impl2.NewDatabaseAdminDataProvider(repo, configuration, resources.NewResourceManager(repo, configuration.ApplicationConfiguration())) + } + + return NewClusterResourceController(adminDataProvider, listTargetsProvider, scope), nil +} diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index c0c016bc1..faae6e398 100644 --- 
a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -372,9 +372,7 @@ func TestCreateExecutionFromWorkflowNode(t *testing.T) { assert.EqualValues(t, input.NodeExecutionIdentifier, parentNodeExecutionID) getNodeExecutionCalled = true return models.NodeExecution{ - BaseModel: models.BaseModel{ - ID: 1, - }, + ID: 1, }, nil }, ) @@ -1197,9 +1195,7 @@ func TestRecoverExecution_RecoveredChildNode(t *testing.T) { assert.True(t, proto.Equal(&parentNodeExecution, &input.NodeExecutionIdentifier)) return models.NodeExecution{ - BaseModel: models.BaseModel{ - ID: parentNodeDatabaseID, - }, + ID: parentNodeDatabaseID, }, nil }) diff --git a/pkg/manager/impl/node_execution_manager_test.go b/pkg/manager/impl/node_execution_manager_test.go index d88899d6a..20ce47a6b 100644 --- a/pkg/manager/impl/node_execution_manager_test.go +++ b/pkg/manager/impl/node_execution_manager_test.go @@ -657,9 +657,7 @@ func TestListNodeExecutionsWithParent(t *testing.T) { repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.NodeExecutionResource) (execution models.NodeExecution, e error) { assert.Equal(t, "parent_1", input.NodeExecutionIdentifier.NodeId) return models.NodeExecution{ - BaseModel: models.BaseModel{ - ID: parentID, - }, + ID: parentID, }, nil }) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetListCallback( diff --git a/pkg/manager/impl/task_execution_manager_test.go b/pkg/manager/impl/task_execution_manager_test.go index 13d11c9d6..a30bb34b9 100644 --- a/pkg/manager/impl/task_execution_manager_test.go +++ b/pkg/manager/impl/task_execution_manager_test.go @@ -540,9 +540,7 @@ func TestGetTaskExecution(t *testing.T) { Closure: closureBytes, ChildNodeExecution: []models.NodeExecution{ { - BaseModel: models.BaseModel{ - ID: uint(2), - }, + ID: uint(2), }, }, }, nil diff --git a/pkg/repositories/config/migration_models.go 
b/pkg/repositories/config/migration_models.go index 184de5f5c..c287dab3a 100644 --- a/pkg/repositories/config/migration_models.go +++ b/pkg/repositories/config/migration_models.go @@ -13,21 +13,17 @@ import ( */ type TaskKey struct { - Project string `gorm:"primary_key"` - Domain string `gorm:"primary_key"` - Name string `gorm:"primary_key"` - Version string `gorm:"primary_key"` -} - -type ExecutionKey struct { - Project string `gorm:"primary_key;column:execution_project"` - Domain string `gorm:"primary_key;column:execution_domain"` - Name string `gorm:"primary_key;column:execution_name"` + Project string `gorm:"uniqueIndex:primary_task_exec_index"` + Domain string `gorm:"uniqueIndex:primary_task_exec_index"` + Name string `gorm:"uniqueIndex:primary_task_exec_index"` + Version string `gorm:"uniqueIndex:primary_task_exec_index"` } type NodeExecutionKey struct { - ExecutionKey - NodeID string `gorm:"primary_key;index"` + Project string `gorm:"uniqueIndex:primary_node_exec_index;column:execution_project"` + Domain string `gorm:"uniqueIndex:primary_node_exec_index;column:execution_domain"` + Name string `gorm:"uniqueIndex:primary_node_exec_index;column:execution_name"` + NodeID string `gorm:"uniqueIndex:primary_node_exec_index;index"` } type NodeExecution struct { @@ -45,23 +41,37 @@ type NodeExecution struct { // Prefixed with NodeExecution to avoid clashes with gorm.Model UpdatedAt NodeExecutionUpdatedAt *time.Time Duration time.Duration + // Metadata about the node execution. + NodeExecutionMetadata []byte + // Parent that spawned this node execution - value is empty for executions at level 0 + ParentID *uint `sql:"default:null" gorm:"index"` + // List of child node executions - for cases like Dynamic task, sub workflow, etc + ChildNodeExecutions []NodeExecution `gorm:"foreignKey:ParentID;references:ID"` // The task execution (if any) which launched this node execution. 
- ParentTaskExecutionID uint `sql:"default:null" gorm:"index"` + // TO BE DEPRECATED - as we have now introduced ParentID + ParentTaskExecutionID *uint `sql:"default:null" gorm:"index"` // The workflow execution (if any) which this node execution launched + // NOTE: LaunchedExecution[foreignkey:ParentNodeExecutionID] refers to Workflow execution launched and is different from ParentID LaunchedExecution models.Execution `gorm:"foreignKey:ParentNodeExecutionID;references:ID"` + // Execution Error Kind. nullable, can be one of core.ExecutionError_ErrorKind + ErrorKind *string `gorm:"index"` + // Execution Error Code nullable. string value, but finite set determined by the execution engine and plugins + ErrorCode *string + // If the node is of Type Task, this should always exist for a successful execution, indicating the cache status for the execution + CacheStatus *string // In the case of dynamic workflow nodes, the remote closure is uploaded to the path specified here. DynamicWorkflowRemoteClosureReference string } type TaskExecutionKey struct { TaskKey - Project string `gorm:"primary_key;column:execution_project;index:idx_task_executions_exec"` - Domain string `gorm:"primary_key;column:execution_domain;index:idx_task_executions_exec"` - Name string `gorm:"primary_key;column:execution_name;index:idx_task_executions_exec"` - NodeID string `gorm:"primary_key;index:idx_task_executions_exec;index"` + Project string `gorm:"uniqueIndex:primary_task_exec_index;column:execution_project;index:idx_task_executions_exec"` + Domain string `gorm:"uniqueIndex:primary_task_exec_index;column:execution_domain;index:idx_task_executions_exec"` + Name string `gorm:"uniqueIndex:primary_task_exec_index;column:execution_name;index:idx_task_executions_exec"` + NodeID string `gorm:"uniqueIndex:primary_task_exec_index;index:idx_task_executions_exec;index"` // *IMPORTANT* This is a pointer to an int in order to allow setting an empty ("0") value according to gorm convention. 
// Because RetryAttempt is part of the TaskExecution primary key is should *never* be null. - RetryAttempt *uint32 `gorm:"primary_key;AUTO_INCREMENT:FALSE"` + RetryAttempt *uint32 `gorm:"uniqueIndex:primary_task_exec_index;AUTO_INCREMENT:FALSE"` } type TaskExecution struct { @@ -84,3 +94,24 @@ type TaskExecution struct { // The child node executions (if any) launched by this task execution. ChildNodeExecution []NodeExecution `gorm:"foreignkey:ParentTaskExecutionID;references:ID"` } + +type ExecutionEvent struct { + models.BaseModel + Project string `gorm:"uniqueIndex:primary_ee_index;column:execution_project"` + Domain string `gorm:"uniqueIndex:primary_ee_index;column:execution_domain"` + Name string `gorm:"uniqueIndex:primary_ee_index;column:execution_name"` + RequestID string `valid:"length(0|255)"` + OccurredAt time.Time + Phase string `gorm:"uniqueIndex:primary_ee_index"` +} + +type NodeExecutionEvent struct { + models.BaseModel + Project string `gorm:"uniqueIndex:primary_nee_index;column:execution_project"` + Domain string `gorm:"uniqueIndex:primary_nee_index;column:execution_domain"` + Name string `gorm:"uniqueIndex:primary_nee_index;column:execution_name"` + NodeID string `gorm:"uniqueIndex:primary_nee_index;index"` + RequestID string + OccurredAt time.Time + Phase string `gorm:"uniqueIndex:primary_nee_index"` +} diff --git a/pkg/repositories/config/migrations.go b/pkg/repositories/config/migrations.go index 0fba3bb03..dbed0fb01 100644 --- a/pkg/repositories/config/migrations.go +++ b/pkg/repositories/config/migrations.go @@ -17,357 +17,434 @@ var ( "schedule_entities_snapshots", "task_executions", "tasks", "workflows"} ) -var Migrations = []*gormigrate.Migration{ +type MigrationOptions struct { + IgnoreForSqlite bool +} + +type Migration struct { + gormigrate.Migration + Options MigrationOptions +} + +var Migrations = []*Migration{ // Create projects table. 
{ - ID: "2019-05-22-projects", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Project{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("projects") + Migration: gormigrate.Migration{ + ID: "2019-05-22-projects", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Project{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("projects") + }, }, }, // Create Task { - ID: "2018-05-23-tasks", - Migrate: func(tx *gorm.DB) error { - // The gormigrate library recommends that we copy the actual struct into here for record-keeping but after - // some internal discussion we've decided that that's not necessary. Just a history of what we've touched - // when should be sufficient. - return tx.AutoMigrate(&models.Task{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("tasks") + Migration: gormigrate.Migration{ + ID: "2018-05-23-tasks", + Migrate: func(tx *gorm.DB) error { + // The gormigrate library recommends that we copy the actual struct into here for record-keeping but after + // some internal discussion we've decided that that's not necessary. Just a history of what we've touched + // when should be sufficient. 
+ return tx.AutoMigrate(&models.Task{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("tasks") + }, }, }, // Create Workflow { - ID: "2018-05-23-workflows", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Workflow{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("workflows") + Migration: gormigrate.Migration{ + ID: "2018-05-23-workflows", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Workflow{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("workflows") + }, }, }, // Create Launch Plan table { - ID: "2019-05-23-lp", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.LaunchPlan{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("launch_plans") + Migration: gormigrate.Migration{ + ID: "2019-05-23-lp", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.LaunchPlan{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("launch_plans") + }, }, }, // Create executions table { - ID: "2019-05-23-executions", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Execution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("executions") + Migration: gormigrate.Migration{ + ID: "2019-05-23-executions", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Execution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("executions") + }, }, }, // Create executions events table { - ID: "2019-01-29-executions-events", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.ExecutionEvent{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("executions_events") + Migration: gormigrate.Migration{ + ID: "2019-01-29-executions-events", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&ExecutionEvent{}) + }, + Rollback: func(tx 
*gorm.DB) error { + return tx.Migrator().DropTable("executions_events") + }, }, }, // Create node executions table { - ID: "2019-04-17-node-executions", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&NodeExecution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("node_executions") + Migration: gormigrate.Migration{ + ID: "2019-04-17-node-executions", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&NodeExecution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("node_executions") + }, }, }, // Create node executions events table { - ID: "2019-01-29-node-executions-events", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.NodeExecutionEvent{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("node_executions_events") + Migration: gormigrate.Migration{ + ID: "2019-01-29-node-executions-events", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&NodeExecutionEvent{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("node_executions_events") + }, }, }, // Create task executions table { - ID: "2019-03-16-task-executions", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&TaskExecution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("task_executions") + Migration: gormigrate.Migration{ + ID: "2019-03-16-task-executions", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&TaskExecution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("task_executions") + }, }, }, // Update node executions with null parent values { - ID: "2019-04-17-node-executions-backfill", - Migrate: func(tx *gorm.DB) error { - return tx.Exec("update node_executions set parent_task_execution_id = NULL where parent_task_execution_id = 0").Error + Migration: gormigrate.Migration{ + ID: "2019-04-17-node-executions-backfill", + Migrate: func(tx *gorm.DB) 
error { + return tx.Exec("update node_executions set parent_task_execution_id = NULL where parent_task_execution_id = 0").Error + }, }, }, // Update executions table to add cluster { - ID: "2019-09-27-executions", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Execution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Exec("ALTER TABLE executions DROP COLUMN IF EXISTS cluster").Error + Migration: gormigrate.Migration{ + ID: "2019-09-27-executions", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Execution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Exec("ALTER TABLE executions DROP COLUMN IF EXISTS cluster").Error + }, }, }, // Update projects table to add description column { - ID: "2019-10-09-project-description", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Project{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Exec("ALTER TABLE projects DROP COLUMN IF EXISTS description").Error + Migration: gormigrate.Migration{ + ID: "2019-10-09-project-description", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Project{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Exec("ALTER TABLE projects DROP COLUMN IF EXISTS description").Error + }, }, }, // Add offloaded URIs to table { - ID: "2019-10-15-offload-inputs", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Execution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Exec("ALTER TABLE executions DROP COLUMN IF EXISTS InputsURI, DROP COLUMN IF EXISTS UserInputsURI").Error + Migration: gormigrate.Migration{ + ID: "2019-10-15-offload-inputs", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Execution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Exec("ALTER TABLE executions DROP COLUMN IF EXISTS InputsURI, DROP COLUMN IF EXISTS UserInputsURI").Error + }, }, }, // Create named_entity_metadata table. 
{ - ID: "2019-11-05-named-entity-metadata", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.NamedEntityMetadata{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("named_entity_metadata") + Migration: gormigrate.Migration{ + ID: "2019-11-05-named-entity-metadata", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.NamedEntityMetadata{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("named_entity_metadata") + }, }, }, // Add ProjectAttributes with custom resource attributes. { - ID: "2020-01-10-resource", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Resource{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("resources") + Migration: gormigrate.Migration{ + ID: "2020-01-10-resource", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Resource{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("resources") + }, }, }, // Add Type to Task model. 
{ - ID: "2020-03-17-task-type", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Task{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Exec("ALTER TABLE tasks DROP COLUMN IF EXISTS type").Error + Migration: gormigrate.Migration{ + ID: "2020-03-17-task-type", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Task{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Exec("ALTER TABLE tasks DROP COLUMN IF EXISTS type").Error + }, }, }, // Add state to name entity model { - ID: "2020-04-03-named-entity-state", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.NamedEntityMetadata{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Table("named_entity_metadata").Migrator().DropColumn(&models.NamedEntityMetadata{}, "state") + Migration: gormigrate.Migration{ + ID: "2020-04-03-named-entity-state", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.NamedEntityMetadata{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Table("named_entity_metadata").Migrator().DropColumn(&models.NamedEntityMetadata{}, "state") + }, }, }, // Set default state value for workflow model { - ID: "2020-04-03-named-entity-state-default", - Migrate: func(tx *gorm.DB) error { - return tx.Exec("UPDATE named_entity_metadata SET state = 0").Error - }, - Rollback: func(tx *gorm.DB) error { - return tx.Exec("UPDATE named_entity_metadata set state = NULL").Error + Migration: gormigrate.Migration{ + ID: "2020-04-03-named-entity-state-default", + Migrate: func(tx *gorm.DB) error { + return tx.Exec("UPDATE named_entity_metadata SET state = 0").Error + }, + Rollback: func(tx *gorm.DB) error { + return tx.Exec("UPDATE named_entity_metadata set state = NULL").Error + }, }, }, // Modify the workflows table, if necessary { - ID: "2020-04-03-workflow-state", - Migrate: func(tx *gorm.DB) error { - return tx.Exec("ALTER TABLE workflows DROP COLUMN IF EXISTS state").Error - }, - Rollback: func(tx *gorm.DB) error { - 
return tx.Exec("ALTER TABLE workflows ADD COLUMN IF NOT EXISTS state integer;").Error + Migration: gormigrate.Migration{ + ID: "2020-04-03-workflow-state", + Migrate: func(tx *gorm.DB) error { + return tx.Exec("ALTER TABLE workflows DROP COLUMN IF EXISTS state").Error + }, + Rollback: func(tx *gorm.DB) error { + return tx.Exec("ALTER TABLE workflows ADD COLUMN IF NOT EXISTS state integer;").Error + }, }, + Options: MigrationOptions{IgnoreForSqlite: true}, }, // Modify the executions & node_execution table, if necessary { - ID: "2020-04-29-executions", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Execution{}, &models.NodeExecution{}) - }, - Rollback: func(tx *gorm.DB) error { - if err := tx.Model(&models.Execution{}).Migrator().DropColumn(&models.Execution{}, "error_code"); err != nil { - return err - } - if err := tx.Model(&models.Execution{}).Migrator().DropColumn(&models.Execution{}, "error_kind"); err != nil { - return err - } - if err := tx.Model(&models.NodeExecution{}).Migrator().DropColumn(&models.NodeExecution{}, "error_code"); err != nil { - return err - } - if err := tx.Model(&models.NodeExecution{}).Migrator().DropColumn(&models.NodeExecution{}, "error_kind"); err != nil { - return err - } - return nil + Migration: gormigrate.Migration{ + ID: "2020-04-29-executions", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Execution{}, &NodeExecution{}) + }, + Rollback: func(tx *gorm.DB) error { + if err := tx.Model(&models.Execution{}).Migrator().DropColumn(&models.Execution{}, "error_code"); err != nil { + return err + } + if err := tx.Model(&models.Execution{}).Migrator().DropColumn(&models.Execution{}, "error_kind"); err != nil { + return err + } + if err := tx.Model(&NodeExecution{}).Migrator().DropColumn(&NodeExecution{}, "error_code"); err != nil { + return err + } + if err := tx.Model(&NodeExecution{}).Migrator().DropColumn(&NodeExecution{}, "error_kind"); err != nil { + return err + } + return nil + }, }, }, 
// Add TaskID to Execution model. { - ID: "2020-04-14-task-type", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Execution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Exec("ALTER TABLE executions DROP COLUMN IF EXISTS task_id").Error + Migration: gormigrate.Migration{ + ID: "2020-04-14-task-type", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Execution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Exec("ALTER TABLE executions DROP COLUMN IF EXISTS task_id").Error + }, }, }, // NodeExecutions table has CacheStatus for Task nodes { - ID: "2020-07-27-cachestatus", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.NodeExecution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Model(&models.NodeExecution{}).Migrator().DropColumn(&models.NodeExecution{}, "cache_status") + Migration: gormigrate.Migration{ + ID: "2020-07-27-cachestatus", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&NodeExecution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Model(&NodeExecution{}).Migrator().DropColumn(&NodeExecution{}, "cache_status") + }, }, }, { - ID: "2020-07-31-node-execution", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.NodeExecution{}) - }, - Rollback: func(tx *gorm.DB) error { - if err := tx.Model(&models.NodeExecution{}).Migrator().DropColumn(&models.NodeExecution{}, "parent_id"); err != nil { - return err - } - return tx.Model(&models.NodeExecution{}).Migrator().DropColumn(&models.NodeExecution{}, "node_execution_metadata") + Migration: gormigrate.Migration{ + ID: "2020-07-31-node-execution", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&NodeExecution{}) + }, + Rollback: func(tx *gorm.DB) error { + if err := tx.Model(&NodeExecution{}).Migrator().DropColumn(&NodeExecution{}, "parent_id"); err != nil { + return err + } + return tx.Model(&NodeExecution{}).Migrator().DropColumn(&NodeExecution{}, 
"node_execution_metadata") + }, }, }, { - ID: "2020-08-17-labels-addition", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Project{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Model(&models.Project{}).Migrator().DropColumn(&models.Project{}, "labels") + Migration: gormigrate.Migration{ + ID: "2020-08-17-labels-addition", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Project{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Model(&models.Project{}).Migrator().DropColumn(&models.Project{}, "labels") + }, }, }, { - ID: "2020-09-01-task-exec-idx", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&TaskExecution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Model(&TaskExecution{}).Migrator().DropIndex(&TaskExecution{}, "idx_task_executions_exec") + Migration: gormigrate.Migration{ + ID: "2020-09-01-task-exec-idx", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&TaskExecution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Model(&TaskExecution{}).Migrator().DropIndex(&TaskExecution{}, "idx_task_executions_exec") + }, }, }, { - ID: "2020-11-03-project-state-addition", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Project{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Model(&models.Project{}).Migrator().DropColumn(&models.Project{}, "state") + Migration: gormigrate.Migration{ + ID: "2020-11-03-project-state-addition", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Project{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Model(&models.Project{}).Migrator().DropColumn(&models.Project{}, "state") + }, }, }, { - ID: "2020-11-03-project-state-default", - Migrate: func(tx *gorm.DB) error { - return tx.Exec("UPDATE projects set state = 0").Error - }, - Rollback: func(tx *gorm.DB) error { - return tx.Exec("UPDATE projects set state = NULL").Error + Migration: gormigrate.Migration{ + ID: 
"2020-11-03-project-state-default", + Migrate: func(tx *gorm.DB) error { + return tx.Exec("UPDATE projects set state = 0").Error + }, + Rollback: func(tx *gorm.DB) error { + return tx.Exec("UPDATE projects set state = NULL").Error + }, }, }, { - ID: "2021-01-22-execution-user", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Execution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Model(&models.Execution{}).Migrator().DropColumn(&models.Execution{}, "user") + Migration: gormigrate.Migration{ + ID: "2021-01-22-execution-user", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Execution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Model(&models.Execution{}).Migrator().DropColumn(&models.Execution{}, "user") + }, }, }, { - ID: "2021-04-19-node-execution_dynamic-workflow", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.NodeExecution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Model(&models.NodeExecution{}).Migrator().DropColumn(&models.NodeExecution{}, "dynamic_workflow_remote_closure_reference") + Migration: gormigrate.Migration{ + ID: "2021-04-19-node-execution_dynamic-workflow", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&NodeExecution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Model(&NodeExecution{}).Migrator().DropColumn(&NodeExecution{}, "dynamic_workflow_remote_closure_reference") + }, }, }, { - ID: "2021-07-22-schedulable_entities", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&schedulerModels.SchedulableEntity{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable(&schedulerModels.SchedulableEntity{}, "schedulable_entities") + Migration: gormigrate.Migration{ + ID: "2021-07-22-schedulable_entities", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&schedulerModels.SchedulableEntity{}) + }, + Rollback: func(tx *gorm.DB) error { + return 
tx.Migrator().DropTable(&schedulerModels.SchedulableEntity{}, "schedulable_entities") + }, }, }, { - ID: "2021-08-05-schedulable_entities_snapshot", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&schedulerModels.ScheduleEntitiesSnapshot{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Migrator().DropTable(&schedulerModels.ScheduleEntitiesSnapshot{}, "schedulable_entities_snapshot") + Migration: gormigrate.Migration{ + ID: "2021-08-05-schedulable_entities_snapshot", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&schedulerModels.ScheduleEntitiesSnapshot{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable(&schedulerModels.ScheduleEntitiesSnapshot{}, "schedulable_entities_snapshot") + }, }, }, // For any new table, Please use the following pattern due to a bug // in the postgres gorm layer https://github.com/go-gorm/postgres/issues/65 { - ID: "2022-01-11-id-to-bigint", - Migrate: func(tx *gorm.DB) error { - db, err := tx.DB() - if err != nil { - return err - } - return alterTableColumnType(db, "id", "bigint") - }, - Rollback: func(tx *gorm.DB) error { - db, err := tx.DB() - if err != nil { - return err - } - return alterTableColumnType(db, "id", "int") - }, + Migration: gormigrate.Migration{ + ID: "2022-01-11-id-to-bigint", + Migrate: func(tx *gorm.DB) error { + db, err := tx.DB() + if err != nil { + return err + } + return alterTableColumnType(db, "id", "bigint") + }, + Rollback: func(tx *gorm.DB) error { + db, err := tx.DB() + if err != nil { + return err + } + return alterTableColumnType(db, "id", "int") + }, + }, + Options: MigrationOptions{IgnoreForSqlite: true}, }, // Add state to execution model. 
{ - ID: "2022-01-11-execution-state", - Migrate: func(tx *gorm.DB) error { - return tx.AutoMigrate(&models.Execution{}) - }, - Rollback: func(tx *gorm.DB) error { - return tx.Table("execution").Migrator().DropColumn(&models.Execution{}, "state") + Migration: gormigrate.Migration{ + ID: "2022-01-11-execution-state", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Execution{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Table("execution").Migrator().DropColumn(&models.Execution{}, "state") + }, }, }, } diff --git a/pkg/repositories/database.go b/pkg/repositories/database.go index 61c161696..a12fec265 100644 --- a/pkg/repositories/database.go +++ b/pkg/repositories/database.go @@ -9,9 +9,10 @@ import ( "strings" repoErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + "gorm.io/driver/sqlite" + runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flytestdlib/logger" - "github.com/jackc/pgconn" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -21,6 +22,13 @@ import ( const pqInvalidDBCode = "3D000" const defaultDB = "postgres" +type DatabaseType int64 + +const ( + DatabaseTypePostgres DatabaseType = iota + DatabaseTypeSqlite +) + // getGormLogLevel converts between the flytestdlib configured log level to the equivalent gorm log level. func getGormLogLevel(ctx context.Context, logConfig *logger.Config) gormLogger.LogLevel { if logConfig == nil { @@ -66,7 +74,7 @@ func resolvePassword(ctx context.Context, passwordVal, passwordPath string) stri } // Produces the DSN (data source name) for opening a postgres db connection. -func getPostgresDsn(ctx context.Context, pgConfig runtimeInterfaces.PostgresConfig) string { +func getPostgresDsn(ctx context.Context, pgConfig *runtimeInterfaces.PostgresConfig) string { password := resolvePassword(ctx, pgConfig.Password, pgConfig.PasswordPath) if len(password) == 0 { // The password-less case is included for development environments. 
@@ -79,8 +87,7 @@ func getPostgresDsn(ctx context.Context, pgConfig runtimeInterfaces.PostgresConf // GetDB uses the dbConfig to create gorm DB object. If the db doesn't exist for the dbConfig then a new one is created // using the default db for the provider. eg : postgres has default dbName as postgres -func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig *logger.Config) ( - gormDb *gorm.DB, err error) { +func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig *logger.Config) (*gorm.DB, DatabaseType, error) { if dbConfig == nil { panic("Cannot initialize database repository from empty db config") } @@ -89,17 +96,29 @@ func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig DisableForeignKeyConstraintWhenMigrating: !dbConfig.EnableForeignKeyConstraintWhenMigrating, } - // TODO: add other gorm-supported db type handling in further case blocks. + var gormDb *gorm.DB + var err error + var databaseType DatabaseType + switch { - // TODO: Figure out a better proxy for a non-empty postgres config - case len(dbConfig.PostgresConfig.Host) > 0 || len(dbConfig.PostgresConfig.User) > 0 || len(dbConfig.PostgresConfig.DbName) > 0: + case dbConfig.SQLiteConfig != nil: + if dbConfig.SQLiteConfig.File == "" { + return nil, databaseType, fmt.Errorf("illegal sqlite database configuration. 
`file` is a required parameter and should be a path") + } + logger.Debugf(ctx, "Opening sqlite db connection") + gormDb, err = gorm.Open(sqlite.Open(dbConfig.SQLiteConfig.File), gormConfig) + if err != nil { + return nil, databaseType, err + } + databaseType = DatabaseTypeSqlite + case dbConfig.PostgresConfig != nil && (len(dbConfig.PostgresConfig.Host) > 0 || len(dbConfig.PostgresConfig.User) > 0 || len(dbConfig.PostgresConfig.DbName) > 0): gormDb, err = createPostgresDbIfNotExists(ctx, gormConfig, dbConfig.PostgresConfig) if err != nil { - return nil, err + return nil, databaseType, err } - + databaseType = DatabaseTypePostgres case len(dbConfig.DeprecatedHost) > 0 || len(dbConfig.DeprecatedUser) > 0 || len(dbConfig.DeprecatedDbName) > 0: - pgConfig := runtimeInterfaces.PostgresConfig{ + pgConfig := &runtimeInterfaces.PostgresConfig{ Host: dbConfig.DeprecatedHost, Port: dbConfig.DeprecatedPort, DbName: dbConfig.DeprecatedDbName, @@ -111,18 +130,19 @@ func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig } gormDb, err = createPostgresDbIfNotExists(ctx, gormConfig, pgConfig) if err != nil { - return nil, err + return nil, databaseType, err } + databaseType = DatabaseTypePostgres default: - panic(fmt.Sprintf("Unrecognized database config %v", dbConfig)) + return nil, databaseType, fmt.Errorf("unrecognized database config, %v. 
Supported only postgres and sqlite", dbConfig) } // Setup connection pool settings - return gormDb, setupDbConnectionPool(gormDb, dbConfig) + return gormDb, databaseType, setupDbConnectionPool(gormDb, dbConfig) } // Creates DB if it doesn't exist for the passed in config -func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig runtimeInterfaces.PostgresConfig) (*gorm.DB, error) { +func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig *runtimeInterfaces.PostgresConfig) (*gorm.DB, error) { dialector := postgres.Open(getPostgresDsn(ctx, pgConfig)) gormDb, err := gorm.Open(dialector, gormConfig) diff --git a/pkg/repositories/database_test.go b/pkg/repositories/database_test.go index 5b1de474e..27b53aaac 100644 --- a/pkg/repositories/database_test.go +++ b/pkg/repositories/database_test.go @@ -4,6 +4,7 @@ import ( "context" "io/ioutil" "os" + "path" "path/filepath" "testing" "time" @@ -58,7 +59,7 @@ func TestResolvePassword(t *testing.T) { } func TestGetPostgresDsn(t *testing.T) { - pgConfig := runtimeInterfaces.PostgresConfig{ + pgConfig := &runtimeInterfaces.PostgresConfig{ Host: "localhost", Port: 5432, DbName: "postgres", @@ -143,3 +144,25 @@ func TestSetupDbConnectionPool(t *testing.T) { assert.NotNil(t, err) }) } + +func TestGetDB(t *testing.T) { + ctx := context.TODO() + + t.Run("missing DB Config", func(t *testing.T) { + _, _, err := GetDB(ctx, &runtimeInterfaces.DbConfig{}, &logger.Config{}) + assert.Error(t, err) + }) + + t.Run("sqlite config", func(t *testing.T) { + dbFile := path.Join(t.TempDir(), "admin.db") + db, _, err := GetDB(ctx, &runtimeInterfaces.DbConfig{ + SQLiteConfig: &runtimeInterfaces.SQLiteConfig{ + File: dbFile, + }, + }, &logger.Config{}) + assert.NoError(t, err) + assert.NotNil(t, db) + assert.FileExists(t, dbFile) + assert.Equal(t, "sqlite", db.Name()) + }) +} diff --git a/pkg/repositories/gormimpl/node_execution_repo_test.go 
b/pkg/repositories/gormimpl/node_execution_repo_test.go index a2679a919..c74d712cc 100644 --- a/pkg/repositories/gormimpl/node_execution_repo_test.go +++ b/pkg/repositories/gormimpl/node_execution_repo_test.go @@ -65,7 +65,7 @@ func TestUpdateNodeExecution(t *testing.T) { nodeExecutionQuery.WithQuery(`UPDATE "node_executions" SET "id"=$1,"updated_at"=$2,"execution_project"=$3,"execution_domain"=$4,"execution_name"=$5,"node_id"=$6,"phase"=$7,"input_uri"=$8,"closure"=$9,"started_at"=$10,"node_execution_created_at"=$11,"node_execution_updated_at"=$12,"duration"=$13 WHERE "execution_project" = $14 AND "execution_domain" = $15 AND "execution_name" = $16 AND "node_id" = $17`) err := nodeExecutionRepo.Update(context.Background(), &models.NodeExecution{ - BaseModel: models.BaseModel{ID: 1}, + ID: 1, NodeExecutionKey: models.NodeExecutionKey{ NodeID: "1", ExecutionKey: models.ExecutionKey{ @@ -381,9 +381,7 @@ func TestNodeExecutionExists(t *testing.T) { Name: "1", }, }, - BaseModel: models.BaseModel{ - ID: id, - }, + ID: id, Phase: nodePhase, Closure: []byte("closure"), } diff --git a/pkg/repositories/models/base_model.go b/pkg/repositories/models/base_model.go index bfc5776ce..ea26447b7 100644 --- a/pkg/repositories/models/base_model.go +++ b/pkg/repositories/models/base_model.go @@ -6,7 +6,7 @@ import "time" // This is nearly identical to http://doc.gorm.io/models.html#conventions except that flyteadmin models define their // own primary keys rather than use the ID as the primary key type BaseModel struct { - ID uint `gorm:"index;autoIncrement"` + ID uint `gorm:"index;autoIncrement;primary_key"` CreatedAt time.Time UpdatedAt time.Time DeletedAt *time.Time `gorm:"index"` diff --git a/pkg/repositories/models/execution.go b/pkg/repositories/models/execution.go index fa429226b..00e44a10a 100644 --- a/pkg/repositories/models/execution.go +++ b/pkg/repositories/models/execution.go @@ -11,9 +11,9 @@ import ( // Execution primary key type ExecutionKey struct { - Project string 
`gorm:"primary_key;column:execution_project" valid:"length(0|255)"` - Domain string `gorm:"primary_key;column:execution_domain" valid:"length(0|255)"` - Name string `gorm:"primary_key;column:execution_name" valid:"length(0|255)"` + Project string `gorm:"uniqueIndex:primary_exec_index;column:execution_project" valid:"length(0|255)"` + Domain string `gorm:"uniqueIndex:primary_exec_index;column:execution_domain" valid:"length(0|255)"` + Name string `gorm:"uniqueIndex:primary_exec_index;column:execution_name" valid:"length(0|255)"` } // Database model to encapsulate a (workflow) execution. diff --git a/pkg/repositories/models/execution_event.go b/pkg/repositories/models/execution_event.go index b6e18decc..117f8bb22 100644 --- a/pkg/repositories/models/execution_event.go +++ b/pkg/repositories/models/execution_event.go @@ -9,5 +9,5 @@ type ExecutionEvent struct { ExecutionKey RequestID string `valid:"length(0|255)"` OccurredAt time.Time - Phase string `gorm:"primary_key"` + Phase string `gorm:"uniqueIndex:primary_exec_event_index"` } diff --git a/pkg/repositories/models/launch_plan.go b/pkg/repositories/models/launch_plan.go index 2b2787518..0920d55d9 100644 --- a/pkg/repositories/models/launch_plan.go +++ b/pkg/repositories/models/launch_plan.go @@ -2,10 +2,10 @@ package models // Launch plan primary key type LaunchPlanKey struct { - Project string `gorm:"primary_key;index:lp_project_domain_name_idx,lp_project_domain_idx" valid:"length(0|255)"` - Domain string `gorm:"primary_key;index:lp_project_domain_name_idx,lp_project_domain_idx" valid:"length(0|255)"` - Name string `gorm:"primary_key;index:lp_project_domain_name_idx" valid:"length(0|255)"` - Version string `gorm:"primary_key" valid:"length(0|255)"` + Project string `gorm:"uniqueIndex:primary_lp_index;index:lp_project_domain_name_idx,lp_project_domain_idx" valid:"length(0|255)"` + Domain string `gorm:"uniqueIndex:primary_lp_index;index:lp_project_domain_name_idx,lp_project_domain_idx" valid:"length(0|255)"` + Name 
string `gorm:"uniqueIndex:primary_lp_index;index:lp_project_domain_name_idx" valid:"length(0|255)"` + Version string `gorm:"uniqueIndex:primary_lp_index" valid:"length(0|255)"` } type LaunchPlanScheduleType string diff --git a/pkg/repositories/models/named_entity.go b/pkg/repositories/models/named_entity.go index 966676d0c..86bccae72 100644 --- a/pkg/repositories/models/named_entity.go +++ b/pkg/repositories/models/named_entity.go @@ -6,10 +6,10 @@ import ( // NamedEntityMetadata primary key type NamedEntityMetadataKey struct { - ResourceType core.ResourceType `gorm:"primary_key;index:named_entity_metadata_type_project_domain_name_idx" valid:"length(0|255)"` - Project string `gorm:"primary_key;index:named_entity_metadata_type_project_domain_name_idx" valid:"length(0|255)"` - Domain string `gorm:"primary_key;index:named_entity_metadata_type_project_domain_name_idx" valid:"length(0|255)"` - Name string `gorm:"primary_key;index:named_entity_metadata_type_project_domain_name_idx" valid:"length(0|255)"` + ResourceType core.ResourceType `gorm:"uniqueIndex:primary_ne_index;index:named_entity_metadata_type_project_domain_name_idx" valid:"length(0|255)"` + Project string `gorm:"uniqueIndex:primary_ne_index;index:named_entity_metadata_type_project_domain_name_idx" valid:"length(0|255)"` + Domain string `gorm:"uniqueIndex:primary_ne_index;index:named_entity_metadata_type_project_domain_name_idx" valid:"length(0|255)"` + Name string `gorm:"uniqueIndex:primary_ne_index;index:named_entity_metadata_type_project_domain_name_idx" valid:"length(0|255)"` } // Fields to be composed into any named entity diff --git a/pkg/repositories/models/node_execution_event.go b/pkg/repositories/models/node_execution_event.go index cd362c8ff..5310756a7 100644 --- a/pkg/repositories/models/node_execution_event.go +++ b/pkg/repositories/models/node_execution_event.go @@ -9,5 +9,5 @@ type NodeExecutionEvent struct { NodeExecutionKey RequestID string OccurredAt time.Time - Phase string 
`gorm:"primary_key"` + Phase string `gorm:"uniqueIndex:primary_nee_index"` } diff --git a/pkg/repositories/models/project.go b/pkg/repositories/models/project.go index a5feedb27..8093c189c 100644 --- a/pkg/repositories/models/project.go +++ b/pkg/repositories/models/project.go @@ -2,7 +2,7 @@ package models type Project struct { BaseModel - Identifier string `gorm:"primary_key"` + Identifier string `gorm:"uniqueIndex:project_index"` Name string `valid:"length(0|255)"` // Human-readable name, not a unique identifier. Description string `gorm:"type:varchar(300)"` Labels []byte diff --git a/pkg/repositories/models/resource.go b/pkg/repositories/models/resource.go index 02da2fbe3..4512038c3 100644 --- a/pkg/repositories/models/resource.go +++ b/pkg/repositories/models/resource.go @@ -1,6 +1,8 @@ package models -import "time" +import ( + "time" +) type ResourcePriority int32 diff --git a/pkg/repositories/models/task.go b/pkg/repositories/models/task.go index 093bd29ec..71f74cd3f 100644 --- a/pkg/repositories/models/task.go +++ b/pkg/repositories/models/task.go @@ -5,10 +5,10 @@ package models // Task primary key type TaskKey struct { - Project string `gorm:"primary_key;index:task_project_domain_name_idx;index:task_project_domain_idx" valid:"length(0|255)"` - Domain string `gorm:"primary_key;index:task_project_domain_name_idx;index:task_project_domain_idx" valid:"length(0|255)"` - Name string `gorm:"primary_key;index:task_project_domain_name_idx" valid:"length(0|255)"` - Version string `gorm:"primary_key" valid:"length(0|255)"` + Project string `gorm:"uniqueIndex:primary_task_index;index:task_project_domain_name_idx;index:task_project_domain_idx" valid:"length(0|255)"` + Domain string `gorm:"uniqueIndex:primary_task_index;index:task_project_domain_name_idx;index:task_project_domain_idx" valid:"length(0|255)"` + Name string `gorm:"uniqueIndex:primary_task_index;index:task_project_domain_name_idx" valid:"length(0|255)"` + Version string 
`gorm:"uniqueIndex:primary_task_index" valid:"length(0|255)"` } // Database model to encapsulate a task. diff --git a/pkg/repositories/models/workflow.go b/pkg/repositories/models/workflow.go index 5f50379b1..6ed810fd2 100644 --- a/pkg/repositories/models/workflow.go +++ b/pkg/repositories/models/workflow.go @@ -2,10 +2,10 @@ package models // Workflow primary key type WorkflowKey struct { - Project string `gorm:"primary_key;index:workflow_project_domain_name_idx;index:workflow_project_domain_idx" valid:"length(0|255)"` - Domain string `gorm:"primary_key;index:workflow_project_domain_name_idx;index:workflow_project_domain_idx" valid:"length(0|255)"` - Name string `gorm:"primary_key;index:workflow_project_domain_name_idx" valid:"length(0|255)"` - Version string `gorm:"primary_key"` + Project string `gorm:"uniqueIndex:primary_workflow_index;index:workflow_project_domain_name_idx;index:workflow_project_domain_idx" valid:"length(0|255)"` + Domain string `gorm:"uniqueIndex:primary_workflow_index;index:workflow_project_domain_name_idx;index:workflow_project_domain_idx" valid:"length(0|255)"` + Name string `gorm:"uniqueIndex:primary_workflow_index;index:workflow_project_domain_name_idx" valid:"length(0|255)"` + Version string `gorm:"uniqueIndex:primary_workflow_index"` } // Database model to encapsulate a workflow. 
diff --git a/pkg/rpc/adminservice/base.go b/pkg/rpc/adminservice/base.go index 891f8f9a5..c7e6584e5 100644 --- a/pkg/rpc/adminservice/base.go +++ b/pkg/rpc/adminservice/base.go @@ -24,7 +24,6 @@ import ( "github.com/flyteorg/flyteadmin/pkg/workflowengine" workflowengineImpl "github.com/flyteorg/flyteadmin/pkg/workflowengine/impl" "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/profutils" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/storage" "github.com/golang/protobuf/proto" @@ -76,7 +75,7 @@ func NewAdminServer(ctx context.Context, kubeConfig, master string) *AdminServic databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() logConfig := logger.GetConfig() - db, err := repositories.GetDB(ctx, databaseConfig, logConfig) + db, _, err := repositories.GetDB(ctx, databaseConfig, logConfig) if err != nil { logger.Fatal(ctx, err) } @@ -155,15 +154,6 @@ func NewAdminServer(ctx context.Context, kubeConfig, master string) *AdminServic scheduledWorkflowExecutor.Run() }() - // Serve profiling endpoints. - go func() { - err := profutils.StartProfilingServerWithDefaultHandlers( - ctx, applicationConfiguration.GetProfilerPort(), nil) - if err != nil { - logger.Panicf(ctx, "Failed to Start profiling and Metrics server. 
Error, %v", err) - } - }() - nodeExecutionEventWriter := eventWriter.NewNodeExecutionEventWriter(repo, applicationConfiguration.GetAsyncEventsBufferSize()) go func() { nodeExecutionEventWriter.Run() diff --git a/pkg/runtime/interfaces/application_configuration.go b/pkg/runtime/interfaces/application_configuration.go index 2eca76c81..96f5a7ad5 100644 --- a/pkg/runtime/interfaces/application_configuration.go +++ b/pkg/runtime/interfaces/application_configuration.go @@ -22,7 +22,14 @@ type DbConfig struct { MaxIdleConnections int `json:"maxIdleConnections" pflag:",maxIdleConnections sets the maximum number of connections in the idle connection pool."` MaxOpenConnections int `json:"maxOpenConnections" pflag:",maxOpenConnections sets the maximum number of open connections to the database."` ConnMaxLifeTime config.Duration `json:"connMaxLifeTime" pflag:",sets the maximum amount of time a connection may be reused"` - PostgresConfig PostgresConfig `json:"postgres"` + PostgresConfig *PostgresConfig `json:"postgres,omitempty"` + SQLiteConfig *SQLiteConfig `json:"sqlite,omitempty"` +} + +// SQLiteConfig can be used to configure the connection to a sqlite database. +type SQLiteConfig struct { + File string `json:"file" pflag:",The path to the file (existing or new) where the DB should be created / stored. If existing, then this will be re-used, else a new will be created"` + Debug bool `json:"debug" pflag:" Whether or not to start the database connection with debug mode enabled."` } // PostgresConfig includes specific config options for opening a connection to a postgres database. 
@@ -38,7 +45,7 @@ type PostgresConfig struct { Debug bool `json:"debug" pflag:" Whether or not to start the database connection with debug mode enabled."` } -// This configuration is the base configuration to start admin +// ApplicationConfig is the base configuration to start admin type ApplicationConfig struct { // The RoleName key inserted as an annotation (https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/) // in Flyte Workflow CRDs created in the CreateExecution flow. The corresponding role value is defined in the diff --git a/pkg/server/initialize.go b/pkg/server/initialize.go new file mode 100644 index 000000000..baa872f09 --- /dev/null +++ b/pkg/server/initialize.go @@ -0,0 +1,88 @@ +package server + +import ( + "context" + "fmt" + + "github.com/flyteorg/flyteadmin/pkg/repositories" + "github.com/flyteorg/flyteadmin/pkg/repositories/config" + "github.com/flyteorg/flyteadmin/pkg/runtime" + "github.com/flyteorg/flytestdlib/logger" + "github.com/go-gormigrate/gormigrate/v2" + "gorm.io/gorm" +) + +func withDB(ctx context.Context, do func(db *gorm.DB, dbType repositories.DatabaseType) error) error { + configuration := runtime.NewConfigurationProvider() + databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() + logConfig := logger.GetConfig() + + db, dbType, err := repositories.GetDB(ctx, databaseConfig, logConfig) + if err != nil { + logger.Fatal(ctx, err) + } + + sqlDB, err := db.DB() + if err != nil { + logger.Fatal(ctx, err) + } + + defer func(deferCtx context.Context) { + if err = sqlDB.Close(); err != nil { + logger.Fatal(deferCtx, err) + } + }(ctx) + + if err = sqlDB.Ping(); err != nil { + return err + } + + return do(db, dbType) +} + +func getMigrations(dbType repositories.DatabaseType) []*gormigrate.Migration { + var migrations = make([]*gormigrate.Migration, 0, len(config.Migrations)) + for _, migration := range config.Migrations { + if dbType == repositories.DatabaseTypeSqlite && 
migration.Options.IgnoreForSqlite { + continue + } + migrations = append(migrations, &migration.Migration) + } + return migrations +} + +// Migrate runs all configured migrations +func Migrate(ctx context.Context) error { + return withDB(ctx, func(db *gorm.DB, dbType repositories.DatabaseType) error { + m := gormigrate.New(db.Debug(), gormigrate.DefaultOptions, getMigrations(dbType)) + if err := m.Migrate(); err != nil { + return fmt.Errorf("database migration failed: %v", err) + } + logger.Infof(ctx, "Migration ran successfully") + return nil + }) +} + +// Rollback rolls back the last migration +func Rollback(ctx context.Context) error { + return withDB(ctx, func(db *gorm.DB, dbType repositories.DatabaseType) error { + m := gormigrate.New(db, gormigrate.DefaultOptions, getMigrations(dbType)) + err := m.RollbackLast() + if err != nil { + return fmt.Errorf("could not rollback latest migration: %v", err) + } + logger.Infof(ctx, "Rolled back one migration successfully") + return nil + }) +} + +// SeedProjects creates a set of given projects in the DB +func SeedProjects(ctx context.Context, projects []string) error { + return withDB(ctx, func(db *gorm.DB, _ repositories.DatabaseType) error { + if err := config.SeedProjects(db, projects); err != nil { + return fmt.Errorf("could not add projects to database with err: %v", err) + } + logger.Infof(ctx, "Successfully added projects to database") + return nil + }) +} diff --git a/pkg/server/service.go b/pkg/server/service.go new file mode 100644 index 000000000..41ce92491 --- /dev/null +++ b/pkg/server/service.go @@ -0,0 +1,366 @@ +package server + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "strings" + + "github.com/flyteorg/flyteadmin/auth" + "github.com/flyteorg/flyteadmin/auth/authzserver" + authConfig "github.com/flyteorg/flyteadmin/auth/config" + "github.com/flyteorg/flyteadmin/auth/interfaces" + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/config" + 
"github.com/flyteorg/flyteadmin/pkg/rpc/adminservice" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" + "github.com/flyteorg/flytestdlib/logger" + "github.com/gorilla/handlers" + grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health" + "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/status" +) + +var defaultCorsHeaders = []string{"Content-Type"} + +// Serve starts a server and blocks the calling goroutine +func Serve(ctx context.Context, additionalHandlers map[string]func(http.ResponseWriter, *http.Request)) error { + serverConfig := config.GetConfig() + + if serverConfig.Security.Secure { + return serveGatewaySecure(ctx, serverConfig, authConfig.GetConfig(), additionalHandlers) + } + + return serveGatewayInsecure(ctx, serverConfig, authConfig.GetConfig(), additionalHandlers) +} + +func blanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + resp interface{}, err error) { + + identityContext := auth.IdentityContextFromContext(ctx) + if identityContext.IsEmpty() { + return handler(ctx, req) + } + + if !identityContext.Scopes().Has(auth.ScopeAll) { + return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope") + } + + return handler(ctx, req) +} + +// Creates a new gRPC Server with all the configuration +func newGRPCServer(ctx context.Context, cfg *config.ServerConfig, authCtx interfaces.AuthenticationContext, + opts ...grpc.ServerOption) *grpc.Server { + // Not yet implemented for 
streaming + var chainedUnaryInterceptors grpc.UnaryServerInterceptor + if cfg.Security.UseAuth { + logger.Infof(ctx, "Creating gRPC server with authentication") + chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor, + auth.GetAuthenticationCustomMetadataInterceptor(authCtx), + grpcauth.UnaryServerInterceptor(auth.GetAuthenticationInterceptor(authCtx)), + auth.AuthenticationLoggingInterceptor, + blanketAuthorization, + ) + } else { + logger.Infof(ctx, "Creating gRPC server without authentication") + chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor) + } + + serverOpts := []grpc.ServerOption{ + grpc.StreamInterceptor(grpcprometheus.StreamServerInterceptor), + grpc.UnaryInterceptor(chainedUnaryInterceptors), + } + if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { + serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes)) + } + serverOpts = append(serverOpts, opts...) + grpcServer := grpc.NewServer(serverOpts...) 
+ grpcprometheus.Register(grpcServer) + service.RegisterAdminServiceServer(grpcServer, adminservice.NewAdminServer(ctx, cfg.KubeConfig, cfg.Master)) + if cfg.Security.UseAuth { + service.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService()) + service.RegisterIdentityServiceServer(grpcServer, authCtx.IdentityService()) + } + + healthServer := health.NewServer() + healthServer.SetServingStatus("flyteadmin", grpc_health_v1.HealthCheckResponse_SERVING) + grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) + if cfg.GrpcConfig.ServerReflection || cfg.GrpcServerReflection { + reflection.Register(grpcServer) + } + return grpcServer +} + +func GetHandleOpenapiSpec(ctx context.Context) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + swaggerBytes, err := service.Asset("admin.swagger.json") + if err != nil { + logger.Warningf(ctx, "Err %v", err) + w.WriteHeader(http.StatusFailedDependency) + } else { + w.WriteHeader(http.StatusOK) + _, err := w.Write(swaggerBytes) + if err != nil { + logger.Errorf(ctx, "failed to write openAPI information, error: %s", err.Error()) + } + } + } +} + +func healthCheckFunc(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) +} + +func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig.Config, authCtx interfaces.AuthenticationContext, + additionalHandlers map[string]func(http.ResponseWriter, *http.Request), + grpcAddress string, grpcConnectionOpts ...grpc.DialOption) (*http.ServeMux, error) { + + // Register the server that will serve HTTP/REST Traffic + mux := http.NewServeMux() + + // Add any additional handlers that have been passed in for the main HTTP server + for p, f := range additionalHandlers { + mux.HandleFunc(p, f) + } + + // Register healthcheck + mux.HandleFunc("/healthcheck", healthCheckFunc) + + // Register OpenAPI endpoint + // This endpoint will serve the OpenAPI2 spec generated by the swagger protoc plugin, and bundled by 
go-bindata + mux.HandleFunc("/api/v1/openapi", GetHandleOpenapiSpec(ctx)) + + var gwmuxOptions = make([]runtime.ServeMuxOption, 0) + // This option means that http requests are served with protobufs, instead of json. We always want this. + gwmuxOptions = append(gwmuxOptions, runtime.WithMarshalerOption("application/octet-stream", &runtime.ProtoMarshaller{})) + + if cfg.Security.UseAuth { + // Add HTTP handlers for OIDC endpoints + auth.RegisterHandlers(ctx, mux, authCtx) + + // Add HTTP handlers for OAuth2 endpoints + authzserver.RegisterHandlers(mux, authCtx) + + // This option translates HTTP authorization data (cookies) into a gRPC metadata field + gwmuxOptions = append(gwmuxOptions, runtime.WithMetadata(auth.GetHTTPRequestCookieToMetadataHandler(authCtx))) + + // In an attempt to be able to selectively enforce whether or not authentication is required, we're going to tag + // the requests that come from the HTTP gateway. See the enforceHttp/Grpc options for more information. + gwmuxOptions = append(gwmuxOptions, runtime.WithMetadata(auth.GetHTTPMetadataTaggingHandler())) + } + + // Create the grpc-gateway server with the options specified + gwmux := runtime.NewServeMux(gwmuxOptions...) 
+ + err := service.RegisterAdminServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) + if err != nil { + return nil, errors.Wrap(err, "error registering admin service") + } + + err = service.RegisterAuthMetadataServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) + if err != nil { + return nil, errors.Wrap(err, "error registering auth service") + } + + err = service.RegisterIdentityServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) + if err != nil { + return nil, errors.Wrap(err, "error registering identity service") + } + + mux.Handle("/", gwmux) + + return mux, nil +} + +func serveGatewayInsecure(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config, additionalHandlers map[string]func(http.ResponseWriter, *http.Request)) error { + logger.Infof(ctx, "Serving Flyte Admin Insecure") + + // This will parse configuration and create the necessary objects for dealing with auth + var authCtx interfaces.AuthenticationContext + var err error + // This code is here to support authentication without SSL. This setup supports a network topology where + // Envoy does the SSL termination. The final hop is made over localhost only on a trusted machine. + // Warning: Running authentication without SSL in any other topology is a severe security flaw. + // See the auth.Config object for additional settings as well. 
+ if cfg.Security.UseAuth { + sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + var oauth2Provider interfaces.OAuth2Provider + var oauth2ResourceServer interfaces.OAuth2ResourceServer + if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { + oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm) + if err != nil { + logger.Errorf(ctx, "Error creating authorization server %s", err) + return err + } + + oauth2ResourceServer = oauth2Provider + } else { + oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) + if err != nil { + logger.Errorf(ctx, "Error creating resource server %s", err) + return err + } + } + + oauth2MetadataProvider := authzserver.NewService(authCfg) + oidcUserInfoProvider := auth.NewUserInfoProvider() + + authCtx, err = auth.NewAuthenticationContext(ctx, sm, oauth2Provider, oauth2ResourceServer, oauth2MetadataProvider, oidcUserInfoProvider, authCfg) + if err != nil { + logger.Errorf(ctx, "Error creating auth context %s", err) + return err + } + } + + grpcServer := newGRPCServer(ctx, cfg, authCtx) + + logger.Infof(ctx, "Serving GRPC Traffic on: %s", cfg.GetGrpcHostAddress()) + lis, err := net.Listen("tcp", cfg.GetGrpcHostAddress()) + if err != nil { + return errors.Wrapf(err, "failed to listen on GRPC port: %s", cfg.GetGrpcHostAddress()) + } + + go func() { + err := grpcServer.Serve(lis) + logger.Fatalf(ctx, "Failed to create GRPC Server, Err: ", err) + }() + + logger.Infof(ctx, "Starting HTTP/1 Gateway server on %s", cfg.GetHostAddress()) + grpcOptions := []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithMaxHeaderListSize(common.MaxResponseStatusBytes), + } + if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { + grpcOptions = append(grpcOptions, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) + } + httpServer, err := newHTTPServer(ctx, cfg, authCfg, 
authCtx, additionalHandlers, cfg.GetGrpcHostAddress(), grpcOptions...)
+	if err != nil {
+		return err
+	}
+
+	var handler http.Handler
+	if cfg.Security.AllowCors {
+		handler = handlers.CORS(
+			handlers.AllowCredentials(),
+			handlers.AllowedOrigins(cfg.Security.AllowedOrigins),
+			handlers.AllowedHeaders(append(defaultCorsHeaders, cfg.Security.AllowedHeaders...)),
+			handlers.AllowedMethods([]string{"GET", "POST", "DELETE", "HEAD", "PUT", "PATCH"}),
+		)(httpServer)
+	} else {
+		handler = httpServer
+	}
+
+	err = http.ListenAndServe(cfg.GetHostAddress(), handler)
+	if err != nil {
+		return errors.Wrapf(err, "failed to Start HTTP Server")
+	}
+
+	return nil
+}
+
+// grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC
+// connections or otherHandler otherwise.
+// See https://github.com/philips/grpc-gateway-example/blob/master/cmd/serve.go for reference
+func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// This is a partial recreation of gRPC's internal checks
+		if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
+			grpcServer.ServeHTTP(w, r)
+		} else {
+			otherHandler.ServeHTTP(w, r)
+		}
+	})
+}
+
+// serveGatewaySecure serves the gRPC server and the HTTP server built by newHTTPServer on a
+// single TLS listener bound to cfg.GetHostAddress(), using grpcHandlerFunc to route each request.
+// The certificate and pool returned by GetSslCredentials are used for the server side and for the
+// dial options handed to newHTTPServer. When cfg.Security.UseAuth is set, an authentication
+// context is built first and passed to both servers. The call blocks in srv.Serve and returns an
+// error if any setup step or Serve itself fails.
+func serveGatewaySecure(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config, additionalHandlers map[string]func(http.ResponseWriter, *http.Request)) error {
+	certPool, cert, err := GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile)
+	if err != nil {
+		return err
+	}
+	// This will parse configuration and create the necessary objects for dealing with auth
+	var authCtx interfaces.AuthenticationContext
+	if cfg.Security.UseAuth {
+		sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig())
+		var oauth2Provider interfaces.OAuth2Provider
+		var oauth2ResourceServer interfaces.OAuth2ResourceServer
+		if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf {
+			oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm)
+			if err != nil {
+				logger.Errorf(ctx, "Error creating authorization server %s", err)
+				return err
+			}
+
+			// The self-hosted provider is also used as the resource server.
+			oauth2ResourceServer = oauth2Provider
+		} else {
+			// Only a resource server is created here; oauth2Provider stays nil.
+			oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL)
+			if err != nil {
+				logger.Errorf(ctx, "Error creating resource server %s", err)
+				return err
+			}
+		}
+
+		oauth2MetadataProvider := authzserver.NewService(authCfg)
+		oidcUserInfoProvider := auth.NewUserInfoProvider()
+
+		authCtx, err = auth.NewAuthenticationContext(ctx, sm, oauth2Provider, oauth2ResourceServer, oauth2MetadataProvider, oidcUserInfoProvider, authCfg)
+		if err != nil {
+			logger.Errorf(ctx, "Error creating auth context %s", err)
+			return err
+		}
+	}
+
+	// The gRPC server presents the same certificate as the HTTP/2 listener created below.
+	grpcServer := newGRPCServer(ctx, cfg, authCtx, grpc.Creds(credentials.NewServerTLSFromCert(cert)))
+
+	// NOTE(review): cfg.GetHostAddress() is also handed to net.Listen below, so it presumably
+	// carries a port; confirm that value is what the TLS ServerName is meant to be.
+	// Whatever certificate is used, pass it along for easier development
+	// #nosec G402
+	dialCreds := credentials.NewTLS(&tls.Config{
+		ServerName: cfg.GetHostAddress(),
+		RootCAs:    certPool,
+	})
+	serverOpts := []grpc.DialOption{
+		grpc.WithTransportCredentials(dialCreds),
+	}
+	// Only set a receive limit on the dial options when one is configured.
+	if cfg.GrpcConfig.MaxMessageSizeBytes > 0 {
+		serverOpts = append(serverOpts,
+			grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes)))
+	}
+	httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, additionalHandlers, cfg.GetHostAddress(), serverOpts...)
+	if err != nil {
+		return err
+	}
+
+	conn, err := net.Listen("tcp", cfg.GetHostAddress())
+	if err != nil {
+		// Report a bind failure through the error return, like every other setup step here,
+		// instead of panicking.
+		return errors.Wrapf(err, "failed to listen on %s", cfg.GetHostAddress())
+	}
+
+	srv := &http.Server{
+		Addr:    cfg.GetHostAddress(),
+		Handler: grpcHandlerFunc(grpcServer, httpServer),
+		// #nosec G402
+		TLSConfig: &tls.Config{
+			Certificates: []tls.Certificate{*cert},
+			// Only "h2" is offered; grpcHandlerFunc routes gRPC on r.ProtoMajor == 2.
+			NextProtos: []string{"h2"},
+		},
+	}
+
+	// The plain TCP listener is wrapped so that TLS is terminated with srv.TLSConfig.
+	err = srv.Serve(tls.NewListener(conn, srv.TLSConfig))
+
+	if err != nil {
+		return errors.Wrapf(err, "failed to Start HTTP/2 Server")
+	}
+	return nil
+}
diff --git a/scheduler/start.go b/scheduler/start.go
new file mode 100644
index 000000000..77c9f79f3
--- /dev/null
+++ b/scheduler/start.go
@@ -0,0 +1,62 @@
+package scheduler
+
+import (
+	"context"
+	"fmt"
+	"runtime/debug"
+
+	"github.com/flyteorg/flyteadmin/pkg/repositories"
+	"github.com/flyteorg/flyteadmin/pkg/repositories/errors"
+	"github.com/flyteorg/flyteadmin/pkg/runtime"
+	"github.com/flyteorg/flyteidl/clients/go/admin"
+	"github.com/flyteorg/flytestdlib/logger"
+	"github.com/flyteorg/flytestdlib/promutils"
+)
+
+// StartScheduler creates and starts a new scheduler instance.
This is a blocking call and will block the calling go-routine
+// until scheduleExecutor.Run returns. It builds the database repository and the admin service
+// client from the runtime and admin configuration. Failures building the client set or running
+// the executor are logged with logger.Fatalf and then returned. A panic raised inside this
+// function is recovered, counted in the "initialization_panic" metric and reported via logger.Fatal.
+func StartScheduler(ctx context.Context) error {
+	configuration := runtime.NewConfigurationProvider()
+	applicationConfiguration := configuration.ApplicationConfiguration().GetTopLevelConfig()
+
+	// Define the schedulerScope for prometheus metrics
+	schedulerScope := promutils.NewScope(applicationConfiguration.MetricsScope).NewSubScope("flytescheduler")
+	schedulerPanics := schedulerScope.MustNewCounter("initialization_panic",
+		"panics encountered initializing the flyte native scheduler")
+
+	defer func() {
+		if r := recover(); r != nil {
+			schedulerPanics.Inc()
+			// The message is already formatted and embeds a stack trace, so it is logged as a
+			// plain value: used as a format string, any '%' inside it would be misrendered.
+			logger.Fatal(ctx, fmt.Sprintf("caught panic: %v [%+v]", r, string(debug.Stack())))
+		}
+	}()
+
+	databaseConfig := configuration.ApplicationConfiguration().GetDbConfig()
+	logConfig := logger.GetConfig()
+
+	db, _, err := repositories.GetDB(ctx, databaseConfig, logConfig)
+	if err != nil {
+		// NOTE(review): there is no return here, so this relies on logger.Fatal not returning
+		// control; confirm against the logger package.
+		logger.Fatal(ctx, err)
+	}
+	dbScope := schedulerScope.NewSubScope("database")
+	// The error transformer reports under the scheduler's "errors" sub-scope, not under dbScope.
+	repo := repositories.NewGormRepo(
+		db, errors.NewPostgresErrorTransformer(schedulerScope.NewSubScope("errors")), dbScope)
+
+	clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)).Build(ctx)
+	if err != nil {
+		logger.Fatalf(ctx, "Flyte native scheduler failed to start due to %v", err)
+		return err
+	}
+	adminServiceClient := clientSet.AdminClient()
+
+	scheduleExecutor := NewScheduledExecutor(repo,
+		configuration.ApplicationConfiguration().GetSchedulerConfig().GetWorkflowExecutorConfig(), schedulerScope, adminServiceClient)
+
+	logger.Info(ctx, "Successfully initialized a native flyte scheduler")
+
+	// Run is where this call blocks; its error is logged and handed back to the caller.
+	err = scheduleExecutor.Run(ctx)
+	if err != nil {
+		logger.Fatalf(ctx, "Flyte native scheduler failed to start due to %v", err)
+		return err
+	}
+	return nil
+}
diff --git a/tests/bootstrap.go b/tests/bootstrap.go
index 57ca36416..44ca517ca 100644
--- a/tests/bootstrap.go
+++ b/tests/bootstrap.go
@@ -23,7 +23,7 @@ var adminScope =
promutils.NewScope("flyteadmin") func getDbConfig() *runtimeInterfaces.DbConfig { return &runtimeInterfaces.DbConfig{ - PostgresConfig: runtimeInterfaces.PostgresConfig{ + PostgresConfig: &runtimeInterfaces.PostgresConfig{ Host: "postgres", Port: 5432, DbName: "postgres", @@ -34,7 +34,7 @@ func getDbConfig() *runtimeInterfaces.DbConfig { func getLocalDbConfig() *runtimeInterfaces.DbConfig { return &runtimeInterfaces.DbConfig{ - PostgresConfig: runtimeInterfaces.PostgresConfig{ + PostgresConfig: &runtimeInterfaces.PostgresConfig{ Host: "localhost", Port: 5432, DbName: "flyteadmin",