diff --git a/context.go b/context.go index 81e5e20..3b985ab 100644 --- a/context.go +++ b/context.go @@ -135,18 +135,28 @@ func (ps *Set) GetAll() []*Plugin { // GetByType should be used. If only one is expected, then to switch plugins, // disable or remove the unused plugins of the same type. func (i *InitContext) GetSingle(t Type) (interface{}, error) { - pt, ok := i.plugins.byTypeAndID[t] - if !ok || len(pt) == 0 { - return nil, fmt.Errorf("no plugins registered for %s: %w", t, ErrPluginNotFound) - } - if len(pt) > 1 { - return nil, fmt.Errorf("multiple plugins registered for %s: %w", t, ErrPluginMultipleInstances) + var ( + found bool + instance interface{} + ) + for _, v := range i.plugins.byTypeAndID[t] { + i, err := v.Instance() + if err != nil { + if IsSkipPlugin(err) { + continue + } + return i, err + } + if found { + return nil, fmt.Errorf("multiple plugins registered for %s: %w", t, ErrPluginMultipleInstances) + } + instance = i + found = true } - var p *Plugin - for _, v := range pt { - p = v + if !found { + return nil, fmt.Errorf("no plugins registered for %s: %w", t, ErrPluginNotFound) } - return p.Instance() + return instance, nil } // Plugins returns plugin set @@ -170,19 +180,20 @@ func (i *InitContext) GetByID(t Type, id string) (interface{}, error) { // GetByType returns all plugins with the specific type. func (i *InitContext) GetByType(t Type) (map[string]interface{}, error) { - pt, ok := i.plugins.byTypeAndID[t] - if !ok { - return nil, fmt.Errorf("no plugins registered for %s: %w", t, ErrPluginNotFound) - } - - pi := make(map[string]interface{}, len(pt)) - for id, p := range pt { + pi := map[string]interface{}{} + for id, p := range i.plugins.byTypeAndID[t] { i, err := p.Instance() if err != nil { + if IsSkipPlugin(err) { + continue + } return nil, err } pi[id] = i } + if len(pi) == 0 { + return nil, fmt.Errorf("no plugins registered for %s: %w", t, ErrPluginNotFound) + } return pi, nil } diff --git a/plugin_test.go b/plugin_test.go index 8238009..185eda5 100644 --- a/plugin_test.go +++ b/plugin_test.go @@ -17,6 +17,8 @@ package plugin import ( + "errors" + "fmt" "testing" ) @@ -377,3 +379,136 @@ func TestPluginGraph(t *testing.T) { cmpOrdered(t, ordered, testcase.expectedURI) } } + +func TestGetPlugins(t *testing.T) { + otherError := fmt.Errorf("other error") + plugins := NewPluginSet() + for _, p := range []*Plugin{ + testPlugin("type1", "id1", "id1", nil), + testPlugin("type1", "id2", "id2", ErrSkipPlugin), + testPlugin("type2", "id3", "id3", ErrSkipPlugin), + testPlugin("type3", "id4", "id4", nil), + testPlugin("type4", "id5", "id5", nil), + testPlugin("type4", "id6", "id6", nil), + testPlugin("type5", "id7", "id7", otherError), + } { + plugins.Add(p) + } + + ic := InitContext{ + plugins: plugins, + } + + for _, tc := range []struct { + pluginType string + err error + }{ + {"type1", nil}, + {"type2", ErrPluginNotFound}, + {"type3", nil}, + {"type4", ErrPluginMultipleInstances}, + {"type5", otherError}, + } { + t.Run("GetSingle", func(t *testing.T) { + instance, err := ic.GetSingle(Type(tc.pluginType)) + if err != nil { + if tc.err == nil { + t.Fatalf("unexpected error %v", err) + } else if !errors.Is(err, tc.err) { + t.Fatalf("unexpected error %v, expected %v", err, tc.err) + } + return + } else if tc.err != nil { + t.Fatalf("expected error %v, got no error", tc.err) + } + _, ok := instance.(string) + if !ok { + t.Fatalf("unexpected instance value %v", instance) + } + }) + } + + for _, tc := range []struct { + pluginType string + expected []string + err error + }{ + {"type1", []string{"id1"}, nil}, + {"type2", nil, ErrPluginNotFound}, + {"type3", []string{"id4"}, nil}, + {"type4", []string{"id5", "id6"}, nil}, + {"type5", nil, otherError}, + } { + t.Run("GetByType", func(t *testing.T) { + m, err := ic.GetByType(Type(tc.pluginType)) + if err != nil { + if tc.err == nil { + t.Fatalf("unexpected error %v", err) + } else if !errors.Is(err, tc.err) { + t.Fatalf("unexpected error %v, expected %v", err, tc.err) + } + return + } else if tc.err != nil { + t.Fatalf("expected error %v, got no error", tc.err) + } + + if len(m) != len(tc.expected) { + t.Fatalf("unexpected result %v, expected %v", m, tc.expected) + } + for _, v := range tc.expected { + instance, ok := m[v] + if !ok { + t.Errorf("missing value for %q", v) + continue + } + if instance.(string) != v { + t.Errorf("unexpected value %v, expected %v", instance, v) + } + } + }) + } + + for _, tc := range []struct { + pluginType string + id string + err error + }{ + {"type1", "id1", nil}, + {"type1", "id2", ErrSkipPlugin}, + {"type2", "id3", ErrSkipPlugin}, + {"type3", "id4", nil}, + {"type4", "id5", nil}, + {"type4", "id6", nil}, + {"type5", "id7", otherError}, + } { + t.Run("GetByID", func(t *testing.T) { + instance, err := ic.GetByID(Type(tc.pluginType), tc.id) + if err != nil { + if tc.err == nil { + t.Fatalf("unexpected error %v", err) + } else if !errors.Is(err, tc.err) { + t.Fatalf("unexpected error %v, expected %v", err, tc.err) + } + return + } else if tc.err != nil { + t.Fatalf("expected error %v, got no error", tc.err) + } + + if instance.(string) != tc.id { + t.Errorf("unexpected value %v, expected %v", instance, tc.id) + } + }) + } + +} + +func testPlugin(t Type, id string, i interface{}, err error) *Plugin { + return &Plugin{ + Registration: Registration{ + Type: t, + ID: id, + }, + instance: i, + err: err, + } +}