diff --git a/internal/index/builder/builder_test.go b/internal/index/builder/builder_test.go new file mode 100644 index 0000000..88a1fd0 --- /dev/null +++ b/internal/index/builder/builder_test.go @@ -0,0 +1,87 @@ +package builder + +import ( + "path" + "reflect" + "testing" + "time" +) + +var ( + t1 time.Time +) + +func init() { + var err error + t1, err = time.Parse(time.RFC3339, "2020-01-01T00:00:00Z") + if err != nil { + panic(err) + } +} + +func TestSnapshots(t *testing.T) { + for _, tc := range []struct { + name string + snapshot []*snapshot + }{ + { + name: "empty", + snapshot: []*snapshot{}, + }, + { + name: "nil", + snapshot: nil, + }, + { + name: "single", + snapshot: []*snapshot{ + { + timestamp: t1, + referencedPackets: map[string][]uint64{"a": {1, 2, 3}}, + chunkCount: 42, + }, + }, + }, + { + name: "multiple", + snapshot: []*snapshot{ + { + timestamp: t1, + referencedPackets: map[string][]uint64{"a": {1, 2, 3}}, + chunkCount: 42, + }, + { + timestamp: t1.Add(time.Hour), + referencedPackets: map[string][]uint64{"b": {4, 5, 6}}, + chunkCount: 43, + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + fn := path.Join(t.TempDir(), "test.snap") + if err := saveSnapshots(fn, tc.snapshot); err != nil { + t.Fatalf("saveSnapshots failed: %v", err) + } + got, err := loadSnapshots(fn) + if err != nil { + t.Fatalf("loadSnapshots failed: %v", err) + } + if len(got) != len(tc.snapshot) { + t.Fatalf("len(got)=%d, want %d", len(got), len(tc.snapshot)) + } + for i, want := range tc.snapshot { + got := *got[i] + if got.timestamp.UTC() != want.timestamp.UTC() { + t.Errorf("got=%v, want %v", got.timestamp, want.timestamp) + } + if got.chunkCount != want.chunkCount { + t.Errorf("got=%v, want %v", got.chunkCount, want.chunkCount) + } + if !reflect.DeepEqual(got.referencedPackets, want.referencedPackets) { + t.Errorf("got=%v, want %v", got.referencedPackets, want.referencedPackets) + } + } + }) + } +} diff --git a/internal/index/manager/manager_test.go b/internal/index/manager/manager_test.go new file mode 100644 index 0000000..d40e4a5 --- /dev/null +++ b/internal/index/manager/manager_test.go @@ -0,0 +1,607 @@ +package manager + +import ( + "context" + _ "embed" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "os" + "path" + "reflect" + "slices" + "strings" + "testing" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcapgo" + "github.com/spq/pkappa2/internal/index/converters" + "github.com/spq/pkappa2/internal/query" +) + +type ( + dirs struct { + base, pcap, index, snapshot, state, converter string + } +) + +var ( + t1 time.Time + + //go:embed testdata/test_converter.py + converterScript []byte +) + +func init() { + var err error + t1, err = time.Parse(time.RFC3339, "2020-01-01T12:00:00Z") + if err != nil { + panic(err) + } +} + +func makeTempdirs(t *testing.T) dirs { + dirs := dirs{ + base: t.TempDir(), + } + dirs.pcap = path.Join(dirs.base, "pcap") + "/" + dirs.index = path.Join(dirs.base, "index") + "/" + dirs.state = path.Join(dirs.base, "state") + "/" + dirs.snapshot = path.Join(dirs.base, "snapshot") + "/" + dirs.converter = path.Join(dirs.base, "converter") + "/" + for _, p := range []string{dirs.pcap, dirs.index, dirs.snapshot, dirs.state, dirs.converter} { + if err := os.Mkdir(p, 0755); err != nil { + t.Fatalf("Mkdir(%q) failed with error: %v", p, err) + } + } + return dirs +} + +func addConverter(dirs dirs, name string) { + if err := os.WriteFile(path.Join(dirs.converter, name), []byte(converterScript), 0775); err != nil { + panic(err) + } +} + +func makeManager(t *testing.T, dirs dirs) *Manager { + mgr, err := New(dirs.pcap, dirs.index, dirs.snapshot, dirs.state, dirs.converter) + if err != nil { + t.Fatalf("manager.New failed with error: %v", err) + } + return mgr +} + +func TestEmptyManager(t *testing.T) { + dirs := makeTempdirs(t) + mgr := makeManager(t, dirs) + if got := mgr.Status(); !reflect.DeepEqual(got, Statistics{}) { + t.Fatalf("Status() = %v, want {}", got) + } + mgr.Close() +} + +func makeUDPPacket(client, server string, t time.Time, payload string) pcapOverIPPacket { + clientAddrPort := netip.MustParseAddrPort(client) + serverAddrPort := netip.MustParseAddrPort(server) + if !(clientAddrPort.Addr().Is4() && serverAddrPort.Addr().Is4()) { + panic("only support v4 for now") + } + ip := layers.IPv4{ + Version: 4, + TTL: 64, + SrcIP: clientAddrPort.Addr().AsSlice(), + DstIP: serverAddrPort.Addr().AsSlice(), + Protocol: layers.IPProtocolUDP, + } + udp := layers.UDP{ + SrcPort: layers.UDPPort(clientAddrPort.Port()), + DstPort: layers.UDPPort(serverAddrPort.Port()), + } + if err := udp.SetNetworkLayerForChecksum(&ip); err != nil { + panic(err) + } + + options := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + buffer := gopacket.NewSerializeBuffer() + + err := gopacket.SerializeLayers(buffer, options, + &ip, + &udp, + gopacket.Payload([]byte(payload)), + ) + if err != nil { + panic(err) + } + data := buffer.Bytes() + return pcapOverIPPacket{ + linkType: layers.LinkTypeIPv4, + ci: gopacket.CaptureInfo{ + Timestamp: t, + CaptureLength: len(data), + Length: 0xffff, + }, + data: data, + } +} + +func TestTags(t *testing.T) { + dirs := makeTempdirs(t) + mgr := makeManager(t, dirs) + defer mgr.Close() + if got, want := mgr.ListTags(), []TagInfo{}; !reflect.DeepEqual(got, want) { + t.Fatalf("Manager.ListTags() = %v, want %v", got, want) + } + testcases := []struct { + tag string + query string + expectError bool + }{ + { + tag: "foo", + query: "id:1", + expectError: true, + }, + { + tag: "tag/foo", + query: "id:1", + }, + { + tag: "service/foo", + query: "id:1", + }, + { + tag: "mark/foo", + query: "id:1", + }, + { + tag: "mark/foo", + query: "port:1", + expectError: true, + }, + { + tag: "generated/foo", + query: "id:1", + }, + { + tag: "tag/foo", + query: "foo", + expectError: true, + }, + } + for _, tc := range testcases { + t.Run(tc.tag, func(t *testing.T) { + err := mgr.AddTag(tc.tag, "blue", tc.query) + if tc.expectError { + if err == nil { + t.Fatalf("Manager.AddTag succeeded, want error") + } + return + } + if err != nil { + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + if got, want := mgr.ListTags(), []TagInfo{{ + Name: tc.tag, + Color: "blue", + Definition: tc.query, + MatchingCount: 0, + UncertainCount: 0, + Referenced: false, + Converters: []string{}, + }}; !reflect.DeepEqual(got, want) { + t.Fatalf("Manager.ListTags() = %v, want %v", got, want) + } + if err := mgr.DelTag(tc.tag); err != nil { + t.Fatalf("Manager.DelTag failed with error: %v", err) + } + }) + } + if err := mgr.AddTag("service/foo", "red", "cport:2,3"); err != nil { + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + importSomePackets(t, mgr, t1, "tagEvaluated") + if got := mgr.ListTags()[0]; got.MatchingCount != 2 || got.UncertainCount != 0 { + t.Fatalf("Manager.ListTags()[0] = %+v, want {MatchingCount: 2, UncertainCount: 0}", got) + } + if err := mgr.DelTag("service/foo"); err != nil { + t.Fatalf("Manager.DelTag failed with error: %v", err) + } + if err := mgr.AddTag("mark/foo", "blue", "id:0"); err != nil { + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + if err := mgr.UpdateTag("mark/foo", UpdateTagOperationMarkAddStream([]uint64{2, 3})); err != nil { + t.Fatalf("Manager.UpdateTag failed with error: %v", err) + } + if got, want := mgr.ListTags()[0].Definition, "id:0,2,3"; got != want { + t.Fatalf("Manager.ListTags()[0].Definition = %v, want %v", got, want) + } + if err := mgr.UpdateTag("mark/foo", UpdateTagOperationMarkDelStream([]uint64{2})); err != nil { + t.Fatalf("Manager.UpdateTag failed with error: %v", err) + } + if got, want := mgr.ListTags()[0].Definition, "id:0,3"; got != want { + t.Fatalf("Manager.ListTags()[0].Definition = %v, want %v", got, want) + } + if err := mgr.UpdateTag("mark/foo", UpdateTagOperationUpdateName("mark/bar")); err != nil { + t.Fatalf("Manager.UpdateTag failed with error: %v", err) + } + if got, want := mgr.ListTags()[0].Name, "mark/bar"; got != want { + t.Fatalf("Manager.ListTags()[0].Name = %v, want %v", got, want) + } + if err := mgr.UpdateTag("mark/bar", UpdateTagOperationUpdateColor("red")); err != nil { + t.Fatalf("Manager.UpdateTag failed with error: %v", err) + } + if got, want := mgr.ListTags()[0].Color, "red"; got != want { + t.Fatalf("Manager.ListTags()[0].Color = %v, want %v", got, want) + } + if err := mgr.AddTag("tag/foo", "blue", "port:123"); err != nil { + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + if err := mgr.AddTag("tag/bar", "blue", "tag:foo"); err != nil { + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + if err := mgr.DelTag("tag/foo"); err == nil { + t.Fatalf("Manager.DelTag succeeded, want error") + } + if err := mgr.UpdateTag("tag/bar", UpdateTagOperationUpdateQuery("port:123")); err != nil { + t.Fatalf("Manager.UpdateTag failed with error: %v", err) + } + if err := mgr.DelTag("tag/foo"); err != nil { + t.Fatalf("Manager.DelTag failed with error: %v", err) + } + if err := mgr.DelTag("tag/bar"); err != nil { + t.Fatalf("Manager.DelTag failed with error: %v", err) + } + if err := mgr.AddTag("tag/foo", "blue", "tag:foo"); err == nil { + t.Fatalf("Manager.DelTag succeeded, want error") + } +} + +func TestManagerRestartKeepsState(t *testing.T) { + dirs := makeTempdirs(t) + mgr := makeManager(t, dirs) + if err := mgr.AddTag("tag/foo", "red", "port:123"); err != nil { + mgr.Close() + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + if err := mgr.AddTag("mark/foo", "red", "id:-1"); err != nil { + mgr.Close() + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + mgr.Close() + mgr = makeManager(t, dirs) + if got, want := mgr.ListTags(), []TagInfo{ + { + Name: "mark/foo", + Color: "red", + Definition: "id:-1", + MatchingCount: 0, + UncertainCount: 0, + Referenced: false, + Converters: []string{}, + }, + { + Name: "tag/foo", + Color: "red", + Definition: "port:123", + MatchingCount: 0, + UncertainCount: 0, + Referenced: false, + Converters: []string{}, + }, + }; !reflect.DeepEqual(got, want) { + t.Fatalf("Manager.ListTags() = %v, want %v", got, want) + } + defer mgr.Close() +} + +func TestManagerPcapOverIP(t *testing.T) { + dirs := makeTempdirs(t) + mgr := makeManager(t, dirs) + defer mgr.Close() + if err := mgr.AddPcapOverIPEndpoint("foo"); err == nil { + t.Fatalf("Manager.AddPcapOverIPEndpoint succeeded, want error") + } + if err := mgr.DelPcapOverIPEndpoint("foo"); err == nil { + t.Fatalf("Manager.DelPcapOverIPEndpoint succeeded, want error") + } + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + if err != nil { + t.Fatalf("net.ListenTCP failed with error: %v", err) + } + events, eventsCloser := mgr.Listen() + if err := mgr.AddPcapOverIPEndpoint(listener.Addr().String()); err != nil { + t.Fatalf("Manager.AddPcapOverIPEndpoint failed: %v", err) + } + if err := mgr.AddPcapOverIPEndpoint(listener.Addr().String()); err == nil { + t.Fatalf("Manager.AddPcapOverIPEndpoint succeeded, want error") + } + conn, err := listener.AcceptTCP() + if err != nil { + t.Fatalf("listener.AcceptTCP failed with error: %v", err) + } + defer conn.Close() + pkt := makeUDPPacket("1.2.3.4:1234", "5.6.7.8:5678", time.Now(), "foo9001") + wr := pcapgo.NewWriter(conn) + if err := wr.WriteFileHeader(0xffff, pkt.linkType); err != nil { + t.Fatalf("pcapgo.NewWriter failed with error: %v", err) + } + if err := wr.WritePacket(pkt.ci, pkt.data); err != nil { + t.Fatalf("pcapgo.Writer.WritePacket failed with error: %v", err) + } + waitForEvent(t, events, eventsCloser, "pcapProcessed") + if got := mgr.ListPcapOverIPEndpoints(); len(got) != 1 || got[0].ReceivedPackets != 1 || got[0].LastConnected == 0 { + t.Fatalf("Manager.ListPcapOverIPEndpoints() = %v, want [{ReceivedPackets:1, LastConnected:non-zero}]", got) + } + v := mgr.GetView() + c, err := v.Stream(0) + if err != nil { + t.Fatalf("View.Stream failed with error: %v", err) + } + d, err := c.Data("") + if err != nil { + t.Fatalf("Stream.Data failed with error: %v", err) + } + if len(d) != 1 || string(d[0].Content) != "foo9001" { + t.Fatalf("Stream.Data = %v, want [{Content:foo}]", d) + } + if got := mgr.KnownPcaps(); len(got) != 1 || got[0].PacketCount != 1 { + t.Fatalf("Manager.KnownPcaps() = %+v, want [{PacketCount:1}]", got) + } + defer v.Release() + if err := mgr.DelPcapOverIPEndpoint(listener.Addr().String()); err != nil { + t.Fatalf("Manager.DelPcapOverIPEndpoint failed: %v", err) + } +} + +func importSomePackets(t *testing.T, mgr *Manager, t1 time.Time, eventType string) { + pcaps, err := writePcaps(mgr.PcapDir, []pcapOverIPPacket{ + makeUDPPacket("1.2.3.4:1", "4.3.2.1:4321", t1.Add(time.Second*0), "foo"), + makeUDPPacket("1.2.3.4:2", "4.3.2.1:4321", t1.Add(time.Second*1), "bar"), + makeUDPPacket("1.2.3.4:3", "4.3.2.1:4321", t1.Add(time.Second*2), "baz"), + makeUDPPacket("1.2.3.4:4", "4.3.2.1:4321", t1.Add(time.Second*3), "qux"), + }) + if err != nil { + t.Fatalf("writePcaps failed with error: %v", err) + } + if eventType != "" { + events, eventCloser := mgr.Listen() + mgr.ImportPcaps(pcaps) + waitForEvent(t, events, eventCloser, eventType) + } else { + mgr.ImportPcaps(pcaps) + } +} + +func TestWebhooks(t *testing.T) { + dirs := makeTempdirs(t) + mgr := makeManager(t, dirs) + defer mgr.Close() + if got := mgr.ListPcapProcessorWebhooks(); len(got) != 0 { + t.Fatalf("Manager.ListPcapProcessorWebhooks() = %v, want []", got) + } + receivedPcaps := make(chan []string, 1) + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + data, err := io.ReadAll(req.Body) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + if _, err := rw.Write([]byte(err.Error())); err != nil { + panic(err) + } + return + } + res := []string(nil) + if err := json.Unmarshal(data, &res); err != nil { + rw.WriteHeader(http.StatusInternalServerError) + if _, err := rw.Write([]byte(err.Error())); err != nil { + panic(err) + } + return + } + if _, err := rw.Write([]byte("ok")); err != nil { + panic(err) + } + receivedPcaps <- res + close(receivedPcaps) + })) + defer server.Close() + + if err := mgr.AddPcapProcessorWebhook(server.URL); err != nil { + t.Fatalf("Manager.AddPcapProcessorWebhook failed with error: %v", err) + } + if err := mgr.AddPcapProcessorWebhook(server.URL); err == nil { + t.Fatalf("Manager.AddPcapProcessorWebhook succeeded, want error") + } + if err := mgr.DelPcapProcessorWebhook("foo"); err == nil { + t.Fatalf("Manager.DelPcapProcessorWebhook succeeded, want error") + } + if got, want := mgr.ListPcapProcessorWebhooks(), []string{server.URL}; !reflect.DeepEqual(got, want) { + t.Fatalf("Manager.ListPcapProcessorWebhooks() = %v, want %v", got, want) + } + importSomePackets(t, mgr, t1, "") + if got, want := <-receivedPcaps, []string{path.Join(mgr.PcapDir, mgr.KnownPcaps()[0].Filename)}; !reflect.DeepEqual(got, want) { + t.Fatalf("receivedPcaps = %v, want %v", got, want) + } + if got, want := mgr.ListPcapProcessorWebhooks(), []string{server.URL}; !slices.Equal(got, want) { + t.Fatalf("Manager.ListPcapProcessorWebhooks() = %v, want %v", got, want) + } + if err := mgr.DelPcapProcessorWebhook(server.URL); err != nil { + t.Fatalf("Manager.DelPcapProcessorWebhook failed with error: %v", err) + } + if got := mgr.ListPcapProcessorWebhooks(); len(got) != 0 { + t.Fatalf("Manager.ListPcapProcessorWebhooks() = %v, want []", got) + } +} + +func TestManagerMerging(t *testing.T) { + dirs := makeTempdirs(t) + mgr := makeManager(t, dirs) + defer mgr.Close() + events, eventsCloser := mgr.Listen() + defer eventsCloser() + for i := 0; i < 10; i++ { + pcaps, err := writePcaps(mgr.PcapDir, []pcapOverIPPacket{ + makeUDPPacket(fmt.Sprintf("9.0.0.%d:123", i), "2.3.4.5:9001", t1.Add(time.Second*time.Duration(i)), "foo"), + }) + if err != nil { + t.Fatalf("writePcaps failed with error: %v", err) + } + mgr.ImportPcaps(pcaps) + waitForEvent(t, events, func() {}, "pcapProcessed") + } + waitForEvent(t, events, eventsCloser, "indexesMerged") +} + +func TestManagerView(t *testing.T) { + dirs := makeTempdirs(t) + mgr := makeManager(t, dirs) + defer mgr.Close() + importSomePackets(t, mgr, t1, "pcapProcessed") + if err := mgr.AddTag("tag/foo", "red", ""); err != nil { + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + view := mgr.GetView() + defer view.Release() + if err := view.AllStreams(context.Background(), func(sc StreamContext) error { + return nil + }, PrefetchAllTags()); err != nil { + t.Fatalf("View.AllStreams failed with error: %v", err) + } + rt, err := view.ReferenceTime() + if err != nil { + t.Fatalf("View.ReferenceTime failed with error: %v", err) + } + if rt.UTC() != t1.UTC() { + t.Fatalf("View.ReferenceTime = %v, want %v", rt, t1) + } + sc, err := view.Stream(0) + if err != nil { + t.Fatalf("View.Stream failed with error: %v", err) + } + if got := sc.Stream().ClientPort; got != 1 { + t.Fatalf("StreamContext.Stream().ClientPort = %v, want 1", got) + } + if got, err := sc.HasTag("foo"); err != nil || got { + t.Fatalf("StreamContext.HasTag(\"foo\") = %v, %v, want false, nil", got, err) + } + if got, err := sc.AllTags(); err != nil || len(got) != 1 || got[0] != "tag/foo" { + t.Fatalf("StreamContext.AllTags() = %v, %v, want [], nil", got, err) + } + if _, err := sc.Data("bar"); err == nil { + t.Fatalf("StreamContext.Data(\"bar\") succeeded, want error") + } + if got, err := sc.Data(""); err != nil || len(got) != 1 || string(got[0].Content) != "foo" { + t.Fatalf("StreamContext.Data(\"\") = %+v, %v, want [{Content:foo}], nil", got, err) + } + if got, err := sc.AllConverters(); err != nil || len(got) != 0 { + t.Fatalf("StreamContext.AllConverters() = %v, %v, want [], nil", got, err) + } + q, err := query.Parse("") + if err != nil { + t.Fatalf("query.Parse failed: %v", err) + } + if err := mgr.AddTag("tag/bar", "red", ""); err != nil { + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + if m, n, err := view.SearchStreams(context.Background(), q, func(StreamContext) error { + return nil + }, Limit(1, 1), PrefetchAllTags()); err != nil || n != 1 || !m { + t.Fatalf("View.SearchStreams() = %v, %v, %v, want true, 1, nil", m, n, err) + } +} + +func waitForEvent(t *testing.T, listener <-chan Event, listenerCloser func(), eventType string) { + for e := range listener { + t.Logf("event: %+v\n", e) + if e.Type == eventType { + break + } + } + if listenerCloser != nil { + listenerCloser() + } +} + +func TestConverters(t *testing.T) { + dirs := makeTempdirs(t) + addConverter(dirs, "foo") + mgr := makeManager(t, dirs) + defer mgr.Close() + if got := mgr.ListConverters(); len(got) != 1 || got[0].Name != "foo" { + gotReadable := []converters.Statistics(nil) + for _, s := range got { + gotReadable = append(gotReadable, *s) + } + t.Fatalf("Manager.ListConverters() = %v, want [{Name:foo}]", gotReadable) + } + listener, listenerCloser := mgr.Listen() + addConverter(dirs, "bar") + waitForEvent(t, listener, listenerCloser, "converterAdded") + if got := mgr.ListConverters(); len(got) != 2 || got[0].Name != "bar" || got[1].Name != "foo" { + gotReadable := []converters.Statistics(nil) + for _, s := range got { + gotReadable = append(gotReadable, *s) + } + t.Fatalf("Manager.ListConverters() = %v, want [{Name:bar}, {Name:foo}]", gotReadable) + } + if err := mgr.ResetConverter("foo"); err != nil { + t.Fatalf("Manager.ResetConverter failed with error: %v", err) + } + if err := mgr.ResetConverter("baz"); err == nil { + t.Fatalf("Manager.ResetConverter succeeded, want error") + } + listener, listenerCloser = mgr.Listen() + defer listenerCloser() + if err := os.Remove(path.Join(dirs.converter, "bar")); err != nil { + t.Fatalf("os.Remove failed with error: %v", err) + } + waitForEvent(t, listener, nil, "converterDeleted") + if got := mgr.ListConverters(); len(got) != 1 || got[0].Name != "foo" { + gotReadable := []converters.Statistics(nil) + for _, s := range got { + gotReadable = append(gotReadable, *s) + } + t.Fatalf("Manager.ListConverters() = %v, want [{Name:foo}]", gotReadable) + } + importSomePackets(t, mgr, t1, "pcapProcessed") + if err := mgr.AddTag("tag/foo", "red", ""); err != nil { + t.Fatalf("Manager.AddTag failed with error: %v", err) + } + if err := mgr.UpdateTag("tag/foo", UpdateTagOperationSetConverter([]string{"foo"})); err != nil { + t.Fatalf("Manager.UpdateTag failed with error: %v", err) + } + waitForEvent(t, listener, nil, "converterCompleted") + view := mgr.GetView() + defer view.Release() + if err := view.AllStreams(context.Background(), func(sc StreamContext) error { + data, err := sc.Data("foo") + if err != nil { + return err + } + if len(data) != 1 || !strings.Contains(string(data[0].Content), fmt.Sprintf("\"StreamID\": %d", sc.Stream().ID())) { + t.Log(string(data[0].Content)) + return fmt.Errorf("StreamContext.Data(\"foo\") = %v, want [{Content:foo}]", data) + } + if got, err := sc.AllConverters(); err != nil || len(got) != 1 || got[0] != "foo" { + t.Fatalf("StreamContext.AllConverters returned: %v, %v, want [foo], nil", got, err) + } + return nil + }, PrefetchTags([]string{"tag/foo"})); err != nil { + t.Fatalf("View.AllStreams failed with error: %v", err) + } + if err := mgr.UpdateTag("tag/foo", UpdateTagOperationSetConverter(nil)); err != nil { + t.Fatalf("Manager.UpdateTag failed with error: %v", err) + } + if got := mgr.ListTags(); len(got) != 1 || len(got[0].Converters) != 0 { + t.Fatalf("ListTags returned %v, want [{Converters: []}]", got) + } +} diff --git a/internal/index/manager/testdata/test_converter.py b/internal/index/manager/testdata/test_converter.py new file mode 100755 index 0000000..d244665 --- /dev/null +++ b/internal/index/manager/testdata/test_converter.py @@ -0,0 +1,22 @@ +#!/usr/bin/python3 +import base64 +import json +import sys + +lines = [] +while 1: + line = sys.stdin.readline().strip() + if line != "": + lines.append(json.loads(line)) + continue + print(json.dumps({ + "Direction": "client-to-server", + "Content": base64.b64encode(json.dumps({ + "converter": sys.argv[0], + "info": lines[0], + "data": lines[1:] + }).encode()).decode() + })) + print() + print("{}", flush=True) + lines = [] diff --git a/internal/index/merger_test.go b/internal/index/merger_test.go new file mode 100644 index 0000000..dcdb077 --- /dev/null +++ b/internal/index/merger_test.go @@ -0,0 +1,120 @@ +package index + +import ( + "encoding/json" + "reflect" + "testing" + "time" +) + +func TestMerge(t *testing.T) { + tmpDir := t.TempDir() + t1, err := time.Parse(time.RFC3339, "2020-01-01T12:00:00Z") + if err != nil { + t.Fatalf("time.Parse failed with error: %v", err) + } + + inputs := []map[uint64]streamInfo{ + { + + 0: makeStream("1.2.3.40:1", "105.6.7.8:9", t1.Add(time.Hour*1), []string{"Lorem", "ipsum", "dolor", "sit", "amet,"}), + 1: makeStream("1.2.30.4:2", "5.106.7.8:8", t1.Add(time.Hour*2), []string{"", "sed", "do", "eiusmod", "tempor"}), + 2: makeStream("[12::34]:3", "[::1234]:7", t1.Add(time.Hour*3), []string{"magna", "aliqua.", "Ut", "enim", "ad"}), + 10: makeStream("1.20.3.4:4", "5.6.107.8:6", t1.Add(time.Hour*4), []string{"", "exercitation", "ullamco", "laboris"}), + 11: makeStream("10.2.3.4:5", "5.6.7.108:5", t1.Add(time.Hour*5), []string{"commodo", "consequat.", "Duis", "aute"}), + 12: makeStream("[0::34:12]:6", "[0::12:34]:4", t1.Add(time.Hour*6), []string{"", "in", "voluptate", "velit", "esse"}), + }, + { + 1: makeStream("1.2.30.4:2", "5.106.7.8:8", t1.Add(time.Hour*2), []string{"", "sed", "do", "eiusmod", "tempor", "incididunt", "ut", "labore", "et", "dolore"}), + 2: makeStream("[12::34]:3", "[::1234]:7", t1.Add(time.Hour*3), []string{"magna", "aliqua.", "Ut", "enim", "ad", "minim", "veniam,", "quis", "nostrud"}), + 3: makeStream("1.2.3.40:1", "105.6.7.8:9", t1.Add(time.Hour*4), []string{"Lorem", "ipsum", "dolor", "sit", "amet,", "consectetur", "adipiscing", "elit,"}), + 11: makeStream("10.2.3.4:5", "5.6.7.108:5", t1.Add(time.Hour*5), []string{"commodo", "consequat.", "Duis", "aute", "irure", "dolor", "in", "reprehenderit"}), + 12: makeStream("[0::34:12]:6", "[0::12:34]:4", t1.Add(time.Hour*6), []string{"", "in", "voluptate", "velit", "esse", "cillum", "dolore", "eu", "fugiat"}), + 13: makeStream("1.20.3.4:4", "5.6.107.8:6", t1.Add(time.Hour*7), []string{"", "exercitation", "ullamco", "laboris", "nisi", "ut", "aliquip", "ex", "ea"}), + }, + } + + indexes := []*Reader(nil) + for _, streams := range inputs { + index, err := makeIndex(tmpDir, streams, nil) + if err != nil { + t.Errorf("makeIndex failed with error: %v", err) + } + indexes = append(indexes, index) + } + + wantStreams := map[uint64]int{ + 0: 0, + 1: 1, + 2: 1, + 3: 1, + 10: 0, + 11: 1, + 12: 1, + 13: 1, + } + + merged, err := Merge(tmpDir, indexes) + if err != nil { + t.Errorf("Merge failed with error: %v", err) + } + + if len(merged) != 1 { + t.Fatalf("Expected 1 merged index, but got %d", len(merged)) + } + + gotJson := map[uint64][]byte{} + gotData := map[uint64][]Data{} + err = merged[0].AllStreams(func(s *Stream) error { + json, err := s.MarshalJSON() + if err != nil { + return err + } + data, err := s.Data() + if err != nil { + return err + } + gotJson[s.StreamID] = json + gotData[s.StreamID] = data + return nil + }) + if err != nil { + t.Errorf("AllStreams failed with error: %v", err) + } + if len(gotJson) != len(wantStreams) { + t.Fatalf("Expected %d streams, but got %d", len(wantStreams), len(gotJson)) + } + for streamID, i := range wantStreams { + wantStream, err := indexes[i].StreamByID(streamID) + if err != nil { + t.Errorf("StreamByID failed with error: %v", err) + } + wantJson, err := wantStream.MarshalJSON() + if err != nil { + t.Errorf("MarshalJSON failed with error: %v", err) + } + var got, want map[string]interface{} + if err := json.Unmarshal(wantJson, &want); err != nil { + t.Errorf("json.Unmarshal failed with error: %v", err) + } + if err := json.Unmarshal(gotJson[streamID], &got); err != nil { + t.Errorf("json.Unmarshal failed with error: %v", err) + } + delete(got, "Index") + delete(want, "Index") + if !reflect.DeepEqual(got, want) { + t.Errorf("Stream %d mismatch:\nGot: %v\nWant: %v", streamID, got, want) + } + wantData, err := wantStream.Data() + if err != nil { + t.Errorf("Data failed with error: %v", err) + } + if !reflect.DeepEqual(gotData[streamID], wantData) { + t.Errorf("Stream %d data mismatch:\nGot: %v\nWant: %v", streamID, gotData[streamID], wantData) + } + + } + if err := merged[0].Close(); err != nil { + t.Errorf("Close failed with error: %v", err) + } +} diff --git a/internal/index/reader_test.go b/internal/index/reader_test.go new file mode 100644 index 0000000..6b6afaf --- /dev/null +++ b/internal/index/reader_test.go @@ -0,0 +1,68 @@ +package index + +import ( + "fmt" + "testing" + "time" + + pcapmetadata "github.com/spq/pkappa2/internal/tools/pcapMetadata" +) + +func TestReader(t *testing.T) { + tmpDir := t.TempDir() + t1, err := time.Parse(time.RFC3339, "2020-01-01T12:00:00Z") + if err != nil { + t.Fatalf("time.Parse failed with error: %v", err) + } + streams := map[uint64]streamInfo{} + for i := 0; i < 10; i++ { + streams[uint64(i*100)] = makeStream("1.2.3.4:1234", "4.3.2.1:4321", t1.Add(time.Hour*time.Duration(i)), []string{fmt.Sprintf("foo%d", i)}) + } + idx, err := makeIndex(tmpDir, streams, nil) + if err != nil { + t.Fatalf("makeIndex failed: %v", err) + } + if got, want := idx.PacketCount(), 30; got != want { + t.Errorf("Reader.PacketCount() = %v, want %v", got, want) + } + gotStreams := idx.StreamIDs() + if len(gotStreams) != len(streams) { + t.Fatalf("len(Reader.StreamIDs()) = %v, want %v", len(gotStreams), len(streams)) + } + for streamID, streamIndex := range gotStreams { + s1, err := idx.StreamByID(streamID) + if err != nil { + t.Fatalf("Reader.StreamByID failed with error: %v", err) + } + if s1.index != streamIndex { + t.Errorf("streamIndex mismatch: %v != %v", s1.index, streamIndex) + } + s2, err := idx.streamByIndex(streamIndex) + if err != nil { + t.Fatalf("Reader.streamByIndex failed with error: %v", err) + } + if s2.StreamID != streamID { + t.Errorf("streamID mismatch: %v != %v", s2.StreamID, streamID) + } + s3, err := idx.StreamByFirstPacketSource(streams[streamID].s.Packets[0].AncillaryData[0].(*pcapmetadata.PcapMetadata).PcapInfo.Filename, 0) + if err != nil { + t.Fatalf("Reader.StreamByFirstPacketSource failed with error: %v", err) + } + if s3.index != streamIndex { + t.Errorf("streamIndex mismatch: %v != %v", s3.index, streamIndex) + } + packets, err := s1.Packets() + if err != nil { + t.Fatalf("Stream.Packets failed with error: %v", err) + } + if len(packets) != 3 { + t.Errorf("len(Stream.Packets()) = %v, want 3", len(packets)) + } + if got, want := s1.FirstPacket().UTC(), t1.Add(time.Hour*time.Duration(streamID/100)).UTC(); !got.Equal(want) { + t.Errorf("Stream[%d].FirstPacket() = %v, want %v", streamID, got, want) + } + if got, want := s1.LastPacket().UTC(), t1.Add(time.Hour*time.Duration(streamID/100)+time.Second*3).UTC(); !got.Equal(want) { + t.Errorf("Stream[%d].LastPacket() = %v, want %v", streamID, got, want) + } + } +} diff --git a/internal/index/search_test.go b/internal/index/search_test.go index 52ca761..abe5239 100644 --- a/internal/index/search_test.go +++ b/internal/index/search_test.go @@ -26,7 +26,19 @@ type ( } ) -func makeStream(client, server string, t1 time.Time, data []string, converterData ...[]string) streamInfo { +var ( + t1 time.Time +) + +func init() { + var err error + t1, err = time.Parse(time.RFC3339, "2020-01-01T12:00:00Z") + if err != nil { + panic(err) + } +} + +func makeStream(client, server string, t time.Time, data []string, converterData ...[]string) streamInfo { first := reassembly.TCPDirClientToServer if len(data) != 0 && data[0] == "" { data = data[1:] @@ -34,14 +46,14 @@ func makeStream(client, server string, t1 time.Time, data []string, converterDat } clientAddrPort := netip.MustParseAddrPort(client) serverAddrPort := netip.MustParseAddrPort(server) - t2 := t1.Add(time.Second * time.Duration(2+len(data))) + t2 := t.Add(time.Second * time.Duration(2+len(data))) t3 := t2.Add(time.Minute) pcapinfo := &pcapmetadata.PcapInfo{ - Filename: "test.pcap", + Filename: fmt.Sprintf("%s_%s_%d.pcap", client, server, t.UnixNano()), Filesize: 123, - PacketTimestampMin: t1, + PacketTimestampMin: t, PacketTimestampMax: t2, ParseTime: t3, @@ -50,7 +62,7 @@ func makeStream(client, server string, t1 time.Time, data []string, converterDat packets := []gopacket.CaptureInfo(nil) packetDirections := []reassembly.TCPFlowDirection(nil) packets = append(packets, gopacket.CaptureInfo{ - Timestamp: t1, + Timestamp: t, CaptureLength: 123, Length: 123, }) @@ -58,7 +70,7 @@ func makeStream(client, server string, t1 time.Time, data []string, converterDat streamData := []streams.StreamData(nil) for i, d := range data { packets = append(packets, gopacket.CaptureInfo{ - Timestamp: t1.Add(time.Second * time.Duration(i+1)), + Timestamp: t.Add(time.Second * time.Duration(i+1)), CaptureLength: 123, Length: 123, }) @@ -95,6 +107,39 @@ func makeStream(client, server string, t1 time.Time, data []string, converterDat } } +func makeIndex(tmpDir string, streams map[uint64]streamInfo, converters *map[string]ConverterAccess) (*Reader, error) { + w, err := NewWriter(tools.MakeFilename(tmpDir, "idx")) + if err != nil { + return nil, err + } + for streamID, si := range streams { + ok, err := w.AddStream(&si.s, streamID) + if err != nil { + return nil, err + } + if !ok { + return nil, fmt.Errorf("Stream couldn't be added to index") + } + for i, d := range streams[streamID].c { + if d == nil { + continue + } + c := fmt.Sprintf("c%d", i) + if _, ok := (*converters)[c]; !ok { + (*converters)[c] = &fakeConverter{ + data: make(map[uint64][]string), + } + } + (*converters)[c].(*fakeConverter).data[streamID] = d + } + } + r, err := w.Finalize() + if err != nil { + return nil, err + } + return r, nil +} + func (c *fakeConverter) Data(stream *Stream, moreDetails bool) (data []Data, clientBytes, serverBytes uint64, wasCached bool, err error) { return nil, 0, 0, false, nil } @@ -122,7 +167,6 @@ func (c *fakeConverter) DataForSearch(streamID uint64) ([2][]byte, [][2]int, uin func TestSearchStreams(t *testing.T) { tmpDir := t.TempDir() - t1, _ := time.Parse(time.RFC3339, "2020-01-01T12:00:00Z00:00") testCases := []struct { name string streams []streamInfo @@ -275,41 +319,165 @@ func TestSearchStreams(t *testing.T) { "@sub:id:0,1 @sub:cdata:\"(?Pneedle[0-9])\" cdata:@sub:var@ id:2", []uint64{2}, }, + { + "test protocol:tcp query", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"needle0"}), + }, + "protocol:tcp", + []uint64{0}, + }, + { + "test protocol:udp query", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"needle0"}), + }, + "protocol:udp", + []uint64{}, + }, + { + "test ftime query", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"needle0"}), + makeStream("192.168.0.101:123", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"needle1"}), + makeStream("192.168.0.102:123", "192.168.0.1:80", t1.Add(time.Hour*3), []string{"needle2"}), + }, + fmt.Sprintf(`ftime:"%s"`, t1.Add(time.Hour*2).Local().Format("2006-01-02 1504")), + []uint64{1}, + }, + { + "test ltime query", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"needle0"}), + makeStream("192.168.0.101:123", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"needle1"}), + makeStream("192.168.0.102:123", "192.168.0.1:80", t1.Add(time.Hour*3), []string{"needle2"}), + }, + fmt.Sprintf(`ltime:":%s"`, t1.Add(time.Hour*2).Local().Format("2006-01-02 1504")), + []uint64{0}, + }, + { + "sort by id", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"needle0"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"needle1"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*3), []string{"needle2"}), + }, + "sort:id", + []uint64{0, 1, 2}, + }, + { + "sort by -id", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"needle0"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"needle1"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*3), []string{"needle2"}), + }, + "sort:-id", + []uint64{2, 1, 0}, + }, + { + "sort by ftime", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"needle0"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"needle1"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*3), []string{"needle2"}), + }, + "sort:ftime", + []uint64{1, 0, 2}, + }, + { + "sort by ltime", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"needle0"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"needle1", "foo"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"needle2"}), + }, + "sort:ltime", + []uint64{2, 1, 0}, + }, + { + "sort by cbytes", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"AA"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"AAA"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*3), []string{"A"}), + }, + "sort:cbytes", + []uint64{2, 0, 1}, + }, + { + "sort by sbytes", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"foo", "A"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"foo", "AAA"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*3), []string{"foo", "AA"}), + }, + "sort:sbytes", + []uint64{0, 2, 1}, + }, + { + "sort by cport", + []streamInfo{ + makeStream("192.168.0.100:2", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"foo"}), + makeStream("192.168.0.100:1", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"foo"}), + makeStream("192.168.0.100:3", "192.168.0.1:80", t1.Add(time.Hour*3), []string{"foo"}), + }, + "sort:cport", + []uint64{1, 0, 2}, + }, + { + "sort by sport", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:3", t1.Add(time.Hour*1), []string{"foo"}), + makeStream("192.168.0.100:123", "192.168.0.1:1", t1.Add(time.Hour*2), []string{"foo"}), + makeStream("192.168.0.100:123", "192.168.0.1:2", t1.Add(time.Hour*3), []string{"foo"}), + }, + "sort:sport", + []uint64{1, 2, 0}, + }, + { + "sort by chost", + []streamInfo{ + makeStream("192.168.0.102:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"foo"}), + makeStream("192.168.0.101:123", "192.168.0.1:80", t1.Add(time.Hour*2), []string{"foo"}), + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*3), []string{"foo"}), + }, + "sort:chost", + []uint64{2, 1, 0}, + }, + { + "sort by shost", + []streamInfo{ + makeStream("192.168.0.100:123", "192.168.0.1:80", t1.Add(time.Hour*1), []string{"foo"}), + makeStream("192.168.0.100:123", "192.168.0.2:80", t1.Add(time.Hour*2), []string{"foo"}), + makeStream("192.168.0.100:123", "192.168.0.3:80", t1.Add(time.Hour*3), []string{"foo"}), + }, + "sort:shost", + []uint64{0, 1, 2}, + }, + { + "sort by multiple", + []streamInfo{ + makeStream("192.168.0.100:2", "192.168.0.1:1", t1.Add(time.Hour*1), []string{"foo"}), + makeStream("192.168.0.100:1", "192.168.0.1:1", t1.Add(time.Hour*2), []string{"foo"}), + makeStream("192.168.0.100:1", "192.168.0.1:2", t1.Add(time.Hour*3), []string{"foo"}), + }, + "sort:cport,sport", + []uint64{1, 2, 0}, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { converters := map[string]ConverterAccess{ "dummy": &fakeConverter{}, } - w, err := NewWriter(tools.MakeFilename(tmpDir, "idx")) - if err != nil { - t.Fatalf("Error creating index writer: %v", err) - } - for streamID := range tc.streams { - streamID := uint64(streamID) - ok, err := w.AddStream(&tc.streams[streamID].s, streamID) - if err != nil { - t.Fatalf("Error adding stream to index: %v", err) - } - if !ok { - t.Fatalf("Stream couldn't be added to index") - } - for i, d := range tc.streams[streamID].c { - if d == nil { - continue - } - c := fmt.Sprintf("c%d", i) - if _, ok := converters[c]; !ok { - converters[c] = &fakeConverter{ - data: make(map[uint64][]string), - } - } - converters[c].(*fakeConverter).data[streamID] = d - } + streamsMap := make(map[uint64]streamInfo) + for i, s := range tc.streams { + streamsMap[uint64(i)] = s } - r, err := w.Finalize() + r, err := makeIndex(tmpDir, streamsMap, &converters) if err != nil { - t.Fatalf("Error finalizing index writer: %v", err) + t.Fatalf("Error creating index: %v", err) } t.Logf("using query %q", tc.query) q, err := query.Parse(tc.query) diff --git a/internal/tools/bitmask/bitmasks_test.go b/internal/tools/bitmask/bitmasks_test.go new file mode 100644 index 0000000..c0ec7c0 --- /dev/null +++ b/internal/tools/bitmask/bitmasks_test.go @@ -0,0 +1,488 @@ +package bitmask + +import ( + "testing" +) + +type ( + bitmask interface { + Set(uint) + Unset(uint) + Flip(uint) + Inject(uint, bool) + // TODO: Extract(uint) bool + } +) + +func compareMasks(t *testing.T, c ConnectedBitmask, s ShortBitmask, l LongBitmask) { + for i := uint(0); i < 100; i++ { + if cv, sv, lv := c.IsSet(i), s.IsSet(i), l.IsSet(i); cv != sv || cv != lv { + t.Fatalf("IsSet(%d): %v %v %v", i, cv, sv, lv) + } + } + if c1, s1, l1 := c.OnesCount(), s.OnesCount(), l.OnesCount(); c1 != s1 || c1 != l1 { + t.Errorf("OnesCount(): %d %d %d", c1, s1, l1) + } + if cl, sl, ll := c.Len(), s.Len(), l.Len(); cl != sl || cl != ll { + t.Errorf("Len(): %d %d %d", cl, sl, ll) + } + if cz, sz, lz := c.IsZero(), s.IsZero(), l.IsZero(); cz != sz || cz != lz { + t.Errorf("IsZero(): %v %v %v", cz, sz, lz) + } +} + +func TestBitmasks(t *testing.T) { + testcasesBase := []struct { + name string + f func(b bitmask) + }{ + { + name: "no-op", + f: func(b bitmask) { + }, + }, + { + name: "set 0", + f: func(b bitmask) { + b.Set(0) + }, + }, + { + name: "set 10-20", + f: func(b bitmask) { + for i := uint(10); i <= 20; i++ { + b.Set(i) + } + }, + }, + { + name: "set 10-20 every 2", + f: func(b bitmask) { + for i := uint(10); i <= 20; i += 2 { + b.Set(i) + } + }, + }, + { + name: "set 10-20 backwards", + f: func(b bitmask) { + for i := uint(20); i >= 10; i-- { + b.Set(i) + } + }, + }, + { + name: "set 10-20 backwards every 2", + f: func(b bitmask) { + for i := uint(20); i >= 10; i -= 2 { + b.Set(i) + } + }, + }, + { + name: "set 99", + f: func(b bitmask) { + b.Set(99) + }, + }, + { + name: "set 0 and 99", + f: func(b bitmask) { + b.Set(0) + b.Set(99) + }, + }, + { + name: "set 0 to 63 unordered", + f: func(b bitmask) { + for i := uint(0); i < 64; i++ { + b.Set(i ^ 18) + } + }, + }, + } + testcasesMod := []struct { + name string + f func(b bitmask) + }{ + { + name: "no-op", + f: func(b bitmask) { + }, + }, + { + name: "unset 0", + f: func(b bitmask) { + b.Unset(0) + }, + }, + { + name: "unset 10-20", + f: func(b bitmask) { + for i := uint(10); i <= 20; i++ { + b.Unset(i) + } + }, + }, + { + name: "unset 10-20 every 2", + f: func(b bitmask) { + for i := uint(10); i <= 20; i += 2 { + b.Unset(i) + } + }, + }, + { + name: "unset 10-20 backwards", + f: func(b bitmask) { + for i := uint(20); i >= 10; i-- { + b.Unset(i) + } + }, + }, + { + name: "unset 10-20 backwards every 2", + f: func(b bitmask) { + for i := uint(20); i >= 10; i -= 2 { + b.Unset(i) + } + }, + }, + { + name: "unset 99", + f: func(b bitmask) { + b.Unset(99) + }, + }, + { + name: "unset 0 and 99", + f: func(b bitmask) { + b.Unset(0) + b.Unset(99) + }, + }, + { + name: "unset 0 to 63 unordered", + f: func(b bitmask) { + for i := uint(0); i < 64; i++ { + b.Unset(i ^ 18) + } + }, + }, + { + name: "flip 0", + f: func(b bitmask) { + b.Flip(0) + }, + }, + { + name: "flip 10-20", + f: func(b bitmask) { + for i := uint(10); i <= 20; i++ { + b.Flip(i) + } + }, + }, + { + name: "flip 10-20 every 2", + f: func(b bitmask) { + for i := uint(10); i <= 20; i += 2 { + b.Flip(i) + } + }, + }, + { + name: "flip 10-20 backwards", + f: func(b bitmask) { + for i := uint(20); i >= 10; i-- { + b.Flip(i) + } + }, + }, + { + name: "flip 10-20 backwards every 2", + f: func(b bitmask) { + for i := uint(20); i >= 10; i -= 2 { + b.Flip(i) + } + }, + }, + { + name: "flip 99", + f: func(b bitmask) { + b.Flip(99) + }, + }, + { + name: "flip 0 and 99", + f: func(b bitmask) { + b.Flip(0) + b.Flip(99) + }, + }, + { + name: "flip 0 to 63 unordered", + f: func(b bitmask) { + for i := uint(0); i < 64; i++ { + b.Flip(i ^ 18) + } + }, + }, + { + name: "inject false 0", + f: func(b bitmask) { + b.Inject(0, false) + }, + }, + { + name: "inject false 10-20", + f: func(b bitmask) { + for i := uint(10); i <= 20; i++ { + b.Inject(i, false) + } + }, + }, + { + name: "inject false 10-20 every 2", + f: func(b bitmask) { + for i := uint(10); i <= 20; i += 2 { + b.Inject(i, false) + } + }, + }, + { + name: "inject false 10-20 backwards", + f: func(b bitmask) { + for i := uint(20); i >= 10; i-- { + b.Inject(i, false) + } + }, + }, + { + name: "inject false 10-20 backwards every 2", + f: func(b bitmask) { + for i := uint(20); i >= 10; i -= 2 { + b.Inject(i, false) + } + }, + }, + { + name: "inject false 99", + f: func(b bitmask) { + b.Inject(99, false) + }, + }, + { + name: "inject false 0 and 99", + f: func(b bitmask) { + b.Inject(0, false) + b.Inject(99, false) + }, + }, + { + name: "inject false 0 to 63 unordered", + f: func(b bitmask) { + for i := uint(0); i < 64; i++ { + b.Inject(i^18, false) + } + }, + }, + { + name: "inject true 0", + f: func(b bitmask) { + b.Inject(0, true) + }, + }, + { + name: "inject true 10-20", + f: func(b bitmask) { + for i := uint(10); i <= 20; i++ { + b.Inject(i, true) + } + }, + }, + { + name: "inject true 10-20 every 2", + f: func(b bitmask) { + for i := uint(10); i <= 20; i += 2 { + b.Inject(i, true) + } + }, + }, + { + name: "inject true 10-20 backwards", + f: func(b bitmask) { + for i := uint(20); i >= 10; i-- { + b.Inject(i, true) + } + }, + }, + { + name: "inject true 10-20 backwards every 2", + f: func(b bitmask) { + for i := uint(20); i >= 10; i -= 2 { + b.Inject(i, true) + } + }, + }, + { + name: "inject true 99", + f: func(b bitmask) { + b.Inject(99, true) + }, + }, + { + name: "inject true 0 and 99", + f: func(b bitmask) { + b.Inject(0, true) + b.Inject(99, true) + }, + }, + { + name: "inject true 0 to 63 unordered", + f: func(b bitmask) { + for i := uint(0); i < 64; i++ { + b.Inject(i^18, true) + } + }, + }, + } + testcasesCombine := []struct { + name string + f func(c *ConnectedBitmask, s *ShortBitmask, l *LongBitmask, c2 *ConnectedBitmask, s2 *ShortBitmask, l2 *LongBitmask) (*ConnectedBitmask, *ShortBitmask, *LongBitmask) + }{ + { + name: "And", + f: func(c *ConnectedBitmask, s *ShortBitmask, l *LongBitmask, c2 *ConnectedBitmask, s2 *ShortBitmask, l2 *LongBitmask) (*ConnectedBitmask, *ShortBitmask, *LongBitmask) { + c.And(*c2) + s.And(*s2) + l.And(*l2) + return c, s, l + }, + }, + { + name: "Or", + f: func(c *ConnectedBitmask, s *ShortBitmask, l *LongBitmask, c2 *ConnectedBitmask, s2 *ShortBitmask, l2 *LongBitmask) (*ConnectedBitmask, *ShortBitmask, *LongBitmask) { + c.Or(*c2) + s.Or(*s2) + l.Or(*l2) + return c, s, l + }, + }, + { + name: "Xor", + f: func(c *ConnectedBitmask, s *ShortBitmask, l *LongBitmask, c2 *ConnectedBitmask, s2 *ShortBitmask, l2 *LongBitmask) (*ConnectedBitmask, *ShortBitmask, *LongBitmask) { + c.Xor(*c2) + s.Xor(*s2) + l.Xor(*l2) + return c, s, l + }, + }, + { + name: "Sub", + f: func(c *ConnectedBitmask, s *ShortBitmask, l *LongBitmask, c2 *ConnectedBitmask, s2 *ShortBitmask, l2 *LongBitmask) (*ConnectedBitmask, *ShortBitmask, *LongBitmask) { + c.Sub(*c2) + s.Sub(*s2) + l.Sub(*l2) + return c, s, l + }, + }, + { + name: "AndCopy", + f: func(c *ConnectedBitmask, s *ShortBitmask, l *LongBitmask, c2 *ConnectedBitmask, s2 *ShortBitmask, l2 *LongBitmask) (*ConnectedBitmask, *ShortBitmask, *LongBitmask) { + c.AndCopy(*c2) + s.AndCopy(*s2) + l.AndCopy(*l2) + return c, s, l + }, + }, + { + name: "OrCopy", + f: func(c *ConnectedBitmask, s *ShortBitmask, l *LongBitmask, c2 *ConnectedBitmask, s2 *ShortBitmask, l2 *LongBitmask) (*ConnectedBitmask, *ShortBitmask, *LongBitmask) { + c.OrCopy(*c2) + s.OrCopy(*s2) + l.OrCopy(*l2) + return c, s, l + }, + }, + { + name: "XorCopy", + f: func(c *ConnectedBitmask, s *ShortBitmask, l *LongBitmask, c2 *ConnectedBitmask, s2 *ShortBitmask, l2 *LongBitmask) (*ConnectedBitmask, *ShortBitmask, *LongBitmask) { + c.XorCopy(*c2) + s.XorCopy(*s2) + l.XorCopy(*l2) + return c, s, l + }, + }, + { + name: "SubCopy", + f: func(c *ConnectedBitmask, s *ShortBitmask, l *LongBitmask, c2 *ConnectedBitmask, s2 *ShortBitmask, l2 *LongBitmask) (*ConnectedBitmask, *ShortBitmask, *LongBitmask) { + c.SubCopy(*c2) + s.SubCopy(*s2) + l.SubCopy(*l2) + return c, s, l + }, + }, + } + cList := []*ConnectedBitmask{} + sList := []*ShortBitmask{} + lList := []*LongBitmask{} + for _, tcBase := range testcasesBase { + for _, tcMod := range testcasesMod { + c := &ConnectedBitmask{} + s := &ShortBitmask{} + l := &LongBitmask{} + tcBase.f(c) + tcMod.f(c) + tcBase.f(s) + tcMod.f(s) + tcBase.f(l) + tcMod.f(l) + isNew := true + for i := range cList { + c2 := *cList[i] + s2 := *sList[i] + l2 := *lList[i] + ce, se, le := c.Equal(c2), s.Equal(s2), l.Equal(l2) + if ce != se || ce != le { + t.Fatalf("Equal(): %v %v %v", ce, se, le) + } + if ce { + isNew = false + } + } + if isNew { + cList = append(cList, c) + sList = append(sList, s) + lList = append(lList, l) + } + t.Run(tcBase.name, func(t *testing.T) { + t.Run(tcMod.name, func(t *testing.T) { + compareMasks(t, *c, *s, *l) + cc := c.Copy() + sc := s.Copy() + lc := l.Copy() + compareMasks(t, cc, sc, lc) + sc.Shrink() + lc.Shrink() + compareMasks(t, cc, sc, lc) + }) + }) + } + } + for _, tcCombine := range testcasesCombine { + t.Run(tcCombine.name, func(t *testing.T) { + for i := range cList { + for j := range cList { + c, c2 := cList[i].Copy(), cList[j].Copy() + s, s2 := sList[i].Copy(), sList[j].Copy() + l, l2 := lList[i].Copy(), lList[j].Copy() + cc, sc, lc := tcCombine.f(&c2, &s2, &l2, &c, &s, &l) + compareMasks(t, c2, s2, l2) + compareMasks(t, c, s, l) + compareMasks(t, *cc, *sc, *lc) + } + } + }) + } +}