diff --git a/integration-tests/pkg/mock_sensor/expect_conn.go b/integration-tests/pkg/mock_sensor/expect_conn.go index bf9b597183..7811db0cb5 100644 --- a/integration-tests/pkg/mock_sensor/expect_conn.go +++ b/integration-tests/pkg/mock_sensor/expect_conn.go @@ -33,7 +33,7 @@ loop: case <-timer: // we know they don't match at this point, but by using // ElementsMatch we get much better logging about the differences - return assert.ElementsMatch(t, expected, s.Connections(containerID), "timed out waiting for networks") + return assert.ElementsMatch(t, expected, s.Connections(containerID), "timed out waiting for network connections") case network := <-s.LiveConnections(): if network.GetContainerId() != containerID { continue loop @@ -72,7 +72,6 @@ loop: if conn.GetContainerId() != containerID { continue loop } - if len(s.Connections(containerID)) == n { return s.Connections(containerID) } @@ -80,6 +79,83 @@ loop: } } +// checkIfConnectionsMatchExpected compares a list of expected and observed connection match exactly. +func (s *MockSensor) checkIfConnectionsMatchExpected(t *testing.T, connections []types.NetworkInfo, expected []types.NetworkInfo) bool { + if len(connections) > len(expected) { + return assert.ElementsMatch(t, expected, connections, "networking connections do not match") + } + + if len(connections) == len(expected) { + types.SortConnections(connections) + for i := range expected { + if !expected[i].Equal(connections[i]) { + return assert.ElementsMatch(t, expected, connections, "networking connections do not match") + } + } + return true + } + return false +} + +// getConnectionsAndCompare gets the connections for a container, sorts them if order is set to true, and compares with a set of expected +// connections. If asssertMismatch is true and the set of observed connections does not match the set of expected connections an assert is +// triggered. If assertMismatch is not set then just return if the observed and expected connections match. +func (s *MockSensor) getConnectionsAndCompare(t *testing.T, containerID string, order bool, assertMismatch bool, expected ...types.NetworkInfo) bool { + connections := s.Connections(containerID) + if order { + types.SortConnections(connections) + } + success := s.checkIfConnectionsMatchExpected(t, connections, expected) + if assertMismatch && !success { + return assert.ElementsMatch(t, expected, connections, "networking connections do not match") + } + return success +} + +// CompareConnections compares a list of expected connections to the observed connections. This comparison is done at the beginning, when a new +// connection arrives, and after a timeout period. The number of connections must match and it can be specified if the order of the connections +// must match or not. The difference between this function and ExpectConnections is that ExpectConnections tolerates extra observed connections +// that are not expected. +func (s *MockSensor) CompareConnections(t *testing.T, containerID string, timeout time.Duration, order bool, expected ...types.NetworkInfo) bool { + if order { + types.SortConnections(expected) + } + + success := s.getConnectionsAndCompare(t, containerID, order, false, expected...) + if success { + return true + } + + timer := time.After(timeout) + + for { + select { + case <-timer: + return s.getConnectionsAndCompare(t, containerID, order, true, expected...) + case conn := <-s.LiveConnections(): + if conn.GetContainerId() != containerID { + continue + } + success := s.getConnectionsAndCompare(t, containerID, order, false, expected...) + if success { + return true + } + } + } +} + +// ExpectExactConnections requires that within a timeout period the networking connections involving containerID match a list of expected +// netwoking connections. +func (s *MockSensor) ExpectExactConnections(t *testing.T, containerID string, timeout time.Duration, expected ...types.NetworkInfo) bool { + return s.CompareConnections(t, containerID, timeout, false, expected...) +} + +// ExpectSameElementsConnections requires that within a timeout period the networking connections involving containerID match a list of expected +// netwoking connections, but the order of those connections do not have to match. +func (s *MockSensor) ExpectSameElementsConnections(t *testing.T, containerID string, timeout time.Duration, expected ...types.NetworkInfo) bool { + return s.CompareConnections(t, containerID, timeout, true, expected...) +} + // ExpectEndpoints waits up to the timeout for the gRPC server to receive // the list of expected Endpoints. It will first check to see if the endpoints // have been received already, and then monitor the live feed of endpoints diff --git a/integration-tests/pkg/mock_sensor/server.go b/integration-tests/pkg/mock_sensor/server.go index 5b99e5fcda..44291d6e1f 100644 --- a/integration-tests/pkg/mock_sensor/server.go +++ b/integration-tests/pkg/mock_sensor/server.go @@ -32,7 +32,6 @@ const ( // us to use any comparable type as the key) type ProcessMap map[types.ProcessInfo]interface{} type LineageMap map[types.ProcessLineage]interface{} -type ConnMap map[types.NetworkInfo]interface{} type EndpointMap map[types.EndpointInfo]interface{} type MockSensor struct { @@ -47,7 +46,7 @@ type MockSensor struct { processLineages map[string]LineageMap processMutex sync.Mutex - connections map[string]ConnMap + connections map[string][]types.NetworkInfo endpoints map[string]EndpointMap networkMutex sync.Mutex @@ -65,7 +64,7 @@ func NewMockSensor(test string) *MockSensor { testName: test, processes: make(map[string]ProcessMap), processLineages: make(map[string]LineageMap), - connections: make(map[string]ConnMap), + connections: make(map[string][]types.NetworkInfo), endpoints: make(map[string]EndpointMap), } } @@ -155,11 +154,7 @@ func (m *MockSensor) Connections(containerID string) []types.NetworkInfo { defer m.networkMutex.Unlock() if connections, ok := m.connections[containerID]; ok { - keys := make([]types.NetworkInfo, 0, len(connections)) - for k := range connections { - keys = append(keys, k) - } - return keys + return connections } return make([]types.NetworkInfo, 0) } @@ -171,8 +166,11 @@ func (m *MockSensor) HasConnection(containerID string, conn types.NetworkInfo) b defer m.networkMutex.Unlock() if conns, ok := m.connections[containerID]; ok { - _, exists := conns[conn] - return exists + for _, connection := range conns { + if connection.Equal(conn) { + return true + } + } } return false @@ -271,7 +269,7 @@ func (m *MockSensor) Stop() { m.processes = make(map[string]ProcessMap) m.processLineages = make(map[string]LineageMap) - m.connections = make(map[string]ConnMap) + m.connections = make(map[string][]types.NetworkInfo) m.endpoints = make(map[string]EndpointMap) m.processChannel.Stop() @@ -432,11 +430,10 @@ func (m *MockSensor) pushConnection(containerID string, connection *sensorAPI.Ne CloseTimestamp: connection.GetCloseTimestamp().String(), } - if connections, ok := m.connections[containerID]; ok { - connections[conn] = true + if _, ok := m.connections[containerID]; ok { + m.connections[containerID] = append(m.connections[containerID], conn) } else { - connections := ConnMap{conn: true} - m.connections[containerID] = connections + m.connections[containerID] = []types.NetworkInfo{conn} } } diff --git a/integration-tests/pkg/types/network.go b/integration-tests/pkg/types/network.go index dbe260ca4d..1f688e24cc 100644 --- a/integration-tests/pkg/types/network.go +++ b/integration-tests/pkg/types/network.go @@ -1,5 +1,9 @@ package types +import ( + "sort" +) + const ( NilTimestamp = "" ) @@ -16,3 +20,39 @@ func (n *NetworkInfo) IsActive() bool { // no close timestamp means the connection is open, and active return n.CloseTimestamp == NilTimestamp } + +func (n *NetworkInfo) Equal(other NetworkInfo) bool { + return n.LocalAddress == other.LocalAddress && + n.RemoteAddress == other.RemoteAddress && + n.Role == other.Role && + n.SocketFamily == other.SocketFamily && + n.IsActive() == other.IsActive() +} + +func (n *NetworkInfo) Less(other NetworkInfo) bool { + if n.LocalAddress != other.LocalAddress { + return n.LocalAddress < other.LocalAddress + } + + if n.RemoteAddress != other.RemoteAddress { + return n.RemoteAddress < other.RemoteAddress + } + + if n.Role != other.Role { + return n.Role < other.Role + } + + if n.SocketFamily != other.SocketFamily { + return n.SocketFamily < other.SocketFamily + } + + if n.IsActive() != other.IsActive() { + return n.IsActive() + } + + return false +} + +func SortConnections(connections []NetworkInfo) { + sort.Slice(connections, func(i, j int) bool { return connections[i].Less(connections[j]) }) +}