diff --git a/v2/backend.go b/v2/backend.go index cccb2c36..ed25de43 100644 --- a/v2/backend.go +++ b/v2/backend.go @@ -55,6 +55,18 @@ type Backend interface { // wasn't found or another error if any occurred. key won't contain the // namespace prefix. Fetch(ctx context.Context, key string) (any, error) + + // Validate validates the given values and returns the index of the "best" + // value or an error and -1 if all values are invalid. If the method is used + // with a single value, it will return 0 and no error if it is valid or an + // error and -1 if it is invalid. For multiple values, it will select the + // "best" value based on user-defined logic and return its index in the + // original values list. If we receive a request for /ipns/$binary_id, the + // key parameter will be set to $binary_id. Decisions about which value is + // the "best" from the given list must be stable. So if there are multiple + // equally good values, the implementation must always return the same + // index - for example, always the first good or last good value. + Validate(ctx context.Context, key string, values ...any) (int, error) } // NewBackendIPNS initializes a new backend for the "ipns" namespace that can diff --git a/v2/backend_provider.go b/v2/backend_provider.go index 3be9d88a..703fa1ee 100644 --- a/v2/backend_provider.go +++ b/v2/backend_provider.go @@ -1,11 +1,11 @@ package dht import ( + "bytes" "context" "encoding/binary" "fmt" "io" - "path" "strings" "sync" "time" @@ -226,13 +226,50 @@ func (p *ProvidersBackend) Fetch(ctx context.Context, key string) (any, error) { out.addProvider(addrInfo, rec.expiry) } - if len(out.providers) > 0 { + if len(out.providers) == 0 { + return nil, ds.ErrNotFound + } else { p.cache.Add(qKey.String(), *out) } return out, nil } +// Validate verifies that the given values are of type [peer.AddrInfo]. Then it +// decides based on the number of attached multi addresses which value is +// "better" than the other. If there is a tie, Validate will return the index +// of the earliest occurrence. +func (p *ProvidersBackend) Validate(ctx context.Context, key string, values ...any) (int, error) { + // short circuit if it's just a single value + if len(values) == 1 { + _, ok := values[0].(peer.AddrInfo) + if !ok { + return -1, fmt.Errorf("invalid type %T", values[0]) + } + return 0, nil + } + + bestIdx := -1 + for i, value := range values { + addrInfo, ok := value.(peer.AddrInfo) + if !ok { + continue + } + + if bestIdx == -1 { + bestIdx = i + } else if len(values[bestIdx].(peer.AddrInfo).Addrs) < len(addrInfo.Addrs) { + bestIdx = i + } + } + + if bestIdx == -1 { + return -1, fmt.Errorf("no value of correct type") + } + + return bestIdx, nil +} + // Close is here to implement the [io.Closer] interface. This will get called // when the [DHT] "shuts down"/closes. 
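The stability requirement on Validate is easiest to see in isolation. Below is a minimal, self-contained sketch of the selection rule that ProvidersBackend.Validate applies, with a simplified candidate type standing in for peer.AddrInfo (the type and function names are illustrative, not part of this package): the value with the most addresses wins, and a tie always resolves to the earliest index.

```go
package main

import (
	"errors"
	"fmt"
)

// candidate is a stand-in for peer.AddrInfo in this sketch; only the number
// of addresses matters for the selection rule.
type candidate struct {
	id        string
	addrCount int
}

// selectBest mirrors the rule above: values of the wrong type are skipped,
// the candidate with the most addresses wins, and ties resolve to the
// earliest index so repeated calls always pick the same value.
func selectBest(values ...any) (int, error) {
	bestIdx := -1
	for i, v := range values {
		c, ok := v.(candidate)
		if !ok {
			continue
		}
		if bestIdx == -1 || c.addrCount > values[bestIdx].(candidate).addrCount {
			bestIdx = i
		}
	}
	if bestIdx == -1 {
		return -1, errors.New("no value of correct type")
	}
	return bestIdx, nil
}

func main() {
	a := candidate{id: "a", addrCount: 1}
	b := candidate{id: "b", addrCount: 1}

	// Ties are stable: the earliest of two equally good values wins,
	// no matter how often we ask.
	idx, _ := selectBest(a, b)
	fmt.Println(idx) // 0
	idx, _ = selectBest(a, b)
	fmt.Println(idx) // 0
}
```

Because the comparison only replaces the current best on a strictly greater address count, repeated calls over the same input always return the same index, which is exactly the stability the Backend interface documentation asks for.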
func (p *ProvidersBackend) Close() error { @@ -431,5 +468,16 @@ func newDatastoreKey(namespace string, binStrs ...string) ds.Key { for i, bin := range binStrs { elems[i+1] = base32.RawStdEncoding.EncodeToString([]byte(bin)) } - return ds.NewKey("/" + path.Join(elems...)) + + return ds.NewKey("/" + strings.Join(elems, "/")) +} + +// newRoutingKey uses the given namespace and binary string key and constructs +// a new string of the format: /$namespace/$binStr +func newRoutingKey(namespace string, binStr string) string { + buf := make([]byte, 0, 2+len(namespace)+len(binStr)) + buffer := bytes.NewBuffer(buf) + buffer.WriteString("/" + namespace + "/") + buffer.Write([]byte(binStr)) + return buffer.String() } diff --git a/v2/backend_provider_test.go b/v2/backend_provider_test.go index d3ab465d..37e01368 100644 --- a/v2/backend_provider_test.go +++ b/v2/backend_provider_test.go @@ -10,9 +10,13 @@ import ( "github.com/benbjohnson/clock" ds "github.com/ipfs/go-datastore" "github.com/libp2p/go-libp2p" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/slog" + + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest" ) var devnull = slog.New(slog.NewTextHandler(io.Discard, nil)) @@ -115,3 +119,62 @@ func TestProvidersBackend_GarbageCollection_lifecycle_thread_safe(t *testing.T) assert.Nil(t, b.gcCancel) assert.Nil(t, b.gcDone) } + +func TestProvidersBackend_Validate(t *testing.T) { + ctx := kadtest.CtxShort(t) + + b := newBackendProvider(t, nil) + + pid := newPeerID(t) + peer1 := peer.AddrInfo{ID: pid, Addrs: make([]multiaddr.Multiaddr, 0)} + peer2 := peer.AddrInfo{ID: pid, Addrs: make([]multiaddr.Multiaddr, 1)} + peer3 := peer.AddrInfo{ID: pid, Addrs: make([]multiaddr.Multiaddr, 2)} + + t.Run("no values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key") + assert.Error(t, err) + assert.Equal(t, -1, idx) + }) + + t.Run("nil value", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", nil) + assert.Error(t, err) + assert.Equal(t, -1, idx) + }) + + t.Run("nil values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", nil, nil) + assert.Error(t, err) + assert.Equal(t, -1, idx) + }) + + t.Run("single valid value", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", peer1) + assert.NoError(t, err) + assert.Equal(t, 0, idx) + }) + + t.Run("increasing better values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", peer1, peer2, peer3) + assert.NoError(t, err) + assert.Equal(t, 2, idx) + }) + + t.Run("mixed better values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", peer1, peer3, peer2) + assert.NoError(t, err) + assert.Equal(t, 1, idx) + }) + + t.Run("mixed invalid values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", peer1, nil, peer2, nil) + assert.NoError(t, err) + assert.Equal(t, 2, idx) + }) + + t.Run("identically good values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", peer1, peer1) + assert.NoError(t, err) + assert.Equal(t, 0, idx) + }) +} diff --git a/v2/backend_record.go b/v2/backend_record.go index ba4a94ba..e0c4284d 100644 --- a/v2/backend_record.go +++ b/v2/backend_record.go @@ -131,6 +131,57 @@ func (r *RecordBackend) Fetch(ctx context.Context, key string) (any, error) { return rec, nil } +func (r *RecordBackend) Validate(ctx context.Context, key string, values ...any) (int, error) { + k := newRoutingKey(r.namespace, key) + + // short circuit if it's 
just a single value + if len(values) == 1 { + data, ok := values[0].([]byte) + if !ok { + return -1, fmt.Errorf("value not byte slice") + } + + if err := r.validator.Validate(k, data); err != nil { + return -1, err + } + + return 0, nil + } + + // In case there are invalid values in the slice, we still want to return + // the index in the original list of values. The Select method below will + // return the index of the "best" value in the slice of valid values. This + // slice can have a different length and therefore that method will return + // an index that doesn't match the values slice that's passed into this + // method. origIdx stores the original index + origIdx := map[int]int{} + validValues := [][]byte{} + for i, value := range values { + data, ok := value.([]byte) + if !ok { + continue + } + + if err := r.validator.Validate(k, data); err != nil { + continue + } + + origIdx[len(validValues)] = i + validValues = append(validValues, data) + } + + if len(validValues) == 0 { + return -1, fmt.Errorf("no valid values") + } + + sel, err := r.validator.Select(k, validValues) + if err != nil { + return -1, err + } + + return origIdx[sel], nil +} + // shouldReplaceExistingRecord returns true if the given record should replace any // existing one in the local datastore. It queries the datastore, unmarshalls // the record, validates it, and compares it to the incoming record. If the diff --git a/v2/backend_record_test.go b/v2/backend_record_test.go new file mode 100644 index 00000000..eb772a3a --- /dev/null +++ b/v2/backend_record_test.go @@ -0,0 +1,114 @@ +package dht + +import ( + "fmt" + "strconv" + "strings" + "testing" + + record "github.com/libp2p/go-libp2p-record" + "github.com/stretchr/testify/assert" + + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest" +) + +// testValidator is a validator that considers all values valid that have a +// "valid-" prefix. Then the suffix will determine which value is better. For +// example, "valid-2" is better than "valid-1". 
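The index re-mapping that the comment in RecordBackend.Validate describes is subtle enough to warrant a standalone illustration: Select only ever sees the filtered slice, so its return value has to be translated back to a position in the caller's original list. A minimal, self-contained sketch of that bookkeeping, reusing the toy "valid-N" convention documented above (all names here are illustrative):

```go
package main

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
)

// selectBest drops values that fail validation, picks the "best" valid value
// (highest numeric suffix in the toy "valid-N" scheme), and maps the winner's
// position back to the caller's original slice via origIdx.
func selectBest(values ...any) (int, error) {
	origIdx := map[int]int{}
	valid := []string{}
	for i, v := range values {
		s, ok := v.(string)
		if !ok || !strings.HasPrefix(s, "valid-") {
			continue // invalid values never win, but they keep their index
		}
		origIdx[len(valid)] = i
		valid = append(valid, s)
	}
	if len(valid) == 0 {
		return -1, errors.New("no valid values")
	}

	// pick the highest suffix among the *valid* values only
	best, bestValidIdx := -1, -1
	for i, s := range valid {
		n, err := strconv.Atoi(strings.TrimPrefix(s, "valid-"))
		if err != nil {
			continue
		}
		if n > best {
			best, bestValidIdx = n, i
		}
	}
	if bestValidIdx == -1 {
		return -1, errors.New("no valid values")
	}

	// translate the index in the filtered slice back to the original slice
	return origIdx[bestValidIdx], nil
}

func main() {
	idx, err := selectBest("valid-0", "garbage", "valid-2", nil)
	fmt.Println(idx, err) // 2 <nil> - an index into the original argument list
}
```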
+type testValidator struct{} + +var _ record.Validator = (*testValidator)(nil) + +func (t testValidator) Validate(key string, value []byte) error { + if strings.HasPrefix(string(value), "valid-") { + return nil + } + return fmt.Errorf("invalid value") +} + +func (t testValidator) Select(key string, values [][]byte) (int, error) { + idx := -1 + best := -1 + for i, val := range values { + if !strings.HasPrefix(string(val), "valid-") { + continue + } + newBest, err := strconv.Atoi(string(val)[6:]) + if err != nil { + continue + } + if newBest > best { + idx = i + best = newBest + } + } + + if idx == -1 { + return idx, fmt.Errorf("no valid value") + } + + return idx, nil +} + +func TestRecordBackend_Validate(t *testing.T) { + ctx := kadtest.CtxShort(t) + + b := &RecordBackend{ + namespace: "test", + validator: &testValidator{}, + } + + t.Run("no values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key") + assert.Error(t, err) + assert.Equal(t, -1, idx) + }) + + t.Run("nil value", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", nil) + assert.Error(t, err) + assert.Equal(t, -1, idx) + }) + + t.Run("nil values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", nil, nil) + assert.Error(t, err) + assert.Equal(t, -1, idx) + }) + + t.Run("single valid value", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", []byte("valid-0")) + assert.NoError(t, err) + assert.Equal(t, 0, idx) + }) + + t.Run("increasing better values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", []byte("valid-0"), []byte("valid-1"), []byte("valid-2")) + assert.NoError(t, err) + assert.Equal(t, 2, idx) + }) + + t.Run("mixed better values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", []byte("valid-0"), []byte("valid-2"), []byte("valid-1")) + assert.NoError(t, err) + assert.Equal(t, 1, idx) + }) + + t.Run("mixed invalid values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", []byte("valid-0"), []byte("invalid"), []byte("valid-2"), []byte("invalid")) + assert.NoError(t, err) + assert.Equal(t, 2, idx) + }) + + t.Run("only invalid values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", []byte("invalid"), nil) + assert.Error(t, err) + assert.Equal(t, -1, idx) + }) + + t.Run("identically good values", func(t *testing.T) { + idx, err := b.Validate(ctx, "some-key", []byte("valid-0"), []byte("valid-0")) + assert.NoError(t, err) + assert.Equal(t, 0, idx) + }) +} diff --git a/v2/backend_trace.go b/v2/backend_trace.go index 72335c35..f09dfd0d 100644 --- a/v2/backend_trace.go +++ b/v2/backend_trace.go @@ -57,6 +57,21 @@ func (t *tracedBackend) Fetch(ctx context.Context, key string) (any, error) { return result, err } +func (t *tracedBackend) Validate(ctx context.Context, key string, values ...any) (int, error) { + ctx, span := t.tracer.Start(ctx, "Validate", t.traceAttributes(key)) + defer span.End() + + idx, err := t.backend.Validate(ctx, key, values...) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } else { + span.SetAttributes(attribute.Int("idx", idx)) + } + + return idx, err +} + // traceAttributes is a helper to build the trace attributes. 
func (t *tracedBackend) traceAttributes(key string) trace.SpanStartEventOption { return trace.WithAttributes(attribute.String("namespace", t.namespace), attribute.String("key", key)) diff --git a/v2/config.go b/v2/config.go index 696b8dfe..b7314ecf 100644 --- a/v2/config.go +++ b/v2/config.go @@ -269,6 +269,15 @@ func (c *Config) Validate() error { } } + for _, bp := range c.BootstrapPeers { + if len(bp.Addrs) == 0 { + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("bootstrap peer with no address"), + } + } + } + if c.ProtocolID == "" { return &ConfigurationError{ Component: "Config", @@ -378,6 +387,14 @@ type QueryConfig struct { // RequestTimeout defines the time to wait before terminating a request to a node that has not responded. RequestTimeout time.Duration + + // DefaultQuorum specifies the minimum number of identical responses before + // a SearchValue/GetValue operation returns. The responses must not only be + // identical, but the responses must also correspond to the "best" records + // we have observed in the network during the SearchValue/GetValue + // operation. A DefaultQuorum of 0 means that we search the network until + // we have exhausted the keyspace. + DefaultQuorum int } // DefaultQueryConfig returns the default query configuration options for a DHT. @@ -387,6 +404,7 @@ func DefaultQueryConfig() *QueryConfig { Timeout: 5 * time.Minute, // MAGIC RequestConcurrency: 3, // MAGIC RequestTimeout: time.Minute, // MAGIC + DefaultQuorum: 0, // MAGIC } } @@ -419,5 +437,12 @@ func (cfg *QueryConfig) Validate() error { } } + if cfg.DefaultQuorum < 0 { + return &ConfigurationError{ + Component: "QueryConfig", + Err: fmt.Errorf("default quorum must not be negative"), + } + } + return nil } diff --git a/v2/config_test.go b/v2/config_test.go index ad84b8d4..5b49f9eb 100644 --- a/v2/config_test.go +++ b/v2/config_test.go @@ -5,7 +5,9 @@ import ( "time" "github.com/libp2p/go-libp2p/core/peer" + ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConfig_Validate(t *testing.T) { @@ -140,6 +142,25 @@ func TestConfig_Validate(t *testing.T) { cfg.BootstrapPeers = []peer.AddrInfo{} assert.Error(t, cfg.Validate()) }) + + t.Run("bootstrap peers no addresses", func(t *testing.T) { + cfg := DefaultConfig() + cfg.BootstrapPeers = []peer.AddrInfo{ + {ID: newPeerID(t), Addrs: []ma.Multiaddr{}}, + } + assert.Error(t, cfg.Validate()) + }) + + t.Run("bootstrap peers mixed no addresses", func(t *testing.T) { + cfg := DefaultConfig() + maddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/1234") + require.NoError(t, err) + cfg.BootstrapPeers = []peer.AddrInfo{ + {ID: newPeerID(t), Addrs: []ma.Multiaddr{}}, + {ID: newPeerID(t), Addrs: []ma.Multiaddr{maddr}}, + } + assert.Error(t, cfg.Validate()) // still an error + }) } func TestQueryConfig_Validate(t *testing.T) { @@ -183,4 +204,13 @@ func TestQueryConfig_Validate(t *testing.T) { cfg.RequestTimeout = -1 assert.Error(t, cfg.Validate()) }) + + t.Run("negative default quorum", func(t *testing.T) { + cfg := DefaultQueryConfig() + + cfg.DefaultQuorum = 0 + assert.NoError(t, cfg.Validate()) + cfg.DefaultQuorum = -1 + assert.Error(t, cfg.Validate()) + }) } diff --git a/v2/dht.go b/v2/dht.go index b9ec9993..5a2dcb03 100644 --- a/v2/dht.go +++ b/v2/dht.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "sync" + "sync/atomic" "time" "github.com/ipfs/go-datastore/trace" @@ -55,6 +56,9 @@ type DHT struct { // tele holds a reference to a telemetry struct tele *Telemetry + + // 
indicates whether this DHT instance was stopped ([DHT.Close] was called). + stopped atomic.Bool } // New constructs a new [DHT] for the given underlying host and with the given @@ -114,7 +118,12 @@ func New(h host.Host, cfg *Config) (*DHT, error) { coordCfg.MeterProvider = cfg.MeterProvider coordCfg.TracerProvider = cfg.TracerProvider - d.kad, err = coord.NewCoordinator(kadt.PeerID(d.host.ID()), &router{host: h, ProtocolID: cfg.ProtocolID}, d.rt, coordCfg) + rtr := &router{ + host: h, + protocolID: cfg.ProtocolID, + tracer: d.tele.Tracer, + } + d.kad, err = coord.NewCoordinator(kadt.PeerID(d.host.ID()), rtr, d.rt, coordCfg) if err != nil { return nil, fmt.Errorf("new coordinator: %w", err) } @@ -201,6 +210,10 @@ func (d *DHT) initAminoBackends() (map[string]Backend, error) { // Close cleans up all resources associated with this DHT. func (d *DHT) Close() error { + if d.stopped.Swap(true) { + return nil + } + if err := d.sub.Close(); err != nil { d.log.With("err", err).Debug("failed closing event bus subscription") } diff --git a/v2/dht_test.go b/v2/dht_test.go index 0ee635df..ad53e2f7 100644 --- a/v2/dht_test.go +++ b/v2/dht_test.go @@ -1,6 +1,7 @@ package dht import ( + "sync" "testing" "time" @@ -103,3 +104,17 @@ func TestAddAddresses(t *testing.T) { _, err = local.kad.GetNode(ctx, kadt.PeerID(remote.host.ID())) require.NoError(t, err) } + +func TestDHT_Close_idempotent(t *testing.T) { + d := newTestDHT(t) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + assert.NoError(t, d.Close()) + wg.Done() + }() + } + wg.Wait() +} diff --git a/v2/handlers.go b/v2/handlers.go index 74e55a2b..aebd9a68 100644 --- a/v2/handlers.go +++ b/v2/handlers.go @@ -83,6 +83,8 @@ func (d *DHT) handlePutValue(ctx context.Context, remote peer.ID, req *pb.Messag return nil, fmt.Errorf("key doesn't match record key") } + // TODO: use putValueLocal? 
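The new stopped flag follows a common idiom for making Close idempotent and safe to call from many goroutines, which is what TestDHT_Close_idempotent exercises. A minimal, self-contained sketch of the pattern with a hypothetical service type (not part of this package):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// service is a hypothetical stand-in for the DHT: it owns resources that must
// be released exactly once, no matter how many goroutines call Close.
type service struct {
	stopped atomic.Bool
	done    chan struct{}
}

func (s *service) Close() error {
	// Swap returns the previous value: only the first caller sees "false"
	// and proceeds; every later or concurrent call returns immediately.
	if s.stopped.Swap(true) {
		return nil
	}
	close(s.done) // closing twice would panic, hence the guard above
	return nil
}

func main() {
	s := &service{done: make(chan struct{})}

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = s.Close()
		}()
	}
	wg.Wait()
	fmt.Println("closed once, no panic")
}
```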
+ // key is /$namespace/$binary_id ns, path, err := record.SplitKey(k) // get namespace (prefix of the key) if err != nil || len(path) == 0 { @@ -208,8 +210,17 @@ func (d *DHT) handleGetProviders(ctx context.Context, remote peer.ID, req *pb.Me return nil, fmt.Errorf("unsupported record type: %s", namespaceProviders) } + resp := &pb.Message{ + Type: pb.Message_GET_PROVIDERS, + Key: k, + CloserPeers: d.closerPeers(ctx, remote, kadt.NewKey(k)), + } + fetched, err := backend.Fetch(ctx, string(req.GetKey())) if err != nil { + if errors.Is(err, ds.ErrNotFound) { + return resp, nil + } return nil, fmt.Errorf("fetch providers from datastore: %w", err) } @@ -223,12 +234,7 @@ func (d *DHT) handleGetProviders(ctx context.Context, remote peer.ID, req *pb.Me pbProviders[i] = pb.FromAddrInfo(p) } - resp := &pb.Message{ - Type: pb.Message_GET_PROVIDERS, - Key: k, - CloserPeers: d.closerPeers(ctx, remote, kadt.NewKey(k)), - ProviderPeers: pbProviders, - } + resp.ProviderPeers = pbProviders return resp, nil } diff --git a/v2/handlers_test.go b/v2/handlers_test.go index 6910c347..f838aaf8 100644 --- a/v2/handlers_test.go +++ b/v2/handlers_test.go @@ -436,6 +436,24 @@ func BenchmarkDHT_handlePing(b *testing.B) { func newPutIPNSRequest(t testing.TB, clk clock.Clock, priv crypto.PrivKey, seq uint64, ttl time.Duration) *pb.Message { t.Helper() + keyStr, value := makeIPNSKeyValue(t, clk, priv, seq, ttl) + + req := &pb.Message{ + Type: pb.Message_PUT_VALUE, + Key: []byte(keyStr), + Record: &recpb.Record{ + Key: []byte(keyStr), + Value: value, + TimeReceived: clk.Now().Format(time.RFC3339Nano), + }, + } + + return req +} + +func makeIPNSKeyValue(t testing.TB, clk clock.Clock, priv crypto.PrivKey, seq uint64, ttl time.Duration) (string, []byte) { + t.Helper() + testPath := path.Path("/ipfs/bafkqac3jobxhgidsn5rww4yk") rec, err := ipns.NewRecord(priv, testPath, seq, clk.Now().Add(ttl), ttl) @@ -447,18 +465,7 @@ func newPutIPNSRequest(t testing.TB, clk clock.Clock, priv crypto.PrivKey, seq u data, err := ipns.MarshalRecord(rec) require.NoError(t, err) - key := ipns.NameFromPeer(remote).RoutingKey() - req := &pb.Message{ - Type: pb.Message_PUT_VALUE, - Key: key, - Record: &recpb.Record{ - Key: key, - Value: data, - TimeReceived: time.Now().Format(time.RFC3339Nano), - }, - } - - return req + return string(ipns.NameFromPeer(remote).RoutingKey()), data } func BenchmarkDHT_handlePutValue_unique_peers(b *testing.B) { diff --git a/v2/internal/coord/brdcst/brdcst.go b/v2/internal/coord/brdcst/brdcst.go index 5d16b973..8711a1b3 100644 --- a/v2/internal/coord/brdcst/brdcst.go +++ b/v2/internal/coord/brdcst/brdcst.go @@ -75,7 +75,7 @@ func (*StateBroadcastIdle) broadcastState() {} // implement this interface. An "Event" is the opposite of a "State." An "Event" // flows into the state machine and a "State" flows out of it. // -// Currently, there are the [FollowUp] and [Optimistic] state machines. +// Currently, there are the [FollowUp] and [Static] state machines. type BroadcastEvent interface { broadcastEvent() } diff --git a/v2/internal/coord/brdcst/config.go b/v2/internal/coord/brdcst/config.go index 4d6d425b..70ca0c85 100644 --- a/v2/internal/coord/brdcst/config.go +++ b/v2/internal/coord/brdcst/config.go @@ -31,16 +31,17 @@ func DefaultConfigPool() *ConfigPool { // Config is an interface that all broadcast configurations must implement. 
// Because we have multiple ways of broadcasting records to the network, like -// [FollowUp] or [Optimistic], the [EventPoolStartBroadcast] has a configuration +// [FollowUp] or [Static], the [EventPoolStartBroadcast] has a configuration // field that depending on the concrete type of [Config] initializes the -// respective state machine. Then the broadcast operation will performed based -// on the encoded rules in that state machine. +// respective state machine. Then the broadcast operation will be performed +// based on the encoded rules in that state machine. type Config interface { broadcastConfig() } func (c *ConfigFollowUp) broadcastConfig() {} func (c *ConfigOptimistic) broadcastConfig() {} +func (c *ConfigStatic) broadcastConfig() {} // ConfigFollowUp specifies the configuration for the [FollowUp] state machine. type ConfigFollowUp struct{} @@ -72,3 +73,19 @@ func (c *ConfigOptimistic) Validate() error { func DefaultConfigOptimistic() *ConfigOptimistic { return &ConfigOptimistic{} } + +// ConfigStatic specifies the configuration for the [Static] state +// machine. +type ConfigStatic struct{} + +// Validate checks the configuration options and returns an error if any have +// invalid values. +func (c *ConfigStatic) Validate() error { + return nil +} + +// DefaultConfigStatic returns the default configuration options for the +// [Static] state machine. +func DefaultConfigStatic() *ConfigStatic { + return &ConfigStatic{} +} diff --git a/v2/internal/coord/brdcst/config_test.go b/v2/internal/coord/brdcst/config_test.go index 68447a1f..43779523 100644 --- a/v2/internal/coord/brdcst/config_test.go +++ b/v2/internal/coord/brdcst/config_test.go @@ -37,6 +37,7 @@ func TestConfig_interface_conformance(t *testing.T) { configs := []Config{ &ConfigFollowUp{}, &ConfigOptimistic{}, + &ConfigStatic{}, } for _, c := range configs { c.broadcastConfig() // drives test coverage diff --git a/v2/internal/coord/brdcst/pool.go b/v2/internal/coord/brdcst/pool.go index bba83dad..71d4e936 100644 --- a/v2/internal/coord/brdcst/pool.go +++ b/v2/internal/coord/brdcst/pool.go @@ -14,7 +14,7 @@ import ( // Broadcast is a type alias for a specific kind of state machine that any // kind of broadcast strategy state machine must implement. Currently, there -// are the [FollowUp] and [Optimistic] state machines. +// are the [FollowUp] and [Static] state machines. type Broadcast = coordt.StateMachine[BroadcastEvent, BroadcastState] // Pool is a [coordt.StateMachine] that manages all running broadcast @@ -26,7 +26,7 @@ type Broadcast = coordt.StateMachine[BroadcastEvent, BroadcastState] // // Conceptually, a broadcast consists of finding the closest nodes to a certain // key and then storing the record with them. There are a few different -// strategies that can be applied. For now, these are the [FollowUp] and the [Optimistic] +// strategies that can be applied. For now, these are the [FollowUp] and the [Static] // strategies. In the future, we also want to support [Reprovide Sweep]. // However, this requires a different type of query as we are not looking for // the closest nodes but rather enumerating the keyspace. In any case, this @@ -104,7 +104,7 @@ func (p *Pool[K, N, M]) Advance(ctx context.Context, ev PoolEvent) (out PoolStat } // handleEvent receives a broadcast [PoolEvent] and returns the corresponding -// broadcast state machine [FollowUp] or [Optimistic] plus the event for that +// broadcast state machine [FollowUp] or [Static] plus the event for that // state machine. 
If any return parameter is nil, either the pool event was for // an unknown query or the event doesn't need to be forwarded to the state // machine. @@ -120,7 +120,9 @@ func (p *Pool[K, N, M]) handleEvent(ctx context.Context, ev PoolEvent) (sm Broad // first initialize the state machine for the broadcast desired strategy switch cfg := ev.Config.(type) { case *ConfigFollowUp: - p.bcs[ev.QueryID] = NewFollowUp(ev.QueryID, p.qp, ev.Message, cfg) + p.bcs[ev.QueryID] = NewFollowUp[K, N, M](ev.QueryID, p.qp, ev.Message, cfg) + case *ConfigStatic: + p.bcs[ev.QueryID] = NewStatic[K, N, M](ev.QueryID, ev.Message, cfg) case *ConfigOptimistic: panic("implement me") } @@ -171,7 +173,7 @@ func (p *Pool[K, N, M]) handleEvent(ctx context.Context, ev PoolEvent) (sm Broad } // advanceBroadcast advances the given broadcast state machine ([FollowUp] or -// [Optimistic]) and returns the new [Pool] state ([PoolState]). The additional +// [Static]) and returns the new [Pool] state ([PoolState]). The additional // boolean value indicates whether the returned [PoolState] should be ignored. func (p *Pool[K, N, M]) advanceBroadcast(ctx context.Context, sm Broadcast, bev BroadcastEvent) (PoolState, bool) { ctx, span := tele.StartSpan(ctx, "Pool.advanceBroadcast", trace.WithAttributes(tele.AttrInEvent(bev))) @@ -284,7 +286,7 @@ type EventPoolStartBroadcast[K kad.Key[K], N kad.NodeID[K], M coordt.Message] st Target K // the key we want to store the record for Message M // the message that we want to send to the closest peers (this encapsulates the payload we want to store) Seed []N // the closest nodes we know so far and from where we start the operation - Config Config // the configuration for this operation. Most importantly, this defines the broadcast strategy ([FollowUp] or [Optimistic]) + Config Config // the configuration for this operation. 
Most importantly, this defines the broadcast strategy ([FollowUp] or [Static]) } // EventPoolStopBroadcast notifies broadcast [Pool] to stop a broadcast diff --git a/v2/internal/coord/brdcst/pool_test.go b/v2/internal/coord/brdcst/pool_test.go index e4e7409c..aaad7dc9 100644 --- a/v2/internal/coord/brdcst/pool_test.go +++ b/v2/internal/coord/brdcst/pool_test.go @@ -272,7 +272,7 @@ func TestPool_FollowUp_stop_during_followup_phase(t *testing.T) { require.Len(t, st.Errors, 2) } -func TestPool_FollowUp_empty_seed(t *testing.T) { +func TestPool_empty_seed(t *testing.T) { ctx := context.Background() cfg := DefaultConfigPool() @@ -286,17 +286,121 @@ func TestPool_FollowUp_empty_seed(t *testing.T) { queryID := coordt.QueryID("test") - state := p.Advance(ctx, &EventPoolStartBroadcast[tiny.Key, tiny.Node, tiny.Message]{ + startEvt := &EventPoolStartBroadcast[tiny.Key, tiny.Node, tiny.Message]{ QueryID: queryID, Target: target, Message: msg, Seed: []tiny.Node{}, - Config: DefaultConfigFollowUp(), + } + + t.Run("follow up", func(t *testing.T) { + startEvt.Config = DefaultConfigFollowUp() + + state := p.Advance(ctx, startEvt) + require.IsType(t, &StatePoolBroadcastFinished[tiny.Key, tiny.Node]{}, state) + + state = p.Advance(ctx, &EventPoolPoll{}) + require.IsType(t, &StatePoolIdle{}, state) + }) + + t.Run("static", func(t *testing.T) { + startEvt.Config = DefaultConfigStatic() + state := p.Advance(ctx, startEvt) + require.IsType(t, &StatePoolBroadcastFinished[tiny.Key, tiny.Node]{}, state) + + state = p.Advance(ctx, &EventPoolPoll{}) + require.IsType(t, &StatePoolIdle{}, state) + }) +} + +func TestPool_Static_happy_path(t *testing.T) { + ctx := context.Background() + cfg := DefaultConfigPool() + + self := tiny.NewNode(0) + + p, err := NewPool[tiny.Key, tiny.Node, tiny.Message](self, cfg) + require.NoError(t, err) + + msg := tiny.Message{Content: "store this"} + target := tiny.Key(0b00000001) + a := tiny.NewNode(0b00000100) // 4 + b := tiny.NewNode(0b00000011) // 3 + c := tiny.NewNode(0b00000010) // 2 + + queryID := coordt.QueryID("test") + + state := p.Advance(ctx, &EventPoolStartBroadcast[tiny.Key, tiny.Node, tiny.Message]{ + QueryID: queryID, + Target: target, + Message: msg, + Seed: []tiny.Node{a, b, c}, + Config: DefaultConfigStatic(), + }) + spsr, ok := state.(*StatePoolStoreRecord[tiny.Key, tiny.Node, tiny.Message]) + require.True(t, ok, "state is %T", state) + first := spsr.NodeID + + state = p.Advance(ctx, &EventPoolPoll{}) + spsr, ok = state.(*StatePoolStoreRecord[tiny.Key, tiny.Node, tiny.Message]) + require.True(t, ok, "state is %T", state) + second := spsr.NodeID + + state = p.Advance(ctx, &EventPoolStoreRecordSuccess[tiny.Key, tiny.Node, tiny.Message]{ + QueryID: queryID, + NodeID: first, + Request: msg, + }) + spsr, ok = state.(*StatePoolStoreRecord[tiny.Key, tiny.Node, tiny.Message]) + require.True(t, ok, "state is %T", state) + third := spsr.NodeID + + state = p.Advance(ctx, &EventPoolStoreRecordFailure[tiny.Key, tiny.Node, tiny.Message]{ + QueryID: queryID, + NodeID: second, + Request: msg, + }) + require.IsType(t, &StatePoolWaiting{}, state) + + state = p.Advance(ctx, &EventPoolStoreRecordSuccess[tiny.Key, tiny.Node, tiny.Message]{ + QueryID: queryID, + NodeID: third, + Request: msg, }) require.IsType(t, &StatePoolBroadcastFinished[tiny.Key, tiny.Node]{}, state) +} + +func TestPool_Static_stop_mid_flight(t *testing.T) { + ctx := context.Background() + cfg := DefaultConfigPool() + + self := tiny.NewNode(0) + + p, err := NewPool[tiny.Key, tiny.Node, tiny.Message](self, cfg) + 
require.NoError(t, err) + + msg := tiny.Message{Content: "store this"} + target := tiny.Key(0b00000001) + a := tiny.NewNode(0b00000100) // 4 + b := tiny.NewNode(0b00000011) // 3 + c := tiny.NewNode(0b00000010) // 2 + + queryID := coordt.QueryID("test") + + state := p.Advance(ctx, &EventPoolStartBroadcast[tiny.Key, tiny.Node, tiny.Message]{ + QueryID: queryID, + Target: target, + Message: msg, + Seed: []tiny.Node{a, b, c}, + Config: DefaultConfigStatic(), + }) + require.IsType(t, &StatePoolStoreRecord[tiny.Key, tiny.Node, tiny.Message]{}, state) state = p.Advance(ctx, &EventPoolPoll{}) - require.IsType(t, &StatePoolIdle{}, state) + require.IsType(t, &StatePoolStoreRecord[tiny.Key, tiny.Node, tiny.Message]{}, state) + + state = p.Advance(ctx, &EventPoolStopBroadcast{QueryID: queryID}) + require.IsType(t, &StatePoolBroadcastFinished[tiny.Key, tiny.Node]{}, state) } func TestPoolState_interface_conformance(t *testing.T) { diff --git a/v2/internal/coord/brdcst/static.go b/v2/internal/coord/brdcst/static.go new file mode 100644 index 00000000..0a36721b --- /dev/null +++ b/v2/internal/coord/brdcst/static.go @@ -0,0 +1,143 @@ +package brdcst + +import ( + "context" + "fmt" + + "github.com/plprobelab/go-kademlia/kad" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/coordt" + "github.com/libp2p/go-libp2p-kad-dht/v2/tele" +) + +// Static is a [Broadcast] state machine and encapsulates the logic around +// doing a put operation to a static set of nodes. That static set of nodes +// is given by the list of seed nodes in the [EventBroadcastStart] event. +type Static[K kad.Key[K], N kad.NodeID[K], M coordt.Message] struct { + // the unique ID for this broadcast operation + queryID coordt.QueryID + + // a struct holding configuration options + cfg *ConfigStatic + + // the message that we will send to the closest nodes in the follow-up phase + msg M + + // nodes we still need to store records with. This map will be filled with + // all the closest nodes after the query has finished. + todo map[string]N + + // nodes we have contacted to store the record but haven't heard a response yet + waiting map[string]N + + // nodes that successfully hold the record for us + success map[string]N + + // nodes that failed to hold the record for us + failed map[string]struct { + Node N + Err error + } +} + +// NewStatic initializes a new [Static] struct. +func NewStatic[K kad.Key[K], N kad.NodeID[K], M coordt.Message](qid coordt.QueryID, msg M, cfg *ConfigStatic) *Static[K, N, M] { + return &Static[K, N, M]{ + queryID: qid, + cfg: cfg, + msg: msg, + todo: map[string]N{}, + waiting: map[string]N{}, + success: map[string]N{}, + failed: map[string]struct { + Node N + Err error + }{}, + } +} + +// Advance advances the state of the [Static] [Broadcast] state machine. 
+func (f *Static[K, N, M]) Advance(ctx context.Context, ev BroadcastEvent) (out BroadcastState) { + _, span := tele.StartSpan(ctx, "Static.Advance", trace.WithAttributes(tele.AttrInEvent(ev))) + defer func() { + span.SetAttributes( + tele.AttrOutEvent(out), + attribute.Int("todo", len(f.todo)), + attribute.Int("waiting", len(f.waiting)), + attribute.Int("success", len(f.success)), + attribute.Int("failed", len(f.failed)), + ) + span.End() + }() + + switch ev := ev.(type) { + case *EventBroadcastStart[K, N]: + span.SetAttributes(attribute.Int("seed", len(ev.Seed))) + for _, seed := range ev.Seed { + f.todo[seed.String()] = seed + } + case *EventBroadcastStop: + for _, n := range f.todo { + delete(f.todo, n.String()) + f.failed[n.String()] = struct { + Node N + Err error + }{Node: n, Err: fmt.Errorf("cancelled")} + } + + for _, n := range f.waiting { + delete(f.waiting, n.String()) + f.failed[n.String()] = struct { + Node N + Err error + }{Node: n, Err: fmt.Errorf("cancelled")} + } + case *EventBroadcastStoreRecordSuccess[K, N, M]: + delete(f.waiting, ev.NodeID.String()) + f.success[ev.NodeID.String()] = ev.NodeID + case *EventBroadcastStoreRecordFailure[K, N, M]: + delete(f.waiting, ev.NodeID.String()) + f.failed[ev.NodeID.String()] = struct { + Node N + Err error + }{Node: ev.NodeID, Err: ev.Error} + case *EventBroadcastPoll: + // ignore, nothing to do + default: + panic(fmt.Sprintf("unexpected event: %T", ev)) + } + + for k, n := range f.todo { + delete(f.todo, k) + f.waiting[k] = n + return &StateBroadcastStoreRecord[K, N, M]{ + QueryID: f.queryID, + NodeID: n, + Message: f.msg, + } + } + + if len(f.waiting) > 0 { + return &StateBroadcastWaiting{} + } + + if len(f.todo) == 0 { + contacted := make([]N, 0, len(f.success)+len(f.failed)) + for _, n := range f.success { + contacted = append(contacted, n) + } + for _, n := range f.failed { + contacted = append(contacted, n.Node) + } + + return &StateBroadcastFinished[K, N]{ + QueryID: f.queryID, + Contacted: contacted, + Errors: f.failed, + } + } + + return &StateBroadcastIdle{} +} diff --git a/v2/internal/coord/coordinator.go b/v2/internal/coord/coordinator.go index e8e1428b..cbaba85f 100644 --- a/v2/internal/coord/coordinator.go +++ b/v2/internal/coord/coordinator.go @@ -378,7 +378,7 @@ func (c *Coordinator) QueryClosest(ctx context.Context, target kadt.Key, fn coor // numResults specifies the minimum number of nodes to successfully contact before considering iteration complete. // The query is considered to be exhausted when it has received responses from at least this number of nodes // and there are no closer nodes remaining to be contacted. A default of 20 is used if this value is less than 1. 
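Stripped of the generics and the event plumbing, the Static machine is bookkeeping over four sets: seed nodes start in todo, move to waiting when a store is dispatched, and end up in success or failed; the broadcast finishes once todo and waiting are both empty. The following self-contained sketch shows that control flow only (illustrative, with plain strings instead of kad node IDs and no events or states):

```go
package main

import "fmt"

// staticBroadcast is a stripped-down illustration of the bookkeeping the
// Static state machine performs: every seed node is contacted exactly once.
type staticBroadcast struct {
	todo    map[string]struct{}
	waiting map[string]struct{}
	success map[string]struct{}
	failed  map[string]error
}

func newStaticBroadcast(seed ...string) *staticBroadcast {
	s := &staticBroadcast{
		todo:    map[string]struct{}{},
		waiting: map[string]struct{}{},
		success: map[string]struct{}{},
		failed:  map[string]error{},
	}
	for _, n := range seed {
		s.todo[n] = struct{}{}
	}
	return s
}

// next moves one node from todo to waiting and returns it; the boolean is
// false once every seed node has been handed out.
func (s *staticBroadcast) next() (string, bool) {
	for n := range s.todo {
		delete(s.todo, n)
		s.waiting[n] = struct{}{}
		return n, true
	}
	return "", false
}

// ack records the outcome of a store attempt for a node we were waiting on.
func (s *staticBroadcast) ack(n string, err error) {
	delete(s.waiting, n)
	if err != nil {
		s.failed[n] = err
	} else {
		s.success[n] = struct{}{}
	}
}

// finished reports whether there is nothing left to contact or wait for.
func (s *staticBroadcast) finished() bool {
	return len(s.todo) == 0 && len(s.waiting) == 0
}

func main() {
	b := newStaticBroadcast("a", "b", "c")
	for {
		n, ok := b.next()
		if !ok {
			break
		}
		b.ack(n, nil) // pretend every store succeeded
	}
	fmt.Println(b.finished(), len(b.success)) // true 3
}
```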
-func (c *Coordinator) QueryMessage(ctx context.Context, msg *pb.Message, fn coordt.QueryFunc, numResults int) (coordt.QueryStats, error) { +func (c *Coordinator) QueryMessage(ctx context.Context, msg *pb.Message, fn coordt.QueryFunc, numResults int) ([]kadt.PeerID, coordt.QueryStats, error) { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.QueryMessage") defer span.End() @@ -391,7 +391,7 @@ func (c *Coordinator) QueryMessage(ctx context.Context, msg *pb.Message, fn coor seeds, err := c.GetClosestNodes(ctx, msg.Target(), numResults) if err != nil { - return coordt.QueryStats{}, err + return nil, coordt.QueryStats{}, err } seedIDs := make([]kadt.PeerID, 0, len(seeds)) @@ -414,8 +414,8 @@ func (c *Coordinator) QueryMessage(ctx context.Context, msg *pb.Message, fn coor // queue the start of the query c.queryBehaviour.Notify(ctx, cmd) - _, stats, err := c.waitForQuery(ctx, queryID, waiter, fn) - return stats, err + closest, stats, err := c.waitForQuery(ctx, queryID, waiter, fn) + return closest, stats, err } func (c *Coordinator) BroadcastRecord(ctx context.Context, msg *pb.Message) error { @@ -425,15 +425,30 @@ func (c *Coordinator) BroadcastRecord(ctx context.Context, msg *pb.Message) erro ctx, cancel := context.WithCancel(ctx) defer cancel() - seeds, err := c.GetClosestNodes(ctx, msg.Target(), 20) // TODO: parameterize + seedNodes, err := c.GetClosestNodes(ctx, msg.Target(), 20) // TODO: parameterize if err != nil { return err } - seedIDs := make([]kadt.PeerID, 0, len(seeds)) - for _, s := range seeds { - seedIDs = append(seedIDs, s.ID()) + seeds := make([]kadt.PeerID, 0, len(seedNodes)) + for _, s := range seedNodes { + seeds = append(seeds, s.ID()) } + return c.broadcast(ctx, msg, seeds, brdcst.DefaultConfigFollowUp()) +} + +func (c *Coordinator) BroadcastStatic(ctx context.Context, msg *pb.Message, seeds []kadt.PeerID) error { + ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.BroadcastStatic") + defer span.End() + return c.broadcast(ctx, msg, seeds, brdcst.DefaultConfigStatic()) +} + +func (c *Coordinator) broadcast(ctx context.Context, msg *pb.Message, seeds []kadt.PeerID, cfg brdcst.Config) error { + ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.broadcast") + defer span.End() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() waiter := NewWaiter[BehaviourEvent]() queryID := c.newOperationID() @@ -442,9 +457,9 @@ func (c *Coordinator) BroadcastRecord(ctx context.Context, msg *pb.Message) erro QueryID: queryID, Target: msg.Target(), Message: msg, - Seed: seedIDs, + Seed: seeds, Notify: waiter, - Config: brdcst.DefaultConfigFollowUp(), + Config: cfg, } // queue the start of the query @@ -470,7 +485,10 @@ func (c *Coordinator) waitForQuery(ctx context.Context, queryID coordt.QueryID, select { case <-ctx.Done(): return nil, lastStats, ctx.Err() - case wev := <-waiter.Chan(): + case wev, more := <-waiter.Chan(): + if !more { + return nil, lastStats, ctx.Err() + } ctx, ev := wev.Ctx, wev.Event switch ev := ev.(type) { case *EventQueryProgressed: @@ -519,7 +537,11 @@ func (c *Coordinator) waitForBroadcast(ctx context.Context, waiter *Waiter[Behav select { case <-ctx.Done(): return nil, nil, ctx.Err() - case wev := <-waiter.Chan(): + case wev, more := <-waiter.Chan(): + if !more { + return nil, nil, ctx.Err() + } + switch ev := wev.Event.(type) { case *EventQueryProgressed: case *EventBroadcastFinished: @@ -567,28 +589,24 @@ func (c *Coordinator) Bootstrap(ctx context.Context, seeds []kadt.PeerID) error // NotifyConnectivity notifies the coordinator that a peer has passed a 
connectivity check // which means it is connected and supports finding closer nodes -func (c *Coordinator) NotifyConnectivity(ctx context.Context, id kadt.PeerID) error { +func (c *Coordinator) NotifyConnectivity(ctx context.Context, id kadt.PeerID) { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.NotifyConnectivity") defer span.End() c.routingBehaviour.Notify(ctx, &EventNotifyConnectivity{ NodeID: id, }) - - return nil } // NotifyNonConnectivity notifies the coordinator that a peer has failed a connectivity check // which means it is not connected and/or it doesn't support finding closer nodes -func (c *Coordinator) NotifyNonConnectivity(ctx context.Context, id kadt.PeerID) error { +func (c *Coordinator) NotifyNonConnectivity(ctx context.Context, id kadt.PeerID) { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.NotifyNonConnectivity") defer span.End() c.routingBehaviour.Notify(ctx, &EventNotifyNonConnectivity{ NodeID: id, }) - - return nil } func (c *Coordinator) newOperationID() coordt.QueryID { diff --git a/v2/router.go b/v2/router.go index bc586a39..999be273 100644 --- a/v2/router.go +++ b/v2/router.go @@ -2,6 +2,7 @@ package dht import ( "context" + "encoding/base64" "fmt" "time" @@ -11,25 +12,46 @@ import ( "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-msgio" "github.com/libp2p/go-msgio/pbio" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "google.golang.org/protobuf/proto" "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/coordt" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" "github.com/libp2p/go-libp2p-kad-dht/v2/pb" + "github.com/libp2p/go-libp2p-kad-dht/v2/tele" ) type router struct { + // the libp2p host to use for sending messages host host.Host - // ProtocolID represents the DHT [protocol] we can query with and respond to. + + // protocolID represents the DHT [protocol] we can query with and respond to. // // [protocol]: https://docs.libp2p.io/concepts/fundamentals/protocols/ - ProtocolID protocol.ID + protocolID protocol.ID + + // an open telemetry tacer instance + tracer trace.Tracer } var _ coordt.Router[kadt.Key, kadt.PeerID, *pb.Message] = (*router)(nil) -func (r *router) SendMessage(ctx context.Context, to kadt.PeerID, req *pb.Message) (*pb.Message, error) { - // TODO: what to do with addresses in peer.AddrInfo? +func (r *router) SendMessage(ctx context.Context, to kadt.PeerID, req *pb.Message) (resp *pb.Message, err error) { + spanOpts := []trace.SpanStartOption{ + trace.WithAttributes(tele.AttrMessageType(req.GetType().String())), + trace.WithAttributes(tele.AttrPeerID(to.String())), + trace.WithAttributes(tele.AttrKey(base64.RawStdEncoding.EncodeToString(req.GetKey()))), + } + ctx, span := r.tracer.Start(ctx, "router.SendMessage", spanOpts...) 
+ defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } + span.End() + }() + if len(r.host.Peerstore().Addrs(peer.ID(to))) == 0 { return nil, fmt.Errorf("no address for peer %s", to) } @@ -38,10 +60,8 @@ func (r *router) SendMessage(ctx context.Context, to kadt.PeerID, req *pb.Messag ctx, cancel = context.WithCancel(ctx) defer cancel() - var err error - var s network.Stream - s, err = r.host.NewStream(ctx, peer.ID(to), r.ProtocolID) + s, err = r.host.NewStream(ctx, peer.ID(to), r.protocolID) if err != nil { return nil, fmt.Errorf("stream creation: %w", err) } @@ -59,10 +79,14 @@ func (r *router) SendMessage(ctx context.Context, to kadt.PeerID, req *pb.Messag return nil, nil } + span.End() + ctx, span = r.tracer.Start(ctx, "router.ReadMessage", spanOpts...) + data, err := reader.ReadMsg() if err != nil { return nil, fmt.Errorf("read message: %w", err) } + protoResp := pb.Message{} if err = proto.Unmarshal(data, &protoResp); err != nil { return nil, err diff --git a/v2/routing.go b/v2/routing.go index 756a3d48..8227fbbb 100644 --- a/v2/routing.go +++ b/v2/routing.go @@ -1,6 +1,7 @@ package dht import ( + "bytes" "context" "errors" "fmt" @@ -39,11 +40,9 @@ func (d *DHT) FindPeer(ctx context.Context, id peer.ID) (peer.AddrInfo, error) { return addrInfo, nil } default: - // we're + // we're not connected or were recently connected } - target := kadt.PeerID(id) - var foundPeer peer.ID fn := func(ctx context.Context, visited kadt.PeerID, msg *pb.Message, stats coordt.QueryStats) error { if peer.ID(visited) == id { @@ -53,7 +52,7 @@ func (d *DHT) FindPeer(ctx context.Context, id peer.ID) (peer.AddrInfo, error) { return nil } - _, _, err := d.kad.QueryClosest(ctx, target.Key(), fn, 20) + _, _, err := d.kad.QueryClosest(ctx, kadt.PeerID(id).Key(), fn, 20) if err != nil { return peer.AddrInfo{}, fmt.Errorf("failed to run query: %w", err) } @@ -116,7 +115,7 @@ func (d *DHT) FindProvidersAsync(ctx context.Context, c cid.Cid, count int) <-ch return peerOut } -func (d *DHT) findProvidersAsyncRoutine(ctx context.Context, c cid.Cid, count int, out chan peer.AddrInfo) { +func (d *DHT) findProvidersAsyncRoutine(ctx context.Context, c cid.Cid, count int, out chan<- peer.AddrInfo) { _, span := d.tele.Tracer.Start(ctx, "DHT.findProvidersAsyncRoutine", otel.WithAttributes(attribute.String("cid", c.String()), attribute.Int("count", count))) defer span.End() @@ -130,12 +129,20 @@ func (d *DHT) findProvidersAsyncRoutine(ctx context.Context, c cid.Cid, count in return } + // send all providers onto the out channel until the desired count + // was reached. If no count was specified, continue with network lookup. + providers := map[peer.ID]struct{}{} + // first fetch the record locally stored, err := b.Fetch(ctx, string(c.Hash())) if err != nil { - span.RecordError(err) - d.log.Warn("Fetching value from provider store", slog.String("cid", c.String()), slog.String("err", err.Error())) - return + if !errors.Is(err, ds.ErrNotFound) { + span.RecordError(err) + d.log.Warn("Fetching value from provider store", slog.String("cid", c.String()), slog.String("err", err.Error())) + return + } + + stored = &providerSet{} } ps, ok := stored.(*providerSet) @@ -145,9 +152,6 @@ func (d *DHT) findProvidersAsyncRoutine(ctx context.Context, c cid.Cid, count in return } - // send all providers onto the out channel until the desired count - // was reached. If no count was specified, continue with network lookup. 
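The findProvidersAsyncRoutine change above makes a missing local entry a non-error: ds.ErrNotFound now just means the provider set starts out empty and the network lookup still runs. A small, self-contained sketch of that fallback with a stand-in sentinel error (names are illustrative):

```go
package main

import (
	"errors"
	"fmt"
)

// errNotFound stands in for ds.ErrNotFound from go-datastore in this sketch.
var errNotFound = errors.New("datastore: key not found")

// localProviders is a hypothetical local lookup that may legitimately find
// nothing for a key.
func localProviders(key string) ([]string, error) {
	return nil, errNotFound
}

// findProviders mirrors the fallback: "not found" locally is not an error, it
// just means we start with an empty set and continue with the network lookup.
// Any other datastore error aborts the routine.
func findProviders(key string) ([]string, error) {
	provs, err := localProviders(key)
	if err != nil {
		if !errors.Is(err, errNotFound) {
			return nil, fmt.Errorf("fetch providers from datastore: %w", err)
		}
		provs = nil // empty local set, fall through to the network
	}

	// ... the network lookup would append to provs here ...
	return provs, nil
}

func main() {
	provs, err := findProviders("some-cid")
	fmt.Println(len(provs), err) // 0 <nil>
}
```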
- providers := map[peer.ID]struct{}{} for _, provider := range ps.providers { providers[provider.ID] = struct{}{} @@ -199,7 +203,7 @@ func (d *DHT) findProvidersAsyncRoutine(ctx context.Context, c cid.Cid, count in return nil } - _, err = d.kad.QueryMessage(ctx, msg, fn, 20) // TODO: parameterize + _, _, err = d.kad.QueryMessage(ctx, msg, fn, d.cfg.BucketSize) if err != nil { span.RecordError(err) d.log.Warn("Failed querying", slog.String("cid", c.String()), slog.String("err", err.Error())) @@ -249,7 +253,8 @@ func (d *DHT) PutValue(ctx context.Context, keyStr string, value []byte, opts .. return nil } -// putValueLocal stores a value in the local datastore without querying the network. +// putValueLocal stores a value in the local datastore without reaching out to +// the network. func (d *DHT) putValueLocal(ctx context.Context, key string, value []byte) error { ctx, span := d.tele.Tracer.Start(ctx, "DHT.PutValueLocal") defer span.End() @@ -265,7 +270,7 @@ func (d *DHT) putValueLocal(ctx context.Context, key string, value []byte) error } rec := record.MakePutRecord(key, value) - rec.TimeReceived = time.Now().UTC().Format(time.RFC3339Nano) + rec.TimeReceived = d.cfg.Clock.Now().UTC().Format(time.RFC3339Nano) _, err = b.Store(ctx, path, rec) if err != nil { @@ -275,57 +280,44 @@ func (d *DHT) putValueLocal(ctx context.Context, key string, value []byte) error return nil } -func (d *DHT) GetValue(ctx context.Context, key string, option ...routing.Option) ([]byte, error) { +func (d *DHT) GetValue(ctx context.Context, key string, opts ...routing.Option) ([]byte, error) { ctx, span := d.tele.Tracer.Start(ctx, "DHT.GetValue") defer span.End() - v, err := d.getValueLocal(ctx, key) - if err == nil { - return v, nil - } - if !errors.Is(err, ds.ErrNotFound) { - return nil, fmt.Errorf("put value locally: %w", err) + valueChan, err := d.SearchValue(ctx, key, opts...) + if err != nil { + return nil, err } - req := &pb.Message{ - Type: pb.Message_GET_VALUE, - Key: []byte(key), + var best []byte + for val := range valueChan { + best = val } - // TODO: quorum - var value []byte - fn := func(ctx context.Context, id kadt.PeerID, resp *pb.Message, stats coordt.QueryStats) error { - if resp == nil { - return nil - } - - if resp.GetType() != pb.Message_GET_VALUE { - return nil - } - - if string(resp.GetKey()) != key { - return nil - } - - value = resp.GetRecord().GetValue() - - return coordt.ErrSkipRemaining + if ctx.Err() != nil { + return best, ctx.Err() } - _, err = d.kad.QueryMessage(ctx, req, fn, d.cfg.BucketSize) - if err != nil { - return nil, fmt.Errorf("failed to run query: %w", err) + if best == nil { + return nil, routing.ErrNotFound } - return value, nil + return best, nil } -// getValueLocal retrieves a value from the local datastore without querying the network. -func (d *DHT) getValueLocal(ctx context.Context, key string) ([]byte, error) { - ctx, span := d.tele.Tracer.Start(ctx, "DHT.GetValueLocal") +// SearchValue will search in the DHT for keyStr. 
keyStr must have the form +// `/$namespace/$binary_id` +func (d *DHT) SearchValue(ctx context.Context, keyStr string, options ...routing.Option) (<-chan []byte, error) { + _, span := d.tele.Tracer.Start(ctx, "DHT.SearchValue") defer span.End() - ns, path, err := record.SplitKey(key) + // first parse the routing options + rOpt := &routing.Options{} // routing config + if err := rOpt.Apply(options...); err != nil { + return nil, fmt.Errorf("apply routing options: %w", err) + } + + ns, path, err := record.SplitKey(keyStr) if err != nil { return nil, fmt.Errorf("splitting key: %w", err) } @@ -337,7 +329,17 @@ func (d *DHT) getValueLocal(ctx context.Context, key string) ([]byte, error) { val, err := b.Fetch(ctx, path) if err != nil { - return nil, fmt.Errorf("fetch from backend: %w", err) + if !errors.Is(err, ds.ErrNotFound) { + return nil, fmt.Errorf("fetch from backend: %w", err) + } + + if rOpt.Offline { + return nil, routing.ErrNotFound + } + + out := make(chan []byte) + go d.searchValueRoutine(ctx, b, ns, path, rOpt, out) + return out, nil } rec, ok := val.(*recpb.Record) @@ -345,14 +347,153 @@ func (d *DHT) getValueLocal(ctx context.Context, key string) ([]byte, error) { return nil, fmt.Errorf("expected *recpb.Record from backend, got: %T", val) } - return rec.GetValue(), nil + if rOpt.Offline { + out := make(chan []byte, 1) + defer close(out) + out <- rec.GetValue() + return out, nil + } + + out := make(chan []byte) + go func() { + out <- rec.GetValue() + d.searchValueRoutine(ctx, b, ns, path, rOpt, out) + }() + + return out, nil } -func (d *DHT) SearchValue(ctx context.Context, s string, option ...routing.Option) (<-chan []byte, error) { - _, span := d.tele.Tracer.Start(ctx, "DHT.SearchValue") +func (d *DHT) searchValueRoutine(ctx context.Context, backend Backend, ns string, path string, ropt *routing.Options, out chan<- []byte) { + _, span := d.tele.Tracer.Start(ctx, "DHT.searchValueRoutine") defer span.End() + defer close(out) + + routingKey := []byte(newRoutingKey(ns, path)) + + req := &pb.Message{ + Type: pb.Message_GET_VALUE, + Key: routingKey, + } + + // The currently known best value for /$ns/$path + var best []byte + + // Peers that we identified to hold stale records + var fixupPeers []kadt.PeerID + + // The peers that returned the best value + quorumPeers := map[kadt.PeerID]struct{}{} + + // The quorum that we require for terminating the query. This number tells + // us how many peers must have responded with the "best" value before we + // cancel the query. + quorum := d.getQuorum(ropt) + + fn := func(ctx context.Context, id kadt.PeerID, resp *pb.Message, stats coordt.QueryStats) error { + rec := resp.GetRecord() + if rec == nil { + return nil + } + + if !bytes.Equal(routingKey, rec.GetKey()) { + return nil + } + + idx, _ := backend.Validate(ctx, path, best, rec.GetValue()) + switch idx { + case 0: // "best" is still the best value + if bytes.Equal(best, rec.GetValue()) { + quorumPeers[id] = struct{}{} + } + + case 1: // rec.GetValue() is better than our current "best" + + // We have identified a better record. 
All peers that were currently + // in our set of quorum peers need to be updated wit this new record + for p := range quorumPeers { + fixupPeers = append(fixupPeers, p) + } + + // re-initialize the quorum peers set for this new record + quorumPeers = map[kadt.PeerID]struct{}{} + quorumPeers[id] = struct{}{} + + // submit the new value to the user + best = rec.GetValue() + out <- best + case -1: // "best" and rec.GetValue() are both invalid + return nil + + default: + d.log.Warn("unexpected validate index", slog.Int("idx", idx)) + } + + // Check if we have reached the quorum + if len(quorumPeers) == quorum { + return coordt.ErrSkipRemaining + } + + return nil + } + + _, _, err := d.kad.QueryMessage(ctx, req, fn, d.cfg.BucketSize) + if err != nil { + d.logErr(err, "Search value query failed") + return + } + + // check if we have peers that we found to hold stale records. If so, + // update them asynchronously. + if len(fixupPeers) == 0 { + return + } + + go func() { + msg := &pb.Message{ + Type: pb.Message_PUT_VALUE, + Key: routingKey, + Record: record.MakePutRecord(string(routingKey), best), + } + + if err := d.kad.BroadcastStatic(ctx, msg, fixupPeers); err != nil { + d.log.Warn("Failed updating peer") + } + }() +} + +// quorumOptionKey is a struct that is used as a routing options key to pass +// the desired quorum value into, e.g., SearchValue or GetValue. +type quorumOptionKey struct{} + +// RoutingQuorum accepts the desired quorum that is required to terminate the +// search query. The quorum value must not be negative but can be 0 in which +// case we continue the query until we have exhausted the keyspace. If no +// quorum is specified, the [Config.DefaultQuorum] value will be used. +func RoutingQuorum(n int) routing.Option { + return func(opts *routing.Options) error { + if n < 0 { + return fmt.Errorf("quorum must not be negative") + } + + if opts.Other == nil { + opts.Other = make(map[interface{}]interface{}, 1) + } + + opts.Other[quorumOptionKey{}] = n + + return nil + } +} + +// getQuorum extracts the quorum value from the given routing options and +// returns [Config.DefaultQuorum] if no quorum value is present. +func (d *DHT) getQuorum(opts *routing.Options) int { + quorum, ok := opts.Other[quorumOptionKey{}].(int) + if !ok { + quorum = d.cfg.Query.DefaultQuorum + } - panic("implement me") + return quorum } func (d *DHT) Bootstrap(ctx context.Context) error { diff --git a/v2/routing_test.go b/v2/routing_test.go index 19209fd4..50d77895 100644 --- a/v2/routing_test.go +++ b/v2/routing_test.go @@ -6,7 +6,9 @@ import ( "crypto/sha256" "fmt" "testing" + "time" + "github.com/benbjohnson/clock" "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore/failstore" "github.com/libp2p/go-libp2p/core/crypto" @@ -15,8 +17,10 @@ import ( mh "github.com/multiformats/go-multihash" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest" + "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" ) // newRandomContent reads 1024 bytes from crypto/rand and builds a content struct. 
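From the caller's perspective, the quorum option and the progressively better values on the SearchValue channel combine as follows. This is a hedged usage sketch, assuming the v2 module path and an already constructed and bootstrapped DHT instance; searchBest and the key placeholder are illustrative, not part of the package:

```go
package main

import (
	"context"
	"fmt"

	dht "github.com/libp2p/go-libp2p-kad-dht/v2"
)

// searchBest drains the SearchValue channel and keeps the last value it saw.
// SearchValue emits every improvement over the previously best record, so the
// final value before the channel closes is the best one observed. A quorum of
// 3 stops the query once three peers returned that best record, while
// RoutingQuorum(0) would keep going until the keyspace is exhausted.
func searchBest(ctx context.Context, d *dht.DHT, key string) ([]byte, error) {
	valueChan, err := d.SearchValue(ctx, key, dht.RoutingQuorum(3))
	if err != nil {
		return nil, err
	}

	var best []byte
	for val := range valueChan {
		best = val
	}
	if best == nil {
		return nil, fmt.Errorf("no record found for %s", key)
	}
	return best, nil
}

func main() {
	// Constructing and bootstrapping a DHT instance is out of scope here; see
	// dht.New. With an instance d in hand, the lookup is just:
	//
	//   value, err := searchBest(context.Background(), d, "/ipns/<binary-id>")
	_ = searchBest
}
```

Keeping the channel open until the query ends means the last value received is the best one observed, which is also how the reworked GetValue is implemented on top of SearchValue.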
@@ -34,7 +38,7 @@ func newRandomContent(t testing.TB) cid.Cid { return cid.NewCidV0(mhash) } -func makePkKeyValue(t *testing.T) (string, []byte) { +func makePkKeyValue(t testing.TB) (string, []byte) { t.Helper() _, pub, _ := crypto.GenerateEd25519Key(rng) @@ -47,6 +51,87 @@ func makePkKeyValue(t *testing.T) (string, []byte) { return routing.KeyForPublicKey(id), v } +func TestDHT_FindPeer_happy_path(t *testing.T) { + ctx := kadtest.CtxShort(t) + + top := NewTopology(t) + d1 := top.AddServer(nil) + d2 := top.AddServer(nil) + d3 := top.AddServer(nil) + d4 := top.AddServer(nil) + top.ConnectChain(ctx, d1, d2, d3, d4) + + addrInfo, err := d1.FindPeer(ctx, d4.host.ID()) + require.NoError(t, err) + assert.Equal(t, d4.host.ID(), addrInfo.ID) +} + +func TestDHT_FindPeer_not_found(t *testing.T) { + ctx := kadtest.CtxShort(t) + + top := NewTopology(t) + d1 := top.AddServer(nil) + d2 := top.AddServer(nil) + d3 := top.AddServer(nil) + d4 := top.AddServer(nil) + top.ConnectChain(ctx, d1, d2, d3) + + _, err := d1.FindPeer(ctx, d4.host.ID()) + assert.Error(t, err) +} + +func TestDHT_FindPeer_already_connected(t *testing.T) { + ctx := kadtest.CtxShort(t) + + top := NewTopology(t) + d1 := top.AddServer(nil) + d2 := top.AddServer(nil) + d3 := top.AddServer(nil) + d4 := top.AddServer(nil) + top.ConnectChain(ctx, d1, d2, d3) + + err := d1.host.Connect(ctx, peer.AddrInfo{ + ID: d4.host.ID(), + Addrs: d4.host.Addrs(), + }) + require.NoError(t, err) + + _, err = d1.FindPeer(ctx, d4.host.ID()) + assert.NoError(t, err) +} + +func TestDHT_PutValue_happy_path(t *testing.T) { + // TIMING: this test is based on timeouts - so might become flaky! + ctx := kadtest.CtxShort(t) + + top := NewTopology(t) + d1 := top.AddServer(nil) + d2 := top.AddServer(nil) + + top.ConnectChain(ctx, d1, d2) + + k, v := makePkKeyValue(t) + + err := d1.PutValue(ctx, k, v) + require.NoError(t, err) + + deadline, hasDeadline := ctx.Deadline() + if !hasDeadline { + deadline = time.Now().Add(5 * time.Second) + } + + // putting data to a remote peer is an asynchronous operation. Even after + // PutValue returns, and although we have closed the stream on our end, an + // acknowledgement that the other peer has received the data is not + // guaranteed. The data will be flushed at this point, but the remote might + // not have handled it yet. Therefore, we use "EventuallyWithT" here. 
+ assert.EventuallyWithT(t, func(t *assert.CollectT) { + val, err := d2.GetValue(ctx, k, routing.Offline) + assert.NoError(t, err) + assert.Equal(t, v, val) + }, time.Until(deadline), 10*time.Millisecond) +} + func TestDHT_PutValue_local_only(t *testing.T) { ctx := kadtest.CtxShort(t) @@ -78,21 +163,16 @@ func TestDHT_PutValue_invalid_key(t *testing.T) { }) } -func TestGetSetValueLocal(t *testing.T) { +func TestDHT_PutValue_routing_option_returns_error(t *testing.T) { ctx := kadtest.CtxShort(t) + d := newTestDHT(t) - top := NewTopology(t) - d := top.AddServer(nil) - - key, v := makePkKeyValue(t) - - err := d.putValueLocal(ctx, key, v) - require.NoError(t, err) - - val, err := d.getValueLocal(ctx, key) - require.NoError(t, err) + errOption := func(opts *routing.Options) error { + return fmt.Errorf("some error") + } - require.Equal(t, v, val) + err := d.PutValue(ctx, "/ipns/some-key", []byte("some value"), errOption) + assert.ErrorContains(t, err, "routing options") } func TestGetValueOnePeer(t *testing.T) { @@ -190,12 +270,7 @@ func TestDHT_FindProvidersAsync_empty_routing_table(t *testing.T) { c := newRandomContent(t) out := d.FindProvidersAsync(ctx, c, 1) - select { - case _, more := <-out: - require.False(t, more) - case <-ctx.Done(): - t.Fatal("timeout") - } + assertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_dht_does_not_support_providers(t *testing.T) { @@ -206,12 +281,7 @@ func TestDHT_FindProvidersAsync_dht_does_not_support_providers(t *testing.T) { delete(d.backends, namespaceProviders) out := d.FindProvidersAsync(ctx, newRandomContent(t), 1) - select { - case _, more := <-out: - require.False(t, more) - case <-ctx.Done(): - t.Fatal("timeout") - } + assertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_providers_stored_locally(t *testing.T) { @@ -225,17 +295,11 @@ func TestDHT_FindProvidersAsync_providers_stored_locally(t *testing.T) { require.NoError(t, err) out := d.FindProvidersAsync(ctx, c, 1) - for { - select { - case p, more := <-out: - if !more { - return - } - assert.Equal(t, provider.ID, p.ID) - case <-ctx.Done(): - t.Fatal("timeout") - } - } + + val := readItem(t, ctx, out) + assert.Equal(t, provider.ID, val.ID) + + assertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_returns_only_count_from_local_store(t *testing.T) { @@ -292,20 +356,11 @@ func TestDHT_FindProvidersAsync_queries_other_peers(t *testing.T) { require.NoError(t, err) out := d1.FindProvidersAsync(ctx, c, 1) - select { - case p, more := <-out: - require.True(t, more) - assert.Equal(t, provider.ID, p.ID) - case <-ctx.Done(): - t.Fatal("timeout") - } - select { - case _, more := <-out: - assert.False(t, more) - case <-ctx.Done(): - t.Fatal("timeout") - } + val := readItem(t, ctx, out) + assert.Equal(t, provider.ID, val.ID) + + assertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_respects_cancelled_context_for_local_query(t *testing.T) { @@ -419,12 +474,7 @@ func TestDHT_FindProvidersAsync_datastore_error(t *testing.T) { be.datastore = dstore out := d.FindProvidersAsync(ctx, newRandomContent(t), 0) - select { - case _, more := <-out: - assert.False(t, more) - case <-ctx.Done(): - t.Fatal("timeout") - } + assertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_invalid_key(t *testing.T) { @@ -432,10 +482,480 @@ func TestDHT_FindProvidersAsync_invalid_key(t *testing.T) { d := newTestDHT(t) out := d.FindProvidersAsync(ctx, cid.Cid{}, 0) + assertClosed(t, ctx, out) +} + +func TestDHT_GetValue_happy_path(t *testing.T) { + ctx := kadtest.CtxShort(t) + + clk := clock.New() + + cfg := 
DefaultConfig()
+	cfg.Clock = clk
+
+	// generate a new identity; its private key signs the IPNS records below
+	priv, _, err := crypto.GenerateEd25519Key(rng)
+	require.NoError(t, err)
+
+	_, validValue := makeIPNSKeyValue(t, clk, priv, 1, time.Hour)
+	_, worseValue := makeIPNSKeyValue(t, clk, priv, 0, time.Hour)
+	key, betterValue := makeIPNSKeyValue(t, clk, priv, 2, time.Hour) // higher sequence number means better value
+
+	top := NewTopology(t)
+	d1 := top.AddServer(cfg)
+	d2 := top.AddServer(cfg)
+	d3 := top.AddServer(cfg)
+	d4 := top.AddServer(cfg)
+	d5 := top.AddServer(cfg)
+
+	top.ConnectChain(ctx, d1, d2, d3, d4, d5)
+
+	err = d3.putValueLocal(ctx, key, validValue)
+	require.NoError(t, err)
+
+	err = d4.putValueLocal(ctx, key, worseValue)
+	require.NoError(t, err)
+
+	err = d5.putValueLocal(ctx, key, betterValue)
+	require.NoError(t, err)
+
+	val, err := d1.GetValue(ctx, key)
+	assert.NoError(t, err)
+	assert.Equal(t, betterValue, val)
+}
+
+func TestDHT_GetValue_returns_context_error(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	cancelledCtx, cancel := context.WithCancel(ctx)
+	cancel()
+
+	_, err := d.GetValue(cancelledCtx, "/"+namespaceIPNS+"/some-key")
+	assert.ErrorIs(t, err, context.Canceled)
+}
+
+func TestDHT_GetValue_returns_not_found_error(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	val, err := d.GetValue(ctx, "/"+namespaceIPNS+"/some-key")
+	assert.ErrorIs(t, err, routing.ErrNotFound)
+	assert.Nil(t, val)
+}
+
+// assertClosed triggers a test failure if the given channel carries another
+// value instead of being closed, or if the context's deadline is reached first.
+func assertClosed[T any](t testing.TB, ctx context.Context, c <-chan T) {
+	t.Helper()
+
 	select {
-	case _, more := <-out:
+	case _, more := <-c:
 		assert.False(t, more)
 	case <-ctx.Done():
-		t.Fatal("timeout")
+		t.Fatal("timeout waiting for channel to close")
 	}
 }
+
+func readItem[T any](t testing.TB, ctx context.Context, c <-chan T) T {
+	t.Helper()
+
+	select {
+	case val, more := <-c:
+		require.True(t, more, "channel closed unexpectedly")
+		return val
+	case <-ctx.Done():
+		t.Fatal("timeout reading item")
+		return *new(T)
+	}
+}
+
+func TestDHT_SearchValue_simple(t *testing.T) {
+	// Test setup:
+	// There is just one other server that returns a valid value.
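+	// d1 issues the search; d2 has stored the record locally beforehand.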
+	ctx := kadtest.CtxShort(t)
+
+	key, v := makePkKeyValue(t)
+
+	top := NewTopology(t)
+	d1 := top.AddServer(nil)
+	d2 := top.AddServer(nil)
+
+	top.Connect(ctx, d1, d2)
+
+	err := d2.putValueLocal(ctx, key, v)
+	require.NoError(t, err)
+
+	valChan, err := d1.SearchValue(ctx, key)
+	require.NoError(t, err)
+
+	val := readItem(t, ctx, valChan)
+	assert.Equal(t, v, val)
+
+	assertClosed(t, ctx, valChan)
+}
+
+func TestDHT_SearchValue_returns_best_values(t *testing.T) {
+	// Test setup:
+	// d2 returns no value
+	// d3 returns a valid value
+	// d4 returns a worse value (it gets rejected because we already have a valid value)
+	// d5 returns a better value
+	// all peers are connected in a chain from d1 to d5 (d1 initiates the query)
+	// assert that we receive two values on the channel (valid + better)
+	ctx := kadtest.CtxShort(t)
+	clk := clock.New()
+
+	cfg := DefaultConfig()
+	cfg.Clock = clk
+
+	// generate a new identity; its private key signs the IPNS records below
+	priv, _, err := crypto.GenerateEd25519Key(rng)
+	require.NoError(t, err)
+
+	_, validValue := makeIPNSKeyValue(t, clk, priv, 1, time.Hour)
+	_, worseValue := makeIPNSKeyValue(t, clk, priv, 0, time.Hour)
+	key, betterValue := makeIPNSKeyValue(t, clk, priv, 2, time.Hour) // higher sequence number means better value
+
+	top := NewTopology(t)
+	d1 := top.AddServer(cfg)
+	d2 := top.AddServer(cfg)
+	d3 := top.AddServer(cfg)
+	d4 := top.AddServer(cfg)
+	d5 := top.AddServer(cfg)
+
+	top.ConnectChain(ctx, d1, d2, d3, d4, d5)
+
+	err = d3.putValueLocal(ctx, key, validValue)
+	require.NoError(t, err)
+
+	err = d4.putValueLocal(ctx, key, worseValue)
+	require.NoError(t, err)
+
+	err = d5.putValueLocal(ctx, key, betterValue)
+	require.NoError(t, err)
+
+	valChan, err := d1.SearchValue(ctx, key)
+	require.NoError(t, err)
+
+	val := readItem(t, ctx, valChan)
+	assert.Equal(t, validValue, val)
+
+	val = readItem(t, ctx, valChan)
+	assert.Equal(t, betterValue, val)
+
+	assertClosed(t, ctx, valChan)
+}
+
+// In order for 'go test' to run this suite, we need to create
+// a normal test function and pass our suite to suite.Run
+func TestDHT_SearchValue_quorum_test_suite(t *testing.T) {
+	suite.Run(t, new(SearchValueQuorumTestSuite))
+}
+
+type SearchValueQuorumTestSuite struct {
+	suite.Suite
+
+	d       *DHT
+	servers []*DHT
+
+	key         string
+	validValue  []byte
+	betterValue []byte
+}
+
+// SetupTest builds the test topology and stores the records on the
+// respective DHT servers before each test in the suite runs.
+func (suite *SearchValueQuorumTestSuite) SetupTest() {
+	// Test setup:
+	// we create 1 DHT server that searches for values
+	// we create 10 additional DHT servers and connect all of them in a chain
+	// the first of the 10 servers holds an invalid record
+	// the next three servers hold a valid record
+	// the remaining six servers hold a better record
+	// first test assertion: with a quorum of 3 we expect only the valid but old record
+	// second test assertion: with a quorum of 5 we expect to receive the valid and then the better record
+
+	t := suite.T()
+	ctx := kadtest.CtxShort(t)
+	clk := clock.New()
+
+	cfg := DefaultConfig()
+	cfg.Clock = clk
+	top := NewTopology(t)
+
+	// init the DHT server that issues the search queries
+	suite.d = top.AddServer(cfg)
+
+	// init the remaining DHT servers
+	suite.servers = make([]*DHT, 10)
+	for i := 0; i < 10; i++ {
+		suite.servers[i] = top.AddServer(cfg)
+	}
+
+	// connect all servers in a chain
+	top.ConnectChain(ctx, append([]*DHT{suite.d}, suite.servers...)...)
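+	// the resulting chain: suite.d <-> servers[0] <-> servers[1] <-> ... <-> servers[9]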
+
+	// generate records
+	remote, priv := newIdentity(t)
+	invalidPutReq := newPutIPNSRequest(t, clk, priv, 3, -time.Hour)
+	suite.key, suite.validValue = makeIPNSKeyValue(t, clk, priv, 1, time.Hour)
+	suite.key, suite.betterValue = makeIPNSKeyValue(t, clk, priv, 2, time.Hour) // higher sequence number means better value
+
+	// store the invalid (expired) record directly in the datastore of
+	// the respective DHT server (bypassing any validation).
+	invalidRec, err := invalidPutReq.Record.Marshal()
+	require.NoError(t, err)
+
+	rbe, err := typedBackend[*RecordBackend](suite.servers[0], namespaceIPNS)
+	require.NoError(t, err)
+
+	dsKey := newDatastoreKey(namespaceIPNS, string(remote))
+	err = rbe.datastore.Put(ctx, dsKey, invalidRec)
+	require.NoError(t, err)
+
+	// servers 1 to 3 hold a valid but old record
+	for i := 1; i < 4; i++ {
+		err = suite.servers[i].putValueLocal(ctx, suite.key, suite.validValue)
+		require.NoError(t, err)
+	}
+
+	// the remaining six DHT servers hold a valid and newer record
+	for i := 4; i < 10; i++ {
+		err = suite.servers[i].putValueLocal(ctx, suite.key, suite.betterValue)
+		require.NoError(t, err)
+	}
+
+	// one of those servers stores the better record a second time
+	err = suite.servers[8].putValueLocal(ctx, suite.key, suite.betterValue)
+	require.NoError(t, err)
+}
+
+func (suite *SearchValueQuorumTestSuite) TestQuorumReachedPrematurely() {
+	t := suite.T()
+	ctx := kadtest.CtxShort(t)
+	out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(3))
+	require.NoError(t, err)
+
+	val := readItem(t, ctx, out)
+	assert.Equal(t, suite.validValue, val)
+
+	assertClosed(t, ctx, out)
+}
+
+func (suite *SearchValueQuorumTestSuite) TestQuorumReachedAfterDiscoveryOfBetter() {
+	t := suite.T()
+	ctx := kadtest.CtxShort(t)
+	out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(5))
+	require.NoError(t, err)
+
+	val := readItem(t, ctx, out)
+	assert.Equal(t, suite.validValue, val)
+
+	val = readItem(t, ctx, out)
+	assert.Equal(t, suite.betterValue, val)
+
+	assertClosed(t, ctx, out)
+}
+
+func (suite *SearchValueQuorumTestSuite) TestQuorumZero() {
+	t := suite.T()
+	ctx := kadtest.CtxShort(t)
+
+	// search until the query is exhausted
+	out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(0))
+	require.NoError(t, err)
+
+	val := readItem(t, ctx, out)
+	assert.Equal(t, suite.validValue, val)
+
+	val = readItem(t, ctx, out)
+	assert.Equal(t, suite.betterValue, val)
+
+	assertClosed(t, ctx, out)
+}
+
+func (suite *SearchValueQuorumTestSuite) TestQuorumUnspecified() {
+	t := suite.T()
+	ctx := kadtest.CtxShort(t)
+
+	// search until the query is exhausted
+	out, err := suite.d.SearchValue(ctx, suite.key)
+	require.NoError(t, err)
+
+	val := readItem(t, ctx, out)
+	assert.Equal(t, suite.validValue, val)
+
+	val = readItem(t, ctx, out)
+	assert.Equal(t, suite.betterValue, val)
+
+	assertClosed(t, ctx, out)
+}
+
+func TestDHT_SearchValue_routing_option_returns_error(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	errOption := func(opts *routing.Options) error {
+		return fmt.Errorf("some error")
+	}
+
+	valueChan, err := d.SearchValue(ctx, "/ipns/some-key", errOption)
+	assert.ErrorContains(t, err, "routing options")
+	assert.Nil(t, valueChan)
+}
+
+func TestDHT_SearchValue_quorum_negative(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	out, err := d.SearchValue(ctx, "/"+namespaceIPNS+"/some-key", RoutingQuorum(-1))
+	assert.ErrorContains(t, err, "quorum must not be negative")
+	assert.Nil(t, out)
+}
+
+func TestDHT_SearchValue_invalid_key(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	valueChan, err := d.SearchValue(ctx, "invalid-key")
+	assert.ErrorContains(t, err, "splitting key")
+	assert.Nil(t, valueChan)
+}
+
+func TestDHT_SearchValue_key_for_unsupported_namespace(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	valueChan, err := d.SearchValue(ctx, "/unsupported/key")
+	assert.ErrorIs(t, err, routing.ErrNotSupported)
+	assert.Nil(t, valueChan)
+}
+
+func TestDHT_SearchValue_stops_with_cancelled_context(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	cancelledCtx, cancel := context.WithCancel(ctx)
+	cancel()
+
+	// make sure we don't just stop because we don't know any other DHT server
+	top := NewTopology(t)
+	d1 := top.AddServer(nil)
+	d2 := top.AddServer(nil)
+	top.Connect(ctx, d1, d2)
+
+	valueChan, err := d1.SearchValue(cancelledCtx, "/"+namespaceIPNS+"/some-key")
+	assert.NoError(t, err)
+	assertClosed(t, ctx, valueChan)
+}
+
+func TestDHT_SearchValue_has_record_locally(t *testing.T) {
+	// Test setup:
+	// d1 holds an older (but valid) record locally, d2 holds a better one.
+	// We expect to receive the local record first and then the better one.
+	ctx := kadtest.CtxShort(t)
+	clk := clock.New()
+
+	_, priv := newIdentity(t)
+	_, validValue := makeIPNSKeyValue(t, clk, priv, 1, time.Hour)
+	key, betterValue := makeIPNSKeyValue(t, clk, priv, 2, time.Hour)
+
+	top := NewTopology(t)
+	d1 := top.AddServer(nil)
+	d2 := top.AddServer(nil)
+
+	top.Connect(ctx, d1, d2)
+
+	err := d1.putValueLocal(ctx, key, validValue)
+	require.NoError(t, err)
+
+	err = d2.putValueLocal(ctx, key, betterValue)
+	require.NoError(t, err)
+
+	valChan, err := d1.SearchValue(ctx, key)
+	require.NoError(t, err)
+
+	val := readItem(t, ctx, valChan) // from local store
+	assert.Equal(t, validValue, val)
+
+	val = readItem(t, ctx, valChan)
+	assert.Equal(t, betterValue, val)
+
+	assertClosed(t, ctx, valChan)
+}
+
+func TestDHT_SearchValue_offline(t *testing.T) {
+	// Test setup:
+	// The record is stored locally and the lookup is restricted to the local
+	// store via the routing.Offline option.
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	key, v := makePkKeyValue(t)
+	err := d.putValueLocal(ctx, key, v)
+	require.NoError(t, err)
+
+	valChan, err := d.SearchValue(ctx, key, routing.Offline)
+	require.NoError(t, err)
+
+	val := readItem(t, ctx, valChan)
+	assert.Equal(t, v, val)
+
+	assertClosed(t, ctx, valChan)
+}
+
+func TestDHT_SearchValue_offline_not_found_locally(t *testing.T) {
+	// Test setup:
+	// We are connected to a peer that holds the record but we require an
+	// offline lookup. Assert that we don't receive the record.
+	ctx := kadtest.CtxShort(t)
+
+	key, v := makePkKeyValue(t)
+
+	top := NewTopology(t)
+	d1 := top.AddServer(nil)
+	d2 := top.AddServer(nil)
+
+	top.Connect(ctx, d1, d2)
+
+	err := d2.putValueLocal(ctx, key, v)
+	require.NoError(t, err)
+
+	valChan, err := d1.SearchValue(ctx, key, routing.Offline)
+	assert.ErrorIs(t, err, routing.ErrNotFound)
+	assert.Nil(t, valChan)
+}
+
+func TestDHT_Bootstrap_no_peers_configured(t *testing.T) {
+	// TIMING: this test is based on timeouts - so might become flaky!
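+	// Test setup: d1 starts with an empty routing table but is configured with
+	// d2 and d3 as bootstrap peers. After calling Bootstrap, d1 and its
+	// bootstrap peers should eventually appear in each other's routing tables.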
+	ctx := kadtest.CtxShort(t)
+
+	top := NewTopology(t)
+	d1 := top.AddServer(nil)
+	d2 := top.AddServer(nil)
+	d3 := top.AddServer(nil)
+
+	d1.cfg.BootstrapPeers = []peer.AddrInfo{
+		{ID: d2.host.ID(), Addrs: d2.host.Addrs()},
+		{ID: d3.host.ID(), Addrs: d3.host.Addrs()},
+	}
+
+	err := d1.Bootstrap(ctx)
+	assert.NoError(t, err)
+
+	deadline, hasDeadline := ctx.Deadline()
+	if !hasDeadline {
+		deadline = time.Now().Add(5 * time.Second)
+	}
+
+	// bootstrapping is an asynchronous process, so we periodically check
+	// if the peers have each other in their routing tables
+	assert.EventuallyWithT(t, func(collect *assert.CollectT) {
+		_, found := d1.rt.GetNode(kadt.PeerID(d2.host.ID()).Key())
+		assert.True(collect, found)
+		_, found = d1.rt.GetNode(kadt.PeerID(d3.host.ID()).Key())
+		assert.True(collect, found)
+
+		_, found = d2.rt.GetNode(kadt.PeerID(d1.host.ID()).Key())
+		assert.True(collect, found)
+		_, found = d3.rt.GetNode(kadt.PeerID(d1.host.ID()).Key())
+		assert.True(collect, found)
+	}, time.Until(deadline), 10*time.Millisecond)
+}
diff --git a/v2/tele/tele.go b/v2/tele/tele.go
index 0cd4fac7..9c22b26a 100644
--- a/v2/tele/tele.go
+++ b/v2/tele/tele.go
@@ -27,7 +27,7 @@ func NoopTracer() trace.Tracer {
 	return trace.NewNoopTracerProvider().Tracer("")
 }
 
-// NoopMeterProvider returns a meter provider that does not record or emit metrics.
+// NoopMeter returns a meter that does not record or emit metrics.
 func NoopMeter() metric.Meter {
 	return noop.NewMeterProvider().Meter("")
 }
diff --git a/v2/topology_test.go b/v2/topology_test.go
index 189b494e..b6be05be 100644
--- a/v2/topology_test.go
+++ b/v2/topology_test.go
@@ -61,10 +61,6 @@ func (t *Topology) AddServer(cfg *Config) *DHT {
 	rn := coord.NewBufferedRoutingNotifier()
 	d.kad.SetRoutingNotifier(rn)
 
-	// add at least 1 entry in the routing table so the server will pass connectivity checks
-	fillRoutingTable(t.tb, d, 1)
-	require.NotEmpty(t.tb, d.rt.NearestNodes(kadt.PeerID(d.host.ID()).Key(), 1))
-
 	t.tb.Cleanup(func() {
 		if err = d.Close(); err != nil {
 			t.tb.Logf("unexpected error when closing dht: %s", err)