From b83a8d5306c0d4568b6d9f84a304743475c3bcad Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 12 Jul 2023 23:27:37 -0700 Subject: [PATCH 01/16] network: add internal/network package Signed-off-by: Achille Roussel --- internal/network/host.go | 47 ++ internal/network/host_darwin.go | 58 ++ internal/network/host_linux.go | 28 + internal/network/host_unix.go | 246 +++++++++ internal/network/network.go | 105 ++++ internal/network/network_unix.go | 43 ++ internal/network/virtual.go | 874 +++++++++++++++++++++++++++++++ internal/network/virtual_test.go | 238 +++++++++ 8 files changed, 1639 insertions(+) create mode 100644 internal/network/host.go create mode 100644 internal/network/host_darwin.go create mode 100644 internal/network/host_linux.go create mode 100644 internal/network/host_unix.go create mode 100644 internal/network/network.go create mode 100644 internal/network/network_unix.go create mode 100644 internal/network/virtual.go create mode 100644 internal/network/virtual_test.go diff --git a/internal/network/host.go b/internal/network/host.go new file mode 100644 index 00000000..1a90c3cc --- /dev/null +++ b/internal/network/host.go @@ -0,0 +1,47 @@ +package network + +import "net" + +func Host() Namespace { return hostNamespace{} } + +type hostNamespace struct{} + +func (hostNamespace) InterfaceByIndex(index int) (Interface, error) { + i, err := net.InterfaceByIndex(index) + if err != nil { + return nil, err + } + return hostInterface{i}, nil +} + +func (hostNamespace) InterfaceByName(name string) (Interface, error) { + i, err := net.InterfaceByName(name) + if err != nil { + return nil, err + } + return hostInterface{i}, nil +} + +func (hostNamespace) Interfaces() ([]Interface, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + hostInterfaces := make([]Interface, len(interfaces)) + for i := range interfaces { + hostInterfaces[i] = hostInterface{&interfaces[i]} + } + return hostInterfaces, nil +} + +type hostInterface struct{ *net.Interface } + +func (i hostInterface) Index() int { return i.Interface.Index } + +func (i hostInterface) MTU() int { return i.Interface.MTU } + +func (i hostInterface) Name() string { return i.Interface.Name } + +func (i hostInterface) HardwareAddr() net.HardwareAddr { return i.Interface.HardwareAddr } + +func (i hostInterface) Flags() net.Flags { return i.Interface.Flags } diff --git a/internal/network/host_darwin.go b/internal/network/host_darwin.go new file mode 100644 index 00000000..4a6acfbd --- /dev/null +++ b/internal/network/host_darwin.go @@ -0,0 +1,58 @@ +package network + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func (hostNamespace) Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) { + syscall.ForkLock.RLock() + defer syscall.ForkLock.RUnlock() + fd, err := ignoreEINTR2(func() (int, error) { + return unix.Socket(int(family), int(socktype), int(protocol)) + }) + if err != nil { + return nil, err + } + if err := setCloseOnExecAndNonBlocking(fd); err != nil { + unix.Close(fd) + return nil, err + } + return newHostSocket(fd, family, socktype), nil +} + +func (s *hostSocket) Accept() (Socket, Sockaddr, error) { + fd := s.acquire() + if fd < 0 { + return nil, nil, EBADF + } + defer s.release(fd) + syscall.ForkLock.RLock() + defer syscall.ForkLock.RUnlock() + conn, addr, err := ignoreEINTR3(func() (int, Sockaddr, error) { + return unix.Accept(fd) + }) + if err != nil { + return nil, nil, err + } + if err := setCloseOnExecAndNonBlocking(conn); err != nil { + unix.Close(conn) + return nil, nil, err + } + return newHostSocket(conn, s.family, s.socktype), addr, nil +} + +func setCloseOnExecAndNonBlocking(fd int) error { + if _, err := ignoreEINTR2(func() (int, error) { + return unix.FcntlInt(uintptr(fd), unix.F_SETFD, unix.O_CLOEXEC) + }); err != nil { + return err + } + if err := ignoreEINTR(func() error { + return unix.SetNonblock(fd, true) + }); err != nil { + return err + } + return nil +} diff --git a/internal/network/host_linux.go b/internal/network/host_linux.go new file mode 100644 index 00000000..7047b755 --- /dev/null +++ b/internal/network/host_linux.go @@ -0,0 +1,28 @@ +package network + +import "golang.org/x/sys/unix" + +func (hostNamespace) Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) { + fd, err := ignoreEINTR2(func() (int, error) { + return unix.Socket(int(family), int(socktype)|unix.SOCK_CLOEXEC|unix.SOCK_NONBLOCK, int(protocol)) + }) + if err != nil { + return nil, err + } + return newHostSocket(fd, family, socktype), nil +} + +func (s *hostSocket) Accept() (Socket, Sockaddr, error) { + fd := s.acquire() + if fd < 0 { + return nil, nil, EBADF + } + defer s.release(fd) + conn, addr, err := ignoreEINTR3(func() (int, unix.Sockaddr, error) { + return unix.Accept4(fd, unix.SOCK_CLOEXEC|unix.SOCK_NONBLOCK) + }) + if err != nil { + return nil, nil, err + } + return newHostSocket(conn, s.family, s.socktype), addr, nil +} diff --git a/internal/network/host_unix.go b/internal/network/host_unix.go new file mode 100644 index 00000000..2f38f1d3 --- /dev/null +++ b/internal/network/host_unix.go @@ -0,0 +1,246 @@ +package network + +import ( + "sync/atomic" + "time" + + "golang.org/x/sys/unix" +) + +type hostSocket struct { + state atomic.Uint64 // upper 32 bits: refCount, lower 32 bits: fd + family Family + socktype Socktype +} + +func newHostSocket(fd int, family Family, socktype Socktype) *hostSocket { + s := &hostSocket{family: family, socktype: socktype} + s.state.Store(uint64(fd)) + return s +} + +func (s *hostSocket) acquire() int { + for { + oldState := s.state.Load() + refCount := (oldState >> 32) + 1 + newState := (refCount << 32) | (oldState & 0xFFFFFFFF) + + if int32(oldState) < 0 { + return -1 + } + if s.state.CompareAndSwap(oldState, newState) { + return int(int32(oldState)) // int32->int for sign extension + } + } +} + +func (s *hostSocket) release(fd int) { + for { + oldState := s.state.Load() + refCount := (oldState >> 32) - 1 + newState := (oldState << 32) | (oldState & 0xFFFFFFFF) + + if s.state.CompareAndSwap(oldState, newState) { + if int32(oldState) < 0 && refCount == 0 { + unix.Close(fd) + } + break + } + } +} + +func (s *hostSocket) Family() Family { + return s.family +} + +func (s *hostSocket) Type() Socktype { + return s.socktype +} + +func (s *hostSocket) Fd() int { + return int(int32(s.state.Load())) +} + +func (s *hostSocket) Close() error { + for { + oldState := s.state.Load() + refCount := oldState >> 32 + newState := oldState | 0xFFFFFFFF + + if s.state.CompareAndSwap(oldState, newState) { + fd := int32(oldState) + if fd < 0 { + return EBADF + } + if refCount == 0 { + return unix.Close(int(fd)) + } + return nil + } + } +} + +func (s *hostSocket) Bind(addr Sockaddr) error { + fd := s.acquire() + if fd < 0 { + return EBADF + } + defer s.release(fd) + return ignoreEINTR(func() error { return unix.Bind(fd, addr) }) +} + +func (s *hostSocket) Listen(backlog int) error { + fd := s.acquire() + if fd < 0 { + return EBADF + } + defer s.release(fd) + return ignoreEINTR(func() error { return unix.Listen(fd, backlog) }) +} + +func (s *hostSocket) Connect(addr Sockaddr) error { + fd := s.acquire() + if fd < 0 { + return EBADF + } + defer s.release(fd) + return ignoreEINTR(func() error { return unix.Connect(fd, addr) }) +} + +func (s *hostSocket) Name() (Sockaddr, error) { + fd := s.acquire() + if fd < 0 { + return nil, EBADF + } + defer s.release(fd) + return ignoreEINTR2(func() (Sockaddr, error) { return unix.Getsockname(fd) }) +} + +func (s *hostSocket) Peer() (Sockaddr, error) { + fd := s.acquire() + if fd < 0 { + return nil, EBADF + } + defer s.release(fd) + return ignoreEINTR2(func() (Sockaddr, error) { return unix.Getpeername(fd) }) +} + +func (s *hostSocket) RecvFrom(iovs [][]byte, oob []byte, flags int) (int, int, int, Sockaddr, error) { + fd := s.acquire() + if fd < 0 { + return -1, 0, 0, nil, EBADF + } + defer s.release(fd) + // TODO: remove the heap allocation that happens for the socket address by + // implementing recvfrom(2) and using a cached socket address for connected + // sockets. + for { + n, oobn, rflags, addr, err := unix.RecvmsgBuffers(fd, iovs, oob, flags) + if err == EINTR { + if n == 0 { + continue + } + err = nil + } + return n, oobn, rflags, addr, err + } +} + +func (s *hostSocket) SendTo(iovs [][]byte, oob []byte, addr Sockaddr, flags int) (int, error) { + fd := s.acquire() + if fd < 0 { + return -1, EBADF + } + defer s.release(fd) + for { + n, err := unix.SendmsgBuffers(fd, iovs, oob, addr, flags) + if err == EINTR { + if n == 0 { + continue + } + err = nil + } + return n, err + } +} + +func (s *hostSocket) Shutdown(how int) error { + fd := s.acquire() + if fd < 0 { + return EBADF + } + defer s.release(fd) + return ignoreEINTR(func() error { + return unix.Shutdown(fd, how) + }) +} + +func (s *hostSocket) SetOption(level, name, value int) error { + fd := s.acquire() + if fd < 0 { + return EBADF + } + defer s.release(fd) + return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) +} + +func (s *hostSocket) GetOption(level, name int) (int, error) { + fd := s.acquire() + if fd < 0 { + return -1, EBADF + } + defer s.release(fd) + return ignoreEINTR2(func() (int, error) { return unix.GetsockoptInt(fd, level, name) }) +} + +// This function is used to automtically retry syscalls when they return EINTR +// due to having handled a signal instead of executing. Despite defininig a +// EINTR constant and having proc_raise to trigger signals from the guest, WASI +// does not provide any mechanism for handling signals so masking those errors +// seems like a safer approach to ensure that guest applications will work the +// same regardless of the compiler being used. +func ignoreEINTR(f func() error) error { + for { + if err := f(); err != EINTR { + return err + } + } +} + +func ignoreEINTR2[F func() (R, error), R any](f F) (R, error) { + for { + v, err := f() + if err != EINTR { + return v, err + } + } +} + +func ignoreEINTR3[F func() (R1, R2, error), R1, R2 any](f F) (R1, R2, error) { + for { + v1, v2, err := f() + if err != EINTR { + return v1, v2, err + } + } +} + +func WaitReadyRead(socket Socket, timeout time.Duration) error { + return wait(socket, unix.POLLIN, timeout) +} + +func WaitReadyWrite(socket Socket, timeout time.Duration) error { + return wait(socket, unix.POLLOUT, timeout) +} + +func wait(socket Socket, events int16, timeout time.Duration) error { + tms := int(timeout / time.Millisecond) + pfd := []unix.PollFd{{ + Fd: int32(socket.Fd()), + Events: events, + }} + return ignoreEINTR(func() error { + _, err := unix.Poll(pfd, tms) + return err + }) +} diff --git a/internal/network/network.go b/internal/network/network.go new file mode 100644 index 00000000..4762a3d3 --- /dev/null +++ b/internal/network/network.go @@ -0,0 +1,105 @@ +package network + +import ( + "net" +) + +type Socket interface { + Family() Family + + Type() Socktype + + Fd() int + + Close() error + + Bind(addr Sockaddr) error + + Listen(backlog int) error + + Connect(addr Sockaddr) error + + Accept() (Socket, Sockaddr, error) + + Name() (Sockaddr, error) + + Peer() (Sockaddr, error) + + RecvFrom(iovs [][]byte, oob []byte, flags int) (n, oobn, rflags int, addr Sockaddr, err error) + + SendTo(iovs [][]byte, oob []byte, addr Sockaddr, flags int) (int, error) + + Shutdown(how int) error + + SetOption(level, name, value int) error + + GetOption(level, name int) (int, error) +} + +type Socktype uint8 + +type Family uint8 + +type Protocol uint16 + +const ( + UNSPEC Protocol = 0 + TCP Protocol = 6 + UDP Protocol = 17 +) + +type Namespace interface { + InterfaceByIndex(index int) (Interface, error) + + InterfaceByName(name string) (Interface, error) + + Interfaces() ([]Interface, error) + + Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) +} + +type Interface interface { + Index() int + + MTU() int + + Name() string + + HardwareAddr() net.HardwareAddr + + Flags() net.Flags + + Addrs() ([]net.Addr, error) + + MulticastAddrs() ([]net.Addr, error) +} + +func SockaddrFamily(sa Sockaddr) Family { + switch sa.(type) { + case *SockaddrInet4: + return INET + case *SockaddrInet6: + return INET6 + default: + return UNIX + } +} + +func isUnspecified(sa Sockaddr) bool { + switch a := sa.(type) { + case *SockaddrInet4: + return isUnspecifiedInet4(a) + case *SockaddrInet6: + return isUnspecifiedInet6(a) + default: + return true + } +} + +func isUnspecifiedInet4(sa *SockaddrInet4) bool { + return sa.Addr == [4]byte{} +} + +func isUnspecifiedInet6(sa *SockaddrInet6) bool { + return sa.Addr == [16]byte{} +} diff --git a/internal/network/network_unix.go b/internal/network/network_unix.go new file mode 100644 index 00000000..94d1c527 --- /dev/null +++ b/internal/network/network_unix.go @@ -0,0 +1,43 @@ +package network + +import "golang.org/x/sys/unix" + +const ( + EADDRNOTAVAIL = unix.EADDRNOTAVAIL + EAFNOSUPPORT = unix.EAFNOSUPPORT + EBADF = unix.EBADF + ECONNREFUSED = unix.ECONNREFUSED + ECONNRESET = unix.ECONNRESET + EHOSTUNREACH = unix.EHOSTUNREACH + EINVAL = unix.EINVAL + EINTR = unix.EINTR + EINPROGRESS = unix.EINPROGRESS + ENETUNREACH = unix.ENETUNREACH + ENOTCONN = unix.ENOTCONN +) + +const ( + UNIX Family = unix.AF_UNIX + INET Family = unix.AF_INET + INET6 Family = unix.AF_INET6 +) + +const ( + STREAM Socktype = unix.SOCK_STREAM + DGRAM Socktype = unix.SOCK_DGRAM +) + +const ( + TRUNC = unix.MSG_TRUNC + PEEK = unix.MSG_PEEK + WAITALL = unix.MSG_WAITALL +) + +const ( + SHUTRD = unix.SHUT_RD + SHUTWR = unix.SHUT_WR +) + +type Sockaddr = unix.Sockaddr +type SockaddrInet4 = unix.SockaddrInet4 +type SockaddrInet6 = unix.SockaddrInet6 diff --git a/internal/network/virtual.go b/internal/network/virtual.go new file mode 100644 index 00000000..fec3504e --- /dev/null +++ b/internal/network/virtual.go @@ -0,0 +1,874 @@ +package network + +import ( + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + + "github.com/stealthrocket/timecraft/internal/ipam" +) + +var ( + errIPv4Exhausted = errors.New("ipv4 pool exhausted") + errIPv6Exhausted = errors.New("ipv6 pool exhausted") +) + +type VirtualNetwork struct { + mutex sync.RWMutex + host Namespace + ipnet4 *net.IPNet + ipnet6 *net.IPNet + ipam4 ipam.IPv4Pool + ipam6 ipam.IPv6Pool + iface4 map[ipam.IPv4]*virtualInterface + iface6 map[ipam.IPv6]*virtualInterface +} + +func NewVirtualNetwork(host Namespace, ipnet4, ipnet6 *net.IPNet) *VirtualNetwork { + n := &VirtualNetwork{ + host: host, + ipnet4: ipnet4, + ipnet6: ipnet6, + iface4: make(map[ipam.IPv4]*virtualInterface), + iface6: make(map[ipam.IPv6]*virtualInterface), + } + if ip := ipnet4.IP.To4(); ip != nil { + ones, _ := ipnet4.Mask.Size() + n.ipam4.Reset((ipam.IPv4)(ip), ones) + } + if ip := ipnet6.IP.To16(); ip != nil { + ones, _ := ipnet6.Mask.Size() + n.ipam6.Reset((ipam.IPv6)(ip), ones) + } + n.ipam4.Get() + n.ipam6.Get() + return n +} + +func (n *VirtualNetwork) CreateNamespace() (*VirtualNamespace, error) { + loIPv4 := ipam.IPv4{127, 0, 0, 1} + loIPv6 := ipam.IPv6{15: 1} + + ns := &VirtualNamespace{ + lo0: virtualInterface{ + index: 0, + name: "lo0", + ipv4: loIPv4, + ipv6: loIPv6, + flags: net.FlagUp | net.FlagLoopback, + }, + en0: virtualInterface{ + index: 1, + name: "en0", + flags: net.FlagUp, + }, + } + + ns.net.Store(n) + + n.mutex.Lock() + defer n.mutex.Unlock() + + var ok bool + ns.en0.ipv4, ok = n.ipam4.Get() + if !ok { + return nil, errIPv4Exhausted + } + ns.en0.ipv6, ok = n.ipam6.Get() + if !ok { + n.ipam4.Put(ns.en0.ipv4) + return nil, errIPv6Exhausted + } + + n.iface4[ns.en0.ipv4] = &ns.en0 + n.iface6[ns.en0.ipv6] = &ns.en0 + return ns, nil +} + +func (n *VirtualNetwork) Contains(ip net.IP) bool { + return n.ipnet4.Contains(ip) || n.ipnet6.Contains(ip) +} + +func (n *VirtualNetwork) containsIPv4(addr ipam.IPv4) bool { + return n.ipnet4.Contains(addr[:]) +} + +func (n *VirtualNetwork) containsIPv6(addr ipam.IPv6) bool { + return n.ipnet6.Contains(addr[:]) +} + +func (n *VirtualNetwork) lookupIPv4Interface(addr ipam.IPv4) *virtualInterface { + n.mutex.RLock() + defer n.mutex.RUnlock() + return n.iface4[addr] +} + +func (n *VirtualNetwork) lookupIPv6Interface(addr ipam.IPv6) *virtualInterface { + n.mutex.RLock() + defer n.mutex.RUnlock() + return n.iface6[addr] +} + +func (n *VirtualNetwork) detachInterface(vi *virtualInterface) { + n.mutex.Lock() + defer n.mutex.Unlock() + if n.iface4[vi.ipv4] == vi { + delete(n.iface4, vi.ipv4) + n.ipam4.Put(vi.ipv4) + } + if n.iface6[vi.ipv6] == vi { + delete(n.iface6, vi.ipv6) + n.ipam6.Put(vi.ipv6) + } +} + +type VirtualNamespace struct { + net atomic.Pointer[VirtualNetwork] + lo0 virtualInterface + en0 virtualInterface +} + +func (ns *VirtualNamespace) Detach() { + if n := ns.net.Swap(nil); n != nil { + n.detachInterface(&ns.en0) + } +} + +func (ns *VirtualNamespace) InterfaceByIndex(index int) (Interface, error) { + switch index { + case ns.lo0.index: + return &ns.lo0, nil + case ns.en0.index: + return &ns.en0, nil + default: + return nil, fmt.Errorf("virtual network interface index out of bounds: %d", index) + } +} + +func (ns *VirtualNamespace) InterfaceByName(name string) (Interface, error) { + switch name { + case ns.lo0.name: + return &ns.lo0, nil + case ns.en0.name: + return &ns.en0, nil + default: + return nil, fmt.Errorf("virtual network interface not found: %s", name) + } +} + +func (ns *VirtualNamespace) Interfaces() ([]Interface, error) { + return []Interface{&ns.lo0, &ns.en0}, nil +} + +func (ns *VirtualNamespace) Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) { + n := ns.net.Load() + if n == nil { + return nil, EAFNOSUPPORT + } + socket, err := n.host.Socket(family, socktype, protocol) + if err != nil { + return nil, err + } + switch family { + case INET, INET6: + socket = &virtualSocket{ + ns: ns, + base: socket, + proto: protocol, + } + } + return socket, nil +} + +func (ns *VirtualNamespace) bindInet4(socket *virtualSocket, host, addr *SockaddrInet4) error { + switch addr.Addr { + case [4]byte{}: // unspecified + if err := ns.lo0.bindInet4(socket, host, addr); err != nil { + return err + } + if err := ns.en0.bindInet4(socket, host, addr); err != nil { + return err + } + return nil + case ns.lo0.ipv4: + return ns.lo0.bindInet4(socket, host, addr) + case ns.en0.ipv4: + return ns.en0.bindInet4(socket, host, addr) + default: + return EADDRNOTAVAIL + } +} + +func (ns *VirtualNamespace) bindInet6(socket *virtualSocket, host, addr *SockaddrInet6) error { + switch addr.Addr { + case [16]byte{}: // unspecified + if err := ns.lo0.bindInet6(socket, host, addr); err != nil { + return err + } + if err := ns.en0.bindInet6(socket, host, addr); err != nil { + return err + } + return nil + case ns.lo0.ipv6: + return ns.lo0.bindInet6(socket, host, addr) + case ns.en0.ipv6: + return ns.en0.bindInet6(socket, host, addr) + default: + return EADDRNOTAVAIL + } +} + +func (ns *VirtualNamespace) lookupByHostInet4(socket *virtualSocket, host *SockaddrInet4) *virtualSocket { + if peer := ns.lo0.lookupByHostInet4(socket, host); peer != nil { + return peer + } + if peer := ns.en0.lookupByHostInet4(socket, host); peer != nil { + return peer + } + if n := ns.net.Load(); n != nil { + n.mutex.RLock() + defer n.mutex.RUnlock() + + for _, i := range n.iface4 { + if i == &ns.en0 { + continue + } + if peer := i.lookupByHostInet4(socket, host); peer != nil { + return peer + } + } + } + return nil +} + +func (ns *VirtualNamespace) lookupByHostInet6(socket *virtualSocket, host *SockaddrInet6) *virtualSocket { + if peer := ns.lo0.lookupByHostInet6(socket, host); peer != nil { + return peer + } + if peer := ns.en0.lookupByHostInet6(socket, host); peer != nil { + return peer + } + if n := ns.net.Load(); n != nil { + n.mutex.RLock() + defer n.mutex.RUnlock() + + for _, i := range n.iface6 { + if i == &ns.en0 { + continue + } + if peer := i.lookupByHostInet6(socket, host); peer != nil { + return peer + } + } + } + return nil +} + +func (ns *VirtualNamespace) lookupByAddrInet4(socket *virtualSocket, addr *SockaddrInet4) (*virtualSocket, error) { + if isUnspecifiedInet4(addr) { + return ns.lookupByPortInet4(socket, addr) + } + var peer *virtualSocket + switch addr.Addr { + case ns.lo0.ipv4: + peer = ns.lo0.lookupByAddrInet4(socket, addr) + case ns.en0.ipv4: + peer = ns.en0.lookupByAddrInet4(socket, addr) + default: + n := ns.net.Load() + if n == nil { + return nil, ENETUNREACH + } + if !n.containsIPv4(addr.Addr) { + return nil, nil + } + iface := n.lookupIPv4Interface(addr.Addr) + if iface == nil { + return nil, EHOSTUNREACH + } + peer = iface.lookupByAddrInet4(socket, addr) + } + if peer != nil { + return peer, nil + } + return nil, EHOSTUNREACH +} + +func (ns *VirtualNamespace) lookupByAddrInet6(socket *virtualSocket, addr *SockaddrInet6) (*virtualSocket, error) { + if isUnspecifiedInet6(addr) { + return ns.lookupByPortInet6(socket, addr) + } + var peer *virtualSocket + switch addr.Addr { + case ns.lo0.ipv6: + peer = ns.lo0.lookupByAddrInet6(socket, addr) + case ns.en0.ipv6: + peer = ns.en0.lookupByAddrInet6(socket, addr) + default: + n := ns.net.Load() + if n == nil { + return nil, ENETUNREACH + } + if !n.containsIPv6(addr.Addr) { + return nil, nil + } + iface := n.lookupIPv6Interface(addr.Addr) + if iface == nil { + return nil, EHOSTUNREACH + } + peer = iface.lookupByAddrInet6(socket, addr) + } + if peer != nil { + return peer, nil + } + return nil, EHOSTUNREACH +} + +func (ns *VirtualNamespace) lookupByPortInet4(socket *virtualSocket, addr *SockaddrInet4) (*virtualSocket, error) { + if peer := ns.lo0.lookupByAddrInet4(socket, addr); peer != nil { + return peer, nil + } + if peer := ns.en0.lookupByAddrInet4(socket, addr); peer != nil { + return peer, nil + } + if n := ns.net.Load(); n != nil { + n.mutex.RLock() + defer n.mutex.RUnlock() + + for _, iface := range n.iface4 { + if peer := iface.lookupByAddrInet4(socket, addr); peer != nil { + return peer, nil + } + } + } + return nil, EHOSTUNREACH +} + +func (ns *VirtualNamespace) lookupByPortInet6(socket *virtualSocket, addr *SockaddrInet6) (*virtualSocket, error) { + if peer := ns.lo0.lookupByAddrInet6(socket, addr); peer != nil { + return peer, nil + } + if peer := ns.en0.lookupByAddrInet6(socket, addr); peer != nil { + return peer, nil + } + if n := ns.net.Load(); n != nil { + n.mutex.RLock() + defer n.mutex.RUnlock() + + for _, iface := range n.iface6 { + if peer := iface.lookupByAddrInet6(socket, addr); peer != nil { + return peer, nil + } + } + } + return nil, EHOSTUNREACH +} + +func (ns *VirtualNamespace) unlinkInet4(socket *virtualSocket, host, addr *SockaddrInet4) { + switch addr.Addr { + case [4]byte{}: + ns.lo0.unlinkInet4(socket, host, addr) + ns.en0.unlinkInet4(socket, host, addr) + case ns.lo0.ipv4: + ns.lo0.unlinkInet4(socket, host, addr) + case ns.en0.ipv4: + ns.en0.unlinkInet4(socket, host, addr) + } +} + +func (ns *VirtualNamespace) unlinkInet6(socket *virtualSocket, host, addr *SockaddrInet6) { + switch addr.Addr { + case [16]byte{}: + ns.lo0.unlinkInet6(socket, host, addr) + ns.en0.unlinkInet6(socket, host, addr) + case ns.lo0.ipv6: + ns.lo0.unlinkInet6(socket, host, addr) + case ns.en0.ipv6: + ns.en0.unlinkInet6(socket, host, addr) + } +} + +type virtualAddress struct { + proto Protocol + port uint16 +} + +func virtualSockaddrInet4(proto Protocol, addr *SockaddrInet4) virtualAddress { + return virtualAddress{ + proto: proto, + port: uint16(addr.Port), + } +} + +func virtualSockaddrInet6(proto Protocol, addr *SockaddrInet6) virtualAddress { + return virtualAddress{ + proto: proto, + port: uint16(addr.Port), + } +} + +type virtualAddressTable struct { + host map[virtualAddress]*virtualSocket + sock map[virtualAddress]*virtualSocket +} + +func (t *virtualAddressTable) bind(socket *virtualSocket, ha, sa virtualAddress) (int, error) { + if sa.port != 0 { + if _, exist := t.sock[sa]; exist { + // TODO: + // - SO_REUSEADDR + // - SO_REUSEPORT + return -1, EADDRNOTAVAIL + } + } else { + var port int + for port = 49152; port <= 65535; port++ { + sa.port = uint16(port) + if _, exist := t.sock[sa]; !exist { + break + } + } + if port == 65535 { + return -1, EADDRNOTAVAIL + } + } + if t.host == nil { + t.host = make(map[virtualAddress]*virtualSocket) + } + if t.sock == nil { + t.sock = make(map[virtualAddress]*virtualSocket) + } + t.host[ha] = socket + t.sock[sa] = socket + return int(sa.port), nil +} + +func (t *virtualAddressTable) unlink(ha, sa virtualAddress) { + delete(t.host, ha) + delete(t.sock, sa) +} + +type virtualInterface struct { + index int + name string + ipv4 ipam.IPv4 + ipv6 ipam.IPv6 + haddr net.HardwareAddr + flags net.Flags + + mutex sync.Mutex + inet4 virtualAddressTable + inet6 virtualAddressTable +} + +func (i *virtualInterface) Index() int { + return i.index +} + +func (i *virtualInterface) MTU() int { + return 1500 +} + +func (i *virtualInterface) Name() string { + return i.name +} + +func (i *virtualInterface) Addrs() ([]net.Addr, error) { + ipv4 := &net.IPAddr{IP: net.IP(i.ipv4[:])} + ipv6 := &net.IPAddr{IP: net.IP(i.ipv6[:])} + return []net.Addr{ipv4, ipv6}, nil +} + +func (i *virtualInterface) MulticastAddrs() ([]net.Addr, error) { + return nil, nil +} + +func (i *virtualInterface) HardwareAddr() net.HardwareAddr { + return i.haddr +} + +func (i *virtualInterface) Flags() net.Flags { + return i.flags +} + +func (i *virtualInterface) bindInet4(socket *virtualSocket, host, addr *SockaddrInet4) error { + hostAddr := virtualSockaddrInet4(socket.proto, host) + bindAddr := virtualSockaddrInet4(socket.proto, addr) + name := &SockaddrInet4{Addr: addr.Addr} + + i.mutex.Lock() + defer i.mutex.Unlock() + + port, err := i.inet4.bind(socket, hostAddr, bindAddr) + if err != nil { + return err + } + + name.Port = port + socket.host.Store(host) + socket.name.Store(name) + socket.bound = true + return nil +} + +func (i *virtualInterface) bindInet6(socket *virtualSocket, host, addr *SockaddrInet6) error { + hostAddr := virtualSockaddrInet6(socket.proto, host) + bindAddr := virtualSockaddrInet6(socket.proto, addr) + name := &SockaddrInet6{Addr: addr.Addr} + + i.mutex.Lock() + defer i.mutex.Unlock() + + port, err := i.inet6.bind(socket, hostAddr, bindAddr) + if err != nil { + return err + } + + name.Port = port + socket.host.Store(host) + socket.name.Store(name) + socket.bound = true + return nil +} + +func (i *virtualInterface) lookupByHostInet4(socket *virtualSocket, host *SockaddrInet4) *virtualSocket { + va := virtualSockaddrInet4(socket.proto, host) + i.mutex.Lock() + defer i.mutex.Unlock() + return i.inet4.host[va] +} + +func (i *virtualInterface) lookupByHostInet6(socket *virtualSocket, host *SockaddrInet6) *virtualSocket { + va := virtualSockaddrInet6(socket.proto, host) + i.mutex.Lock() + defer i.mutex.Unlock() + return i.inet6.host[va] +} + +func (i *virtualInterface) lookupByAddrInet4(socket *virtualSocket, addr *SockaddrInet4) *virtualSocket { + va := virtualSockaddrInet4(socket.proto, addr) + i.mutex.Lock() + defer i.mutex.Unlock() + return i.inet4.sock[va] +} + +func (i *virtualInterface) lookupByAddrInet6(socket *virtualSocket, addr *SockaddrInet6) *virtualSocket { + va := virtualSockaddrInet6(socket.proto, addr) + i.mutex.Lock() + defer i.mutex.Unlock() + return i.inet6.sock[va] +} + +func (i *virtualInterface) unlinkInet4(socket *virtualSocket, host, addr *SockaddrInet4) { + hostAddr := virtualSockaddrInet4(socket.proto, host) + sockAddr := virtualSockaddrInet4(socket.proto, addr) + + i.mutex.Lock() + defer i.mutex.Unlock() + + i.inet4.unlink(hostAddr, sockAddr) + socket.host.Store(&sockaddrInet4Any) + socket.name.Store(&sockaddrInet4Any) + socket.bound = false +} + +func (i *virtualInterface) unlinkInet6(socket *virtualSocket, host, addr *SockaddrInet6) { + hostAddr := virtualSockaddrInet6(socket.proto, host) + sockAddr := virtualSockaddrInet6(socket.proto, addr) + + i.mutex.Lock() + defer i.mutex.Unlock() + + i.inet6.unlink(hostAddr, sockAddr) + socket.host.Store(&sockaddrInet6Any) + socket.name.Store(&sockaddrInet6Any) + socket.bound = false +} + +var ( + sockaddrInet4Any SockaddrInet4 + sockaddrInet6Any SockaddrInet6 +) + +type virtualSocket struct { + ns *VirtualNamespace + base Socket + host atomic.Value + name atomic.Value + peer atomic.Value + proto Protocol + bound bool +} + +func (s *virtualSocket) Family() Family { + return s.base.Family() +} + +func (s *virtualSocket) Type() Socktype { + return s.base.Type() +} + +func (s *virtualSocket) Fd() int { + return s.base.Fd() +} + +func (s *virtualSocket) Close() error { + if err := s.base.Close(); err != nil { + return err + } + if s.bound { + host := s.host.Load() + name := s.name.Load() + switch a := name.(type) { + case *SockaddrInet4: + s.ns.unlinkInet4(s, host.(*SockaddrInet4), a) + case *SockaddrInet6: + s.ns.unlinkInet6(s, host.(*SockaddrInet6), a) + } + } + return nil +} + +func (s *virtualSocket) Bind(addr Sockaddr) error { + if s.name.Load() != nil { + return EINVAL + } + + switch addr.(type) { + case *SockaddrInet4: + if s.Family() != INET { + return EAFNOSUPPORT + } + _ = s.base.Bind(&sockaddrInet4Any) + case *SockaddrInet6: + if s.Family() != INET6 { + return EAFNOSUPPORT + } + _ = s.base.Bind(&sockaddrInet6Any) + default: + return EINVAL + } + + // The host socket was bound to a random port, we retrieve the address that + // it got associated with. + a, err := s.base.Name() + if err != nil { + return err + } + + switch host := a.(type) { + case *SockaddrInet4: + return s.ns.bindInet4(s, host, addr.(*SockaddrInet4)) + case *SockaddrInet6: + return s.ns.bindInet6(s, host, addr.(*SockaddrInet6)) + default: + return EAFNOSUPPORT + } +} + +func (s *virtualSocket) bindAny() error { + a, err := s.base.Name() + if err != nil { + return err + } + switch host := a.(type) { + case *SockaddrInet4: + return s.ns.bindInet4(s, host, &sockaddrInet4Any) + case *SockaddrInet6: + return s.ns.bindInet6(s, host, &sockaddrInet6Any) + default: + return EAFNOSUPPORT + } +} + +func (s *virtualSocket) Listen(backlog int) error { + if err := s.base.Listen(backlog); err != nil { + return err + } + if s.name.Load() == nil { + return s.bindAny() + } + return nil +} + +func (s *virtualSocket) Connect(addr Sockaddr) error { + var peer *virtualSocket + var err error + switch a := addr.(type) { + case *SockaddrInet4: + peer, err = s.ns.lookupByAddrInet4(s, a) + case *SockaddrInet6: + peer, err = s.ns.lookupByAddrInet6(s, a) + default: + return EAFNOSUPPORT + } + if err != nil { + return ECONNREFUSED + } + + connectAddr := addr + if peer != nil { + switch a := peer.host.Load().(type) { + case *SockaddrInet4: + connectAddr = a + case *SockaddrInet6: + connectAddr = a + default: + return ECONNREFUSED + } + } + + err = s.base.Connect(connectAddr) + if err != nil && err != EINPROGRESS { + return err + } + + s.peer.Store(addr) + + if s.name.Load() == nil { + if err := s.bindAny(); err != nil { + return err + } + } + return err +} + +func (s *virtualSocket) Accept() (Socket, Sockaddr, error) { + base, addr, err := s.base.Accept() + if err != nil { + return nil, nil, err + } + conn := &virtualSocket{ + ns: s.ns, + base: base, + } + var peer *virtualSocket + switch a := addr.(type) { + case *SockaddrInet4: + peer = s.ns.lookupByHostInet4(s, a) + case *SockaddrInet6: + peer = s.ns.lookupByHostInet6(s, a) + } + conn.host.Store(addr) + if peer != nil { + addr, _ = peer.name.Load().(Sockaddr) + } + conn.name.Store(s.name.Load()) + conn.peer.Store(addr) + return conn, addr, nil +} + +func (s *virtualSocket) Name() (Sockaddr, error) { + switch name := s.name.Load().(type) { + case *SockaddrInet4: + return name, nil + case *SockaddrInet6: + return name, nil + } + switch s.Family() { + case INET: + return &sockaddrInet4Any, nil + default: + return &sockaddrInet6Any, nil + } +} + +func (s *virtualSocket) Peer() (Sockaddr, error) { + switch peer := s.peer.Load().(type) { + case *SockaddrInet4: + return peer, nil + case *SockaddrInet6: + return peer, nil + } + return nil, ENOTCONN +} + +func (s *virtualSocket) RecvFrom(iovs [][]byte, oob []byte, flags int) (int, int, int, Sockaddr, error) { + n, oobn, flags, addr, err := s.base.RecvFrom(iovs, oob, flags) + if err != nil { + return -1, -1, 0, nil, err + } + switch conn := s.peer.Load().(type) { + case *SockaddrInet4: + addr = conn + case *SockaddrInet6: + addr = conn + default: + addr, err = s.toVirtualAddr(addr) + } + return n, oobn, flags, addr, err +} + +func (s *virtualSocket) SendTo(iovs [][]byte, oob []byte, addr Sockaddr, flags int) (int, error) { + if addr != nil { + a, err := s.toHostAddr(addr) + if err != nil { + return -1, err + } + addr = a + } + n, err := s.base.SendTo(iovs, oob, addr, flags) + if s.name.Load() == nil { + if err := s.bindAny(); err != nil { + return n, err + } + } + return n, err +} + +func (s *virtualSocket) Shutdown(how int) error { + return s.base.Shutdown(how) +} + +func (s *virtualSocket) SetOption(level, name, value int) error { + return s.base.SetOption(level, name, value) +} + +func (s *virtualSocket) GetOption(level, name int) (int, error) { + return s.base.GetOption(level, name) +} + +func (s *virtualSocket) toHostAddr(addr Sockaddr) (Sockaddr, error) { + var peer *virtualSocket + var err error + switch a := addr.(type) { + case *SockaddrInet4: + peer, err = s.ns.lookupByAddrInet4(s, a) + case *SockaddrInet6: + peer, err = s.ns.lookupByAddrInet6(s, a) + default: + return nil, EAFNOSUPPORT + } + if peer == nil { + return addr, err + } + peerAddr, err := peer.Name() + switch err { + case nil: + return peerAddr, nil + case EBADF: + return nil, ECONNRESET + default: + return nil, err + } +} + +func (s *virtualSocket) toVirtualAddr(addr Sockaddr) (Sockaddr, error) { + var peer *virtualSocket + switch a := addr.(type) { + case *SockaddrInet4: + peer = s.ns.lookupByHostInet4(s, a) + case *SockaddrInet6: + peer = s.ns.lookupByHostInet6(s, a) + default: + return nil, EAFNOSUPPORT + } + if peer != nil { + return peer.name.Load().(Sockaddr), nil + } + // TODO: this races if the virtual peer was closed after sending a + // datagram but before we looked it up on the interfaces. + return addr, nil +} diff --git a/internal/network/virtual_test.go b/internal/network/virtual_test.go new file mode 100644 index 00000000..efe95c38 --- /dev/null +++ b/internal/network/virtual_test.go @@ -0,0 +1,238 @@ +package network_test + +import ( + "net" + "testing" + "time" + + "github.com/stealthrocket/timecraft/internal/assert" + "github.com/stealthrocket/timecraft/internal/network" +) + +func TestVirtualNetwork(t *testing.T) { + tests := []struct { + scenario string + function func(*testing.T, *network.VirtualNetwork) + }{ + { + scenario: "a virtual network namespace has two interfaces", + function: testVirtualNetworkInterfaces, + }, + + { + scenario: "ipv4 sockets can connect to one another on a loopback interface", + function: testVirtualNetworkConnectLoopbackIPv4, + }, + + { + scenario: "ipv6 sockets can connect to one another on a loopback interface", + function: testVirtualNetworkConnectLoopbackIPv6, + }, + + { + scenario: "ipv4 sockets can connect to one another on a network interface", + function: testVirtualNetworkConnectInterfaceIPv4, + }, + + { + scenario: "ipv6 sockets can connect to one another on a network interface", + function: testVirtualNetworkConnectInterfaceIPv6, + }, + + { + scenario: "ipv4 sockets in different namespaces can connet to one another", + function: testVirtualNetworkConnectNamespacesIPv4, + }, + + { + scenario: "ipv6 sockets in different namespaces can connet to one another", + function: testVirtualNetworkConnectNamespacesIPv6, + }, + } + + _, ipnet4, _ := net.ParseCIDR("192.168.0.0/24") + _, ipnet6, _ := net.ParseCIDR("fe80::/64") + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + test.function(t, + network.NewVirtualNetwork(network.Host(), ipnet4, ipnet6), + ) + }) + } +} + +func testVirtualNetworkInterfaces(t *testing.T, n *network.VirtualNetwork) { + ns, err := n.CreateNamespace() + assert.OK(t, err) + + ifaces, err := ns.Interfaces() + assert.OK(t, err) + assert.Equal(t, len(ifaces), 2) + + lo0 := ifaces[0] + assert.Equal(t, lo0.Index(), 0) + assert.Equal(t, lo0.MTU(), 1500) + assert.Equal(t, lo0.Name(), "lo0") + assert.Equal(t, lo0.Flags(), net.FlagUp|net.FlagLoopback) + + lo0Addrs, err := lo0.Addrs() + assert.OK(t, err) + assert.Equal(t, len(lo0Addrs), 2) + assert.Equal(t, lo0Addrs[0].String(), "127.0.0.1") + assert.Equal(t, lo0Addrs[1].String(), "::1") + + en0 := ifaces[1] + assert.Equal(t, en0.Index(), 1) + assert.Equal(t, en0.MTU(), 1500) + assert.Equal(t, en0.Name(), "en0") + assert.Equal(t, en0.Flags(), net.FlagUp) + + en0Addrs, err := en0.Addrs() + assert.OK(t, err) + assert.Equal(t, len(en0Addrs), 2) + assert.Equal(t, en0Addrs[0].String(), "192.168.0.1") + assert.Equal(t, en0Addrs[1].String(), "fe80::1") +} + +func testVirtualNetworkConnectLoopbackIPv4(t *testing.T, n *network.VirtualNetwork) { + testVirtualNetworkConnect(t, n, &network.SockaddrInet4{ + Addr: [4]byte{127, 0, 0, 1}, + Port: 80, + }) +} + +func testVirtualNetworkConnectLoopbackIPv6(t *testing.T, n *network.VirtualNetwork) { + testVirtualNetworkConnect(t, n, &network.SockaddrInet6{ + Addr: [16]byte{15: 1}, + Port: 80, + }) +} + +func testVirtualNetworkConnectInterfaceIPv4(t *testing.T, n *network.VirtualNetwork) { + testVirtualNetworkConnect(t, n, &network.SockaddrInet4{ + Addr: [4]byte{192, 168, 0, 1}, + Port: 80, + }) +} + +func testVirtualNetworkConnectInterfaceIPv6(t *testing.T, n *network.VirtualNetwork) { + testVirtualNetworkConnect(t, n, &network.SockaddrInet6{ + Addr: [16]byte{0: 0xfe, 1: 0x80, 15: 1}, + Port: 80, + }) +} + +func testVirtualNetworkConnect(t *testing.T, n *network.VirtualNetwork, bind network.Sockaddr) { + family := network.SockaddrFamily(bind) + + ns, err := n.CreateNamespace() + assert.OK(t, err) + + server, err := ns.Socket(family, network.STREAM, network.TCP) + assert.OK(t, err) + defer server.Close() + + assert.OK(t, server.Bind(bind)) + assert.OK(t, server.Listen(1)) + serverAddr, err := server.Name() + assert.OK(t, err) + + client, err := ns.Socket(family, network.STREAM, network.TCP) + assert.OK(t, err) + defer client.Close() + + assert.Error(t, client.Connect(serverAddr), network.EINPROGRESS) + assert.OK(t, waitReadyRead(server)) + conn, addr, err := server.Accept() + assert.OK(t, err) + defer conn.Close() + + assert.OK(t, waitReadyWrite(client)) + peer, err := client.Peer() + assert.OK(t, err) + assert.DeepEqual(t, peer, serverAddr) + + name, err := client.Name() + assert.OK(t, err) + assert.DeepEqual(t, name, addr) + + wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, nil, 0) + assert.OK(t, err) + assert.Equal(t, wn, 13) + + assert.OK(t, waitReadyRead(conn)) + buf := make([]byte, 32) + rn, oobn, rflags, peer, err := conn.RecvFrom([][]byte{buf}, nil, 0) + assert.OK(t, err) + assert.Equal(t, rn, 13) + assert.Equal(t, oobn, 0) + assert.Equal(t, rflags, 0) + assert.Equal(t, string(buf[:13]), "Hello, World!") + assert.Equal(t, peer, addr) +} + +func testVirtualNetworkConnectNamespacesIPv4(t *testing.T, n *network.VirtualNetwork) { + testVirtualNetworkConnectNamespaces(t, n, network.INET) +} + +func testVirtualNetworkConnectNamespacesIPv6(t *testing.T, n *network.VirtualNetwork) { + testVirtualNetworkConnectNamespaces(t, n, network.INET6) +} + +func testVirtualNetworkConnectNamespaces(t *testing.T, n *network.VirtualNetwork, family network.Family) { + ns1, err := n.CreateNamespace() + assert.OK(t, err) + + ns2, err := n.CreateNamespace() + assert.OK(t, err) + + server, err := ns1.Socket(family, network.STREAM, network.TCP) + assert.OK(t, err) + defer server.Close() + + assert.OK(t, server.Listen(1)) + serverAddr, err := server.Name() + assert.OK(t, err) + + client, err := ns2.Socket(family, network.STREAM, network.TCP) + assert.OK(t, err) + defer client.Close() + + assert.Error(t, client.Connect(serverAddr), network.EINPROGRESS) + assert.OK(t, waitReadyRead(server)) + conn, addr, err := server.Accept() + assert.OK(t, err) + defer conn.Close() + + assert.OK(t, waitReadyWrite(client)) + peer, err := client.Peer() + assert.OK(t, err) + assert.DeepEqual(t, peer, serverAddr) + + name, err := client.Name() + assert.OK(t, err) + assert.DeepEqual(t, name, addr) + + wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, nil, 0) + assert.OK(t, err) + assert.Equal(t, wn, 13) + + assert.OK(t, waitReadyRead(conn)) + buf := make([]byte, 32) + rn, oobn, rflags, peer, err := conn.RecvFrom([][]byte{buf}, nil, 0) + assert.OK(t, err) + assert.Equal(t, rn, 13) + assert.Equal(t, oobn, 0) + assert.Equal(t, rflags, 0) + assert.Equal(t, string(buf[:13]), "Hello, World!") + assert.Equal(t, peer, addr) +} + +func waitReadyRead(socket network.Socket) error { + return network.WaitReadyRead(socket, time.Second) +} + +func waitReadyWrite(socket network.Socket) error { + return network.WaitReadyWrite(socket, time.Second) +} From a9f34672eb08386b73ac237f7027968310861f3b Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Thu, 13 Jul 2023 18:20:58 -0700 Subject: [PATCH 02/16] network: add local and virtual networks Signed-off-by: Achille Roussel --- internal/network/host_darwin.go | 18 +- internal/network/host_unix.go | 110 ++---- internal/network/local.go | 280 ++++++++++++++++ internal/network/local_darwin.go | 53 +++ internal/network/local_linux.go | 47 +++ internal/network/local_test.go | 67 ++++ internal/network/local_unix.go | 515 +++++++++++++++++++++++++++++ internal/network/network.go | 49 ++- internal/network/network_darwin.go | 17 + internal/network/network_test.go | 125 +++++++ internal/network/network_unix.go | 70 +++- internal/network/virtual.go | 67 ++-- internal/network/virtual_test.go | 57 +--- 13 files changed, 1280 insertions(+), 195 deletions(-) create mode 100644 internal/network/local.go create mode 100644 internal/network/local_darwin.go create mode 100644 internal/network/local_linux.go create mode 100644 internal/network/local_test.go create mode 100644 internal/network/local_unix.go create mode 100644 internal/network/network_darwin.go create mode 100644 internal/network/network_test.go diff --git a/internal/network/host_darwin.go b/internal/network/host_darwin.go index 4a6acfbd..590ac2cd 100644 --- a/internal/network/host_darwin.go +++ b/internal/network/host_darwin.go @@ -23,11 +23,11 @@ func (hostNamespace) Socket(family Family, socktype Socktype, protocol Protocol) } func (s *hostSocket) Accept() (Socket, Sockaddr, error) { - fd := s.acquire() + fd := s.fd.acquire() if fd < 0 { return nil, nil, EBADF } - defer s.release(fd) + defer s.fd.release(fd) syscall.ForkLock.RLock() defer syscall.ForkLock.RUnlock() conn, addr, err := ignoreEINTR3(func() (int, Sockaddr, error) { @@ -42,17 +42,3 @@ func (s *hostSocket) Accept() (Socket, Sockaddr, error) { } return newHostSocket(conn, s.family, s.socktype), addr, nil } - -func setCloseOnExecAndNonBlocking(fd int) error { - if _, err := ignoreEINTR2(func() (int, error) { - return unix.FcntlInt(uintptr(fd), unix.F_SETFD, unix.O_CLOEXEC) - }); err != nil { - return err - } - if err := ignoreEINTR(func() error { - return unix.SetNonblock(fd, true) - }); err != nil { - return err - } - return nil -} diff --git a/internal/network/host_unix.go b/internal/network/host_unix.go index 2f38f1d3..8e136596 100644 --- a/internal/network/host_unix.go +++ b/internal/network/host_unix.go @@ -1,54 +1,23 @@ package network import ( - "sync/atomic" "time" "golang.org/x/sys/unix" ) type hostSocket struct { - state atomic.Uint64 // upper 32 bits: refCount, lower 32 bits: fd + fd socketFD family Family socktype Socktype } func newHostSocket(fd int, family Family, socktype Socktype) *hostSocket { s := &hostSocket{family: family, socktype: socktype} - s.state.Store(uint64(fd)) + s.fd.init(fd) return s } -func (s *hostSocket) acquire() int { - for { - oldState := s.state.Load() - refCount := (oldState >> 32) + 1 - newState := (refCount << 32) | (oldState & 0xFFFFFFFF) - - if int32(oldState) < 0 { - return -1 - } - if s.state.CompareAndSwap(oldState, newState) { - return int(int32(oldState)) // int32->int for sign extension - } - } -} - -func (s *hostSocket) release(fd int) { - for { - oldState := s.state.Load() - refCount := (oldState >> 32) - 1 - newState := (oldState << 32) | (oldState & 0xFFFFFFFF) - - if s.state.CompareAndSwap(oldState, newState) { - if int32(oldState) < 0 && refCount == 0 { - unix.Close(fd) - } - break - } - } -} - func (s *hostSocket) Family() Family { return s.family } @@ -58,102 +27,87 @@ func (s *hostSocket) Type() Socktype { } func (s *hostSocket) Fd() int { - return int(int32(s.state.Load())) + return s.fd.load() } func (s *hostSocket) Close() error { - for { - oldState := s.state.Load() - refCount := oldState >> 32 - newState := oldState | 0xFFFFFFFF - - if s.state.CompareAndSwap(oldState, newState) { - fd := int32(oldState) - if fd < 0 { - return EBADF - } - if refCount == 0 { - return unix.Close(int(fd)) - } - return nil - } - } + return s.fd.close() } func (s *hostSocket) Bind(addr Sockaddr) error { - fd := s.acquire() + fd := s.fd.acquire() if fd < 0 { return EBADF } - defer s.release(fd) + defer s.fd.release(fd) return ignoreEINTR(func() error { return unix.Bind(fd, addr) }) } func (s *hostSocket) Listen(backlog int) error { - fd := s.acquire() + fd := s.fd.acquire() if fd < 0 { return EBADF } - defer s.release(fd) + defer s.fd.release(fd) return ignoreEINTR(func() error { return unix.Listen(fd, backlog) }) } func (s *hostSocket) Connect(addr Sockaddr) error { - fd := s.acquire() + fd := s.fd.acquire() if fd < 0 { return EBADF } - defer s.release(fd) + defer s.fd.release(fd) return ignoreEINTR(func() error { return unix.Connect(fd, addr) }) } func (s *hostSocket) Name() (Sockaddr, error) { - fd := s.acquire() + fd := s.fd.acquire() if fd < 0 { return nil, EBADF } - defer s.release(fd) + defer s.fd.release(fd) return ignoreEINTR2(func() (Sockaddr, error) { return unix.Getsockname(fd) }) } func (s *hostSocket) Peer() (Sockaddr, error) { - fd := s.acquire() + fd := s.fd.acquire() if fd < 0 { return nil, EBADF } - defer s.release(fd) + defer s.fd.release(fd) return ignoreEINTR2(func() (Sockaddr, error) { return unix.Getpeername(fd) }) } -func (s *hostSocket) RecvFrom(iovs [][]byte, oob []byte, flags int) (int, int, int, Sockaddr, error) { - fd := s.acquire() +func (s *hostSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, error) { + fd := s.fd.acquire() if fd < 0 { - return -1, 0, 0, nil, EBADF + return -1, 0, nil, EBADF } - defer s.release(fd) + defer s.fd.release(fd) // TODO: remove the heap allocation that happens for the socket address by // implementing recvfrom(2) and using a cached socket address for connected // sockets. for { - n, oobn, rflags, addr, err := unix.RecvmsgBuffers(fd, iovs, oob, flags) + n, _, rflags, addr, err := unix.RecvmsgBuffers(fd, iovs, nil, flags) if err == EINTR { if n == 0 { continue } err = nil } - return n, oobn, rflags, addr, err + return n, rflags, addr, err } } -func (s *hostSocket) SendTo(iovs [][]byte, oob []byte, addr Sockaddr, flags int) (int, error) { - fd := s.acquire() +func (s *hostSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, error) { + fd := s.fd.acquire() if fd < 0 { return -1, EBADF } - defer s.release(fd) + defer s.fd.release(fd) for { - n, err := unix.SendmsgBuffers(fd, iovs, oob, addr, flags) + n, err := unix.SendmsgBuffers(fd, iovs, nil, addr, flags) if err == EINTR { if n == 0 { continue @@ -165,31 +119,31 @@ func (s *hostSocket) SendTo(iovs [][]byte, oob []byte, addr Sockaddr, flags int) } func (s *hostSocket) Shutdown(how int) error { - fd := s.acquire() + fd := s.fd.acquire() if fd < 0 { return EBADF } - defer s.release(fd) + defer s.fd.release(fd) return ignoreEINTR(func() error { return unix.Shutdown(fd, how) }) } -func (s *hostSocket) SetOption(level, name, value int) error { - fd := s.acquire() +func (s *hostSocket) SetOptInt(level, name, value int) error { + fd := s.fd.acquire() if fd < 0 { return EBADF } - defer s.release(fd) + defer s.fd.release(fd) return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) } -func (s *hostSocket) GetOption(level, name int) (int, error) { - fd := s.acquire() +func (s *hostSocket) GetOptInt(level, name int) (int, error) { + fd := s.fd.acquire() if fd < 0 { return -1, EBADF } - defer s.release(fd) + defer s.fd.release(fd) return ignoreEINTR2(func() (int, error) { return unix.GetsockoptInt(fd, level, name) }) } diff --git a/internal/network/local.go b/internal/network/local.go new file mode 100644 index 00000000..fd8472cd --- /dev/null +++ b/internal/network/local.go @@ -0,0 +1,280 @@ +package network + +import ( + "context" + "net" + "sync" +) + +var ( + localAddrs = [2]net.Addr{ + &net.IPAddr{IP: net.IP{127, 0, 0, 1}}, + &net.IPAddr{IP: net.IP{15: 1}}, + } +) + +type LocalOption func(*localNamespace) + +func DialFunc(dial func(context.Context, string, string) (net.Conn, error)) LocalOption { + return func(ns *localNamespace) { ns.dial = dial } +} + +func ListenFunc(listen func(context.Context, string, string) (net.Listener, error)) LocalOption { + return func(ns *localNamespace) { ns.listen = listen } +} + +func ListenPacketFunc(listenPacket func(context.Context, string, string) (net.PacketConn, error)) LocalOption { + return func(ns *localNamespace) { ns.listenPacket = listenPacket } +} + +func NewLocalNamespace(host Namespace, opts ...LocalOption) Namespace { + ns := &localNamespace{host: host} + for _, opt := range opts { + opt(ns) + } + return ns +} + +type localNamespace struct { + host Namespace + dial func(context.Context, string, string) (net.Conn, error) + listen func(context.Context, string, string) (net.Listener, error) + listenPacket func(context.Context, string, string) (net.PacketConn, error) + lo0 localInterface +} + +func (ns *localNamespace) InterfaceByIndex(index int) (Interface, error) { + if index != 0 { + return nil, errInterfaceIndexNotFound(index) + } + return &ns.lo0, nil +} + +func (ns *localNamespace) InterfaceByName(name string) (Interface, error) { + if name != "lo0" { + return nil, errInterfaceNameNotFound(name) + } + return &ns.lo0, nil +} + +func (ns *localNamespace) Interfaces() ([]Interface, error) { + return []Interface{&ns.lo0}, nil +} + +func (ns *localNamespace) bindInet4(sock *localSocket, addr *SockaddrInet4) error { + switch { + case isUnspecifiedInet4(addr), isLoopbackInet4(addr): + return ns.lo0.bindInet4(sock, addr) + default: + return EADDRNOTAVAIL + } +} + +func (ns *localNamespace) bindInet6(sock *localSocket, addr *SockaddrInet6) error { + switch { + case isUnspecifiedInet6(addr), isLoopbackInet6(addr): + return ns.lo0.bindInet6(sock, addr) + default: + return EADDRNOTAVAIL + } +} + +func (ns *localNamespace) lookupInet4(sock *localSocket, addr *SockaddrInet4) (*localSocket, error) { + switch { + case isUnspecifiedInet4(addr), isLoopbackInet4(addr): + return ns.lo0.lookupInet4(sock, addr) + default: + return nil, ENETUNREACH + } +} + +func (ns *localNamespace) lookupInet6(sock *localSocket, addr *SockaddrInet6) (*localSocket, error) { + switch { + case isUnspecifiedInet6(addr), isLoopbackInet6(addr): + return ns.lo0.lookupInet6(sock, addr) + default: + return nil, ENETUNREACH + } +} + +func (ns *localNamespace) unlinkInet4(sock *localSocket, addr *SockaddrInet4) { + ns.lo0.unlinkInet4(sock, addr) +} + +func (ns *localNamespace) unlinkInet6(sock *localSocket, addr *SockaddrInet6) { + ns.lo0.unlinkInet6(sock, addr) +} + +type localAddress struct { + proto Protocol + port uint16 +} + +func localSockaddrInet4(proto Protocol, addr *SockaddrInet4) localAddress { + return localAddress{ + proto: proto, + port: uint16(addr.Port), + } +} + +func localSockaddrInet6(proto Protocol, addr *SockaddrInet6) localAddress { + return localAddress{ + proto: proto, + port: uint16(addr.Port), + } +} + +type localSocketTable struct { + sockets map[localAddress]*localSocket +} + +func (t *localSocketTable) bind(sock *localSocket, addr localAddress) (int, error) { + if addr.port != 0 { + if _, exist := t.sockets[addr]; exist { + // TODO: + // - SO_REUSEADDR + // - SO_REUSEPORT + return -1, EADDRNOTAVAIL + } + } else { + var port int + for port = 49152; port <= 65535; port++ { + addr.port = uint16(port) + if _, exist := t.sockets[addr]; !exist { + break + } + } + if port == 65535 { + return -1, EADDRNOTAVAIL + } + } + if t.sockets == nil { + t.sockets = make(map[localAddress]*localSocket) + } + t.sockets[addr] = sock + return int(addr.port), nil +} + +func (t *localSocketTable) unlink(sock *localSocket, addr localAddress) { + if t.sockets[addr] == sock { + delete(t.sockets, addr) + } +} + +type localInterface struct { + mutex sync.RWMutex + ipv4 localSocketTable + ipv6 localSocketTable +} + +func (i *localInterface) Index() int { + return 0 +} + +func (i *localInterface) MTU() int { + return 1500 +} + +func (i *localInterface) Name() string { + return "lo0" +} + +func (i *localInterface) HardwareAddr() net.HardwareAddr { + return net.HardwareAddr{} +} + +func (i *localInterface) Flags() net.Flags { + return net.FlagUp | net.FlagLoopback +} + +func (i *localInterface) Addrs() ([]net.Addr, error) { + return localAddrs[:], nil +} + +func (i *localInterface) MulticastAddrs() ([]net.Addr, error) { + return nil, nil +} + +func (i *localInterface) bindInet4(sock *localSocket, addr *SockaddrInet4) error { + link := localSockaddrInet4(sock.protocol, addr) + name := &SockaddrInet4{Addr: addr.Addr} + + i.mutex.Lock() + defer i.mutex.Unlock() + + port, err := i.ipv4.bind(sock, link) + if err != nil { + return err + } + + name.Port = port + sock.name.Store(name) + sock.state.set(bound) + return nil +} + +func (i *localInterface) bindInet6(sock *localSocket, addr *SockaddrInet6) error { + link := localSockaddrInet6(sock.protocol, addr) + name := &SockaddrInet6{Addr: addr.Addr} + + i.mutex.Lock() + defer i.mutex.Unlock() + + port, err := i.ipv6.bind(sock, link) + if err != nil { + return err + } + + name.Port = port + sock.name.Store(name) + sock.state.set(bound) + return nil +} + +func (i *localInterface) lookupInet4(sock *localSocket, addr *SockaddrInet4) (*localSocket, error) { + link := localSockaddrInet4(sock.protocol, addr) + + i.mutex.RLock() + defer i.mutex.RUnlock() + + if peer := i.ipv4.sockets[link]; peer != nil { + return peer, nil + } + return nil, EHOSTUNREACH +} + +func (i *localInterface) lookupInet6(sock *localSocket, addr *SockaddrInet6) (*localSocket, error) { + link := localSockaddrInet6(sock.protocol, addr) + + i.mutex.RLock() + defer i.mutex.RUnlock() + + if peer := i.ipv6.sockets[link]; peer != nil { + return peer, nil + } + return nil, EHOSTUNREACH +} + +func (i *localInterface) unlinkInet4(sock *localSocket, addr *SockaddrInet4) { + link := localSockaddrInet4(sock.protocol, addr) + + i.mutex.Lock() + defer i.mutex.Unlock() + + i.ipv4.unlink(sock, link) + + sock.name.Store(&sockaddrInet4Any) + sock.state.unset(bound) +} + +func (i *localInterface) unlinkInet6(sock *localSocket, addr *SockaddrInet6) { + link := localSockaddrInet6(sock.protocol, addr) + + i.mutex.Lock() + defer i.mutex.Unlock() + + i.ipv6.unlink(sock, link) + + sock.name.Store(&sockaddrInet6Any) + sock.state.unset(bound) +} diff --git a/internal/network/local_darwin.go b/internal/network/local_darwin.go new file mode 100644 index 00000000..de8108df --- /dev/null +++ b/internal/network/local_darwin.go @@ -0,0 +1,53 @@ +package network + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func socketpair(family, socktype, protocol int) ([2]int, error) { + syscall.ForkLock.RLock() + defer syscall.ForkLock.RUnlock() + + fds, err := ignoreEINTR2(func() ([2]int, error) { + return unix.Socketpair(family, socktype, protocol) + }) + if err != nil { + return fds, err + } + if err := setCloseOnExecAndNonBlocking(fds[0]); err != nil { + closePair(&fds) + return fds, err + } + if err := setCloseOnExecAndNonBlocking(fds[1]); err != nil { + closePair(&fds) + return fds, err + } + return fds, nil +} + +func closePair(fds *[2]int) { + unix.Close(fds[0]) + unix.Close(fds[1]) + fds[0] = -1 + fds[1] = -1 +} + +func (s *localSocket) SetOptInt(level, name, value int) error { + fd := s.fd0.acquire() + if fd < 0 { + return EBADF + } + defer s.fd0.release(fd) + return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) +} + +func (s *localSocket) GetOptInt(level, name int) (int, error) { + fd := s.fd0.acquire() + if fd < 0 { + return 0, EBADF + } + defer s.fd0.release(fd) + return ignoreEINTR2(func() (int, error) { return unix.GetsockoptInt(fd, level, name) }) +} diff --git a/internal/network/local_linux.go b/internal/network/local_linux.go new file mode 100644 index 00000000..214823f5 --- /dev/null +++ b/internal/network/local_linux.go @@ -0,0 +1,47 @@ +package network + +import "golang.org/x/sys/unix" + +func socketpair(family, socktype, protocol int) ([2]int, error) { + return ignoreEINTR2(func() ([2]int, error) { + return unix.Socketpair(family, socktype|unix.SOCK_CLOEXEC|unix.SOCK_NONBLOCK, protocol) + }) +} + +func (s *localSocket) SetOptInt(level, name, value int) error { + fd := s.fd0.acquire() + if fd < 0 { + return EBADF + } + defer s.fd0.release(fd) + + switch level { + case unix.SOL_SOCKET: + switch name { + case unix.SO_BINDTODEVICE: + return 0, ENOPROTOOPT + } + } + + return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) +} + +func (s *localSocket) GetOptInt(level, name int) (int, error) { + fd := s.fd0.acquire() + if fd < 0 { + return 0, EBADF + } + defer s.fd0.release(fd) + + switch level { + case unix.SOL_SOCKET: + switch name { + case unix.SO_DOMAIN: + return int(s.family), nil + case unix.SO_BINDTODEVICE: + return 0, ENOPROTOOPT + } + } + + return ignoreEINTR2(func() (int, error) { return unix.GetsockoptInt(fd, level, name) }) +} diff --git a/internal/network/local_test.go b/internal/network/local_test.go new file mode 100644 index 00000000..401d99ff --- /dev/null +++ b/internal/network/local_test.go @@ -0,0 +1,67 @@ +package network_test + +import ( + "net" + "testing" + + "github.com/stealthrocket/timecraft/internal/assert" + "github.com/stealthrocket/timecraft/internal/network" +) + +func TestLocalNetwork(t *testing.T) { + tests := []struct { + scenario string + function func(*testing.T, network.Namespace) + }{ + { + scenario: "a local network namespace has one interface", + function: testLocalNetworkInterface, + }, + + { + scenario: "ipv4 sockets can connect to one another on the loopback interface", + function: testNamespaceConnectLoopbackIPv4, + }, + + { + scenario: "ipv6 sockets can connect to one another on the loopback interface", + function: testNamespaceConnectLoopbackIPv6, + }, + + { + scenario: "ipv4 sockets can exchange datagrams on the loopback interface", + function: testNamespaceExchangeDatagramLoopbackIPv4, + }, + + { + scenario: "ipv6 sockets can exchange datagrams on the loopback interface", + function: testNamespaceExchangeDatagramLoopbackIPv6, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + test.function(t, + network.NewLocalNamespace(nil), + ) + }) + } +} + +func testLocalNetworkInterface(t *testing.T, ns network.Namespace) { + ifaces, err := ns.Interfaces() + assert.OK(t, err) + assert.Equal(t, len(ifaces), 1) + + lo0 := ifaces[0] + assert.Equal(t, lo0.Index(), 0) + assert.Equal(t, lo0.MTU(), 1500) + assert.Equal(t, lo0.Name(), "lo0") + assert.Equal(t, lo0.Flags(), net.FlagUp|net.FlagLoopback) + + lo0Addrs, err := lo0.Addrs() + assert.OK(t, err) + assert.Equal(t, len(lo0Addrs), 2) + assert.Equal(t, lo0Addrs[0].String(), "127.0.0.1") + assert.Equal(t, lo0Addrs[1].String(), "::1") +} diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go new file mode 100644 index 00000000..e5da3358 --- /dev/null +++ b/internal/network/local_unix.go @@ -0,0 +1,515 @@ +package network + +import ( + "encoding/binary" + "sync/atomic" + + "golang.org/x/sys/unix" +) + +func (ns *localNamespace) Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) { + switch family { + case INET, INET6: + default: + return ns.host.Socket(family, socktype, protocol) + } + socket := &localSocket{ + ns: ns, + family: family, + socktype: socktype, + protocol: protocol, + } + fds, err := socketpair(int(UNIX), int(socktype), 0) + if err != nil { + return nil, err + } + socket.fd0.init(fds[0]) + socket.fd1.init(fds[1]) + return socket, nil +} + +type localSocketState uint8 + +const ( + bound localSocketState = 1 << iota + connected + listening +) + +func (state localSocketState) is(s localSocketState) bool { + return (state & s) != 0 +} + +func (state *localSocketState) set(s localSocketState) { + *state |= s +} + +func (state *localSocketState) unset(s localSocketState) { + *state &= ^s +} + +const ( + addrBufSize = 20 +) + +type localSocket struct { + ns *localNamespace + fd0 socketFD + fd1 socketFD + family Family + socktype Socktype + protocol Protocol + state localSocketState + name atomic.Value + peer atomic.Value + iovs [][]byte + addrBuf [addrBufSize]byte +} + +func (s *localSocket) Family() Family { + return s.family +} + +func (s *localSocket) Type() Socktype { + return s.socktype +} + +func (s *localSocket) Fd() int { + return s.fd0.load() +} + +func (s *localSocket) Close() error { + if s.state.is(bound) { + switch addr := s.name.Load().(type) { + case *SockaddrInet4: + s.ns.unlinkInet4(s, addr) + case *SockaddrInet6: + s.ns.unlinkInet6(s, addr) + } + } + s.fd0.close() + s.fd1.close() + return nil +} + +func (s *localSocket) Bind(addr Sockaddr) error { + fd := s.fd0.acquire() + if fd < 0 { + return EBADF + } + defer s.fd0.release(fd) + + if s.state.is(bound) { + return EINVAL + } + + switch bind := addr.(type) { + case *SockaddrInet4: + if s.family == INET { + return s.ns.bindInet4(s, bind) + } + case *SockaddrInet6: + if s.family == INET6 { + return s.ns.bindInet6(s, bind) + } + } + return EAFNOSUPPORT +} + +func (s *localSocket) bindAny() error { + switch s.family { + case INET: + return s.ns.bindInet4(s, &sockaddrInet4Any) + default: + return s.ns.bindInet6(s, &sockaddrInet6Any) + } +} + +func (s *localSocket) bindLoopback() error { + switch s.family { + case INET: + return s.ns.bindInet4(s, &sockaddrInet4Loopback) + default: + return s.ns.bindInet6(s, &sockaddrInet6Loopback) + } +} + +func (s *localSocket) Listen(backlog int) error { + fd := s.fd0.acquire() + if fd < 0 { + return EBADF + } + defer s.fd0.release(fd) + + if s.state.is(listening) { + return nil + } + if s.state.is(connected) { + return EINVAL + } + if s.socktype != STREAM { + return EINVAL + } + if !s.state.is(bound) { + if err := s.bindAny(); err != nil { + return err + } + } + + s.state.set(listening) + return nil +} + +func (s *localSocket) Connect(addr Sockaddr) error { + fd := s.fd0.acquire() + if fd < 0 { + return EBADF + } + defer s.fd0.release(fd) + + if s.state.is(listening) { + return EINVAL + } + if s.state.is(connected) { + return EISCONN + } + if !s.state.is(bound) { + if err := s.bindLoopback(); err != nil { + return err + } + } + + var peer *localSocket + var err error + switch a := addr.(type) { + case *SockaddrInet4: + peer, err = s.ns.lookupInet4(s, a) + case *SockaddrInet6: + peer, err = s.ns.lookupInet6(s, a) + default: + return EAFNOSUPPORT + } + if err != nil { + return ECONNREFUSED + } + if peer.family != s.family { + return ECONNREFUSED + } + if peer.socktype != s.socktype { + return ECONNREFUSED + } + + if s.socktype == DGRAM { + s.peer.Store(addr) + s.state.set(connected) + return nil + } + + peerFd := peer.fd1.acquire() + if peerFd < 0 { + return ECONNREFUSED + } + defer peer.fd1.release(peerFd) + + fd1 := s.fd1.acquire() + if fd1 < 0 { + return EBADF + } + defer s.fd1.release(fd1) + + addrBuf := encodeSockaddrAny(s.name.Load()) + // TODO: remove the heap allocation by implementing UnixRights to output to + // a stack buffer. + rights := unix.UnixRights(fd1) + if err := unix.Sendmsg(peerFd, addrBuf[:], rights, nil, 0); err != nil { + return ECONNREFUSED + } + + s.fd1.close() + s.peer.Store(addr) + s.state.set(connected) + return EINPROGRESS +} + +func (s *localSocket) Accept() (Socket, Sockaddr, error) { + fd := s.fd0.acquire() + if fd < 0 { + return nil, nil, EBADF + } + defer s.fd0.release(fd) + + if !s.state.is(listening) { + return nil, nil, EINVAL + } + + socket := &localSocket{ + ns: s.ns, + family: s.family, + socktype: s.socktype, + protocol: s.protocol, + state: bound | connected, + } + + var oobn int + var oobBuf [16]byte + var addrBuf [addrBufSize]byte + for { + var err error + // TOOD: remove the heap allocation for the receive address by + // implementing recvmsg and using the stack-allocated socket address + // buffer. + _, oobn, _, _, err = unix.Recvmsg(fd, addrBuf[:], oobBuf[:unix.CmsgSpace(1)], 0) + if err == nil { + break + } + if err != EINTR { + return nil, nil, err + } + if oobn > 0 { + break + } + } + + // TOOD: remove the heap allocation for the return value by implementing + // ParseSocketControlMessage; we know that we will receive at most most one + // message since we sized the buffer accordingly. + msgs, err := unix.ParseSocketControlMessage(oobBuf[:oobn]) + if err != nil { + return nil, nil, err + } + + // TODO: remove the heap allocation for the return fd slice by implementing + // ParseUnixRights and decoding the single file descriptor we received in a + // stack-allocated variabled. + fds, err := unix.ParseUnixRights(&msgs[0]) + if err != nil { + return nil, nil, err + } + + addr := decodeSockaddr(addrBuf) + socket.fd0.init(fds[0]) + socket.fd1.init(-1) + socket.name.Store(s.name.Load()) + socket.peer.Store(addr) + return socket, addr, nil +} + +func (s *localSocket) Name() (Sockaddr, error) { + fd := s.fd0.acquire() + if fd < 0 { + return nil, EBADF + } + defer s.fd0.release(fd) + switch name := s.name.Load().(type) { + case *SockaddrInet4: + return name, nil + case *SockaddrInet6: + return name, nil + } + switch s.family { + case INET: + return &sockaddrInet4Any, nil + default: + return &sockaddrInet6Any, nil + } +} + +func (s *localSocket) Peer() (Sockaddr, error) { + fd := s.fd0.acquire() + if fd < 0 { + return nil, EBADF + } + defer s.fd0.release(fd) + switch peer := s.peer.Load().(type) { + case *SockaddrInet4: + return peer, nil + case *SockaddrInet6: + return peer, nil + } + return nil, ENOTCONN +} + +func (s *localSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, error) { + fd := s.fd0.acquire() + if fd < 0 { + return -1, 0, nil, EBADF + } + defer s.fd0.release(fd) + + if s.state.is(listening) { + return -1, 0, nil, EINVAL + } + if !s.state.is(bound) { + if err := s.bindLoopback(); err != nil { + return -1, 0, nil, err + } + } + + if s.socktype == DGRAM { + s.iovs = s.iovs[:0] + s.iovs = append(s.iovs, s.addrBuf[:]) + s.iovs = append(s.iovs, iovs...) + iovs = s.iovs + defer clearIOVecs(iovs) + } + + // TODO: remove the heap allocation that happens for the socket address by + // implementing recvfrom(2) and using a cached socket address for connected + // sockets. + for { + n, _, rflags, addr, err := unix.RecvmsgBuffers(fd, iovs, nil, flags) + if err == EINTR { + if n == 0 { + continue + } + err = nil + } + if s.socktype == DGRAM { + addr = decodeSockaddr(s.addrBuf) + n -= addrBufSize + } else { + switch a := s.peer.Load().(type) { + case *SockaddrInet4: + addr = a + case *SockaddrInet6: + addr = a + default: + addr = nil + } + } + return n, rflags, addr, err + } +} + +func (s *localSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, error) { + fd := s.fd0.acquire() + if fd < 0 { + return -1, EBADF + } + defer s.fd0.release(fd) + + if s.state.is(listening) { + return -1, EINVAL + } + if s.state.is(connected) && addr != nil { + return -1, EISCONN + } + if !s.state.is(connected) && addr == nil { + return -1, ENOTCONN + } + if !s.state.is(bound) { + if err := s.bindLoopback(); err != nil { + return -1, err + } + } + + sendSocketFd := fd + if addr != nil { + var peer *localSocket + var err error + switch a := addr.(type) { + case *SockaddrInet4: + peer, err = s.ns.lookupInet4(s, a) + case *SockaddrInet6: + peer, err = s.ns.lookupInet6(s, a) + default: + return -1, EAFNOSUPPORT + } + if err != nil { + return -1, err + } + if peer.socktype != DGRAM { + return iovecLen(iovs), nil + } + + peerFd := peer.fd1.acquire() + if peerFd < 0 { + return -1, EHOSTUNREACH + } + defer peer.fd1.release(peerFd) + sendSocketFd = peerFd + + s.addrBuf = encodeSockaddrAny(s.name.Load()) + s.iovs = s.iovs[:0] + s.iovs = append(s.iovs, s.addrBuf[:]) + s.iovs = append(s.iovs, iovs...) + iovs = s.iovs + defer clearIOVecs(iovs) + } + + for { + n, err := unix.SendmsgBuffers(sendSocketFd, iovs, nil, nil, flags) + if err == EINTR { + if n == 0 { + continue + } + err = nil + } + if n > 0 && s.socktype == DGRAM { + n -= addrBufSize + } + return n, err + } +} + +func (s *localSocket) Shutdown(how int) error { + fd := s.fd0.acquire() + if fd < 0 { + return EBADF + } + defer s.fd0.release(fd) + return ignoreEINTR(func() error { return unix.Shutdown(fd, how) }) +} + +func clearIOVecs(iovs [][]byte) { + for i := range iovs { + iovs[i] = nil + } +} + +func encodeSockaddrAny(addr any) (buf [addrBufSize]byte) { + switch a := addr.(type) { + case *SockaddrInet4: + return encodeSockaddrInet4(a) + case *SockaddrInet6: + return encodeSockaddrInet6(a) + default: + return + } +} + +func encodeSockaddrInet4(addr *SockaddrInet4) (buf [addrBufSize]byte) { + binary.LittleEndian.PutUint16(buf[0:2], uint16(INET)) + binary.LittleEndian.PutUint16(buf[2:4], uint16(addr.Port)) + *(*[4]byte)(buf[4:]) = addr.Addr + return +} + +func encodeSockaddrInet6(addr *SockaddrInet6) (buf [addrBufSize]byte) { + binary.LittleEndian.PutUint16(buf[0:2], uint16(INET6)) + binary.LittleEndian.PutUint16(buf[2:4], uint16(addr.Port)) + *(*[16]byte)(buf[4:]) = addr.Addr + return +} + +func decodeSockaddr(buf [addrBufSize]byte) Sockaddr { + switch Family(binary.LittleEndian.Uint16(buf[0:2])) { + case INET: + return &SockaddrInet4{ + Port: int(binary.LittleEndian.Uint16(buf[2:4])), + Addr: ([4]byte)(buf[4:]), + } + default: + return &SockaddrInet6{ + Port: int(binary.LittleEndian.Uint16(buf[2:4])), + Addr: ([16]byte)(buf[4:]), + } + } +} + +func iovecLen(iovs [][]byte) (n int) { + for _, iov := range iovs { + n += len(iov) + } + return n +} diff --git a/internal/network/network.go b/internal/network/network.go index 4762a3d3..e424339c 100644 --- a/internal/network/network.go +++ b/internal/network/network.go @@ -1,9 +1,15 @@ package network import ( + "errors" + "fmt" "net" ) +var ( + ErrInterfaceNotFound = errors.New("network interface not found") +) + type Socket interface { Family() Family @@ -25,15 +31,15 @@ type Socket interface { Peer() (Sockaddr, error) - RecvFrom(iovs [][]byte, oob []byte, flags int) (n, oobn, rflags int, addr Sockaddr, err error) + RecvFrom(iovs [][]byte, flags int) (n, rflags int, addr Sockaddr, err error) - SendTo(iovs [][]byte, oob []byte, addr Sockaddr, flags int) (int, error) + SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, error) Shutdown(how int) error - SetOption(level, name, value int) error + SetOptInt(level, name, value int) error - GetOption(level, name int) (int, error) + GetOptInt(level, name int) (int, error) } type Socktype uint8 @@ -85,17 +91,6 @@ func SockaddrFamily(sa Sockaddr) Family { } } -func isUnspecified(sa Sockaddr) bool { - switch a := sa.(type) { - case *SockaddrInet4: - return isUnspecifiedInet4(a) - case *SockaddrInet6: - return isUnspecifiedInet6(a) - default: - return true - } -} - func isUnspecifiedInet4(sa *SockaddrInet4) bool { return sa.Addr == [4]byte{} } @@ -103,3 +98,27 @@ func isUnspecifiedInet4(sa *SockaddrInet4) bool { func isUnspecifiedInet6(sa *SockaddrInet6) bool { return sa.Addr == [16]byte{} } + +func isLoopbackInet4(sa *SockaddrInet4) bool { + return sa.Addr == [4]byte{127, 0, 0, 1} +} + +func isLoopbackInet6(sa *SockaddrInet6) bool { + return sa.Addr == [16]byte{15: 1} +} + +func errInterfaceIndexNotFound(index int) error { + return fmt.Errorf("index=%d: %w", index, ErrInterfaceNotFound) +} + +func errInterfaceNameNotFound(name string) error { + return fmt.Errorf("name=%q: %w", name, ErrInterfaceNotFound) +} + +var ( + sockaddrInet4Any SockaddrInet4 + sockaddrInet6Any SockaddrInet6 + + sockaddrInet4Loopback = SockaddrInet4{Addr: [4]byte{127, 0, 0, 1}} + sockaddrInet6Loopback = SockaddrInet6{Addr: [16]byte{15: 1}} +) diff --git a/internal/network/network_darwin.go b/internal/network/network_darwin.go new file mode 100644 index 00000000..6796d4fc --- /dev/null +++ b/internal/network/network_darwin.go @@ -0,0 +1,17 @@ +package network + +import "golang.org/x/sys/unix" + +func setCloseOnExecAndNonBlocking(fd int) error { + if _, err := ignoreEINTR2(func() (int, error) { + return unix.FcntlInt(uintptr(fd), unix.F_SETFD, unix.O_CLOEXEC) + }); err != nil { + return err + } + if err := ignoreEINTR(func() error { + return unix.SetNonblock(fd, true) + }); err != nil { + return err + } + return nil +} diff --git a/internal/network/network_test.go b/internal/network/network_test.go new file mode 100644 index 00000000..19ae875d --- /dev/null +++ b/internal/network/network_test.go @@ -0,0 +1,125 @@ +package network_test + +import ( + "testing" + + "github.com/stealthrocket/timecraft/internal/assert" + "github.com/stealthrocket/timecraft/internal/network" +) + +func testNamespaceConnectLoopbackIPv4(t *testing.T, ns network.Namespace) { + testNamespaceConnect(t, ns, &network.SockaddrInet4{ + Addr: [4]byte{127, 0, 0, 1}, + Port: 80, + }) +} + +func testNamespaceConnectLoopbackIPv6(t *testing.T, ns network.Namespace) { + testNamespaceConnect(t, ns, &network.SockaddrInet6{ + Addr: [16]byte{15: 1}, + Port: 80, + }) +} + +func testNamespaceConnect(t *testing.T, ns network.Namespace, bind network.Sockaddr) { + family := network.SockaddrFamily(bind) + + server, err := ns.Socket(family, network.STREAM, network.TCP) + assert.OK(t, err) + defer server.Close() + + assert.OK(t, server.Bind(bind)) + assert.OK(t, server.Listen(1)) + serverAddr, err := server.Name() + assert.OK(t, err) + + client, err := ns.Socket(family, network.STREAM, network.TCP) + assert.OK(t, err) + defer client.Close() + + assert.Error(t, client.Connect(serverAddr), network.EINPROGRESS) + assert.OK(t, waitReadyRead(server)) + conn, addr, err := server.Accept() + assert.OK(t, err) + defer conn.Close() + + assert.OK(t, waitReadyWrite(client)) + peer, err := client.Peer() + assert.OK(t, err) + assert.DeepEqual(t, peer, serverAddr) + + name, err := client.Name() + assert.OK(t, err) + assert.DeepEqual(t, name, addr) + + wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, 0) + assert.OK(t, err) + assert.Equal(t, wn, 13) + + assert.OK(t, waitReadyRead(conn)) + + buf := make([]byte, 32) + rn, rflags, peer, err := conn.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, rn, 13) + assert.Equal(t, rflags, 0) + assert.Equal(t, string(buf[:13]), "Hello, World!") + assert.DeepEqual(t, peer, addr) +} + +func testNamespaceExchangeDatagramLoopbackIPv4(t *testing.T, ns network.Namespace) { + testNamespaceExchangeDatagram(t, ns, &network.SockaddrInet4{ + Addr: [4]byte{127, 0, 0, 1}, + Port: 80, + }) +} + +func testNamespaceExchangeDatagramLoopbackIPv6(t *testing.T, ns network.Namespace) { + testNamespaceExchangeDatagram(t, ns, &network.SockaddrInet6{ + Addr: [16]byte{15: 1}, + Port: 80, + }) +} + +func testNamespaceExchangeDatagram(t *testing.T, ns network.Namespace, bind network.Sockaddr) { + family := network.SockaddrFamily(bind) + + socket1, err := ns.Socket(family, network.DGRAM, network.TCP) + assert.OK(t, err) + defer socket1.Close() + + socket2, err := ns.Socket(family, network.DGRAM, network.TCP) + assert.OK(t, err) + defer socket2.Close() + + assert.OK(t, socket2.Bind(bind)) + addr2, err := socket2.Name() + assert.OK(t, err) + + wn, err := socket1.SendTo([][]byte{[]byte("Hello, World!")}, addr2, 0) + assert.OK(t, err) + assert.Equal(t, wn, 13) + addr1, err := socket1.Name() + assert.OK(t, err) + + assert.OK(t, waitReadyRead(socket2)) + buf := make([]byte, 32) + + rn, rflags, addr, err := socket2.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, rn, 13) + assert.Equal(t, rflags, 0) + assert.Equal(t, string(buf[:13]), "Hello, World!") + assert.DeepEqual(t, addr, addr1) + + wn, err = socket2.SendTo([][]byte{[]byte("How are you?")}, addr1, 0) + assert.OK(t, err) + assert.Equal(t, wn, 12) + + rn, rflags, addr, err = socket1.RecvFrom([][]byte{buf[:11]}, 0) + assert.OK(t, err) + assert.Equal(t, rn, 11) + assert.Equal(t, rflags, network.TRUNC) + assert.Equal(t, string(buf[:11]), "How are you") + assert.DeepEqual(t, addr, addr2) +} diff --git a/internal/network/network_unix.go b/internal/network/network_unix.go index 94d1c527..552f2b0a 100644 --- a/internal/network/network_unix.go +++ b/internal/network/network_unix.go @@ -1,6 +1,10 @@ package network -import "golang.org/x/sys/unix" +import ( + "sync/atomic" + + "golang.org/x/sys/unix" +) const ( EADDRNOTAVAIL = unix.EADDRNOTAVAIL @@ -12,7 +16,10 @@ const ( EINVAL = unix.EINVAL EINTR = unix.EINTR EINPROGRESS = unix.EINPROGRESS + EISCONN = unix.EISCONN ENETUNREACH = unix.ENETUNREACH + ENOPROTOOPT = unix.ENOPROTOOPT + ENOSYS = unix.ENOSYS ENOTCONN = unix.ENOTCONN ) @@ -41,3 +48,64 @@ const ( type Sockaddr = unix.Sockaddr type SockaddrInet4 = unix.SockaddrInet4 type SockaddrInet6 = unix.SockaddrInet6 + +type socketFD struct { + state atomic.Uint64 // upper 32 bits: refCount, lower 32 bits: fd +} + +func (s *socketFD) init(fd int) { + s.state.Store(uint64(fd)) +} + +func (s *socketFD) load() int { + return int(int32(s.state.Load())) +} + +func (s *socketFD) acquire() int { + for { + oldState := s.state.Load() + refCount := (oldState >> 32) + 1 + newState := (refCount << 32) | (oldState & 0xFFFFFFFF) + + if int32(oldState) < 0 { + return -1 + } + if s.state.CompareAndSwap(oldState, newState) { + return int(int32(oldState)) // int32->int for sign extension + } + } +} + +func (s *socketFD) release(fd int) { + for { + oldState := s.state.Load() + refCount := (oldState >> 32) - 1 + newState := (oldState << 32) | (oldState & 0xFFFFFFFF) + + if s.state.CompareAndSwap(oldState, newState) { + if int32(oldState) < 0 && refCount == 0 { + unix.Close(fd) + } + break + } + } +} + +func (s *socketFD) close() error { + for { + oldState := s.state.Load() + refCount := oldState >> 32 + newState := oldState | 0xFFFFFFFF + + if s.state.CompareAndSwap(oldState, newState) { + fd := int32(oldState) + if fd < 0 { + return EBADF + } + if refCount == 0 { + return unix.Close(int(fd)) + } + return nil + } + } +} diff --git a/internal/network/virtual.go b/internal/network/virtual.go index fec3504e..49f73e49 100644 --- a/internal/network/virtual.go +++ b/internal/network/virtual.go @@ -2,7 +2,6 @@ package network import ( "errors" - "fmt" "net" "sync" "sync/atomic" @@ -143,7 +142,7 @@ func (ns *VirtualNamespace) InterfaceByIndex(index int) (Interface, error) { case ns.en0.index: return &ns.en0, nil default: - return nil, fmt.Errorf("virtual network interface index out of bounds: %d", index) + return nil, errInterfaceIndexNotFound(index) } } @@ -154,7 +153,7 @@ func (ns *VirtualNamespace) InterfaceByName(name string) (Interface, error) { case ns.en0.name: return &ns.en0, nil default: - return nil, fmt.Errorf("virtual network interface not found: %s", name) + return nil, errInterfaceNameNotFound(name) } } @@ -445,9 +444,13 @@ func (t *virtualAddressTable) bind(socket *virtualSocket, ha, sa virtualAddress) return int(sa.port), nil } -func (t *virtualAddressTable) unlink(ha, sa virtualAddress) { - delete(t.host, ha) - delete(t.sock, sa) +func (t *virtualAddressTable) unlink(socket *virtualSocket, ha, sa virtualAddress) { + if t.host[ha] == socket { + delete(t.host, ha) + } + if t.sock[sa] == socket { + delete(t.sock, sa) + } } type virtualInterface struct { @@ -458,7 +461,7 @@ type virtualInterface struct { haddr net.HardwareAddr flags net.Flags - mutex sync.Mutex + mutex sync.RWMutex inet4 virtualAddressTable inet6 virtualAddressTable } @@ -535,29 +538,29 @@ func (i *virtualInterface) bindInet6(socket *virtualSocket, host, addr *Sockaddr func (i *virtualInterface) lookupByHostInet4(socket *virtualSocket, host *SockaddrInet4) *virtualSocket { va := virtualSockaddrInet4(socket.proto, host) - i.mutex.Lock() - defer i.mutex.Unlock() + i.mutex.RLock() + defer i.mutex.RUnlock() return i.inet4.host[va] } func (i *virtualInterface) lookupByHostInet6(socket *virtualSocket, host *SockaddrInet6) *virtualSocket { va := virtualSockaddrInet6(socket.proto, host) - i.mutex.Lock() - defer i.mutex.Unlock() + i.mutex.RLock() + defer i.mutex.RUnlock() return i.inet6.host[va] } func (i *virtualInterface) lookupByAddrInet4(socket *virtualSocket, addr *SockaddrInet4) *virtualSocket { va := virtualSockaddrInet4(socket.proto, addr) - i.mutex.Lock() - defer i.mutex.Unlock() + i.mutex.RLock() + defer i.mutex.RUnlock() return i.inet4.sock[va] } func (i *virtualInterface) lookupByAddrInet6(socket *virtualSocket, addr *SockaddrInet6) *virtualSocket { va := virtualSockaddrInet6(socket.proto, addr) - i.mutex.Lock() - defer i.mutex.Unlock() + i.mutex.RLock() + defer i.mutex.RUnlock() return i.inet6.sock[va] } @@ -568,7 +571,7 @@ func (i *virtualInterface) unlinkInet4(socket *virtualSocket, host, addr *Sockad i.mutex.Lock() defer i.mutex.Unlock() - i.inet4.unlink(hostAddr, sockAddr) + i.inet4.unlink(socket, hostAddr, sockAddr) socket.host.Store(&sockaddrInet4Any) socket.name.Store(&sockaddrInet4Any) socket.bound = false @@ -581,17 +584,12 @@ func (i *virtualInterface) unlinkInet6(socket *virtualSocket, host, addr *Sockad i.mutex.Lock() defer i.mutex.Unlock() - i.inet6.unlink(hostAddr, sockAddr) + i.inet6.unlink(socket, hostAddr, sockAddr) socket.host.Store(&sockaddrInet6Any) socket.name.Store(&sockaddrInet6Any) socket.bound = false } -var ( - sockaddrInet4Any SockaddrInet4 - sockaddrInet6Any SockaddrInet6 -) - type virtualSocket struct { ns *VirtualNamespace base Socket @@ -615,9 +613,6 @@ func (s *virtualSocket) Fd() int { } func (s *virtualSocket) Close() error { - if err := s.base.Close(); err != nil { - return err - } if s.bound { host := s.host.Load() name := s.name.Load() @@ -628,7 +623,7 @@ func (s *virtualSocket) Close() error { s.ns.unlinkInet6(s, host.(*SockaddrInet6), a) } } - return nil + return s.base.Close() } func (s *virtualSocket) Bind(addr Sockaddr) error { @@ -785,10 +780,10 @@ func (s *virtualSocket) Peer() (Sockaddr, error) { return nil, ENOTCONN } -func (s *virtualSocket) RecvFrom(iovs [][]byte, oob []byte, flags int) (int, int, int, Sockaddr, error) { - n, oobn, flags, addr, err := s.base.RecvFrom(iovs, oob, flags) +func (s *virtualSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, error) { + n, flags, addr, err := s.base.RecvFrom(iovs, flags) if err != nil { - return -1, -1, 0, nil, err + return -1, 0, nil, err } switch conn := s.peer.Load().(type) { case *SockaddrInet4: @@ -798,10 +793,10 @@ func (s *virtualSocket) RecvFrom(iovs [][]byte, oob []byte, flags int) (int, int default: addr, err = s.toVirtualAddr(addr) } - return n, oobn, flags, addr, err + return n, flags, addr, err } -func (s *virtualSocket) SendTo(iovs [][]byte, oob []byte, addr Sockaddr, flags int) (int, error) { +func (s *virtualSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, error) { if addr != nil { a, err := s.toHostAddr(addr) if err != nil { @@ -809,7 +804,7 @@ func (s *virtualSocket) SendTo(iovs [][]byte, oob []byte, addr Sockaddr, flags i } addr = a } - n, err := s.base.SendTo(iovs, oob, addr, flags) + n, err := s.base.SendTo(iovs, addr, flags) if s.name.Load() == nil { if err := s.bindAny(); err != nil { return n, err @@ -822,12 +817,12 @@ func (s *virtualSocket) Shutdown(how int) error { return s.base.Shutdown(how) } -func (s *virtualSocket) SetOption(level, name, value int) error { - return s.base.SetOption(level, name, value) +func (s *virtualSocket) SetOptInt(level, name, value int) error { + return s.base.SetOptInt(level, name, value) } -func (s *virtualSocket) GetOption(level, name int) (int, error) { - return s.base.GetOption(level, name) +func (s *virtualSocket) GetOptInt(level, name int) (int, error) { + return s.base.GetOptInt(level, name) } func (s *virtualSocket) toHostAddr(addr Sockaddr) (Sockaddr, error) { diff --git a/internal/network/virtual_test.go b/internal/network/virtual_test.go index efe95c38..a7f5c079 100644 --- a/internal/network/virtual_test.go +++ b/internal/network/virtual_test.go @@ -50,8 +50,11 @@ func TestVirtualNetwork(t *testing.T) { }, } - _, ipnet4, _ := net.ParseCIDR("192.168.0.0/24") - _, ipnet6, _ := net.ParseCIDR("fe80::/64") + _, ipnet4, err := net.ParseCIDR("192.168.0.0/24") + assert.OK(t, err) + + _, ipnet6, err := net.ParseCIDR("fe80::/64") + assert.OK(t, err) for _, test := range tests { t.Run(test.scenario, func(t *testing.T) { @@ -124,52 +127,9 @@ func testVirtualNetworkConnectInterfaceIPv6(t *testing.T, n *network.VirtualNetw } func testVirtualNetworkConnect(t *testing.T, n *network.VirtualNetwork, bind network.Sockaddr) { - family := network.SockaddrFamily(bind) - ns, err := n.CreateNamespace() assert.OK(t, err) - - server, err := ns.Socket(family, network.STREAM, network.TCP) - assert.OK(t, err) - defer server.Close() - - assert.OK(t, server.Bind(bind)) - assert.OK(t, server.Listen(1)) - serverAddr, err := server.Name() - assert.OK(t, err) - - client, err := ns.Socket(family, network.STREAM, network.TCP) - assert.OK(t, err) - defer client.Close() - - assert.Error(t, client.Connect(serverAddr), network.EINPROGRESS) - assert.OK(t, waitReadyRead(server)) - conn, addr, err := server.Accept() - assert.OK(t, err) - defer conn.Close() - - assert.OK(t, waitReadyWrite(client)) - peer, err := client.Peer() - assert.OK(t, err) - assert.DeepEqual(t, peer, serverAddr) - - name, err := client.Name() - assert.OK(t, err) - assert.DeepEqual(t, name, addr) - - wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, nil, 0) - assert.OK(t, err) - assert.Equal(t, wn, 13) - - assert.OK(t, waitReadyRead(conn)) - buf := make([]byte, 32) - rn, oobn, rflags, peer, err := conn.RecvFrom([][]byte{buf}, nil, 0) - assert.OK(t, err) - assert.Equal(t, rn, 13) - assert.Equal(t, oobn, 0) - assert.Equal(t, rflags, 0) - assert.Equal(t, string(buf[:13]), "Hello, World!") - assert.Equal(t, peer, addr) + testNamespaceConnect(t, ns, bind) } func testVirtualNetworkConnectNamespacesIPv4(t *testing.T, n *network.VirtualNetwork) { @@ -214,16 +174,15 @@ func testVirtualNetworkConnectNamespaces(t *testing.T, n *network.VirtualNetwork assert.OK(t, err) assert.DeepEqual(t, name, addr) - wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, nil, 0) + wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, 0) assert.OK(t, err) assert.Equal(t, wn, 13) assert.OK(t, waitReadyRead(conn)) buf := make([]byte, 32) - rn, oobn, rflags, peer, err := conn.RecvFrom([][]byte{buf}, nil, 0) + rn, rflags, peer, err := conn.RecvFrom([][]byte{buf}, 0) assert.OK(t, err) assert.Equal(t, rn, 13) - assert.Equal(t, oobn, 0) assert.Equal(t, rflags, 0) assert.Equal(t, string(buf[:13]), "Hello, World!") assert.Equal(t, peer, addr) From 1e7f20ce6d558477aabbba7047879d47d0a40cef Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Thu, 13 Jul 2023 19:20:54 -0700 Subject: [PATCH 03/16] network: run tests against host network Signed-off-by: Achille Roussel --- internal/network/host_test.go | 82 ++++++++++++++++++++++++++++++++ internal/network/local_unix.go | 9 +--- internal/network/network.go | 25 ++++++++++ internal/network/network_test.go | 24 +++++----- internal/network/virtual.go | 7 +-- internal/network/virtual_test.go | 6 +-- 6 files changed, 124 insertions(+), 29 deletions(-) create mode 100644 internal/network/host_test.go diff --git a/internal/network/host_test.go b/internal/network/host_test.go new file mode 100644 index 00000000..a03e491d --- /dev/null +++ b/internal/network/host_test.go @@ -0,0 +1,82 @@ +package network_test + +import ( + "net" + "testing" + + "github.com/stealthrocket/timecraft/internal/assert" + "github.com/stealthrocket/timecraft/internal/network" +) + +func TestHostNetwork(t *testing.T) { + tests := []struct { + scenario string + function func(*testing.T, network.Namespace) + }{ + { + scenario: "a host network namespace has at least one loopback interface", + function: testHostNetworkInterface, + }, + + { + scenario: "ipv4 sockets can connect to one another on the loopback interface", + function: testNamespaceConnectLoopbackIPv4, + }, + + { + scenario: "ipv6 sockets can connect to one another on the loopback interface", + function: testNamespaceConnectLoopbackIPv6, + }, + + { + scenario: "ipv4 sockets can exchange datagrams on the loopback interface", + function: testNamespaceExchangeDatagramLoopbackIPv4, + }, + + { + scenario: "ipv6 sockets can exchange datagrams on the loopback interface", + function: testNamespaceExchangeDatagramLoopbackIPv6, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + test.function(t, network.Host()) + }) + } +} + +func testHostNetworkInterface(t *testing.T, ns network.Namespace) { + ifaces, err := ns.Interfaces() + assert.OK(t, err) + + for _, iface := range ifaces { + if (iface.Flags() & net.FlagLoopback) == 0 { + continue + } + if (iface.Flags() & net.FlagUp) == 0 { + continue + } + + lo0 := iface + assert.NotEqual(t, lo0.Name(), "") + + lo0Addrs, err := lo0.Addrs() + assert.OK(t, err) + + ipv4 := false + ipv6 := false + for _, addr := range lo0Addrs { + switch addr.String() { + case "127.0.0.1/8": + ipv4 = true + case "::1/128": + ipv6 = true + } + } + assert.True(t, ipv4 && ipv6) + return + } + + t.Fatal("host network has not loopback interface") +} diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index e5da3358..eb9a92dd 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -368,14 +368,7 @@ func (s *localSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, er addr = decodeSockaddr(s.addrBuf) n -= addrBufSize } else { - switch a := s.peer.Load().(type) { - case *SockaddrInet4: - addr = a - case *SockaddrInet6: - addr = a - default: - addr = nil - } + addr = nil } return n, rflags, addr, err } diff --git a/internal/network/network.go b/internal/network/network.go index e424339c..b8fe8f21 100644 --- a/internal/network/network.go +++ b/internal/network/network.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net" + "net/netip" ) var ( @@ -46,6 +47,19 @@ type Socktype uint8 type Family uint8 +func (f Family) String() string { + switch f { + case UNIX: + return "UNIX" + case INET: + return "INET" + case INET6: + return "INET6" + default: + return "UNSPEC" + } +} + type Protocol uint16 const ( @@ -91,6 +105,17 @@ func SockaddrFamily(sa Sockaddr) Family { } } +func SockaddrAddrPort(sa Sockaddr) netip.AddrPort { + switch a := sa.(type) { + case *SockaddrInet4: + return netip.AddrPortFrom(netip.AddrFrom4(a.Addr), uint16(a.Port)) + case *SockaddrInet6: + return netip.AddrPortFrom(netip.AddrFrom16(a.Addr), uint16(a.Port)) + default: + return netip.AddrPort{} + } +} + func isUnspecifiedInet4(sa *SockaddrInet4) bool { return sa.Addr == [4]byte{} } diff --git a/internal/network/network_test.go b/internal/network/network_test.go index 19ae875d..9944afb9 100644 --- a/internal/network/network_test.go +++ b/internal/network/network_test.go @@ -10,14 +10,12 @@ import ( func testNamespaceConnectLoopbackIPv4(t *testing.T, ns network.Namespace) { testNamespaceConnect(t, ns, &network.SockaddrInet4{ Addr: [4]byte{127, 0, 0, 1}, - Port: 80, }) } func testNamespaceConnectLoopbackIPv6(t *testing.T, ns network.Namespace) { testNamespaceConnect(t, ns, &network.SockaddrInet6{ Addr: [16]byte{15: 1}, - Port: 80, }) } @@ -46,7 +44,7 @@ func testNamespaceConnect(t *testing.T, ns network.Namespace, bind network.Socka assert.OK(t, waitReadyWrite(client)) peer, err := client.Peer() assert.OK(t, err) - assert.DeepEqual(t, peer, serverAddr) + assert.Equal(t, network.SockaddrAddrPort(peer), network.SockaddrAddrPort(serverAddr)) name, err := client.Name() assert.OK(t, err) @@ -64,31 +62,33 @@ func testNamespaceConnect(t *testing.T, ns network.Namespace, bind network.Socka assert.Equal(t, rn, 13) assert.Equal(t, rflags, 0) assert.Equal(t, string(buf[:13]), "Hello, World!") - assert.DeepEqual(t, peer, addr) + assert.Equal(t, peer, nil) } func testNamespaceExchangeDatagramLoopbackIPv4(t *testing.T, ns network.Namespace) { testNamespaceExchangeDatagram(t, ns, &network.SockaddrInet4{ Addr: [4]byte{127, 0, 0, 1}, - Port: 80, }) } func testNamespaceExchangeDatagramLoopbackIPv6(t *testing.T, ns network.Namespace) { testNamespaceExchangeDatagram(t, ns, &network.SockaddrInet6{ Addr: [16]byte{15: 1}, - Port: 80, }) } func testNamespaceExchangeDatagram(t *testing.T, ns network.Namespace, bind network.Sockaddr) { family := network.SockaddrFamily(bind) - socket1, err := ns.Socket(family, network.DGRAM, network.TCP) + socket1, err := ns.Socket(family, network.DGRAM, network.UDP) assert.OK(t, err) defer socket1.Close() - socket2, err := ns.Socket(family, network.DGRAM, network.TCP) + assert.OK(t, socket1.Bind(bind)) + addr1, err := socket1.Name() + assert.OK(t, err) + + socket2, err := ns.Socket(family, network.DGRAM, network.UDP) assert.OK(t, err) defer socket2.Close() @@ -99,8 +99,6 @@ func testNamespaceExchangeDatagram(t *testing.T, ns network.Namespace, bind netw wn, err := socket1.SendTo([][]byte{[]byte("Hello, World!")}, addr2, 0) assert.OK(t, err) assert.Equal(t, wn, 13) - addr1, err := socket1.Name() - assert.OK(t, err) assert.OK(t, waitReadyRead(socket2)) buf := make([]byte, 32) @@ -110,16 +108,18 @@ func testNamespaceExchangeDatagram(t *testing.T, ns network.Namespace, bind netw assert.Equal(t, rn, 13) assert.Equal(t, rflags, 0) assert.Equal(t, string(buf[:13]), "Hello, World!") - assert.DeepEqual(t, addr, addr1) + assert.Equal(t, network.SockaddrAddrPort(addr), network.SockaddrAddrPort(addr1)) wn, err = socket2.SendTo([][]byte{[]byte("How are you?")}, addr1, 0) assert.OK(t, err) assert.Equal(t, wn, 12) + assert.OK(t, waitReadyRead(socket1)) + rn, rflags, addr, err = socket1.RecvFrom([][]byte{buf[:11]}, 0) assert.OK(t, err) assert.Equal(t, rn, 11) assert.Equal(t, rflags, network.TRUNC) assert.Equal(t, string(buf[:11]), "How are you") - assert.DeepEqual(t, addr, addr2) + assert.Equal(t, network.SockaddrAddrPort(addr), network.SockaddrAddrPort(addr2)) } diff --git a/internal/network/virtual.go b/internal/network/virtual.go index 49f73e49..d0ed5c47 100644 --- a/internal/network/virtual.go +++ b/internal/network/virtual.go @@ -785,12 +785,7 @@ func (s *virtualSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, if err != nil { return -1, 0, nil, err } - switch conn := s.peer.Load().(type) { - case *SockaddrInet4: - addr = conn - case *SockaddrInet6: - addr = conn - default: + if addr != nil { addr, err = s.toVirtualAddr(addr) } return n, flags, addr, err diff --git a/internal/network/virtual_test.go b/internal/network/virtual_test.go index a7f5c079..9117e95b 100644 --- a/internal/network/virtual_test.go +++ b/internal/network/virtual_test.go @@ -168,11 +168,11 @@ func testVirtualNetworkConnectNamespaces(t *testing.T, n *network.VirtualNetwork assert.OK(t, waitReadyWrite(client)) peer, err := client.Peer() assert.OK(t, err) - assert.DeepEqual(t, peer, serverAddr) + assert.Equal(t, network.SockaddrAddrPort(peer), network.SockaddrAddrPort(serverAddr)) name, err := client.Name() assert.OK(t, err) - assert.DeepEqual(t, name, addr) + assert.Equal(t, network.SockaddrAddrPort(name), network.SockaddrAddrPort(addr)) wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, 0) assert.OK(t, err) @@ -185,7 +185,7 @@ func testVirtualNetworkConnectNamespaces(t *testing.T, n *network.VirtualNetwork assert.Equal(t, rn, 13) assert.Equal(t, rflags, 0) assert.Equal(t, string(buf[:13]), "Hello, World!") - assert.Equal(t, peer, addr) + assert.Equal(t, peer, nil) } func waitReadyRead(socket network.Socket) error { From 063cbae8064e195cea93e2e5aff65865dfa03577 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jul 2023 00:17:28 -0700 Subject: [PATCH 04/16] network: add local network Signed-off-by: Achille Roussel --- internal/ipam/ipam.go | 22 ++ internal/ipam/ipv4.go | 13 + internal/ipam/ipv6.go | 13 + internal/network/local.go | 457 ++++++++++++++++++++----------- internal/network/local_test.go | 179 ++++++++++-- internal/network/local_unix.go | 52 ++-- internal/network/network.go | 11 - internal/network/network_test.go | 9 + internal/network/virtual_test.go | 9 - 9 files changed, 544 insertions(+), 221 deletions(-) diff --git a/internal/ipam/ipam.go b/internal/ipam/ipam.go index 9989cd0d..1e13a85f 100644 --- a/internal/ipam/ipam.go +++ b/internal/ipam/ipam.go @@ -1,3 +1,25 @@ // Package ipam contains types to implement IP address management on IPv4 and // IPv6 networks. package ipam + +import "net" + +// Pool is an interface implemented by the IPv4Pool and IPv6Pool types to +// abstract the type of IP addresses that are managed by the pool. +type Pool interface { + // Obtains the next IP address, or returns nil if the pool was exhausted. + GetIP() net.IP + // Returns an IP address to the pool. The ip address must have been obtained + // by a previous call to GetIP or the method panics. + PutIP(net.IP) +} + +// NewPool constructs a pool of IP addresses for the network passed as argument. +func NewPool(ipnet *net.IPNet) Pool { + ones, _ := ipnet.Mask.Size() + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + return NewIPv4Pool((IPv4)(ipv4), ones) + } else { + return NewIPv6Pool((IPv6)(ipnet.IP), ones) + } +} diff --git a/internal/ipam/ipv4.go b/internal/ipam/ipv4.go index 27ec9138..e7bc6ecd 100644 --- a/internal/ipam/ipv4.go +++ b/internal/ipam/ipv4.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "math/bits" + "net" "net/netip" ) @@ -69,6 +70,14 @@ func (p *IPv4Pool) Reset(ip IPv4, nbits int) { p.bits.clear() } +func (p *IPv4Pool) GetIP() net.IP { + ip, ok := p.Get() + if !ok { + return nil + } + return ip[:] +} + func (p *IPv4Pool) Get() (IPv4, bool) { i := p.bits.findFirstZeroBit() a := p.base.add(i) @@ -82,6 +91,10 @@ func (p *IPv4Pool) Get() (IPv4, bool) { return p.base.add(i), true } +func (p *IPv4Pool) PutIP(ip net.IP) { + p.Put((IPv4)(ip)) +} + func (p *IPv4Pool) Put(ip IPv4) { i := ip.sub(p.base) if !p.bits.has(i) { diff --git a/internal/ipam/ipv6.go b/internal/ipam/ipv6.go index e16ca2e0..2597a651 100644 --- a/internal/ipam/ipv6.go +++ b/internal/ipam/ipv6.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "math/bits" + "net" "net/netip" ) @@ -95,6 +96,14 @@ func (p *IPv6Pool) Reset(ip IPv6, nbits int) { p.bits.clear() } +func (p *IPv6Pool) GetIP() net.IP { + ip, ok := p.Get() + if !ok { + return nil + } + return ip[:] +} + func (p *IPv6Pool) Get() (IPv6, bool) { i := p.bits.findFirstZeroBit() a := p.base.add(i) @@ -108,6 +117,10 @@ func (p *IPv6Pool) Get() (IPv6, bool) { return p.base.add(i), true } +func (p *IPv6Pool) PutIP(ip net.IP) { + p.Put((IPv6)(ip)) +} + func (p *IPv6Pool) Put(ip IPv6) { i := ip.sub(p.base) if !p.bits.has(i) { diff --git a/internal/network/local.go b/internal/network/local.go index fd8472cd..944ecc4f 100644 --- a/internal/network/local.go +++ b/internal/network/local.go @@ -2,173 +2,278 @@ package network import ( "context" + "errors" + "fmt" "net" + "net/netip" "sync" -) + "sync/atomic" -var ( - localAddrs = [2]net.Addr{ - &net.IPAddr{IP: net.IP{127, 0, 0, 1}}, - &net.IPAddr{IP: net.IP{15: 1}}, - } + "github.com/stealthrocket/timecraft/internal/ipam" ) -type LocalOption func(*localNamespace) +var ErrIPAM = errors.New("IP pool exhausted") + +var localAddrs = [2]net.IPNet{ + net.IPNet{ + IP: net.IP{ + 127, 0, 0, 1, + }, + Mask: net.IPMask{ + 255, 0, 0, 0, + }, + }, + net.IPNet{ + IP: net.IP{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + }, + Mask: net.IPMask{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + }, + }, +} + +type LocalOption func(*LocalNamespace) func DialFunc(dial func(context.Context, string, string) (net.Conn, error)) LocalOption { - return func(ns *localNamespace) { ns.dial = dial } + return func(ns *LocalNamespace) { ns.dial = dial } } func ListenFunc(listen func(context.Context, string, string) (net.Listener, error)) LocalOption { - return func(ns *localNamespace) { ns.listen = listen } + return func(ns *LocalNamespace) { ns.listen = listen } } func ListenPacketFunc(listenPacket func(context.Context, string, string) (net.PacketConn, error)) LocalOption { - return func(ns *localNamespace) { ns.listenPacket = listenPacket } + return func(ns *LocalNamespace) { ns.listenPacket = listenPacket } } -func NewLocalNamespace(host Namespace, opts ...LocalOption) Namespace { - ns := &localNamespace{host: host} +type LocalNetwork struct { + addrs []net.IPNet + ipams []ipam.Pool + + mutex sync.RWMutex + routes map[netip.Addr]*localInterface +} + +func NewLocalNetwork(addrs ...*net.IPNet) *LocalNetwork { + n := &LocalNetwork{ + addrs: make([]net.IPNet, len(addrs)), + ipams: make([]ipam.Pool, len(addrs)), + routes: make(map[netip.Addr]*localInterface), + } + for i, addr := range addrs { + n.addrs[i] = *addr + n.ipams[i] = ipam.NewPool(&n.addrs[i]) + } + return n +} + +func (n *LocalNetwork) CreateNamespace(host Namespace, opts ...LocalOption) (*LocalNamespace, error) { + ns := &LocalNamespace{ + host: host, + lo0: localInterface{ + index: 0, + name: "lo0", + flags: net.FlagUp | net.FlagLoopback, + addrs: localAddrs[:], + }, + en0: localInterface{ + index: 1, + name: "en0", + flags: net.FlagUp, + addrs: make([]net.IPNet, 0, len(n.addrs)), + }, + } + for _, opt := range opts { opt(ns) } - return ns -} -type localNamespace struct { - host Namespace - dial func(context.Context, string, string) (net.Conn, error) - listen func(context.Context, string, string) (net.Listener, error) - listenPacket func(context.Context, string, string) (net.PacketConn, error) - lo0 localInterface + n.mutex.Lock() + defer n.mutex.Unlock() + + for i, ipam := range n.ipams { + ip := ipam.GetIP() + if ip == nil { + n.detach(&ns.en0) + return nil, fmt.Errorf("%s: %w", ipam, ErrIPAM) + } + ns.en0.addrs = append(ns.en0.addrs, net.IPNet{ + IP: ip, + Mask: n.addrs[i].Mask, + }) + } + + ns.lo0.ports = make([]map[localPort]*localSocket, len(ns.lo0.addrs)) + ns.en0.ports = make([]map[localPort]*localSocket, len(ns.en0.addrs)) + n.attach(&ns.en0) + ns.network.Store(n) + return ns, nil } -func (ns *localNamespace) InterfaceByIndex(index int) (Interface, error) { - if index != 0 { - return nil, errInterfaceIndexNotFound(index) +func (n *LocalNetwork) attach(iface *localInterface) { + for _, ipnet := range iface.addrs { + n.routes[addrFromIP(ipnet.IP)] = iface } - return &ns.lo0, nil } -func (ns *localNamespace) InterfaceByName(name string) (Interface, error) { - if name != "lo0" { - return nil, errInterfaceNameNotFound(name) +func (n *LocalNetwork) detach(iface *localInterface) { + for i, ipnet := range iface.addrs { + delete(n.routes, addrFromIP(ipnet.IP)) + n.ipams[i].PutIP(ipnet.IP) } - return &ns.lo0, nil } -func (ns *localNamespace) Interfaces() ([]Interface, error) { - return []Interface{&ns.lo0}, nil +func (n *LocalNetwork) lookup(ip net.IP) *localInterface { + return n.routes[addrFromIP(ip)] } -func (ns *localNamespace) bindInet4(sock *localSocket, addr *SockaddrInet4) error { - switch { - case isUnspecifiedInet4(addr), isLoopbackInet4(addr): - return ns.lo0.bindInet4(sock, addr) - default: - return EADDRNOTAVAIL +func addrFromIP(ip net.IP) netip.Addr { + if ipv4 := ip.To4(); ipv4 != nil { + return netip.AddrFrom4(([4]byte)(ipv4)) + } else { + return netip.AddrFrom16(([16]byte)(ip)) } } -func (ns *localNamespace) bindInet6(sock *localSocket, addr *SockaddrInet6) error { - switch { - case isUnspecifiedInet6(addr), isLoopbackInet6(addr): - return ns.lo0.bindInet6(sock, addr) - default: - return EADDRNOTAVAIL +type LocalNamespace struct { + network atomic.Pointer[LocalNetwork] + host Namespace + + dial func(context.Context, string, string) (net.Conn, error) + listen func(context.Context, string, string) (net.Listener, error) + listenPacket func(context.Context, string, string) (net.PacketConn, error) + + lo0 localInterface + en0 localInterface +} + +func (ns *LocalNamespace) Detach() { + if n := ns.network.Swap(nil); n != nil { + n.mutex.Lock() + n.detach(&ns.en0) + n.mutex.Unlock() } } -func (ns *localNamespace) lookupInet4(sock *localSocket, addr *SockaddrInet4) (*localSocket, error) { - switch { - case isUnspecifiedInet4(addr), isLoopbackInet4(addr): - return ns.lo0.lookupInet4(sock, addr) +func (ns *LocalNamespace) InterfaceByIndex(index int) (Interface, error) { + switch index { + case ns.lo0.index: + return &ns.lo0, nil + case ns.en0.index: + return &ns.en0, nil default: - return nil, ENETUNREACH + return nil, errInterfaceIndexNotFound(index) } } -func (ns *localNamespace) lookupInet6(sock *localSocket, addr *SockaddrInet6) (*localSocket, error) { - switch { - case isUnspecifiedInet6(addr), isLoopbackInet6(addr): - return ns.lo0.lookupInet6(sock, addr) +func (ns *LocalNamespace) InterfaceByName(name string) (Interface, error) { + switch name { + case ns.lo0.name: + return &ns.lo0, nil + case ns.en0.name: + return &ns.en0, nil default: - return nil, ENETUNREACH + return nil, errInterfaceNameNotFound(name) } } -func (ns *localNamespace) unlinkInet4(sock *localSocket, addr *SockaddrInet4) { - ns.lo0.unlinkInet4(sock, addr) +func (ns *LocalNamespace) Interfaces() ([]Interface, error) { + return []Interface{&ns.lo0, &ns.en0}, nil } -func (ns *localNamespace) unlinkInet6(sock *localSocket, addr *SockaddrInet6) { - ns.lo0.unlinkInet6(sock, addr) +func (ns *LocalNamespace) interfaces() []*localInterface { + return []*localInterface{&ns.lo0, &ns.en0} } -type localAddress struct { - proto Protocol - port uint16 +func (ns *LocalNamespace) bindInet4(sock *localSocket, addr *SockaddrInet4) error { + return ns.bind(sock, addr.Addr[:], addr.Port) } -func localSockaddrInet4(proto Protocol, addr *SockaddrInet4) localAddress { - return localAddress{ - proto: proto, - port: uint16(addr.Port), - } +func (ns *LocalNamespace) bindInet6(sock *localSocket, addr *SockaddrInet6) error { + return ns.bind(sock, addr.Addr[:], addr.Port) } -func localSockaddrInet6(proto Protocol, addr *SockaddrInet6) localAddress { - return localAddress{ - proto: proto, - port: uint16(addr.Port), +func (ns *LocalNamespace) bind(sock *localSocket, addr net.IP, port int) error { + for _, iface := range ns.interfaces() { + if err := iface.bind(sock, addr, port); err != nil { + return err + } } + return nil } -type localSocketTable struct { - sockets map[localAddress]*localSocket +func (ns *LocalNamespace) lookupInet4(sock *localSocket, addr *SockaddrInet4) (*localSocket, error) { + return ns.lookup(sock, addr.Addr[:], addr.Port) } -func (t *localSocketTable) bind(sock *localSocket, addr localAddress) (int, error) { - if addr.port != 0 { - if _, exist := t.sockets[addr]; exist { - // TODO: - // - SO_REUSEADDR - // - SO_REUSEPORT - return -1, EADDRNOTAVAIL - } - } else { - var port int - for port = 49152; port <= 65535; port++ { - addr.port = uint16(port) - if _, exist := t.sockets[addr]; !exist { - break - } - } - if port == 65535 { - return -1, EADDRNOTAVAIL +func (ns *LocalNamespace) lookupInet6(sock *localSocket, addr *SockaddrInet6) (*localSocket, error) { + return ns.lookup(sock, addr.Addr[:], addr.Port) +} + +func (ns *LocalNamespace) lookup(sock *localSocket, addr net.IP, port int) (*localSocket, error) { + for _, iface := range ns.interfaces() { + if peer := iface.lookup(sock, addr, port); peer != nil { + return peer, nil } } - if t.sockets == nil { - t.sockets = make(map[localAddress]*localSocket) + + if addr.IsUnspecified() { + return nil, EHOSTUNREACH + } + + n := ns.network.Load() + if n == nil { + return nil, ENETUNREACH } - t.sockets[addr] = sock - return int(addr.port), nil + + n.mutex.RLock() + iface := n.lookup(addr) + n.mutex.RUnlock() + + if iface != nil { + if peer := iface.lookup(sock, addr, port); peer != nil { + return peer, nil + } + return nil, EHOSTUNREACH + } + return nil, ENETUNREACH } -func (t *localSocketTable) unlink(sock *localSocket, addr localAddress) { - if t.sockets[addr] == sock { - delete(t.sockets, addr) +func (ns *LocalNamespace) unlinkInet4(sock *localSocket, addr *SockaddrInet4) { + ns.unlink(sock, addr.Addr[:], addr.Port) +} + +func (ns *LocalNamespace) unlinkInet6(sock *localSocket, addr *SockaddrInet6) { + ns.unlink(sock, addr.Addr[:], addr.Port) +} + +func (ns *LocalNamespace) unlink(sock *localSocket, addr net.IP, port int) { + for _, iface := range ns.interfaces() { + iface.unlink(sock, addr, port) } } +type localPort struct { + proto Protocol + port uint16 +} + type localInterface struct { + index int + name string + haddr net.HardwareAddr + flags net.Flags + addrs []net.IPNet + mutex sync.RWMutex - ipv4 localSocketTable - ipv6 localSocketTable + ports []map[localPort]*localSocket } func (i *localInterface) Index() int { - return 0 + return i.index } func (i *localInterface) MTU() int { @@ -176,105 +281,145 @@ func (i *localInterface) MTU() int { } func (i *localInterface) Name() string { - return "lo0" + return i.name } func (i *localInterface) HardwareAddr() net.HardwareAddr { - return net.HardwareAddr{} + return i.haddr } func (i *localInterface) Flags() net.Flags { - return net.FlagUp | net.FlagLoopback + return i.flags } func (i *localInterface) Addrs() ([]net.Addr, error) { - return localAddrs[:], nil + addrs := make([]net.Addr, len(i.addrs)) + for j := range addrs { + addrs[j] = &i.addrs[j] + } + return addrs, nil } func (i *localInterface) MulticastAddrs() ([]net.Addr, error) { return nil, nil } -func (i *localInterface) bindInet4(sock *localSocket, addr *SockaddrInet4) error { - link := localSockaddrInet4(sock.protocol, addr) - name := &SockaddrInet4{Addr: addr.Addr} +func (i *localInterface) bind(sock *localSocket, addr net.IP, port int) error { + link := localPort{sock.protocol, uint16(port)} + ipv4 := addr.To4() + name := (Sockaddr)(nil) + + if ipv4 != nil { + name = &SockaddrInet4{Addr: ([4]byte)(ipv4), Port: port} + } else { + name = &SockaddrInet6{Addr: ([16]byte)(addr), Port: port} + } i.mutex.Lock() defer i.mutex.Unlock() - port, err := i.ipv4.bind(sock, link) - if err != nil { - return err + if link.port != 0 { + for j, a := range i.addrs { + if !socketAndInterfaceMatch(addr, a.IP) { + continue + } + if _, used := i.ports[j][link]; used { + // TODO: + // - SO_REUSEADDR + // - SO_REUSEPORT + return EADDRNOTAVAIL + } + } + } else { + var port int + searchFreePort: + for port = 49152; port <= 65535; port++ { + link.port = uint16(port) + for j, a := range i.addrs { + if !socketAndInterfaceMatch(addr, a.IP) { + continue + } + if _, used := i.ports[j][link]; used { + continue searchFreePort + } + } + break + } + if port == 65535 { + return EADDRNOTAVAIL + } + switch a := name.(type) { + case *SockaddrInet4: + a.Port = port + case *SockaddrInet6: + a.Port = port + } } - name.Port = port - sock.name.Store(name) - sock.state.set(bound) - return nil -} - -func (i *localInterface) bindInet6(sock *localSocket, addr *SockaddrInet6) error { - link := localSockaddrInet6(sock.protocol, addr) - name := &SockaddrInet6{Addr: addr.Addr} + for j := range i.ports { + a := i.addrs[j] + p := i.ports[j] - i.mutex.Lock() - defer i.mutex.Unlock() - - port, err := i.ipv6.bind(sock, link) - if err != nil { - return err + if socketAndInterfaceMatch(addr, a.IP) { + if p == nil { + p = make(map[localPort]*localSocket) + i.ports[j] = p + } + p[link] = sock + } } - name.Port = port sock.name.Store(name) sock.state.set(bound) return nil } -func (i *localInterface) lookupInet4(sock *localSocket, addr *SockaddrInet4) (*localSocket, error) { - link := localSockaddrInet4(sock.protocol, addr) - +func (i *localInterface) lookup(sock *localSocket, addr net.IP, port int) *localSocket { i.mutex.RLock() defer i.mutex.RUnlock() - if peer := i.ipv4.sockets[link]; peer != nil { - return peer, nil - } - return nil, EHOSTUNREACH -} - -func (i *localInterface) lookupInet6(sock *localSocket, addr *SockaddrInet6) (*localSocket, error) { - link := localSockaddrInet6(sock.protocol, addr) + for j := range i.addrs { + a := i.addrs[j] + p := i.ports[j] - i.mutex.RLock() - defer i.mutex.RUnlock() - - if peer := i.ipv6.sockets[link]; peer != nil { - return peer, nil + if socketAndInterfaceMatch(addr, a.IP) { + link := localPort{sock.protocol, uint16(port)} + peer := p[link] + if peer != nil { + return peer + } + } } - return nil, EHOSTUNREACH -} -func (i *localInterface) unlinkInet4(sock *localSocket, addr *SockaddrInet4) { - link := localSockaddrInet4(sock.protocol, addr) + return nil +} +func (i *localInterface) unlink(sock *localSocket, addr net.IP, port int) { i.mutex.Lock() defer i.mutex.Unlock() - i.ipv4.unlink(sock, link) + for j := range i.addrs { + a := i.addrs[j] + p := i.ports[j] - sock.name.Store(&sockaddrInet4Any) - sock.state.unset(bound) + if socketAndInterfaceMatch(addr, a.IP) { + link := localPort{sock.protocol, uint16(port)} + peer := p[link] + if sock == peer { + delete(p, link) + } + } + } } -func (i *localInterface) unlinkInet6(sock *localSocket, addr *SockaddrInet6) { - link := localSockaddrInet6(sock.protocol, addr) - - i.mutex.Lock() - defer i.mutex.Unlock() - - i.ipv6.unlink(sock, link) +func socketAndInterfaceMatch(sock, iface net.IP) bool { + return (sock.IsUnspecified() && family(sock) == family(iface)) || sock.Equal(iface) +} - sock.name.Store(&sockaddrInet6Any) - sock.state.unset(bound) +func family(ip net.IP) Family { + if ip.To4() != nil { + return INET + } else { + return INET6 + } } diff --git a/internal/network/local_test.go b/internal/network/local_test.go index 401d99ff..8b52976e 100644 --- a/internal/network/local_test.go +++ b/internal/network/local_test.go @@ -11,47 +11,69 @@ import ( func TestLocalNetwork(t *testing.T) { tests := []struct { scenario string - function func(*testing.T, network.Namespace) + function func(*testing.T, *network.LocalNetwork) }{ { - scenario: "a local network namespace has one interface", - function: testLocalNetworkInterface, + scenario: "a local network namespace has two interfaces", + function: testLocalNetworkInterfaces, }, { - scenario: "ipv4 sockets can connect to one another on the loopback interface", - function: testNamespaceConnectLoopbackIPv4, + scenario: "ipv4 sockets can connect to one another on a loopback interface", + function: testLocalNetworkConnectLoopbackIPv4, }, { - scenario: "ipv6 sockets can connect to one another on the loopback interface", - function: testNamespaceConnectLoopbackIPv6, + scenario: "ipv6 sockets can connect to one another on a loopback interface", + function: testLocalNetworkConnectLoopbackIPv6, }, { - scenario: "ipv4 sockets can exchange datagrams on the loopback interface", - function: testNamespaceExchangeDatagramLoopbackIPv4, + scenario: "ipv4 sockets can connect to one another on a network interface", + function: testLocalNetworkConnectInterfaceIPv4, }, { - scenario: "ipv6 sockets can exchange datagrams on the loopback interface", - function: testNamespaceExchangeDatagramLoopbackIPv6, + scenario: "ipv6 sockets can connect to one another on a network interface", + function: testLocalNetworkConnectInterfaceIPv6, + }, + + { + scenario: "ipv4 sockets in different namespaces can connect to one another", + function: testLocalNetworkConnectNamespacesIPv4, + }, + + { + scenario: "ipv6 sockets in different namespaces can connect to one another", + function: testLocalNetworkConnectNamespacesIPv6, }, } + ipv4, ipnet4, err := net.ParseCIDR("192.168.0.1/24") + assert.OK(t, err) + + ipv6, ipnet6, err := net.ParseCIDR("fe80::1/64") + assert.OK(t, err) + + ipnet4.IP = ipv4 + ipnet6.IP = ipv6 + for _, test := range tests { t.Run(test.scenario, func(t *testing.T) { test.function(t, - network.NewLocalNamespace(nil), + network.NewLocalNetwork(ipnet4, ipnet6), ) }) } } -func testLocalNetworkInterface(t *testing.T, ns network.Namespace) { +func testLocalNetworkInterfaces(t *testing.T, n *network.LocalNetwork) { + ns, err := n.CreateNamespace(nil) + assert.OK(t, err) + ifaces, err := ns.Interfaces() assert.OK(t, err) - assert.Equal(t, len(ifaces), 1) + assert.Equal(t, len(ifaces), 2) lo0 := ifaces[0] assert.Equal(t, lo0.Index(), 0) @@ -62,6 +84,131 @@ func testLocalNetworkInterface(t *testing.T, ns network.Namespace) { lo0Addrs, err := lo0.Addrs() assert.OK(t, err) assert.Equal(t, len(lo0Addrs), 2) - assert.Equal(t, lo0Addrs[0].String(), "127.0.0.1") - assert.Equal(t, lo0Addrs[1].String(), "::1") + assert.Equal(t, lo0Addrs[0].String(), "127.0.0.1/8") + assert.Equal(t, lo0Addrs[1].String(), "::1/128") + + en0 := ifaces[1] + assert.Equal(t, en0.Index(), 1) + assert.Equal(t, en0.MTU(), 1500) + assert.Equal(t, en0.Name(), "en0") + assert.Equal(t, en0.Flags(), net.FlagUp) + + en0Addrs, err := en0.Addrs() + assert.OK(t, err) + assert.Equal(t, len(en0Addrs), 2) + assert.Equal(t, en0Addrs[0].String(), "192.168.0.1/24") + assert.Equal(t, en0Addrs[1].String(), "fe80::1/64") +} + +func testLocalNetworkConnectLoopbackIPv4(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnect(t, n, &network.SockaddrInet4{ + Addr: [4]byte{127, 0, 0, 1}, + Port: 80, + }) +} + +func testLocalNetworkConnectLoopbackIPv6(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnect(t, n, &network.SockaddrInet6{ + Addr: [16]byte{15: 1}, + Port: 80, + }) +} + +func testLocalNetworkConnectInterfaceIPv4(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnect(t, n, &network.SockaddrInet4{ + Addr: [4]byte{192, 168, 0, 1}, + Port: 80, + }) +} + +func testLocalNetworkConnectInterfaceIPv6(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnect(t, n, &network.SockaddrInet6{ + Addr: [16]byte{0: 0xfe, 1: 0x80, 15: 1}, + Port: 80, + }) +} + +func testLocalNetworkConnect(t *testing.T, n *network.LocalNetwork, bind network.Sockaddr) { + ns, err := n.CreateNamespace(nil) + assert.OK(t, err) + testNamespaceConnect(t, ns, bind) +} + +func testLocalNetworkConnectNamespacesIPv4(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnectNamespaces(t, n, network.INET) +} + +func testLocalNetworkConnectNamespacesIPv6(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnectNamespaces(t, n, network.INET6) +} + +func testLocalNetworkConnectNamespaces(t *testing.T, n *network.LocalNetwork, family network.Family) { + ns1, err := n.CreateNamespace(nil) + assert.OK(t, err) + + ns2, err := n.CreateNamespace(nil) + assert.OK(t, err) + + ifaces1, err := ns1.Interfaces() + assert.OK(t, err) + assert.Equal(t, len(ifaces1), 2) + + addrs1, err := ifaces1[1].Addrs() + assert.OK(t, err) + + server, err := ns1.Socket(family, network.STREAM, network.TCP) + assert.OK(t, err) + defer server.Close() + + assert.OK(t, server.Listen(1)) + serverAddr, err := server.Name() + assert.OK(t, err) + + switch a := serverAddr.(type) { + case *network.SockaddrInet4: + for _, addr := range addrs1 { + if ipnet := addr.(*net.IPNet); ipnet.IP.To4() != nil { + copy(a.Addr[:], ipnet.IP.To4()) + } + } + case *network.SockaddrInet6: + for _, addr := range addrs1 { + if ipnet := addr.(*net.IPNet); ipnet.IP.To4() == nil { + copy(a.Addr[:], ipnet.IP) + } + } + } + + client, err := ns2.Socket(family, network.STREAM, network.TCP) + assert.OK(t, err) + defer client.Close() + + assert.Error(t, client.Connect(serverAddr), network.EINPROGRESS) + assert.OK(t, waitReadyRead(server)) + + conn, addr, err := server.Accept() + assert.OK(t, err) + defer conn.Close() + + assert.OK(t, waitReadyWrite(client)) + peer, err := client.Peer() + assert.OK(t, err) + assert.Equal(t, network.SockaddrAddrPort(peer), network.SockaddrAddrPort(serverAddr)) + + name, err := client.Name() + assert.OK(t, err) + assert.Equal(t, network.SockaddrAddrPort(name), network.SockaddrAddrPort(addr)) + + wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, 0) + assert.OK(t, err) + assert.Equal(t, wn, 13) + + assert.OK(t, waitReadyRead(conn)) + buf := make([]byte, 32) + rn, rflags, peer, err := conn.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, rn, 13) + assert.Equal(t, rflags, 0) + assert.Equal(t, string(buf[:13]), "Hello, World!") + assert.Equal(t, peer, nil) } diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index eb9a92dd..54200e6b 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -7,7 +7,7 @@ import ( "golang.org/x/sys/unix" ) -func (ns *localNamespace) Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) { +func (ns *LocalNamespace) Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) { switch family { case INET, INET6: default: @@ -44,16 +44,12 @@ func (state *localSocketState) set(s localSocketState) { *state |= s } -func (state *localSocketState) unset(s localSocketState) { - *state &= ^s -} - const ( addrBufSize = 20 ) type localSocket struct { - ns *localNamespace + ns *LocalNamespace fd0 socketFD fd1 socketFD family Family @@ -125,15 +121,6 @@ func (s *localSocket) bindAny() error { } } -func (s *localSocket) bindLoopback() error { - switch s.family { - case INET: - return s.ns.bindInet4(s, &sockaddrInet4Loopback) - default: - return s.ns.bindInet6(s, &sockaddrInet6Loopback) - } -} - func (s *localSocket) Listen(backlog int) error { fd := s.fd0.acquire() if fd < 0 { @@ -174,7 +161,7 @@ func (s *localSocket) Connect(addr Sockaddr) error { return EISCONN } if !s.state.is(bound) { - if err := s.bindLoopback(); err != nil { + if err := s.bindAny(); err != nil { return err } } @@ -340,7 +327,7 @@ func (s *localSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, er return -1, 0, nil, EINVAL } if !s.state.is(bound) { - if err := s.bindLoopback(); err != nil { + if err := s.bindAny(); err != nil { return -1, 0, nil, err } } @@ -357,18 +344,17 @@ func (s *localSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, er // implementing recvfrom(2) and using a cached socket address for connected // sockets. for { - n, _, rflags, addr, err := unix.RecvmsgBuffers(fd, iovs, nil, flags) + n, _, rflags, _, err := unix.RecvmsgBuffers(fd, iovs, nil, flags) if err == EINTR { if n == 0 { continue } err = nil } + var addr Sockaddr if s.socktype == DGRAM { addr = decodeSockaddr(s.addrBuf) n -= addrBufSize - } else { - addr = nil } return n, rflags, addr, err } @@ -391,7 +377,7 @@ func (s *localSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, erro return -1, ENOTCONN } if !s.state.is(bound) { - if err := s.bindLoopback(); err != nil { + if err := s.bindAny(); err != nil { return -1, err } } @@ -488,15 +474,23 @@ func encodeSockaddrInet6(addr *SockaddrInet6) (buf [addrBufSize]byte) { func decodeSockaddr(buf [addrBufSize]byte) Sockaddr { switch Family(binary.LittleEndian.Uint16(buf[0:2])) { case INET: - return &SockaddrInet4{ - Port: int(binary.LittleEndian.Uint16(buf[2:4])), - Addr: ([4]byte)(buf[4:]), - } + return decodeSockaddrInet4(buf) default: - return &SockaddrInet6{ - Port: int(binary.LittleEndian.Uint16(buf[2:4])), - Addr: ([16]byte)(buf[4:]), - } + return decodeSockaddrInet6(buf) + } +} + +func decodeSockaddrInet4(buf [addrBufSize]byte) *SockaddrInet4 { + return &SockaddrInet4{ + Port: int(binary.LittleEndian.Uint16(buf[2:4])), + Addr: ([4]byte)(buf[4:]), + } +} + +func decodeSockaddrInet6(buf [addrBufSize]byte) *SockaddrInet6 { + return &SockaddrInet6{ + Port: int(binary.LittleEndian.Uint16(buf[2:4])), + Addr: ([16]byte)(buf[4:]), } } diff --git a/internal/network/network.go b/internal/network/network.go index b8fe8f21..7c5ddfe2 100644 --- a/internal/network/network.go +++ b/internal/network/network.go @@ -124,14 +124,6 @@ func isUnspecifiedInet6(sa *SockaddrInet6) bool { return sa.Addr == [16]byte{} } -func isLoopbackInet4(sa *SockaddrInet4) bool { - return sa.Addr == [4]byte{127, 0, 0, 1} -} - -func isLoopbackInet6(sa *SockaddrInet6) bool { - return sa.Addr == [16]byte{15: 1} -} - func errInterfaceIndexNotFound(index int) error { return fmt.Errorf("index=%d: %w", index, ErrInterfaceNotFound) } @@ -143,7 +135,4 @@ func errInterfaceNameNotFound(name string) error { var ( sockaddrInet4Any SockaddrInet4 sockaddrInet6Any SockaddrInet6 - - sockaddrInet4Loopback = SockaddrInet4{Addr: [4]byte{127, 0, 0, 1}} - sockaddrInet6Loopback = SockaddrInet6{Addr: [16]byte{15: 1}} ) diff --git a/internal/network/network_test.go b/internal/network/network_test.go index 9944afb9..6690dc4f 100644 --- a/internal/network/network_test.go +++ b/internal/network/network_test.go @@ -2,6 +2,7 @@ package network_test import ( "testing" + "time" "github.com/stealthrocket/timecraft/internal/assert" "github.com/stealthrocket/timecraft/internal/network" @@ -123,3 +124,11 @@ func testNamespaceExchangeDatagram(t *testing.T, ns network.Namespace, bind netw assert.Equal(t, string(buf[:11]), "How are you") assert.Equal(t, network.SockaddrAddrPort(addr), network.SockaddrAddrPort(addr2)) } + +func waitReadyRead(socket network.Socket) error { + return network.WaitReadyRead(socket, time.Second) +} + +func waitReadyWrite(socket network.Socket) error { + return network.WaitReadyWrite(socket, time.Second) +} diff --git a/internal/network/virtual_test.go b/internal/network/virtual_test.go index 9117e95b..ff245f3d 100644 --- a/internal/network/virtual_test.go +++ b/internal/network/virtual_test.go @@ -3,7 +3,6 @@ package network_test import ( "net" "testing" - "time" "github.com/stealthrocket/timecraft/internal/assert" "github.com/stealthrocket/timecraft/internal/network" @@ -187,11 +186,3 @@ func testVirtualNetworkConnectNamespaces(t *testing.T, n *network.VirtualNetwork assert.Equal(t, string(buf[:13]), "Hello, World!") assert.Equal(t, peer, nil) } - -func waitReadyRead(socket network.Socket) error { - return network.WaitReadyRead(socket, time.Second) -} - -func waitReadyWrite(socket network.Socket) error { - return network.WaitReadyWrite(socket, time.Second) -} From 8f1ae296d68151a6323378b0ce2562b712ba2609 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jul 2023 00:18:04 -0700 Subject: [PATCH 05/16] network: remove virtual network Signed-off-by: Achille Roussel --- internal/network/virtual.go | 864 ------------------------------- internal/network/virtual_test.go | 188 ------- 2 files changed, 1052 deletions(-) delete mode 100644 internal/network/virtual.go delete mode 100644 internal/network/virtual_test.go diff --git a/internal/network/virtual.go b/internal/network/virtual.go deleted file mode 100644 index d0ed5c47..00000000 --- a/internal/network/virtual.go +++ /dev/null @@ -1,864 +0,0 @@ -package network - -import ( - "errors" - "net" - "sync" - "sync/atomic" - - "github.com/stealthrocket/timecraft/internal/ipam" -) - -var ( - errIPv4Exhausted = errors.New("ipv4 pool exhausted") - errIPv6Exhausted = errors.New("ipv6 pool exhausted") -) - -type VirtualNetwork struct { - mutex sync.RWMutex - host Namespace - ipnet4 *net.IPNet - ipnet6 *net.IPNet - ipam4 ipam.IPv4Pool - ipam6 ipam.IPv6Pool - iface4 map[ipam.IPv4]*virtualInterface - iface6 map[ipam.IPv6]*virtualInterface -} - -func NewVirtualNetwork(host Namespace, ipnet4, ipnet6 *net.IPNet) *VirtualNetwork { - n := &VirtualNetwork{ - host: host, - ipnet4: ipnet4, - ipnet6: ipnet6, - iface4: make(map[ipam.IPv4]*virtualInterface), - iface6: make(map[ipam.IPv6]*virtualInterface), - } - if ip := ipnet4.IP.To4(); ip != nil { - ones, _ := ipnet4.Mask.Size() - n.ipam4.Reset((ipam.IPv4)(ip), ones) - } - if ip := ipnet6.IP.To16(); ip != nil { - ones, _ := ipnet6.Mask.Size() - n.ipam6.Reset((ipam.IPv6)(ip), ones) - } - n.ipam4.Get() - n.ipam6.Get() - return n -} - -func (n *VirtualNetwork) CreateNamespace() (*VirtualNamespace, error) { - loIPv4 := ipam.IPv4{127, 0, 0, 1} - loIPv6 := ipam.IPv6{15: 1} - - ns := &VirtualNamespace{ - lo0: virtualInterface{ - index: 0, - name: "lo0", - ipv4: loIPv4, - ipv6: loIPv6, - flags: net.FlagUp | net.FlagLoopback, - }, - en0: virtualInterface{ - index: 1, - name: "en0", - flags: net.FlagUp, - }, - } - - ns.net.Store(n) - - n.mutex.Lock() - defer n.mutex.Unlock() - - var ok bool - ns.en0.ipv4, ok = n.ipam4.Get() - if !ok { - return nil, errIPv4Exhausted - } - ns.en0.ipv6, ok = n.ipam6.Get() - if !ok { - n.ipam4.Put(ns.en0.ipv4) - return nil, errIPv6Exhausted - } - - n.iface4[ns.en0.ipv4] = &ns.en0 - n.iface6[ns.en0.ipv6] = &ns.en0 - return ns, nil -} - -func (n *VirtualNetwork) Contains(ip net.IP) bool { - return n.ipnet4.Contains(ip) || n.ipnet6.Contains(ip) -} - -func (n *VirtualNetwork) containsIPv4(addr ipam.IPv4) bool { - return n.ipnet4.Contains(addr[:]) -} - -func (n *VirtualNetwork) containsIPv6(addr ipam.IPv6) bool { - return n.ipnet6.Contains(addr[:]) -} - -func (n *VirtualNetwork) lookupIPv4Interface(addr ipam.IPv4) *virtualInterface { - n.mutex.RLock() - defer n.mutex.RUnlock() - return n.iface4[addr] -} - -func (n *VirtualNetwork) lookupIPv6Interface(addr ipam.IPv6) *virtualInterface { - n.mutex.RLock() - defer n.mutex.RUnlock() - return n.iface6[addr] -} - -func (n *VirtualNetwork) detachInterface(vi *virtualInterface) { - n.mutex.Lock() - defer n.mutex.Unlock() - if n.iface4[vi.ipv4] == vi { - delete(n.iface4, vi.ipv4) - n.ipam4.Put(vi.ipv4) - } - if n.iface6[vi.ipv6] == vi { - delete(n.iface6, vi.ipv6) - n.ipam6.Put(vi.ipv6) - } -} - -type VirtualNamespace struct { - net atomic.Pointer[VirtualNetwork] - lo0 virtualInterface - en0 virtualInterface -} - -func (ns *VirtualNamespace) Detach() { - if n := ns.net.Swap(nil); n != nil { - n.detachInterface(&ns.en0) - } -} - -func (ns *VirtualNamespace) InterfaceByIndex(index int) (Interface, error) { - switch index { - case ns.lo0.index: - return &ns.lo0, nil - case ns.en0.index: - return &ns.en0, nil - default: - return nil, errInterfaceIndexNotFound(index) - } -} - -func (ns *VirtualNamespace) InterfaceByName(name string) (Interface, error) { - switch name { - case ns.lo0.name: - return &ns.lo0, nil - case ns.en0.name: - return &ns.en0, nil - default: - return nil, errInterfaceNameNotFound(name) - } -} - -func (ns *VirtualNamespace) Interfaces() ([]Interface, error) { - return []Interface{&ns.lo0, &ns.en0}, nil -} - -func (ns *VirtualNamespace) Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) { - n := ns.net.Load() - if n == nil { - return nil, EAFNOSUPPORT - } - socket, err := n.host.Socket(family, socktype, protocol) - if err != nil { - return nil, err - } - switch family { - case INET, INET6: - socket = &virtualSocket{ - ns: ns, - base: socket, - proto: protocol, - } - } - return socket, nil -} - -func (ns *VirtualNamespace) bindInet4(socket *virtualSocket, host, addr *SockaddrInet4) error { - switch addr.Addr { - case [4]byte{}: // unspecified - if err := ns.lo0.bindInet4(socket, host, addr); err != nil { - return err - } - if err := ns.en0.bindInet4(socket, host, addr); err != nil { - return err - } - return nil - case ns.lo0.ipv4: - return ns.lo0.bindInet4(socket, host, addr) - case ns.en0.ipv4: - return ns.en0.bindInet4(socket, host, addr) - default: - return EADDRNOTAVAIL - } -} - -func (ns *VirtualNamespace) bindInet6(socket *virtualSocket, host, addr *SockaddrInet6) error { - switch addr.Addr { - case [16]byte{}: // unspecified - if err := ns.lo0.bindInet6(socket, host, addr); err != nil { - return err - } - if err := ns.en0.bindInet6(socket, host, addr); err != nil { - return err - } - return nil - case ns.lo0.ipv6: - return ns.lo0.bindInet6(socket, host, addr) - case ns.en0.ipv6: - return ns.en0.bindInet6(socket, host, addr) - default: - return EADDRNOTAVAIL - } -} - -func (ns *VirtualNamespace) lookupByHostInet4(socket *virtualSocket, host *SockaddrInet4) *virtualSocket { - if peer := ns.lo0.lookupByHostInet4(socket, host); peer != nil { - return peer - } - if peer := ns.en0.lookupByHostInet4(socket, host); peer != nil { - return peer - } - if n := ns.net.Load(); n != nil { - n.mutex.RLock() - defer n.mutex.RUnlock() - - for _, i := range n.iface4 { - if i == &ns.en0 { - continue - } - if peer := i.lookupByHostInet4(socket, host); peer != nil { - return peer - } - } - } - return nil -} - -func (ns *VirtualNamespace) lookupByHostInet6(socket *virtualSocket, host *SockaddrInet6) *virtualSocket { - if peer := ns.lo0.lookupByHostInet6(socket, host); peer != nil { - return peer - } - if peer := ns.en0.lookupByHostInet6(socket, host); peer != nil { - return peer - } - if n := ns.net.Load(); n != nil { - n.mutex.RLock() - defer n.mutex.RUnlock() - - for _, i := range n.iface6 { - if i == &ns.en0 { - continue - } - if peer := i.lookupByHostInet6(socket, host); peer != nil { - return peer - } - } - } - return nil -} - -func (ns *VirtualNamespace) lookupByAddrInet4(socket *virtualSocket, addr *SockaddrInet4) (*virtualSocket, error) { - if isUnspecifiedInet4(addr) { - return ns.lookupByPortInet4(socket, addr) - } - var peer *virtualSocket - switch addr.Addr { - case ns.lo0.ipv4: - peer = ns.lo0.lookupByAddrInet4(socket, addr) - case ns.en0.ipv4: - peer = ns.en0.lookupByAddrInet4(socket, addr) - default: - n := ns.net.Load() - if n == nil { - return nil, ENETUNREACH - } - if !n.containsIPv4(addr.Addr) { - return nil, nil - } - iface := n.lookupIPv4Interface(addr.Addr) - if iface == nil { - return nil, EHOSTUNREACH - } - peer = iface.lookupByAddrInet4(socket, addr) - } - if peer != nil { - return peer, nil - } - return nil, EHOSTUNREACH -} - -func (ns *VirtualNamespace) lookupByAddrInet6(socket *virtualSocket, addr *SockaddrInet6) (*virtualSocket, error) { - if isUnspecifiedInet6(addr) { - return ns.lookupByPortInet6(socket, addr) - } - var peer *virtualSocket - switch addr.Addr { - case ns.lo0.ipv6: - peer = ns.lo0.lookupByAddrInet6(socket, addr) - case ns.en0.ipv6: - peer = ns.en0.lookupByAddrInet6(socket, addr) - default: - n := ns.net.Load() - if n == nil { - return nil, ENETUNREACH - } - if !n.containsIPv6(addr.Addr) { - return nil, nil - } - iface := n.lookupIPv6Interface(addr.Addr) - if iface == nil { - return nil, EHOSTUNREACH - } - peer = iface.lookupByAddrInet6(socket, addr) - } - if peer != nil { - return peer, nil - } - return nil, EHOSTUNREACH -} - -func (ns *VirtualNamespace) lookupByPortInet4(socket *virtualSocket, addr *SockaddrInet4) (*virtualSocket, error) { - if peer := ns.lo0.lookupByAddrInet4(socket, addr); peer != nil { - return peer, nil - } - if peer := ns.en0.lookupByAddrInet4(socket, addr); peer != nil { - return peer, nil - } - if n := ns.net.Load(); n != nil { - n.mutex.RLock() - defer n.mutex.RUnlock() - - for _, iface := range n.iface4 { - if peer := iface.lookupByAddrInet4(socket, addr); peer != nil { - return peer, nil - } - } - } - return nil, EHOSTUNREACH -} - -func (ns *VirtualNamespace) lookupByPortInet6(socket *virtualSocket, addr *SockaddrInet6) (*virtualSocket, error) { - if peer := ns.lo0.lookupByAddrInet6(socket, addr); peer != nil { - return peer, nil - } - if peer := ns.en0.lookupByAddrInet6(socket, addr); peer != nil { - return peer, nil - } - if n := ns.net.Load(); n != nil { - n.mutex.RLock() - defer n.mutex.RUnlock() - - for _, iface := range n.iface6 { - if peer := iface.lookupByAddrInet6(socket, addr); peer != nil { - return peer, nil - } - } - } - return nil, EHOSTUNREACH -} - -func (ns *VirtualNamespace) unlinkInet4(socket *virtualSocket, host, addr *SockaddrInet4) { - switch addr.Addr { - case [4]byte{}: - ns.lo0.unlinkInet4(socket, host, addr) - ns.en0.unlinkInet4(socket, host, addr) - case ns.lo0.ipv4: - ns.lo0.unlinkInet4(socket, host, addr) - case ns.en0.ipv4: - ns.en0.unlinkInet4(socket, host, addr) - } -} - -func (ns *VirtualNamespace) unlinkInet6(socket *virtualSocket, host, addr *SockaddrInet6) { - switch addr.Addr { - case [16]byte{}: - ns.lo0.unlinkInet6(socket, host, addr) - ns.en0.unlinkInet6(socket, host, addr) - case ns.lo0.ipv6: - ns.lo0.unlinkInet6(socket, host, addr) - case ns.en0.ipv6: - ns.en0.unlinkInet6(socket, host, addr) - } -} - -type virtualAddress struct { - proto Protocol - port uint16 -} - -func virtualSockaddrInet4(proto Protocol, addr *SockaddrInet4) virtualAddress { - return virtualAddress{ - proto: proto, - port: uint16(addr.Port), - } -} - -func virtualSockaddrInet6(proto Protocol, addr *SockaddrInet6) virtualAddress { - return virtualAddress{ - proto: proto, - port: uint16(addr.Port), - } -} - -type virtualAddressTable struct { - host map[virtualAddress]*virtualSocket - sock map[virtualAddress]*virtualSocket -} - -func (t *virtualAddressTable) bind(socket *virtualSocket, ha, sa virtualAddress) (int, error) { - if sa.port != 0 { - if _, exist := t.sock[sa]; exist { - // TODO: - // - SO_REUSEADDR - // - SO_REUSEPORT - return -1, EADDRNOTAVAIL - } - } else { - var port int - for port = 49152; port <= 65535; port++ { - sa.port = uint16(port) - if _, exist := t.sock[sa]; !exist { - break - } - } - if port == 65535 { - return -1, EADDRNOTAVAIL - } - } - if t.host == nil { - t.host = make(map[virtualAddress]*virtualSocket) - } - if t.sock == nil { - t.sock = make(map[virtualAddress]*virtualSocket) - } - t.host[ha] = socket - t.sock[sa] = socket - return int(sa.port), nil -} - -func (t *virtualAddressTable) unlink(socket *virtualSocket, ha, sa virtualAddress) { - if t.host[ha] == socket { - delete(t.host, ha) - } - if t.sock[sa] == socket { - delete(t.sock, sa) - } -} - -type virtualInterface struct { - index int - name string - ipv4 ipam.IPv4 - ipv6 ipam.IPv6 - haddr net.HardwareAddr - flags net.Flags - - mutex sync.RWMutex - inet4 virtualAddressTable - inet6 virtualAddressTable -} - -func (i *virtualInterface) Index() int { - return i.index -} - -func (i *virtualInterface) MTU() int { - return 1500 -} - -func (i *virtualInterface) Name() string { - return i.name -} - -func (i *virtualInterface) Addrs() ([]net.Addr, error) { - ipv4 := &net.IPAddr{IP: net.IP(i.ipv4[:])} - ipv6 := &net.IPAddr{IP: net.IP(i.ipv6[:])} - return []net.Addr{ipv4, ipv6}, nil -} - -func (i *virtualInterface) MulticastAddrs() ([]net.Addr, error) { - return nil, nil -} - -func (i *virtualInterface) HardwareAddr() net.HardwareAddr { - return i.haddr -} - -func (i *virtualInterface) Flags() net.Flags { - return i.flags -} - -func (i *virtualInterface) bindInet4(socket *virtualSocket, host, addr *SockaddrInet4) error { - hostAddr := virtualSockaddrInet4(socket.proto, host) - bindAddr := virtualSockaddrInet4(socket.proto, addr) - name := &SockaddrInet4{Addr: addr.Addr} - - i.mutex.Lock() - defer i.mutex.Unlock() - - port, err := i.inet4.bind(socket, hostAddr, bindAddr) - if err != nil { - return err - } - - name.Port = port - socket.host.Store(host) - socket.name.Store(name) - socket.bound = true - return nil -} - -func (i *virtualInterface) bindInet6(socket *virtualSocket, host, addr *SockaddrInet6) error { - hostAddr := virtualSockaddrInet6(socket.proto, host) - bindAddr := virtualSockaddrInet6(socket.proto, addr) - name := &SockaddrInet6{Addr: addr.Addr} - - i.mutex.Lock() - defer i.mutex.Unlock() - - port, err := i.inet6.bind(socket, hostAddr, bindAddr) - if err != nil { - return err - } - - name.Port = port - socket.host.Store(host) - socket.name.Store(name) - socket.bound = true - return nil -} - -func (i *virtualInterface) lookupByHostInet4(socket *virtualSocket, host *SockaddrInet4) *virtualSocket { - va := virtualSockaddrInet4(socket.proto, host) - i.mutex.RLock() - defer i.mutex.RUnlock() - return i.inet4.host[va] -} - -func (i *virtualInterface) lookupByHostInet6(socket *virtualSocket, host *SockaddrInet6) *virtualSocket { - va := virtualSockaddrInet6(socket.proto, host) - i.mutex.RLock() - defer i.mutex.RUnlock() - return i.inet6.host[va] -} - -func (i *virtualInterface) lookupByAddrInet4(socket *virtualSocket, addr *SockaddrInet4) *virtualSocket { - va := virtualSockaddrInet4(socket.proto, addr) - i.mutex.RLock() - defer i.mutex.RUnlock() - return i.inet4.sock[va] -} - -func (i *virtualInterface) lookupByAddrInet6(socket *virtualSocket, addr *SockaddrInet6) *virtualSocket { - va := virtualSockaddrInet6(socket.proto, addr) - i.mutex.RLock() - defer i.mutex.RUnlock() - return i.inet6.sock[va] -} - -func (i *virtualInterface) unlinkInet4(socket *virtualSocket, host, addr *SockaddrInet4) { - hostAddr := virtualSockaddrInet4(socket.proto, host) - sockAddr := virtualSockaddrInet4(socket.proto, addr) - - i.mutex.Lock() - defer i.mutex.Unlock() - - i.inet4.unlink(socket, hostAddr, sockAddr) - socket.host.Store(&sockaddrInet4Any) - socket.name.Store(&sockaddrInet4Any) - socket.bound = false -} - -func (i *virtualInterface) unlinkInet6(socket *virtualSocket, host, addr *SockaddrInet6) { - hostAddr := virtualSockaddrInet6(socket.proto, host) - sockAddr := virtualSockaddrInet6(socket.proto, addr) - - i.mutex.Lock() - defer i.mutex.Unlock() - - i.inet6.unlink(socket, hostAddr, sockAddr) - socket.host.Store(&sockaddrInet6Any) - socket.name.Store(&sockaddrInet6Any) - socket.bound = false -} - -type virtualSocket struct { - ns *VirtualNamespace - base Socket - host atomic.Value - name atomic.Value - peer atomic.Value - proto Protocol - bound bool -} - -func (s *virtualSocket) Family() Family { - return s.base.Family() -} - -func (s *virtualSocket) Type() Socktype { - return s.base.Type() -} - -func (s *virtualSocket) Fd() int { - return s.base.Fd() -} - -func (s *virtualSocket) Close() error { - if s.bound { - host := s.host.Load() - name := s.name.Load() - switch a := name.(type) { - case *SockaddrInet4: - s.ns.unlinkInet4(s, host.(*SockaddrInet4), a) - case *SockaddrInet6: - s.ns.unlinkInet6(s, host.(*SockaddrInet6), a) - } - } - return s.base.Close() -} - -func (s *virtualSocket) Bind(addr Sockaddr) error { - if s.name.Load() != nil { - return EINVAL - } - - switch addr.(type) { - case *SockaddrInet4: - if s.Family() != INET { - return EAFNOSUPPORT - } - _ = s.base.Bind(&sockaddrInet4Any) - case *SockaddrInet6: - if s.Family() != INET6 { - return EAFNOSUPPORT - } - _ = s.base.Bind(&sockaddrInet6Any) - default: - return EINVAL - } - - // The host socket was bound to a random port, we retrieve the address that - // it got associated with. - a, err := s.base.Name() - if err != nil { - return err - } - - switch host := a.(type) { - case *SockaddrInet4: - return s.ns.bindInet4(s, host, addr.(*SockaddrInet4)) - case *SockaddrInet6: - return s.ns.bindInet6(s, host, addr.(*SockaddrInet6)) - default: - return EAFNOSUPPORT - } -} - -func (s *virtualSocket) bindAny() error { - a, err := s.base.Name() - if err != nil { - return err - } - switch host := a.(type) { - case *SockaddrInet4: - return s.ns.bindInet4(s, host, &sockaddrInet4Any) - case *SockaddrInet6: - return s.ns.bindInet6(s, host, &sockaddrInet6Any) - default: - return EAFNOSUPPORT - } -} - -func (s *virtualSocket) Listen(backlog int) error { - if err := s.base.Listen(backlog); err != nil { - return err - } - if s.name.Load() == nil { - return s.bindAny() - } - return nil -} - -func (s *virtualSocket) Connect(addr Sockaddr) error { - var peer *virtualSocket - var err error - switch a := addr.(type) { - case *SockaddrInet4: - peer, err = s.ns.lookupByAddrInet4(s, a) - case *SockaddrInet6: - peer, err = s.ns.lookupByAddrInet6(s, a) - default: - return EAFNOSUPPORT - } - if err != nil { - return ECONNREFUSED - } - - connectAddr := addr - if peer != nil { - switch a := peer.host.Load().(type) { - case *SockaddrInet4: - connectAddr = a - case *SockaddrInet6: - connectAddr = a - default: - return ECONNREFUSED - } - } - - err = s.base.Connect(connectAddr) - if err != nil && err != EINPROGRESS { - return err - } - - s.peer.Store(addr) - - if s.name.Load() == nil { - if err := s.bindAny(); err != nil { - return err - } - } - return err -} - -func (s *virtualSocket) Accept() (Socket, Sockaddr, error) { - base, addr, err := s.base.Accept() - if err != nil { - return nil, nil, err - } - conn := &virtualSocket{ - ns: s.ns, - base: base, - } - var peer *virtualSocket - switch a := addr.(type) { - case *SockaddrInet4: - peer = s.ns.lookupByHostInet4(s, a) - case *SockaddrInet6: - peer = s.ns.lookupByHostInet6(s, a) - } - conn.host.Store(addr) - if peer != nil { - addr, _ = peer.name.Load().(Sockaddr) - } - conn.name.Store(s.name.Load()) - conn.peer.Store(addr) - return conn, addr, nil -} - -func (s *virtualSocket) Name() (Sockaddr, error) { - switch name := s.name.Load().(type) { - case *SockaddrInet4: - return name, nil - case *SockaddrInet6: - return name, nil - } - switch s.Family() { - case INET: - return &sockaddrInet4Any, nil - default: - return &sockaddrInet6Any, nil - } -} - -func (s *virtualSocket) Peer() (Sockaddr, error) { - switch peer := s.peer.Load().(type) { - case *SockaddrInet4: - return peer, nil - case *SockaddrInet6: - return peer, nil - } - return nil, ENOTCONN -} - -func (s *virtualSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, error) { - n, flags, addr, err := s.base.RecvFrom(iovs, flags) - if err != nil { - return -1, 0, nil, err - } - if addr != nil { - addr, err = s.toVirtualAddr(addr) - } - return n, flags, addr, err -} - -func (s *virtualSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, error) { - if addr != nil { - a, err := s.toHostAddr(addr) - if err != nil { - return -1, err - } - addr = a - } - n, err := s.base.SendTo(iovs, addr, flags) - if s.name.Load() == nil { - if err := s.bindAny(); err != nil { - return n, err - } - } - return n, err -} - -func (s *virtualSocket) Shutdown(how int) error { - return s.base.Shutdown(how) -} - -func (s *virtualSocket) SetOptInt(level, name, value int) error { - return s.base.SetOptInt(level, name, value) -} - -func (s *virtualSocket) GetOptInt(level, name int) (int, error) { - return s.base.GetOptInt(level, name) -} - -func (s *virtualSocket) toHostAddr(addr Sockaddr) (Sockaddr, error) { - var peer *virtualSocket - var err error - switch a := addr.(type) { - case *SockaddrInet4: - peer, err = s.ns.lookupByAddrInet4(s, a) - case *SockaddrInet6: - peer, err = s.ns.lookupByAddrInet6(s, a) - default: - return nil, EAFNOSUPPORT - } - if peer == nil { - return addr, err - } - peerAddr, err := peer.Name() - switch err { - case nil: - return peerAddr, nil - case EBADF: - return nil, ECONNRESET - default: - return nil, err - } -} - -func (s *virtualSocket) toVirtualAddr(addr Sockaddr) (Sockaddr, error) { - var peer *virtualSocket - switch a := addr.(type) { - case *SockaddrInet4: - peer = s.ns.lookupByHostInet4(s, a) - case *SockaddrInet6: - peer = s.ns.lookupByHostInet6(s, a) - default: - return nil, EAFNOSUPPORT - } - if peer != nil { - return peer.name.Load().(Sockaddr), nil - } - // TODO: this races if the virtual peer was closed after sending a - // datagram but before we looked it up on the interfaces. - return addr, nil -} diff --git a/internal/network/virtual_test.go b/internal/network/virtual_test.go deleted file mode 100644 index ff245f3d..00000000 --- a/internal/network/virtual_test.go +++ /dev/null @@ -1,188 +0,0 @@ -package network_test - -import ( - "net" - "testing" - - "github.com/stealthrocket/timecraft/internal/assert" - "github.com/stealthrocket/timecraft/internal/network" -) - -func TestVirtualNetwork(t *testing.T) { - tests := []struct { - scenario string - function func(*testing.T, *network.VirtualNetwork) - }{ - { - scenario: "a virtual network namespace has two interfaces", - function: testVirtualNetworkInterfaces, - }, - - { - scenario: "ipv4 sockets can connect to one another on a loopback interface", - function: testVirtualNetworkConnectLoopbackIPv4, - }, - - { - scenario: "ipv6 sockets can connect to one another on a loopback interface", - function: testVirtualNetworkConnectLoopbackIPv6, - }, - - { - scenario: "ipv4 sockets can connect to one another on a network interface", - function: testVirtualNetworkConnectInterfaceIPv4, - }, - - { - scenario: "ipv6 sockets can connect to one another on a network interface", - function: testVirtualNetworkConnectInterfaceIPv6, - }, - - { - scenario: "ipv4 sockets in different namespaces can connet to one another", - function: testVirtualNetworkConnectNamespacesIPv4, - }, - - { - scenario: "ipv6 sockets in different namespaces can connet to one another", - function: testVirtualNetworkConnectNamespacesIPv6, - }, - } - - _, ipnet4, err := net.ParseCIDR("192.168.0.0/24") - assert.OK(t, err) - - _, ipnet6, err := net.ParseCIDR("fe80::/64") - assert.OK(t, err) - - for _, test := range tests { - t.Run(test.scenario, func(t *testing.T) { - test.function(t, - network.NewVirtualNetwork(network.Host(), ipnet4, ipnet6), - ) - }) - } -} - -func testVirtualNetworkInterfaces(t *testing.T, n *network.VirtualNetwork) { - ns, err := n.CreateNamespace() - assert.OK(t, err) - - ifaces, err := ns.Interfaces() - assert.OK(t, err) - assert.Equal(t, len(ifaces), 2) - - lo0 := ifaces[0] - assert.Equal(t, lo0.Index(), 0) - assert.Equal(t, lo0.MTU(), 1500) - assert.Equal(t, lo0.Name(), "lo0") - assert.Equal(t, lo0.Flags(), net.FlagUp|net.FlagLoopback) - - lo0Addrs, err := lo0.Addrs() - assert.OK(t, err) - assert.Equal(t, len(lo0Addrs), 2) - assert.Equal(t, lo0Addrs[0].String(), "127.0.0.1") - assert.Equal(t, lo0Addrs[1].String(), "::1") - - en0 := ifaces[1] - assert.Equal(t, en0.Index(), 1) - assert.Equal(t, en0.MTU(), 1500) - assert.Equal(t, en0.Name(), "en0") - assert.Equal(t, en0.Flags(), net.FlagUp) - - en0Addrs, err := en0.Addrs() - assert.OK(t, err) - assert.Equal(t, len(en0Addrs), 2) - assert.Equal(t, en0Addrs[0].String(), "192.168.0.1") - assert.Equal(t, en0Addrs[1].String(), "fe80::1") -} - -func testVirtualNetworkConnectLoopbackIPv4(t *testing.T, n *network.VirtualNetwork) { - testVirtualNetworkConnect(t, n, &network.SockaddrInet4{ - Addr: [4]byte{127, 0, 0, 1}, - Port: 80, - }) -} - -func testVirtualNetworkConnectLoopbackIPv6(t *testing.T, n *network.VirtualNetwork) { - testVirtualNetworkConnect(t, n, &network.SockaddrInet6{ - Addr: [16]byte{15: 1}, - Port: 80, - }) -} - -func testVirtualNetworkConnectInterfaceIPv4(t *testing.T, n *network.VirtualNetwork) { - testVirtualNetworkConnect(t, n, &network.SockaddrInet4{ - Addr: [4]byte{192, 168, 0, 1}, - Port: 80, - }) -} - -func testVirtualNetworkConnectInterfaceIPv6(t *testing.T, n *network.VirtualNetwork) { - testVirtualNetworkConnect(t, n, &network.SockaddrInet6{ - Addr: [16]byte{0: 0xfe, 1: 0x80, 15: 1}, - Port: 80, - }) -} - -func testVirtualNetworkConnect(t *testing.T, n *network.VirtualNetwork, bind network.Sockaddr) { - ns, err := n.CreateNamespace() - assert.OK(t, err) - testNamespaceConnect(t, ns, bind) -} - -func testVirtualNetworkConnectNamespacesIPv4(t *testing.T, n *network.VirtualNetwork) { - testVirtualNetworkConnectNamespaces(t, n, network.INET) -} - -func testVirtualNetworkConnectNamespacesIPv6(t *testing.T, n *network.VirtualNetwork) { - testVirtualNetworkConnectNamespaces(t, n, network.INET6) -} - -func testVirtualNetworkConnectNamespaces(t *testing.T, n *network.VirtualNetwork, family network.Family) { - ns1, err := n.CreateNamespace() - assert.OK(t, err) - - ns2, err := n.CreateNamespace() - assert.OK(t, err) - - server, err := ns1.Socket(family, network.STREAM, network.TCP) - assert.OK(t, err) - defer server.Close() - - assert.OK(t, server.Listen(1)) - serverAddr, err := server.Name() - assert.OK(t, err) - - client, err := ns2.Socket(family, network.STREAM, network.TCP) - assert.OK(t, err) - defer client.Close() - - assert.Error(t, client.Connect(serverAddr), network.EINPROGRESS) - assert.OK(t, waitReadyRead(server)) - conn, addr, err := server.Accept() - assert.OK(t, err) - defer conn.Close() - - assert.OK(t, waitReadyWrite(client)) - peer, err := client.Peer() - assert.OK(t, err) - assert.Equal(t, network.SockaddrAddrPort(peer), network.SockaddrAddrPort(serverAddr)) - - name, err := client.Name() - assert.OK(t, err) - assert.Equal(t, network.SockaddrAddrPort(name), network.SockaddrAddrPort(addr)) - - wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, 0) - assert.OK(t, err) - assert.Equal(t, wn, 13) - - assert.OK(t, waitReadyRead(conn)) - buf := make([]byte, 32) - rn, rflags, peer, err := conn.RecvFrom([][]byte{buf}, 0) - assert.OK(t, err) - assert.Equal(t, rn, 13) - assert.Equal(t, rflags, 0) - assert.Equal(t, string(buf[:13]), "Hello, World!") - assert.Equal(t, peer, nil) -} From 2f40be52903aca6265d166aa4c3005bd1969ba67 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jul 2023 01:57:35 -0700 Subject: [PATCH 06/16] network: WIP Dial, Listen, ListenPacket Signed-off-by: Achille Roussel --- internal/network/conn.go | 52 +++++++++++ internal/network/host_unix.go | 72 ++------------ internal/network/local_darwin.go | 12 +-- internal/network/local_linux.go | 16 ++-- internal/network/local_test.go | 84 +++++++++++++++++ internal/network/local_unix.go | 155 ++++++++++++++++++++++++++++--- internal/network/network.go | 53 ++++++++--- internal/network/network_unix.go | 85 +++++++++++++++++ 8 files changed, 426 insertions(+), 103 deletions(-) create mode 100644 internal/network/conn.go diff --git a/internal/network/conn.go b/internal/network/conn.go new file mode 100644 index 00000000..39bfa045 --- /dev/null +++ b/internal/network/conn.go @@ -0,0 +1,52 @@ +package network + +import ( + "io" + "net" + "sync" +) + +type closeReader interface { + CloseRead() error +} + +type closeWriter interface { + CloseWrite() error +} + +var ( + _ closeReader = (*net.UnixConn)(nil) + _ closeWriter = (*net.UnixConn)(nil) +) + +func closeRead(conn io.Closer) error { + switch c := conn.(type) { + case closeReader: + return c.CloseRead() + default: + return c.Close() + } +} + +func closeWrite(conn io.Closer) error { + switch c := conn.(type) { + case closeWriter: + return c.CloseWrite() + default: + return c.Close() + } +} + +func tunnel(w, r net.Conn, b []byte, errs chan<- error, wg *sync.WaitGroup) { + defer wg.Done() + defer closeWrite(w) //nolint:errcheck + _, err := io.CopyBuffer(w, r, b) + if err != nil { + errs <- err + } +} + +func isTemporary(err error) bool { + e, _ := err.(interface{ Temporary() bool }) + return e != nil && e.Temporary() +} diff --git a/internal/network/host_unix.go b/internal/network/host_unix.go index 8e136596..76c333a2 100644 --- a/internal/network/host_unix.go +++ b/internal/network/host_unix.go @@ -1,8 +1,6 @@ package network import ( - "time" - "golang.org/x/sys/unix" ) @@ -40,7 +38,7 @@ func (s *hostSocket) Bind(addr Sockaddr) error { return EBADF } defer s.fd.release(fd) - return ignoreEINTR(func() error { return unix.Bind(fd, addr) }) + return bind(fd, addr) } func (s *hostSocket) Listen(backlog int) error { @@ -49,7 +47,7 @@ func (s *hostSocket) Listen(backlog int) error { return EBADF } defer s.fd.release(fd) - return ignoreEINTR(func() error { return unix.Listen(fd, backlog) }) + return listen(fd, backlog) } func (s *hostSocket) Connect(addr Sockaddr) error { @@ -58,7 +56,7 @@ func (s *hostSocket) Connect(addr Sockaddr) error { return EBADF } defer s.fd.release(fd) - return ignoreEINTR(func() error { return unix.Connect(fd, addr) }) + return connect(fd, addr) } func (s *hostSocket) Name() (Sockaddr, error) { @@ -67,7 +65,7 @@ func (s *hostSocket) Name() (Sockaddr, error) { return nil, EBADF } defer s.fd.release(fd) - return ignoreEINTR2(func() (Sockaddr, error) { return unix.Getsockname(fd) }) + return getsockname(fd) } func (s *hostSocket) Peer() (Sockaddr, error) { @@ -76,7 +74,7 @@ func (s *hostSocket) Peer() (Sockaddr, error) { return nil, EBADF } defer s.fd.release(fd) - return ignoreEINTR2(func() (Sockaddr, error) { return unix.Getpeername(fd) }) + return getpeername(fd) } func (s *hostSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, error) { @@ -124,9 +122,7 @@ func (s *hostSocket) Shutdown(how int) error { return EBADF } defer s.fd.release(fd) - return ignoreEINTR(func() error { - return unix.Shutdown(fd, how) - }) + return shutdown(fd, how) } func (s *hostSocket) SetOptInt(level, name, value int) error { @@ -135,7 +131,7 @@ func (s *hostSocket) SetOptInt(level, name, value int) error { return EBADF } defer s.fd.release(fd) - return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) + return setsockoptInt(fd, level, name, value) } func (s *hostSocket) GetOptInt(level, name int) (int, error) { @@ -144,57 +140,5 @@ func (s *hostSocket) GetOptInt(level, name int) (int, error) { return -1, EBADF } defer s.fd.release(fd) - return ignoreEINTR2(func() (int, error) { return unix.GetsockoptInt(fd, level, name) }) -} - -// This function is used to automtically retry syscalls when they return EINTR -// due to having handled a signal instead of executing. Despite defininig a -// EINTR constant and having proc_raise to trigger signals from the guest, WASI -// does not provide any mechanism for handling signals so masking those errors -// seems like a safer approach to ensure that guest applications will work the -// same regardless of the compiler being used. -func ignoreEINTR(f func() error) error { - for { - if err := f(); err != EINTR { - return err - } - } -} - -func ignoreEINTR2[F func() (R, error), R any](f F) (R, error) { - for { - v, err := f() - if err != EINTR { - return v, err - } - } -} - -func ignoreEINTR3[F func() (R1, R2, error), R1, R2 any](f F) (R1, R2, error) { - for { - v1, v2, err := f() - if err != EINTR { - return v1, v2, err - } - } -} - -func WaitReadyRead(socket Socket, timeout time.Duration) error { - return wait(socket, unix.POLLIN, timeout) -} - -func WaitReadyWrite(socket Socket, timeout time.Duration) error { - return wait(socket, unix.POLLOUT, timeout) -} - -func wait(socket Socket, events int16, timeout time.Duration) error { - tms := int(timeout / time.Millisecond) - pfd := []unix.PollFd{{ - Fd: int32(socket.Fd()), - Events: events, - }} - return ignoreEINTR(func() error { - _, err := unix.Poll(pfd, tms) - return err - }) + return getsockoptInt(fd, level, name) } diff --git a/internal/network/local_darwin.go b/internal/network/local_darwin.go index de8108df..7c936efe 100644 --- a/internal/network/local_darwin.go +++ b/internal/network/local_darwin.go @@ -34,20 +34,20 @@ func closePair(fds *[2]int) { fds[1] = -1 } -func (s *localSocket) SetOptInt(level, name, value int) error { +func (s *localSocket) GetOptInt(level, name int) (int, error) { fd := s.fd0.acquire() if fd < 0 { - return EBADF + return 0, EBADF } defer s.fd0.release(fd) - return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) + return getsockoptInt(fd, level, name) } -func (s *localSocket) GetOptInt(level, name int) (int, error) { +func (s *localSocket) SetOptInt(level, name, value int) error { fd := s.fd0.acquire() if fd < 0 { - return 0, EBADF + return EBADF } defer s.fd0.release(fd) - return ignoreEINTR2(func() (int, error) { return unix.GetsockoptInt(fd, level, name) }) + return setsockoptInt(fd, level, name, value) } diff --git a/internal/network/local_linux.go b/internal/network/local_linux.go index 214823f5..d9d14abe 100644 --- a/internal/network/local_linux.go +++ b/internal/network/local_linux.go @@ -8,40 +8,40 @@ func socketpair(family, socktype, protocol int) ([2]int, error) { }) } -func (s *localSocket) SetOptInt(level, name, value int) error { +func (s *localSocket) GetOptInt(level, name int) (int, error) { fd := s.fd0.acquire() if fd < 0 { - return EBADF + return 0, EBADF } defer s.fd0.release(fd) switch level { case unix.SOL_SOCKET: switch name { + case unix.SO_DOMAIN: + return int(s.family), nil case unix.SO_BINDTODEVICE: return 0, ENOPROTOOPT } } - return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) + return getsockoptInt(fd, level, name) } -func (s *localSocket) GetOptInt(level, name int) (int, error) { +func (s *localSocket) SetOptInt(level, name, value int) error { fd := s.fd0.acquire() if fd < 0 { - return 0, EBADF + return EBADF } defer s.fd0.release(fd) switch level { case unix.SOL_SOCKET: switch name { - case unix.SO_DOMAIN: - return int(s.family), nil case unix.SO_BINDTODEVICE: return 0, ENOPROTOOPT } } - return ignoreEINTR2(func() (int, error) { return unix.GetsockoptInt(fd, level, name) }) + return setsockoptInt(fd, level, name, value) } diff --git a/internal/network/local_test.go b/internal/network/local_test.go index 8b52976e..41061e99 100644 --- a/internal/network/local_test.go +++ b/internal/network/local_test.go @@ -47,6 +47,11 @@ func TestLocalNetwork(t *testing.T) { scenario: "ipv6 sockets in different namespaces can connect to one another", function: testLocalNetworkConnectNamespacesIPv6, }, + + { + scenario: "local sockets can establish connections to foreign networks when a dial function is configured", + function: testLocalNetworkOutboundConnect, + }, } ipv4, ipnet4, err := net.ParseCIDR("192.168.0.1/24") @@ -212,3 +217,82 @@ func testLocalNetworkConnectNamespaces(t *testing.T, n *network.LocalNetwork, fa assert.Equal(t, string(buf[:13]), "Hello, World!") assert.Equal(t, peer, nil) } + +func testLocalNetworkOutboundConnect(t *testing.T, n *network.LocalNetwork) { + ns1 := network.Host() + + ifaces1, err := ns1.Interfaces() + assert.OK(t, err) + assert.NotEqual(t, len(ifaces1), 0) + + var hostAddr *network.SockaddrInet4 + for _, iface := range ifaces1 { + if hostAddr != nil { + break + } + if (iface.Flags() & net.FlagUp) == 0 { + continue + } + if (iface.Flags() & net.FlagLoopback) != 0 { + continue + } + + addrs, err := iface.Addrs() + assert.OK(t, err) + + for _, addr := range addrs { + if a, ok := addr.(*net.IPNet); ok { + if ipv4 := a.IP.To4(); ipv4 != nil { + hostAddr = &network.SockaddrInet4{Addr: ([4]byte)(ipv4)} + break + } + } + } + } + assert.NotEqual(t, hostAddr, nil) + + server, err := ns1.Socket(network.INET, network.STREAM, network.TCP) + assert.OK(t, err) + defer server.Close() + + assert.OK(t, server.Bind(hostAddr)) + assert.OK(t, server.Listen(1)) + serverAddr, err := server.Name() + assert.OK(t, err) + + var dialer net.Dialer + ns2, err := n.CreateNamespace(nil, network.DialFunc(dialer.DialContext)) + assert.OK(t, err) + + client, err := ns2.Socket(network.INET, network.STREAM, network.TCP) + assert.OK(t, err) + defer client.Close() + + assert.Error(t, client.Connect(serverAddr), network.EINPROGRESS) + assert.OK(t, waitReadyRead(server)) + + conn, addr, err := server.Accept() + assert.OK(t, err) + defer conn.Close() + assert.NotEqual(t, addr, nil) + + assert.OK(t, waitReadyWrite(client)) + peer, err := client.Peer() + assert.OK(t, err) + assert.Equal(t, + network.SockaddrAddrPort(peer), + network.SockaddrAddrPort(serverAddr)) + + wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, 0) + assert.OK(t, err) + assert.Equal(t, wn, 13) + + assert.OK(t, waitReadyRead(conn)) + buf := make([]byte, 32) + rn, rflags, peer, err := conn.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, rn, 13) + assert.Equal(t, rflags, 0) + assert.Equal(t, string(buf[:13]), "Hello, World!") + assert.Equal(t, peer, nil) +} diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index 54200e6b..0dc72b20 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -1,7 +1,12 @@ package network import ( + "context" "encoding/binary" + "fmt" + "net" + "os" + "sync" "sync/atomic" "golang.org/x/sys/unix" @@ -50,16 +55,22 @@ const ( type localSocket struct { ns *LocalNamespace - fd0 socketFD - fd1 socketFD family Family socktype Socktype protocol Protocol - state localSocketState - name atomic.Value - peer atomic.Value - iovs [][]byte - addrBuf [addrBufSize]byte + + fd0 socketFD + fd1 socketFD + state localSocketState + + name atomic.Value + peer atomic.Value + + iovs [][]byte + addrBuf [addrBufSize]byte + + errs <-chan error + cancel context.CancelFunc } func (s *localSocket) Family() Family { @@ -142,11 +153,49 @@ func (s *localSocket) Listen(backlog int) error { return err } } - + if s.ns.listen != nil { + if err := s.bridge(); err != nil { + return err + } + } s.state.set(listening) return nil } +func (s *localSocket) bridge() error { + var address string + switch a := s.name.Load().(type) { + case *SockaddrInet4: + address = fmt.Sprintf(":%d", a.Port) + case *SockaddrInet6: + address = fmt.Sprintf("[::]:%d", a.Port) + } + + l, err := s.ns.listen(context.TODO(), "tcp", address) + if err != nil { + return err + } + + go func() { + for { + c, err := l.Accept() + if err != nil { + if !isTemporary(err) { + return + } + continue + } + + _ = c + // TOOD: + // - create local socket + // - send to parent server socket + // - start connection tunnel + } + }() + return nil +} + func (s *localSocket) Connect(addr Sockaddr) error { fd := s.fd0.acquire() if fd < 0 { @@ -177,9 +226,9 @@ func (s *localSocket) Connect(addr Sockaddr) error { return EAFNOSUPPORT } if err != nil { - return ECONNREFUSED - } - if peer.family != s.family { + if err == ENETUNREACH { + return s.dial(addr) + } return ECONNREFUSED } if peer.socktype != s.socktype { @@ -218,6 +267,73 @@ func (s *localSocket) Connect(addr Sockaddr) error { return EINPROGRESS } +func (s *localSocket) dial(addr Sockaddr) error { + dial := s.ns.dial + if dial == nil { + return ENETUNREACH + } + + fd1 := s.fd1.acquire() + if fd1 < 0 { + return EBADF + } + defer s.fd1.release(fd1) + + rbufsize, err := getsockoptInt(fd1, unix.SOL_SOCKET, unix.SO_RCVBUF) + if err != nil { + return err + } + wbufsize, err := getsockoptInt(fd1, unix.SOL_SOCKET, unix.SO_SNDBUF) + if err != nil { + return err + } + + f := os.NewFile(uintptr(fd1), "") + defer f.Close() + s.fd1.acquire() + s.fd1.close() // detach from the socket, f owns the fd now + + downstream, err := net.FileConn(f) + if err != nil { + return err + } + + dialNetwork := s.protocol.Network() + dialAddress := SockaddrAddrPort(addr).String() + ctx, cancel := context.WithCancel(context.Background()) + + errs := make(chan error, 2) + s.errs = errs + s.cancel = cancel + + go func() { + defer close(errs) + defer downstream.Close() + + upstream, err := dial(ctx, dialNetwork, dialAddress) + if err != nil { + errs <- err + return + } + defer upstream.Close() + buffer := make([]byte, rbufsize+wbufsize) + + wg := new(sync.WaitGroup) + wg.Add(2) + defer wg.Wait() + + go tunnel(downstream, upstream, buffer[:rbufsize], errs, wg) + go tunnel(upstream, downstream, buffer[rbufsize:], errs, wg) + + <-ctx.Done() + closeRead(upstream) //nolint:errcheck + }() + + s.peer.Store(addr) + s.state.set(connected) + return EINPROGRESS +} + func (s *localSocket) Accept() (Socket, Sockaddr, error) { fd := s.fd0.acquire() if fd < 0 { @@ -326,6 +442,9 @@ func (s *localSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, er if s.state.is(listening) { return -1, 0, nil, EINVAL } + if err := s.getError(); err != nil { + return -1, 0, nil, err + } if !s.state.is(bound) { if err := s.bindAny(); err != nil { return -1, 0, nil, err @@ -376,6 +495,9 @@ func (s *localSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, erro if !s.state.is(connected) && addr == nil { return -1, ENOTCONN } + if err := s.getError(); err != nil { + return -1, err + } if !s.state.is(bound) { if err := s.bindAny(); err != nil { return -1, err @@ -437,7 +559,16 @@ func (s *localSocket) Shutdown(how int) error { return EBADF } defer s.fd0.release(fd) - return ignoreEINTR(func() error { return unix.Shutdown(fd, how) }) + return shutdown(fd, how) +} + +func (s *localSocket) getError() error { + select { + case err := <-s.errs: + return err + default: + return nil + } } func clearIOVecs(iovs [][]byte) { diff --git a/internal/network/network.go b/internal/network/network.go index 7c5ddfe2..0ed54c30 100644 --- a/internal/network/network.go +++ b/internal/network/network.go @@ -38,9 +38,9 @@ type Socket interface { Shutdown(how int) error - SetOptInt(level, name, value int) error - GetOptInt(level, name int) (int, error) + + SetOptInt(level, name, value int) error } type Socktype uint8 @@ -63,11 +63,35 @@ func (f Family) String() string { type Protocol uint16 const ( - UNSPEC Protocol = 0 - TCP Protocol = 6 - UDP Protocol = 17 + NOPROTO Protocol = 0 + TCP Protocol = 6 + UDP Protocol = 17 ) +func (p Protocol) String() string { + switch p { + case NOPROTO: + return "NOPROTO" + case TCP: + return "TCP" + case UDP: + return "UDP" + default: + return "UNKNOWN" + } +} + +func (p Protocol) Network() string { + switch p { + case TCP: + return "tcp" + case UDP: + return "udp" + default: + return "ip" + } +} + type Namespace interface { InterfaceByIndex(index int) (Interface, error) @@ -105,6 +129,17 @@ func SockaddrFamily(sa Sockaddr) Family { } } +func SockaddrAddr(sa Sockaddr) netip.Addr { + switch a := sa.(type) { + case *SockaddrInet4: + return netip.AddrFrom4(a.Addr) + case *SockaddrInet6: + return netip.AddrFrom16(a.Addr) + default: + return netip.Addr{} + } +} + func SockaddrAddrPort(sa Sockaddr) netip.AddrPort { switch a := sa.(type) { case *SockaddrInet4: @@ -116,14 +151,6 @@ func SockaddrAddrPort(sa Sockaddr) netip.AddrPort { } } -func isUnspecifiedInet4(sa *SockaddrInet4) bool { - return sa.Addr == [4]byte{} -} - -func isUnspecifiedInet6(sa *SockaddrInet6) bool { - return sa.Addr == [16]byte{} -} - func errInterfaceIndexNotFound(index int) error { return fmt.Errorf("index=%d: %w", index, ErrInterfaceNotFound) } diff --git a/internal/network/network_unix.go b/internal/network/network_unix.go index 552f2b0a..87c5e403 100644 --- a/internal/network/network_unix.go +++ b/internal/network/network_unix.go @@ -2,6 +2,7 @@ package network import ( "sync/atomic" + "time" "golang.org/x/sys/unix" ) @@ -49,6 +50,90 @@ type Sockaddr = unix.Sockaddr type SockaddrInet4 = unix.SockaddrInet4 type SockaddrInet6 = unix.SockaddrInet6 +// This function is used to automtically retry syscalls when they return EINTR +// due to having handled a signal instead of executing. Despite defininig a +// EINTR constant and having proc_raise to trigger signals from the guest, WASI +// does not provide any mechanism for handling signals so masking those errors +// seems like a safer approach to ensure that guest applications will work the +// same regardless of the compiler being used. +func ignoreEINTR(f func() error) error { + for { + if err := f(); err != EINTR { + return err + } + } +} + +func ignoreEINTR2[F func() (R, error), R any](f F) (R, error) { + for { + v, err := f() + if err != EINTR { + return v, err + } + } +} + +func ignoreEINTR3[F func() (R1, R2, error), R1, R2 any](f F) (R1, R2, error) { + for { + v1, v2, err := f() + if err != EINTR { + return v1, v2, err + } + } +} + +func WaitReadyRead(socket Socket, timeout time.Duration) error { + return wait(socket, unix.POLLIN, timeout) +} + +func WaitReadyWrite(socket Socket, timeout time.Duration) error { + return wait(socket, unix.POLLOUT, timeout) +} + +func wait(socket Socket, events int16, timeout time.Duration) error { + tms := int(timeout / time.Millisecond) + pfd := []unix.PollFd{{ + Fd: int32(socket.Fd()), + Events: events, + }} + return ignoreEINTR(func() error { + _, err := unix.Poll(pfd, tms) + return err + }) +} + +func bind(fd int, addr Sockaddr) error { + return ignoreEINTR(func() error { return unix.Bind(fd, addr) }) +} + +func listen(fd, backlog int) error { + return ignoreEINTR(func() error { return unix.Listen(fd, backlog) }) +} + +func connect(fd int, addr Sockaddr) error { + return ignoreEINTR(func() error { return unix.Connect(fd, addr) }) +} + +func shutdown(fd, how int) error { + return ignoreEINTR(func() error { return unix.Shutdown(fd, how) }) +} + +func getsockname(fd int) (Sockaddr, error) { + return ignoreEINTR2(func() (Sockaddr, error) { return unix.Getsockname(fd) }) +} + +func getpeername(fd int) (Sockaddr, error) { + return ignoreEINTR2(func() (Sockaddr, error) { return unix.Getpeername(fd) }) +} + +func getsockoptInt(fd, level, name int) (int, error) { + return ignoreEINTR2(func() (int, error) { return unix.GetsockoptInt(fd, level, name) }) +} + +func setsockoptInt(fd, level, name, value int) error { + return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) +} + type socketFD struct { state atomic.Uint64 // upper 32 bits: refCount, lower 32 bits: fd } From 8c890fd241fcbfbd684a681c5f42e33e5a40ea70 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jul 2023 11:58:02 -0700 Subject: [PATCH 07/16] network: allow local network to receive inbound connections from external networks Signed-off-by: Achille Roussel --- internal/network/conn.go | 21 +++- internal/network/host_unix.go | 3 +- internal/network/local.go | 13 +++ internal/network/local_test.go | 71 ++++++++++++ internal/network/local_unix.go | 185 +++++++++++++++++++++++-------- internal/network/network_unix.go | 63 +---------- internal/network/socket.go | 66 +++++++++++ internal/network/socket_test.go | 60 ++++++++++ internal/network/socket_unix.go | 24 ++++ 9 files changed, 390 insertions(+), 116 deletions(-) create mode 100644 internal/network/socket.go create mode 100644 internal/network/socket_test.go create mode 100644 internal/network/socket_unix.go diff --git a/internal/network/conn.go b/internal/network/conn.go index 39bfa045..300c50d4 100644 --- a/internal/network/conn.go +++ b/internal/network/conn.go @@ -37,7 +37,21 @@ func closeWrite(conn io.Closer) error { } } -func tunnel(w, r net.Conn, b []byte, errs chan<- error, wg *sync.WaitGroup) { +func tunnel(downstream, upstream net.Conn, rbufsize, wbufsize int) error { + buffer := make([]byte, rbufsize+wbufsize) // TODO: pool this buffer? + errs := make(chan error, 2) + wg := new(sync.WaitGroup) + wg.Add(2) + + go copyAndClose(downstream, upstream, buffer[:rbufsize], errs, wg) + go copyAndClose(upstream, downstream, buffer[rbufsize:], errs, wg) + + wg.Wait() + close(errs) + return <-errs +} + +func copyAndClose(w, r net.Conn, b []byte, errs chan<- error, wg *sync.WaitGroup) { defer wg.Done() defer closeWrite(w) //nolint:errcheck _, err := io.CopyBuffer(w, r, b) @@ -45,8 +59,3 @@ func tunnel(w, r net.Conn, b []byte, errs chan<- error, wg *sync.WaitGroup) { errs <- err } } - -func isTemporary(err error) bool { - e, _ := err.(interface{ Temporary() bool }) - return e != nil && e.Temporary() -} diff --git a/internal/network/host_unix.go b/internal/network/host_unix.go index 76c333a2..2c64c793 100644 --- a/internal/network/host_unix.go +++ b/internal/network/host_unix.go @@ -29,7 +29,8 @@ func (s *hostSocket) Fd() int { } func (s *hostSocket) Close() error { - return s.fd.close() + s.fd.close() + return nil } func (s *hostSocket) Bind(addr Sockaddr) error { diff --git a/internal/network/local.go b/internal/network/local.go index 944ecc4f..4feebe5a 100644 --- a/internal/network/local.go +++ b/internal/network/local.go @@ -150,6 +150,19 @@ type LocalNamespace struct { en0 localInterface } +func (ns *LocalNamespace) Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) { + switch family { + case INET, INET6: + default: + return ns.host.Socket(family, socktype, protocol) + } + s, err := ns.socket(family, socktype, protocol) + if err != nil { + return nil, err + } + return s, nil +} + func (ns *LocalNamespace) Detach() { if n := ns.network.Swap(nil); n != nil { n.mutex.Lock() diff --git a/internal/network/local_test.go b/internal/network/local_test.go index 41061e99..681b2f81 100644 --- a/internal/network/local_test.go +++ b/internal/network/local_test.go @@ -1,7 +1,11 @@ package network_test import ( + "context" + "io" "net" + "net/netip" + "strconv" "testing" "github.com/stealthrocket/timecraft/internal/assert" @@ -52,6 +56,11 @@ func TestLocalNetwork(t *testing.T) { scenario: "local sockets can establish connections to foreign networks when a dial function is configured", function: testLocalNetworkOutboundConnect, }, + + { + scenario: "local sockets can receive inbound connections from foreign network when a listen function is configured", + function: testLocalNetworkInboundAccept, + }, } ipv4, ipnet4, err := net.ParseCIDR("192.168.0.1/24") @@ -296,3 +305,65 @@ func testLocalNetworkOutboundConnect(t *testing.T, n *network.LocalNetwork) { assert.Equal(t, string(buf[:13]), "Hello, World!") assert.Equal(t, peer, nil) } + +func testLocalNetworkInboundAccept(t *testing.T, n *network.LocalNetwork) { + ns, err := n.CreateNamespace(nil, + network.ListenFunc(func(ctx context.Context, network, address string) (net.Listener, error) { + _, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + return net.Listen(network, net.JoinHostPort("127.0.0.1", port)) + }), + ) + assert.OK(t, err) + + sock, err := ns.Socket(network.INET, network.STREAM, network.TCP) + assert.OK(t, err) + defer sock.Close() + + assert.OK(t, sock.Listen(0)) + addr, err := sock.Name() + assert.OK(t, err) + + addrPort := network.SockaddrAddrPort(addr) + connAddr := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(addrPort.Port()))) + conn, err := net.Dial("tcp", connAddr) + assert.OK(t, err) + defer conn.Close() + + assert.OK(t, waitReadyRead(sock)) + peer, peerAddr, err := sock.Accept() + assert.OK(t, err) + + // verify that the address of the inbound connection matches the remote + // address of the peer socket + connLocalAddr := conn.LocalAddr().(*net.TCPAddr) + peerAddrPort := network.SockaddrAddrPort(peerAddr) + connAddrPort := netip.AddrPortFrom(netip.AddrFrom4(([4]byte)(connLocalAddr.IP)), uint16(connLocalAddr.Port)) + assert.Equal(t, peerAddrPort, connAddrPort) + + // verify that the inbound connection and the peer socket can exchange data + size, err := conn.Write([]byte("message")) + assert.OK(t, err) + assert.Equal(t, size, 7) + + assert.OK(t, waitReadyRead(peer)) + buf := make([]byte, 32) + size, _, _, err = peer.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, size, 7) + assert.Equal(t, string(buf[:7]), "message") + + // exercise shutting down the write end of the inbound connection + conn.(*net.TCPConn).CloseWrite() + waitReadyRead(peer) + size, _, _, err = peer.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, size, 0) + + // exercise shutting down the write end of the peer socket + assert.OK(t, peer.Shutdown(network.SHUTWR)) + _, err = conn.Read(buf) + assert.Equal(t, err, io.EOF) +} diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index 0dc72b20..05ff09d7 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -6,18 +6,12 @@ import ( "fmt" "net" "os" - "sync" "sync/atomic" "golang.org/x/sys/unix" ) -func (ns *LocalNamespace) Socket(family Family, socktype Socktype, protocol Protocol) (Socket, error) { - switch family { - case INET, INET6: - default: - return ns.host.Socket(family, socktype, protocol) - } +func (ns *LocalNamespace) socket(family Family, socktype Socktype, protocol Protocol) (*localSocket, error) { socket := &localSocket{ ns: ns, family: family, @@ -37,6 +31,7 @@ type localSocketState uint8 const ( bound localSocketState = 1 << iota + accepted connected listening ) @@ -69,10 +64,29 @@ type localSocket struct { iovs [][]byte addrBuf [addrBufSize]byte + lstn net.Listener errs <-chan error cancel context.CancelFunc } +func (s *localSocket) String() string { + name := s.name.Load() + peer := s.peer.Load() + if name == nil { + return "?" + } + if peer == nil { + return SockaddrAddrPort(name.(Sockaddr)).String() + } + nameString := SockaddrAddrPort(name.(Sockaddr)).String() + peerString := SockaddrAddrPort(peer.(Sockaddr)).String() + if s.state.is(accepted) { + return peerString + "->" + nameString + } else { + return nameString + "->" + peerString + } +} + func (s *localSocket) Family() Family { return s.family } @@ -94,8 +108,20 @@ func (s *localSocket) Close() error { s.ns.unlinkInet6(s, addr) } } + s.fd0.close() s.fd1.close() + + if s.lstn != nil { + s.lstn.Close() + } + if s.cancel != nil { + s.cancel() + } + if s.errs != nil { + for range s.errs { + } + } return nil } @@ -176,21 +202,77 @@ func (s *localSocket) bridge() error { return err } + errs := make(chan error, 1) + s.lstn = l + s.errs = errs + go func() { + defer close(errs) for { c, err := l.Accept() if err != nil { - if !isTemporary(err) { + e, _ := err.(interface{ Temporary() bool }) + if e == nil || !e.Temporary() { + errs <- err return } - continue + } else if s.serve(c) != nil { + c.Close() + select { + case errs <- ECONNABORTED: + default: + } } + } + }() + return nil +} + +func (s *localSocket) serve(upstream net.Conn) error { + serverFd := s.fd1.acquire() + if serverFd < 0 { + return EBADF + } + defer s.fd1.release(serverFd) + + rbufsize, err := getsockoptInt(serverFd, unix.SOL_SOCKET, unix.SO_RCVBUF) + if err != nil { + return err + } + wbufsize, err := getsockoptInt(serverFd, unix.SOL_SOCKET, unix.SO_SNDBUF) + if err != nil { + return err + } + socket, err := s.ns.socket(s.family, s.socktype, s.protocol) + if err != nil { + return err + } + defer socket.Close() - _ = c - // TOOD: - // - create local socket - // - send to parent server socket - // - start connection tunnel + if err := socket.connect(serverFd, upstream.RemoteAddr()); err != nil { + return err + } + + f := os.NewFile(uintptr(socket.fd0.acquire()), "") + defer f.Close() + socket.fd0.close() // detach from the socket, f owns the fd now + + downstream, err := net.FileConn(f) + if err != nil { + return err + } + + go func() { + defer upstream.Close() + defer downstream.Close() + + if err := tunnel(downstream, upstream, rbufsize, wbufsize); err != nil { + // TODO: figure out if this error needs to be reported: + // + // When the downstream file is closed, the other end of the + // connection will observe that the socket was shutdown, we + // lose the information of why but we currently have no way + // of locating the peer socket on the other side. } }() return nil @@ -215,13 +297,13 @@ func (s *localSocket) Connect(addr Sockaddr) error { } } - var peer *localSocket + var server *localSocket var err error switch a := addr.(type) { case *SockaddrInet4: - peer, err = s.ns.lookupInet4(s, a) + server, err = s.ns.lookupInet4(s, a) case *SockaddrInet6: - peer, err = s.ns.lookupInet6(s, a) + server, err = s.ns.lookupInet6(s, a) default: return EAFNOSUPPORT } @@ -231,7 +313,7 @@ func (s *localSocket) Connect(addr Sockaddr) error { } return ECONNREFUSED } - if peer.socktype != s.socktype { + if server.socktype != s.socktype { return ECONNREFUSED } @@ -241,30 +323,37 @@ func (s *localSocket) Connect(addr Sockaddr) error { return nil } - peerFd := peer.fd1.acquire() - if peerFd < 0 { + serverFd := server.fd1.acquire() + if serverFd < 0 { return ECONNREFUSED } - defer peer.fd1.release(peerFd) + defer server.fd1.release(serverFd) + if err := s.connect(serverFd, s.name.Load()); err != nil { + return err + } + + s.peer.Store(addr) + s.state.set(connected) + return EINPROGRESS +} + +func (s *localSocket) connect(serverFd int, addr any) error { fd1 := s.fd1.acquire() if fd1 < 0 { return EBADF } defer s.fd1.release(fd1) - addrBuf := encodeSockaddrAny(s.name.Load()) + addrBuf := encodeSockaddrAny(addr) // TODO: remove the heap allocation by implementing UnixRights to output to // a stack buffer. rights := unix.UnixRights(fd1) - if err := unix.Sendmsg(peerFd, addrBuf[:], rights, nil, 0); err != nil { + if err := unix.Sendmsg(serverFd, addrBuf[:], rights, nil, 0); err != nil { return ECONNREFUSED } - s.fd1.close() - s.peer.Store(addr) - s.state.set(connected) - return EINPROGRESS + return nil } func (s *localSocket) dial(addr Sockaddr) error { @@ -302,7 +391,7 @@ func (s *localSocket) dial(addr Sockaddr) error { dialAddress := SockaddrAddrPort(addr).String() ctx, cancel := context.WithCancel(context.Background()) - errs := make(chan error, 2) + errs := make(chan error, 1) s.errs = errs s.cancel = cancel @@ -316,17 +405,10 @@ func (s *localSocket) dial(addr Sockaddr) error { return } defer upstream.Close() - buffer := make([]byte, rbufsize+wbufsize) - wg := new(sync.WaitGroup) - wg.Add(2) - defer wg.Wait() - - go tunnel(downstream, upstream, buffer[:rbufsize], errs, wg) - go tunnel(upstream, downstream, buffer[rbufsize:], errs, wg) - - <-ctx.Done() - closeRead(upstream) //nolint:errcheck + if err := tunnel(downstream, upstream, rbufsize, wbufsize); err != nil { + errs <- err + } }() s.peer.Store(addr) @@ -344,13 +426,16 @@ func (s *localSocket) Accept() (Socket, Sockaddr, error) { if !s.state.is(listening) { return nil, nil, EINVAL } + if err := s.getError(); err != nil { + return nil, nil, err + } socket := &localSocket{ ns: s.ns, family: s.family, socktype: s.socktype, protocol: s.protocol, - state: bound | connected, + state: bound | accepted | connected, } var oobn int @@ -580,25 +665,31 @@ func clearIOVecs(iovs [][]byte) { func encodeSockaddrAny(addr any) (buf [addrBufSize]byte) { switch a := addr.(type) { case *SockaddrInet4: - return encodeSockaddrInet4(a) + return encodeAddrPortInet4(a.Addr, a.Port) case *SockaddrInet6: - return encodeSockaddrInet6(a) + return encodeAddrPortInet6(a.Addr, a.Port) + case *net.TCPAddr: + if ipv4 := a.IP.To4(); ipv4 != nil { + return encodeAddrPortInet4(([4]byte)(ipv4), a.Port) + } else { + return encodeAddrPortInet6(([16]byte)(a.IP), a.Port) + } default: return } } -func encodeSockaddrInet4(addr *SockaddrInet4) (buf [addrBufSize]byte) { +func encodeAddrPortInet4(addr [4]byte, port int) (buf [addrBufSize]byte) { binary.LittleEndian.PutUint16(buf[0:2], uint16(INET)) - binary.LittleEndian.PutUint16(buf[2:4], uint16(addr.Port)) - *(*[4]byte)(buf[4:]) = addr.Addr + binary.LittleEndian.PutUint16(buf[2:4], uint16(port)) + *(*[4]byte)(buf[4:]) = addr return } -func encodeSockaddrInet6(addr *SockaddrInet6) (buf [addrBufSize]byte) { +func encodeAddrPortInet6(addr [16]byte, port int) (buf [addrBufSize]byte) { binary.LittleEndian.PutUint16(buf[0:2], uint16(INET6)) - binary.LittleEndian.PutUint16(buf[2:4], uint16(addr.Port)) - *(*[16]byte)(buf[4:]) = addr.Addr + binary.LittleEndian.PutUint16(buf[2:4], uint16(port)) + *(*[16]byte)(buf[4:]) = addr return } diff --git a/internal/network/network_unix.go b/internal/network/network_unix.go index 87c5e403..fd22c6a2 100644 --- a/internal/network/network_unix.go +++ b/internal/network/network_unix.go @@ -1,7 +1,6 @@ package network import ( - "sync/atomic" "time" "golang.org/x/sys/unix" @@ -11,6 +10,7 @@ const ( EADDRNOTAVAIL = unix.EADDRNOTAVAIL EAFNOSUPPORT = unix.EAFNOSUPPORT EBADF = unix.EBADF + ECONNABORTED = unix.ECONNABORTED ECONNREFUSED = unix.ECONNREFUSED ECONNRESET = unix.ECONNRESET EHOSTUNREACH = unix.EHOSTUNREACH @@ -133,64 +133,3 @@ func getsockoptInt(fd, level, name int) (int, error) { func setsockoptInt(fd, level, name, value int) error { return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) } - -type socketFD struct { - state atomic.Uint64 // upper 32 bits: refCount, lower 32 bits: fd -} - -func (s *socketFD) init(fd int) { - s.state.Store(uint64(fd)) -} - -func (s *socketFD) load() int { - return int(int32(s.state.Load())) -} - -func (s *socketFD) acquire() int { - for { - oldState := s.state.Load() - refCount := (oldState >> 32) + 1 - newState := (refCount << 32) | (oldState & 0xFFFFFFFF) - - if int32(oldState) < 0 { - return -1 - } - if s.state.CompareAndSwap(oldState, newState) { - return int(int32(oldState)) // int32->int for sign extension - } - } -} - -func (s *socketFD) release(fd int) { - for { - oldState := s.state.Load() - refCount := (oldState >> 32) - 1 - newState := (oldState << 32) | (oldState & 0xFFFFFFFF) - - if s.state.CompareAndSwap(oldState, newState) { - if int32(oldState) < 0 && refCount == 0 { - unix.Close(fd) - } - break - } - } -} - -func (s *socketFD) close() error { - for { - oldState := s.state.Load() - refCount := oldState >> 32 - newState := oldState | 0xFFFFFFFF - - if s.state.CompareAndSwap(oldState, newState) { - fd := int32(oldState) - if fd < 0 { - return EBADF - } - if refCount == 0 { - return unix.Close(int(fd)) - } - return nil - } - } -} diff --git a/internal/network/socket.go b/internal/network/socket.go new file mode 100644 index 00000000..bd801e17 --- /dev/null +++ b/internal/network/socket.go @@ -0,0 +1,66 @@ +package network + +import ( + "sync/atomic" +) + +type socketFD struct { + state atomic.Uint64 // upper 32 bits: refCount, lower 32 bits: fd +} + +func (s *socketFD) init(fd int) { + s.state.Store(uint64(fd & 0xFFFFFFFF)) +} + +func (s *socketFD) load() int { + return int(int32(s.state.Load())) +} + +func (s *socketFD) refCount() int { + return int(s.state.Load() >> 32) +} + +func (s *socketFD) acquire() int { + for { + oldState := s.state.Load() + refCount := (oldState >> 32) + 1 + newState := (refCount << 32) | (oldState & 0xFFFFFFFF) + + if int32(oldState) < 0 { + return -1 + } + if s.state.CompareAndSwap(oldState, newState) { + return int(int32(oldState)) // int32->int for sign extension + } + } +} + +func (s *socketFD) releaseFunc(fd int, closeFD func(int)) { + for { + oldState := s.state.Load() + refCount := (oldState >> 32) - 1 + newState := (refCount << 32) | (oldState & 0xFFFFFFFF) + + if s.state.CompareAndSwap(oldState, newState) { + if int32(oldState) < 0 && refCount == 0 { + closeFD(fd) + } + break + } + } +} + +func (s *socketFD) closeFunc(closeFD func(int)) { + for { + oldState := s.state.Load() + refCount := oldState >> 32 + newState := oldState | 0xFFFFFFFF + + if s.state.CompareAndSwap(oldState, newState) { + if fd := int32(oldState); fd >= 0 && refCount == 0 { + closeFD(int(fd)) + } + break + } + } +} diff --git a/internal/network/socket_test.go b/internal/network/socket_test.go new file mode 100644 index 00000000..30e9ab26 --- /dev/null +++ b/internal/network/socket_test.go @@ -0,0 +1,60 @@ +package network + +import ( + "testing" + + "github.com/stealthrocket/timecraft/internal/assert" +) + +func TestSocketRefCount(t *testing.T) { + var lastCloseFD int + var closeFD = func(fd int) { lastCloseFD = fd } + + t.Run("close with zero ref count", func(t *testing.T) { + var s socketFD + s.init(42) + assert.Equal(t, s.refCount(), 0) + + s.closeFunc(closeFD) + assert.Equal(t, lastCloseFD, 42) + }) + + t.Run("release with zero ref count", func(t *testing.T) { + var s socketFD + s.init(21) + + fd := s.acquire() + assert.Equal(t, fd, 21) + assert.Equal(t, s.refCount(), 1) + + lastCloseFD = -1 + s.releaseFunc(fd, closeFD) + assert.Equal(t, lastCloseFD, -1) + assert.Equal(t, s.refCount(), 0) + }) + + t.Run("close with non zero ref count", func(t *testing.T) { + var s socketFD + s.init(10) + + fd0 := s.acquire() + assert.Equal(t, fd0, 10) + assert.Equal(t, s.refCount(), 1) + + fd1 := s.acquire() + assert.Equal(t, fd1, 10) + assert.Equal(t, s.refCount(), 2) + + lastCloseFD = -1 + s.closeFunc(closeFD) + assert.Equal(t, lastCloseFD, -1) + + s.releaseFunc(fd0, closeFD) + assert.Equal(t, lastCloseFD, -1) + assert.Equal(t, s.refCount(), 1) + + s.releaseFunc(fd1, closeFD) + assert.Equal(t, lastCloseFD, 10) + assert.Equal(t, s.refCount(), 0) + }) +} diff --git a/internal/network/socket_unix.go b/internal/network/socket_unix.go new file mode 100644 index 00000000..6932c67e --- /dev/null +++ b/internal/network/socket_unix.go @@ -0,0 +1,24 @@ +package network + +import ( + "fmt" + "os" + "runtime/debug" + + "golang.org/x/sys/unix" +) + +func (s *socketFD) release(fd int) { + s.releaseFunc(fd, closeTraceError) +} + +func (s *socketFD) close() { + s.closeFunc(closeTraceError) +} + +func closeTraceError(fd int) { + if err := unix.Close(fd); err != nil { + fmt.Fprintf(os.Stderr, "close(%d) => %s\n", fd, err) + debug.PrintStack() + } +} From d6c6b7a3d67f643380d407e3833cfb7e4de3a959 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jul 2023 12:56:32 -0700 Subject: [PATCH 08/16] network: add documentation Signed-off-by: Achille Roussel --- internal/network/conn.go | 71 +++++++++++++++++++ internal/network/local_unix.go | 125 +++++++++++++++++++++++++++++---- internal/network/socket.go | 30 +++++++- 3 files changed, 208 insertions(+), 18 deletions(-) diff --git a/internal/network/conn.go b/internal/network/conn.go index 300c50d4..7d03f076 100644 --- a/internal/network/conn.go +++ b/internal/network/conn.go @@ -59,3 +59,74 @@ func copyAndClose(w, r net.Conn, b []byte, errs chan<- error, wg *sync.WaitGroup errs <- err } } + +/* +func packetTunnel(downstream, upstream net.PacketConn, rbufsize, wbufsize int) error { + buffer := make([]byte, 2*addrBufSize+rbufsize+wbufsize) // TODO: pool this buffer? + errs := make(chan error, 2) + wg := new(sync.WaitGroup) + wg.Add(2) + + go packetCopyAndCloseInbound(downstream, upstream, buffer[:addrBufSize+rbufsize], errs, wg) + go packetCopyAndCloseOutbound(upstream, downstream, buffer[addrBufSize+rbufsize:], errs, wg) + + wg.Wait() + close(errs) + return <-errs +} + +func packetCopyAndCloseInbound(w, r net.PacketConn, b []byte, errs chan<- error, wg *sync.WaitGroup) { + defer wg.Done() + defer closeWrite(w) + + for { + n, addr, err := r.ReadFrom(b[addrBufSize:]) + if err != nil { + if err != io.EOF { + errs <- err + } + return + } + + addrBuf := encodeSockaddrAny(addr) + copy(b, addrBuf[:]) + + _, err := w.WriteTo(b[:addrBufSize+n], nil) + if err != nil { + errs <- err + return + } + } +} + +func packetCopyAndCloseOutbound(w, r net.PacketConn, b []byte, errs chan<- error, wg *sync.WaitGroup) { + defer wg.Done() + defer closeWrite(w) + + var addrBuf = b[:addrBufSize] + var dstAddr net.UDPAddr + for { + n, _, err := r.ReadFrom(b) + if err != nil { + if err != io.EOF { + errs <- err + } + return + } + + dstAddr.Port = int(binary.LittleEndian.Uint16(addrBuf[2:4])) + switch Family(binary.LittleEndian.Uint16(addrBuf[0:2])) { + case INET: + dstAddr.IP = net.IP(addrBuf[4:8]) + default: + dstAddr.IP = net.IP(addrBuf[4:20]) + } + + _, err := w.WriteTo(b[addrBufSize:addrBufSize+n], &dstAddr) + if err != nil { + errs <- err + return + } + } +} +*/ diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index 05ff09d7..bc03fab3 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -109,15 +109,38 @@ func (s *localSocket) Close() error { } } + // First close the socket pair; if there are background goroutines managing + // connections to external networks, this will interrupt the tunnels in + // charge of passing data back and forth between the local socket fds and + // the connections. Note that fd1 may have been detached if it was sent to + // another socket to establish a connection. s.fd0.close() s.fd1.close() + // When a listen function is configured on the parent namespace, the socket + // may have created a bridge to accept inbound connections from other + // networks so we have to close the net.Listener in order to interrupt the + // goroutine in charge of accepting connections. if s.lstn != nil { s.lstn.Close() } + + // When a dial function isconfigured on the parent namespace, the socket may + // be in the process of establishing an outbound connection; in that case, + // a context was created to control asynchronous cancellation of the dial + // and we must invoke the cancellation function to interrupt it. if s.cancel != nil { s.cancel() } + + // When either a listen or dial functions were set on the parent namespace + // and the socket has created bridges to external networks, an error channel + // was set to receive errors from the background goroutines managing those + // connections. Because we arleady interrupted asynchrononus operations, we + // have the guarantee that the channel will be closed when the goroutines + // exit, so the for loop will eventually stop. Flushing the error channel is + // necessary to ensure that none of the background goroutines remain blocked + // attempting to produce to the errors channel. if s.errs != nil { for range s.errs { } @@ -139,11 +162,11 @@ func (s *localSocket) Bind(addr Sockaddr) error { switch bind := addr.(type) { case *SockaddrInet4: if s.family == INET { - return s.ns.bindInet4(s, bind) + return s.bindInet4(bind) } case *SockaddrInet6: if s.family == INET6 { - return s.ns.bindInet6(s, bind) + return s.bindInet6(bind) } } return EAFNOSUPPORT @@ -152,10 +175,30 @@ func (s *localSocket) Bind(addr Sockaddr) error { func (s *localSocket) bindAny() error { switch s.family { case INET: - return s.ns.bindInet4(s, &sockaddrInet4Any) + return s.bindInet4(&sockaddrInet4Any) default: - return s.ns.bindInet6(s, &sockaddrInet6Any) + return s.bindInet6(&sockaddrInet6Any) + } +} + +func (s *localSocket) bindInet4(addr *SockaddrInet4) error { + if err := s.ns.bindInet4(s, addr); err != nil { + return err + } + if s.socktype == DGRAM && s.ns.listenPacket != nil { + return s.listenPacket() + } + return nil +} + +func (s *localSocket) bindInet6(addr *SockaddrInet6) error { + if err := s.ns.bindInet6(s, addr); err != nil { + return err } + if s.socktype == DGRAM && s.ns.listenPacket != nil { + return s.listenPacket() + } + return nil } func (s *localSocket) Listen(backlog int) error { @@ -180,7 +223,7 @@ func (s *localSocket) Listen(backlog int) error { } } if s.ns.listen != nil { - if err := s.bridge(); err != nil { + if err := s.listen(); err != nil { return err } } @@ -188,16 +231,59 @@ func (s *localSocket) Listen(backlog int) error { return nil } -func (s *localSocket) bridge() error { - var address string +func (s *localSocket) listenAddress() string { switch a := s.name.Load().(type) { case *SockaddrInet4: - address = fmt.Sprintf(":%d", a.Port) + return fmt.Sprintf(":%d", a.Port) case *SockaddrInet6: - address = fmt.Sprintf("[::]:%d", a.Port) + return fmt.Sprintf("[::]:%d", a.Port) + default: + return "" } +} - l, err := s.ns.listen(context.TODO(), "tcp", address) +func (s *localSocket) listenPacket() error { + // TODO: figure out how to tunnel packet connections in a way that allows + // for both internal and external packets to transit. + + // fd := s.fd1.acquire() + // if fd < 0 { + // return EBADF + // } + // defer s.fd1.release(fd) + + // network := s.protocol.Network() + // address := s.listenAddress() + + // upstream, err := s.ns.listenPacket(context.TODO(), network, address) + // if err != nil { + // return err + // } + + // f := os.NewFile(uintptr(fd), "") + // defer f.Close() + // s.fd1.acquire() + // s.fd1.close() + + // downstream, err := net.FilePacketConn(f) + // if err != nil { + // upstream.Close() + // return err + // } + + // go func() { + // defer downstream.Close() + // defer upstream.Close() + + // }() + return nil +} + +func (s *localSocket) listen() error { + network := s.protocol.Network() + address := s.listenAddress() + + l, err := s.ns.listen(context.TODO(), network, address) if err != nil { return err } @@ -368,6 +454,9 @@ func (s *localSocket) dial(addr Sockaddr) error { } defer s.fd1.release(fd1) + // TODO: + // - remove the 2x factor that linux applies on socket buffers + // - do the two ends of a socket pair share the same buffer sizes? rbufsize, err := getsockoptInt(fd1, unix.SOL_SOCKET, unix.SO_RCVBUF) if err != nil { return err @@ -669,16 +758,22 @@ func encodeSockaddrAny(addr any) (buf [addrBufSize]byte) { case *SockaddrInet6: return encodeAddrPortInet6(a.Addr, a.Port) case *net.TCPAddr: - if ipv4 := a.IP.To4(); ipv4 != nil { - return encodeAddrPortInet4(([4]byte)(ipv4), a.Port) - } else { - return encodeAddrPortInet6(([16]byte)(a.IP), a.Port) - } + return encodeAddrPortIP(a.IP, a.Port) + case *net.UDPAddr: + return encodeAddrPortIP(a.IP, a.Port) default: return } } +func encodeAddrPortIP(addr net.IP, port int) (buf [addrBufSize]byte) { + if ipv4 := addr.To4(); ipv4 != nil { + return encodeAddrPortInet4(([4]byte)(ipv4), port) + } else { + return encodeAddrPortInet6(([16]byte)(addr), port) + } +} + func encodeAddrPortInet4(addr [4]byte, port int) (buf [addrBufSize]byte) { binary.LittleEndian.PutUint16(buf[0:2], uint16(INET)) binary.LittleEndian.PutUint16(buf[2:4], uint16(port)) diff --git a/internal/network/socket.go b/internal/network/socket.go index bd801e17..15f89761 100644 --- a/internal/network/socket.go +++ b/internal/network/socket.go @@ -4,6 +4,25 @@ import ( "sync/atomic" ) +// socketFD is used to manage the lifecycle of socket file descriptors; +// it allows multiple goroutines to share ownership of the socket while +// coordinating to close the file descriptor via an atomic reference count. +// +// Goroutines must call acquire to access the file descriptor; if they get a +// negative number, it indicates that the socket was already closed and the +// method should usually return EBADF. +// +// After acquiring a valid file descriptor, the goroutine is responsible for +// calling release with the same fd number that was returned by acquire. The +// release may cause the file descriptor to be closed if the close method was +// called in between and releasing the fd causes the reference count to reach +// zero. +// +// The close method detaches the file descriptor from the socketFD, but it only +// closes it if the reference count is zero (no other goroutines was sharing +// ownership). After closing the socketFD, all future calls to acquire return a +// negative number, preventing other goroutines from acquiring ownership of the +// file descriptor and guaranteeing that it will eventually be closed. type socketFD struct { state atomic.Uint64 // upper 32 bits: refCount, lower 32 bits: fd } @@ -26,11 +45,12 @@ func (s *socketFD) acquire() int { refCount := (oldState >> 32) + 1 newState := (refCount << 32) | (oldState & 0xFFFFFFFF) - if int32(oldState) < 0 { + fd := int32(oldState) + if fd < 0 { return -1 } if s.state.CompareAndSwap(oldState, newState) { - return int(int32(oldState)) // int32->int for sign extension + return int(fd) } } } @@ -56,8 +76,12 @@ func (s *socketFD) closeFunc(closeFD func(int)) { refCount := oldState >> 32 newState := oldState | 0xFFFFFFFF + fd := int32(oldState) + if fd < 0 { + break + } if s.state.CompareAndSwap(oldState, newState) { - if fd := int32(oldState); fd >= 0 && refCount == 0 { + if refCount == 0 { closeFD(int(fd)) } break From bb5dc7e7b0dd4678936f6c88b408b90e62638922 Mon Sep 17 00:00:00 2001 From: Achille Date: Mon, 17 Jul 2023 15:04:50 -0700 Subject: [PATCH 09/16] Update internal/network/local.go Co-authored-by: Thomas Pelletier Signed-off-by: Achille --- internal/network/local.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/network/local.go b/internal/network/local.go index 4feebe5a..5c5006f8 100644 --- a/internal/network/local.go +++ b/internal/network/local.go @@ -358,7 +358,7 @@ func (i *localInterface) bind(sock *localSocket, addr net.IP, port int) error { } break } - if port == 65535 { + if port == 65536 { return EADDRNOTAVAIL } switch a := name.(type) { From 4134ef28d3162bb113c40bed70726d09115da8f2 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 18 Jul 2023 23:07:26 -0700 Subject: [PATCH 10/16] make lint happy Signed-off-by: Achille Roussel --- internal/network/conn.go | 40 ++++++---------------------------- internal/network/local_test.go | 5 +++-- internal/network/local_unix.go | 15 ++++++------- 3 files changed, 17 insertions(+), 43 deletions(-) diff --git a/internal/network/conn.go b/internal/network/conn.go index 7d03f076..e217dbfc 100644 --- a/internal/network/conn.go +++ b/internal/network/conn.go @@ -6,37 +6,6 @@ import ( "sync" ) -type closeReader interface { - CloseRead() error -} - -type closeWriter interface { - CloseWrite() error -} - -var ( - _ closeReader = (*net.UnixConn)(nil) - _ closeWriter = (*net.UnixConn)(nil) -) - -func closeRead(conn io.Closer) error { - switch c := conn.(type) { - case closeReader: - return c.CloseRead() - default: - return c.Close() - } -} - -func closeWrite(conn io.Closer) error { - switch c := conn.(type) { - case closeWriter: - return c.CloseWrite() - default: - return c.Close() - } -} - func tunnel(downstream, upstream net.Conn, rbufsize, wbufsize int) error { buffer := make([]byte, rbufsize+wbufsize) // TODO: pool this buffer? errs := make(chan error, 2) @@ -52,12 +21,17 @@ func tunnel(downstream, upstream net.Conn, rbufsize, wbufsize int) error { } func copyAndClose(w, r net.Conn, b []byte, errs chan<- error, wg *sync.WaitGroup) { - defer wg.Done() - defer closeWrite(w) //nolint:errcheck _, err := io.CopyBuffer(w, r, b) if err != nil { errs <- err } + switch c := w.(type) { + case interface{ CloseWrite() error }: + c.CloseWrite() //nolint:errcheck + default: + c.Close() + } + wg.Done() } /* diff --git a/internal/network/local_test.go b/internal/network/local_test.go index 681b2f81..21f23c0c 100644 --- a/internal/network/local_test.go +++ b/internal/network/local_test.go @@ -356,8 +356,9 @@ func testLocalNetworkInboundAccept(t *testing.T, n *network.LocalNetwork) { assert.Equal(t, string(buf[:7]), "message") // exercise shutting down the write end of the inbound connection - conn.(*net.TCPConn).CloseWrite() - waitReadyRead(peer) + assert.OK(t, conn.(*net.TCPConn).CloseWrite()) + assert.OK(t, waitReadyRead(peer)) + size, _, _, err = peer.RecvFrom([][]byte{buf}, 0) assert.OK(t, err) assert.Equal(t, size, 0) diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index bc03fab3..265560b1 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -352,14 +352,13 @@ func (s *localSocket) serve(upstream net.Conn) error { defer upstream.Close() defer downstream.Close() - if err := tunnel(downstream, upstream, rbufsize, wbufsize); err != nil { - // TODO: figure out if this error needs to be reported: - // - // When the downstream file is closed, the other end of the - // connection will observe that the socket was shutdown, we - // lose the information of why but we currently have no way - // of locating the peer socket on the other side. - } + _ = tunnel(downstream, upstream, rbufsize, wbufsize) + // TODO: figure out if this error needs to be reported: + // + // When the downstream file is closed, the other end of the + // connection will observe that the socket was shutdown, we + // lose the information of why but we currently have no way + // of locating the peer socket on the other side. }() return nil } From e52b773da1c9966c4564c579e816dc1271650e9f Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 19 Jul 2023 06:23:00 +0000 Subject: [PATCH 11/16] fix linux build Signed-off-by: Achille Roussel --- internal/network/host_linux.go | 4 ++-- internal/network/local_linux.go | 2 +- internal/network/local_unix.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/network/host_linux.go b/internal/network/host_linux.go index 7047b755..3a2161d5 100644 --- a/internal/network/host_linux.go +++ b/internal/network/host_linux.go @@ -13,11 +13,11 @@ func (hostNamespace) Socket(family Family, socktype Socktype, protocol Protocol) } func (s *hostSocket) Accept() (Socket, Sockaddr, error) { - fd := s.acquire() + fd := s.fd.acquire() if fd < 0 { return nil, nil, EBADF } - defer s.release(fd) + defer s.fd.release(fd) conn, addr, err := ignoreEINTR3(func() (int, unix.Sockaddr, error) { return unix.Accept4(fd, unix.SOCK_CLOEXEC|unix.SOCK_NONBLOCK) }) diff --git a/internal/network/local_linux.go b/internal/network/local_linux.go index d9d14abe..47735332 100644 --- a/internal/network/local_linux.go +++ b/internal/network/local_linux.go @@ -39,7 +39,7 @@ func (s *localSocket) SetOptInt(level, name, value int) error { case unix.SOL_SOCKET: switch name { case unix.SO_BINDTODEVICE: - return 0, ENOPROTOOPT + return ENOPROTOOPT } } diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index 265560b1..73e8f7e5 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -527,7 +527,7 @@ func (s *localSocket) Accept() (Socket, Sockaddr, error) { } var oobn int - var oobBuf [16]byte + var oobBuf [24]byte var addrBuf [addrBufSize]byte for { var err error From ae2221d839c8c2ae815c64862bc95333a225116f Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 19 Jul 2023 07:10:05 +0000 Subject: [PATCH 12/16] network: implement hTLS Signed-off-by: Achille Roussel --- internal/htls/htls.go | 3 +- internal/network/host_unix.go | 24 ++++++- internal/network/local_darwin.go | 18 ----- internal/network/local_linux.go | 42 +----------- internal/network/local_unix.go | 114 +++++++++++++++++++++++++++++-- internal/network/network.go | 4 ++ internal/network/network_unix.go | 8 +++ internal/sandbox/socket.go | 2 +- 8 files changed, 149 insertions(+), 66 deletions(-) diff --git a/internal/htls/htls.go b/internal/htls/htls.go index dedb55b3..047792b9 100644 --- a/internal/htls/htls.go +++ b/internal/htls/htls.go @@ -1,4 +1,5 @@ package htls const Level = 0x74696d65 -const Option = 1 + +const ServerName = 1 diff --git a/internal/network/host_unix.go b/internal/network/host_unix.go index 2c64c793..da685d0a 100644 --- a/internal/network/host_unix.go +++ b/internal/network/host_unix.go @@ -126,6 +126,24 @@ func (s *hostSocket) Shutdown(how int) error { return shutdown(fd, how) } +func (s *hostSocket) GetOptInt(level, name int) (int, error) { + fd := s.fd.acquire() + if fd < 0 { + return -1, EBADF + } + defer s.fd.release(fd) + return getsockoptInt(fd, level, name) +} + +func (s *hostSocket) GetOptString(level, name int) (string, error) { + fd := s.fd.acquire() + if fd < 0 { + return "", EBADF + } + defer s.fd.release(fd) + return getsockoptString(fd, level, name) +} + func (s *hostSocket) SetOptInt(level, name, value int) error { fd := s.fd.acquire() if fd < 0 { @@ -135,11 +153,11 @@ func (s *hostSocket) SetOptInt(level, name, value int) error { return setsockoptInt(fd, level, name, value) } -func (s *hostSocket) GetOptInt(level, name int) (int, error) { +func (s *hostSocket) SetOptString(level, name int, value string) error { fd := s.fd.acquire() if fd < 0 { - return -1, EBADF + return EBADF } defer s.fd.release(fd) - return getsockoptInt(fd, level, name) + return setsockoptString(fd, level, name, value) } diff --git a/internal/network/local_darwin.go b/internal/network/local_darwin.go index 7c936efe..5cd136b9 100644 --- a/internal/network/local_darwin.go +++ b/internal/network/local_darwin.go @@ -33,21 +33,3 @@ func closePair(fds *[2]int) { fds[0] = -1 fds[1] = -1 } - -func (s *localSocket) GetOptInt(level, name int) (int, error) { - fd := s.fd0.acquire() - if fd < 0 { - return 0, EBADF - } - defer s.fd0.release(fd) - return getsockoptInt(fd, level, name) -} - -func (s *localSocket) SetOptInt(level, name, value int) error { - fd := s.fd0.acquire() - if fd < 0 { - return EBADF - } - defer s.fd0.release(fd) - return setsockoptInt(fd, level, name, value) -} diff --git a/internal/network/local_linux.go b/internal/network/local_linux.go index 47735332..3370f115 100644 --- a/internal/network/local_linux.go +++ b/internal/network/local_linux.go @@ -1,47 +1,11 @@ package network -import "golang.org/x/sys/unix" +import ( + "golang.org/x/sys/unix" +) func socketpair(family, socktype, protocol int) ([2]int, error) { return ignoreEINTR2(func() ([2]int, error) { return unix.Socketpair(family, socktype|unix.SOCK_CLOEXEC|unix.SOCK_NONBLOCK, protocol) }) } - -func (s *localSocket) GetOptInt(level, name int) (int, error) { - fd := s.fd0.acquire() - if fd < 0 { - return 0, EBADF - } - defer s.fd0.release(fd) - - switch level { - case unix.SOL_SOCKET: - switch name { - case unix.SO_DOMAIN: - return int(s.family), nil - case unix.SO_BINDTODEVICE: - return 0, ENOPROTOOPT - } - } - - return getsockoptInt(fd, level, name) -} - -func (s *localSocket) SetOptInt(level, name, value int) error { - fd := s.fd0.acquire() - if fd < 0 { - return EBADF - } - defer s.fd0.release(fd) - - switch level { - case unix.SOL_SOCKET: - switch name { - case unix.SO_BINDTODEVICE: - return ENOPROTOOPT - } - } - - return setsockoptInt(fd, level, name, value) -} diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index 73e8f7e5..afe1b5ad 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -2,12 +2,14 @@ package network import ( "context" + "crypto/tls" "encoding/binary" "fmt" "net" "os" "sync/atomic" + "github.com/stealthrocket/timecraft/internal/htls" "golang.org/x/sys/unix" ) @@ -65,6 +67,7 @@ type localSocket struct { addrBuf [addrBufSize]byte lstn net.Listener + htls chan<- string errs <-chan error cancel context.CancelFunc } @@ -475,25 +478,46 @@ func (s *localSocket) dial(addr Sockaddr) error { return err } - dialNetwork := s.protocol.Network() - dialAddress := SockaddrAddrPort(addr).String() ctx, cancel := context.WithCancel(context.Background()) + s.cancel = cancel + + htls := make(chan string, 1) + s.htls = htls errs := make(chan error, 1) s.errs = errs - s.cancel = cancel + network := s.protocol.Network() + address := SockaddrAddrPort(addr).String() go func() { defer close(errs) defer downstream.Close() - upstream, err := dial(ctx, dialNetwork, dialAddress) + upstream, err := dial(ctx, network, address) if err != nil { errs <- err return } defer upstream.Close() + select { + case <-ctx.Done(): + errs <- ctx.Err() + return + case serverName, ok := <-htls: + if !ok { + break + } + tlsConn := tls.Client(upstream, &tls.Config{ + ServerName: serverName, + }) + if err := tlsConn.HandshakeContext(ctx); err != nil { + errs <- err + return + } + upstream = tlsConn + } + if err := tunnel(downstream, upstream, rbufsize, wbufsize); err != nil { errs <- err } @@ -611,6 +635,7 @@ func (s *localSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, er return -1, 0, nil, EBADF } defer s.fd0.release(fd) + s.htlsClear() if s.state.is(listening) { return -1, 0, nil, EINVAL @@ -658,6 +683,7 @@ func (s *localSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, erro return -1, EBADF } defer s.fd0.release(fd) + s.htlsClear() if s.state.is(listening) { return -1, EINVAL @@ -732,9 +758,68 @@ func (s *localSocket) Shutdown(how int) error { return EBADF } defer s.fd0.release(fd) + defer s.htlsClear() return shutdown(fd, how) } +func (s *localSocket) GetOptInt(level, name int) (int, error) { + fd := s.fd0.acquire() + if fd < 0 { + return -1, EBADF + } + defer s.fd0.release(fd) + return getsockoptInt(fd, level, name) +} + +func (s *localSocket) GetOptString(level, name int) (string, error) { + fd := s.fd0.acquire() + if fd < 0 { + return "", EBADF + } + defer s.fd0.release(fd) + + switch level { + case htls.Level: + switch name { + case htls.ServerName: + return "", EINVAL + default: + return "", ENOPROTOOPT + } + } + + return getsockoptString(fd, level, name) +} + +func (s *localSocket) SetOptInt(level, name, value int) error { + fd := s.fd0.acquire() + if fd < 0 { + return EBADF + } + defer s.fd0.release(fd) + return setsockoptInt(fd, level, name, value) +} + +func (s *localSocket) SetOptString(level, name int, value string) error { + fd := s.fd0.acquire() + if fd < 0 { + return EBADF + } + defer s.fd0.release(fd) + + switch level { + case htls.Level: + switch name { + case htls.ServerName: + return s.htlsSetServerName(value) + default: + return ENOPROTOOPT + } + } + + return setsockoptString(fd, level, name, value) +} + func (s *localSocket) getError() error { select { case err := <-s.errs: @@ -744,6 +829,27 @@ func (s *localSocket) getError() error { } } +func (s *localSocket) htlsClear() { + if s.htls != nil { + close(s.htls) + } +} + +func (s *localSocket) htlsSetServerName(hostname string) (err error) { + defer func() { + if recover() != nil { + if s.htls == nil { + err = EINVAL + } else { + err = EISCONN + } + } + }() + s.htls <- hostname + close(s.htls) + return nil +} + func clearIOVecs(iovs [][]byte) { for i := range iovs { iovs[i] = nil diff --git a/internal/network/network.go b/internal/network/network.go index 0ed54c30..1cbafa6d 100644 --- a/internal/network/network.go +++ b/internal/network/network.go @@ -40,7 +40,11 @@ type Socket interface { GetOptInt(level, name int) (int, error) + GetOptString(level, name int) (string, error) + SetOptInt(level, name, value int) error + + SetOptString(level, name int, value string) error } type Socktype uint8 diff --git a/internal/network/network_unix.go b/internal/network/network_unix.go index fd22c6a2..4ec366f6 100644 --- a/internal/network/network_unix.go +++ b/internal/network/network_unix.go @@ -130,6 +130,14 @@ func getsockoptInt(fd, level, name int) (int, error) { return ignoreEINTR2(func() (int, error) { return unix.GetsockoptInt(fd, level, name) }) } +func getsockoptString(fd, level, name int) (string, error) { + return ignoreEINTR2(func() (string, error) { return unix.GetsockoptString(fd, level, name) }) +} + func setsockoptInt(fd, level, name, value int) error { return ignoreEINTR(func() error { return unix.SetsockoptInt(fd, level, name, value) }) } + +func setsockoptString(fd, level, name int, value string) error { + return ignoreEINTR(func() error { return unix.SetsockoptString(fd, level, name, value) }) +} diff --git a/internal/sandbox/socket.go b/internal/sandbox/socket.go index 17644652..77e1cf11 100644 --- a/internal/sandbox/socket.go +++ b/internal/sandbox/socket.go @@ -20,7 +20,7 @@ func sumIOVecLen(iovs []wasi.IOVec) (n int) { } const ( - htlsOption wasi.SocketOption = (wasi.SocketOption(htls.Level) << 32) | (htls.Option) + htlsOption wasi.SocketOption = (wasi.SocketOption(htls.Level) << 32) | (htls.ServerName) ) type sockflags uint32 From b2cd4f8c8791c6ce510995fadf1df15549993dd5 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 19 Jul 2023 13:37:51 -0700 Subject: [PATCH 13/16] network: implement bridge of datagram sockets to external networks Signed-off-by: Achille Roussel --- internal/network/conn.go | 106 --------- internal/network/local_unix.go | 253 +++++++++++++++++---- internal/timemachine/wasicall/wasi_test.go | 2 +- 3 files changed, 204 insertions(+), 157 deletions(-) delete mode 100644 internal/network/conn.go diff --git a/internal/network/conn.go b/internal/network/conn.go deleted file mode 100644 index e217dbfc..00000000 --- a/internal/network/conn.go +++ /dev/null @@ -1,106 +0,0 @@ -package network - -import ( - "io" - "net" - "sync" -) - -func tunnel(downstream, upstream net.Conn, rbufsize, wbufsize int) error { - buffer := make([]byte, rbufsize+wbufsize) // TODO: pool this buffer? - errs := make(chan error, 2) - wg := new(sync.WaitGroup) - wg.Add(2) - - go copyAndClose(downstream, upstream, buffer[:rbufsize], errs, wg) - go copyAndClose(upstream, downstream, buffer[rbufsize:], errs, wg) - - wg.Wait() - close(errs) - return <-errs -} - -func copyAndClose(w, r net.Conn, b []byte, errs chan<- error, wg *sync.WaitGroup) { - _, err := io.CopyBuffer(w, r, b) - if err != nil { - errs <- err - } - switch c := w.(type) { - case interface{ CloseWrite() error }: - c.CloseWrite() //nolint:errcheck - default: - c.Close() - } - wg.Done() -} - -/* -func packetTunnel(downstream, upstream net.PacketConn, rbufsize, wbufsize int) error { - buffer := make([]byte, 2*addrBufSize+rbufsize+wbufsize) // TODO: pool this buffer? - errs := make(chan error, 2) - wg := new(sync.WaitGroup) - wg.Add(2) - - go packetCopyAndCloseInbound(downstream, upstream, buffer[:addrBufSize+rbufsize], errs, wg) - go packetCopyAndCloseOutbound(upstream, downstream, buffer[addrBufSize+rbufsize:], errs, wg) - - wg.Wait() - close(errs) - return <-errs -} - -func packetCopyAndCloseInbound(w, r net.PacketConn, b []byte, errs chan<- error, wg *sync.WaitGroup) { - defer wg.Done() - defer closeWrite(w) - - for { - n, addr, err := r.ReadFrom(b[addrBufSize:]) - if err != nil { - if err != io.EOF { - errs <- err - } - return - } - - addrBuf := encodeSockaddrAny(addr) - copy(b, addrBuf[:]) - - _, err := w.WriteTo(b[:addrBufSize+n], nil) - if err != nil { - errs <- err - return - } - } -} - -func packetCopyAndCloseOutbound(w, r net.PacketConn, b []byte, errs chan<- error, wg *sync.WaitGroup) { - defer wg.Done() - defer closeWrite(w) - - var addrBuf = b[:addrBufSize] - var dstAddr net.UDPAddr - for { - n, _, err := r.ReadFrom(b) - if err != nil { - if err != io.EOF { - errs <- err - } - return - } - - dstAddr.Port = int(binary.LittleEndian.Uint16(addrBuf[2:4])) - switch Family(binary.LittleEndian.Uint16(addrBuf[0:2])) { - case INET: - dstAddr.IP = net.IP(addrBuf[4:8]) - default: - dstAddr.IP = net.IP(addrBuf[4:20]) - } - - _, err := w.WriteTo(b[addrBufSize:addrBufSize+n], &dstAddr) - if err != nil { - errs <- err - return - } - } -} -*/ diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index afe1b5ad..9ad52d57 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -5,8 +5,10 @@ import ( "crypto/tls" "encoding/binary" "fmt" + "io" "net" "os" + "sync" "sync/atomic" "github.com/stealthrocket/timecraft/internal/htls" @@ -35,6 +37,7 @@ const ( bound localSocketState = 1 << iota accepted connected + tunneled listening ) @@ -66,6 +69,7 @@ type localSocket struct { iovs [][]byte addrBuf [addrBufSize]byte + conn net.PacketConn lstn net.Listener htls chan<- string errs <-chan error @@ -120,6 +124,14 @@ func (s *localSocket) Close() error { s.fd0.close() s.fd1.close() + // When a packet listen function is configured on the parent namespace, the + // socket may have created a bridge to accept inbound datagrams from other + // networks so we have to close the net.PacketConn in order to interrupt the + // goroutine in charge of receiving packets. + if s.conn != nil { + s.conn.Close() + } + // When a listen function is configured on the parent namespace, the socket // may have created a bridge to accept inbound connections from other // networks so we have to close the net.Listener in order to interrupt the @@ -246,39 +258,74 @@ func (s *localSocket) listenAddress() string { } func (s *localSocket) listenPacket() error { - // TODO: figure out how to tunnel packet connections in a way that allows - // for both internal and external packets to transit. - - // fd := s.fd1.acquire() - // if fd < 0 { - // return EBADF - // } - // defer s.fd1.release(fd) - - // network := s.protocol.Network() - // address := s.listenAddress() - - // upstream, err := s.ns.listenPacket(context.TODO(), network, address) - // if err != nil { - // return err - // } - - // f := os.NewFile(uintptr(fd), "") - // defer f.Close() - // s.fd1.acquire() - // s.fd1.close() - - // downstream, err := net.FilePacketConn(f) - // if err != nil { - // upstream.Close() - // return err - // } - - // go func() { - // defer downstream.Close() - // defer upstream.Close() - - // }() + sendSocketFd := s.fd1.acquire() + if sendSocketFd < 0 { + return EBADF + } + defer s.fd1.release(sendSocketFd) + + rbufsize, err := getsockoptInt(sendSocketFd, unix.SOL_SOCKET, unix.SO_RCVBUF) + if err != nil { + return err + } + wbufsize, err := getsockoptInt(sendSocketFd, unix.SOL_SOCKET, unix.SO_RCVBUF) + if err != nil { + return err + } + + network := s.protocol.Network() + address := s.listenAddress() + + conn, err := s.ns.listenPacket(context.TODO(), network, address) + if err != nil { + return err + } + + switch c := conn.(type) { + case *net.UDPConn: + _ = c.SetReadBuffer(rbufsize) + _ = c.SetWriteBuffer(wbufsize) + } + + errs := make(chan error, 1) + s.errs = errs + s.conn = conn + + buffer := make([]byte, rbufsize) + // Increase reference count for fd1 because the goroutine will now share + // ownership of the file descriptor. + s.fd1.acquire() + go func() { + defer close(errs) + defer conn.Close() + defer s.fd1.release(sendSocketFd) + + var iovs [2][]byte + var addrBuf [addrBufSize]byte + // TODO: use optimizations like net.(*UDPConn).ReadMsgUDPAddrPort to + // remove the heap allocation of the net.Addr returned by ReadFrom. + for { + n, addr, err := conn.ReadFrom(buffer) + if err != nil { + if err != io.EOF { + errs <- err + } + return + } + + addrBuf = encodeSockaddrAny(addr) + iovs[0] = addrBuf[:] + iovs[1] = buffer[:n] + + _, err = ignoreEINTR2(func() (int, error) { + return unix.SendmsgBuffers(sendSocketFd, iovs[:], nil, nil, 0) + }) + if err != nil { + errs <- err + return + } + } + }() return nil } @@ -405,20 +452,16 @@ func (s *localSocket) Connect(addr Sockaddr) error { return ECONNREFUSED } - if s.socktype == DGRAM { - s.peer.Store(addr) - s.state.set(connected) - return nil - } - - serverFd := server.fd1.acquire() - if serverFd < 0 { - return ECONNREFUSED - } - defer server.fd1.release(serverFd) + if s.socktype != DGRAM { + serverFd := server.fd1.acquire() + if serverFd < 0 { + return ECONNREFUSED + } + defer server.fd1.release(serverFd) - if err := s.connect(serverFd, s.name.Load()); err != nil { - return err + if err := s.connect(serverFd, s.name.Load()); err != nil { + return err + } } s.peer.Store(addr) @@ -445,6 +488,16 @@ func (s *localSocket) connect(serverFd int, addr any) error { } func (s *localSocket) dial(addr Sockaddr) error { + // When using datagram sockets with a packet listen function setup, the + // connection is emulated by simply setting the peer address since a packet + // tunnel has already been constructed, all we care about is making sure the + // socket only exchange datagrams with the address it is connected to. + if s.conn != nil { + s.peer.Store(addr) + s.state.set(connected) + return nil + } + dial := s.ns.dial if dial == nil { return ENETUNREACH @@ -524,7 +577,7 @@ func (s *localSocket) dial(addr Sockaddr) error { }() s.peer.Store(addr) - s.state.set(connected) + s.state.set(connected | tunneled) return EINPROGRESS } @@ -649,7 +702,7 @@ func (s *localSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, er } } - if s.socktype == DGRAM { + if s.socktype == DGRAM && !s.state.is(tunneled) { s.iovs = s.iovs[:0] s.iovs = append(s.iovs, s.addrBuf[:]) s.iovs = append(s.iovs, iovs...) @@ -668,11 +721,22 @@ func (s *localSocket) RecvFrom(iovs [][]byte, flags int) (int, int, Sockaddr, er } err = nil } + var addr Sockaddr - if s.socktype == DGRAM { + if s.socktype == DGRAM && !s.state.is(tunneled) { addr = decodeSockaddr(s.addrBuf) n -= addrBufSize + // Connected datagram sockets may receive data from addresses that + // they are not connected to, those datagrams should be dropped. + if s.state.is(connected) { + recvAddrPort := SockaddrAddrPort(addr) + peerAddrPort := SockaddrAddrPort(s.peer.Load().(Sockaddr)) + if recvAddrPort != peerAddrPort { + continue + } + } } + return n, rflags, addr, err } } @@ -703,10 +767,25 @@ func (s *localSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, erro } } + if s.socktype == DGRAM && !s.state.is(tunneled) && addr == nil { + switch peer := s.peer.Load().(type) { + case *SockaddrInet4: + addr = peer + case *SockaddrInet6: + addr = peer + } + } + sendSocketFd := fd + // We only perform a lookup of the peer socket if an address is provided, + // which means that the socket is not connected to a particular destination + // (it must be a datagram socket). This may result in sending the datagram + // directly to the peer socket's file descriptor, or passing it to a packet + // connection if one was opened by the socket. if addr != nil { var peer *localSocket var err error + switch a := addr.(type) { case *SockaddrInet4: peer, err = s.ns.lookupInet4(s, a) @@ -715,20 +794,39 @@ func (s *localSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, erro default: return -1, EAFNOSUPPORT } + if err != nil { + // When the destination is not an address within the local network, + // but a packet listen function was setup then the socket has opened + // a packet connection that we use to send the datagram to a remote + // address. Note that we do expect that writing datagrams is not a + // blocking operation and the net.PacketConn may drop packets. + if err == ENETUNREACH && s.conn != nil { + return writeTo(s.conn, iovs, addr) + } return -1, err } + + // If the application tried to send a datagram to a socket which is not + // a datagram socket, we drop the data here pretending that we were able + // to send it. if peer.socktype != DGRAM { return iovecLen(iovs), nil } + // There are two reasons why the peer's second file descriptor may not + // be available here: the peer could have been closed concurrently, or + // it may have been connected to a specific address which indicates + // that it is not a listening socket. peerFd := peer.fd1.acquire() if peerFd < 0 { return -1, EHOSTUNREACH } defer peer.fd1.release(peerFd) sendSocketFd = peerFd + } + if s.socktype == DGRAM && !s.state.is(tunneled) { s.addrBuf = encodeSockaddrAny(s.name.Load()) s.iovs = s.iovs[:0] s.iovs = append(s.iovs, s.addrBuf[:]) @@ -745,7 +843,7 @@ func (s *localSocket) SendTo(iovs [][]byte, addr Sockaddr, flags int) (int, erro } err = nil } - if n > 0 && s.socktype == DGRAM { + if n > 0 && addr != nil { n -= addrBufSize } return n, err @@ -922,3 +1020,58 @@ func iovecLen(iovs [][]byte) (n int) { } return n } + +func iovecBuf(iovs [][]byte) []byte { + buf := make([]byte, 0, iovecLen(iovs)) + for _, iov := range iovs { + buf = append(buf, iov...) + } + return buf +} + +func tunnel(downstream, upstream net.Conn, rbufsize, wbufsize int) error { + buffer := make([]byte, rbufsize+wbufsize) // TODO: pool this buffer? + errs := make(chan error, 2) + wg := new(sync.WaitGroup) + wg.Add(2) + + go copyAndClose(downstream, upstream, buffer[:rbufsize], errs, wg) + go copyAndClose(upstream, downstream, buffer[rbufsize:], errs, wg) + + wg.Wait() + close(errs) + return <-errs +} + +func copyAndClose(w, r net.Conn, b []byte, errs chan<- error, wg *sync.WaitGroup) { + _, err := io.CopyBuffer(w, r, b) + if err != nil { + errs <- err + } + switch c := w.(type) { + case interface{ CloseWrite() error }: + c.CloseWrite() //nolint:errcheck + default: + c.Close() + } + wg.Done() +} + +// writeTo writes a datagram represented by the given I/O vector to a destination +// address on a packet connection. +// +// If the net.PacketConn is an instance of *net.UDPConn, the function uses +// optimized methods to avoid heap allocations for the intermediary net.Addr +// value that must be constructed. +func writeTo(conn net.PacketConn, iovs [][]byte, addr Sockaddr) (int, error) { + buf := iovecBuf(iovs) // TODO: pool this buffer? + addrPort := SockaddrAddrPort(addr) + switch c := conn.(type) { + case *net.UDPConn: + return c.WriteToUDPAddrPort(buf, addrPort) + default: + addr := addrPort.Addr() + port := addrPort.Port() + return c.WriteTo(buf, &net.UDPAddr{IP: addr.AsSlice(), Port: int(port)}) + } +} diff --git a/internal/timemachine/wasicall/wasi_test.go b/internal/timemachine/wasicall/wasi_test.go index 37eb5187..295e9106 100644 --- a/internal/timemachine/wasicall/wasi_test.go +++ b/internal/timemachine/wasicall/wasi_test.go @@ -155,7 +155,7 @@ var syscalls = []Syscall{ &SockGetOptSyscall{FD: 1, Option: wasi.Broadcast, Value: wasi.IntValue(1), Errno: 0}, &SockGetOptSyscall{FD: 1, Option: ^wasi.SocketOption(0), Value: nil, Errno: 1}, &SockGetOptSyscall{}, - &SockSetOptSyscall{FD: 1, Option: wasi.MakeSocketOption(htls.Level, htls.Option), Value: wasi.BytesValue("foo"), Errno: 0}, + &SockSetOptSyscall{FD: 1, Option: wasi.MakeSocketOption(htls.Level, htls.ServerName), Value: wasi.BytesValue("foo"), Errno: 0}, &SockSetOptSyscall{FD: 1, Option: wasi.Broadcast, Value: wasi.IntValue(1), Errno: 0}, &SockSetOptSyscall{FD: 1, Option: ^wasi.SocketOption(0), Value: nil, Errno: 1}, &SockSetOptSyscall{}, From 9543619c86005799b8cfb6210fce1c52f1527123 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 19 Jul 2023 15:47:53 -0700 Subject: [PATCH 14/16] network: add tests for datagram sockets Signed-off-by: Achille Roussel --- internal/network/host_test.go | 18 ++- internal/network/local_test.go | 208 ++++++++++++++++++++++++++----- internal/network/local_unix.go | 6 +- internal/network/network_test.go | 134 +++++++++++++++++++- 4 files changed, 325 insertions(+), 41 deletions(-) diff --git a/internal/network/host_test.go b/internal/network/host_test.go index a03e491d..2c9d4ee9 100644 --- a/internal/network/host_test.go +++ b/internal/network/host_test.go @@ -19,13 +19,23 @@ func TestHostNetwork(t *testing.T) { }, { - scenario: "ipv4 sockets can connect to one another on the loopback interface", - function: testNamespaceConnectLoopbackIPv4, + scenario: "ipv4 stream sockets can connect to one another on the loopback interface", + function: testNamespaceConnectStreamLoopbackIPv4, }, { - scenario: "ipv6 sockets can connect to one another on the loopback interface", - function: testNamespaceConnectLoopbackIPv6, + scenario: "ipv6 stream sockets can connect to one another on the loopback interface", + function: testNamespaceConnectStreamLoopbackIPv6, + }, + + { + scenario: "ipv4 datagram sockets can connect to one another on the loopback interface", + function: testNamespaceConnectDatagramLoopbackIPv4, + }, + + { + scenario: "ipv6 datagram sockets can connect to one another on the loopback interface", + function: testNamespaceConnectDatagramLoopbackIPv6, }, { diff --git a/internal/network/local_test.go b/internal/network/local_test.go index 21f23c0c..5a61060d 100644 --- a/internal/network/local_test.go +++ b/internal/network/local_test.go @@ -23,44 +23,74 @@ func TestLocalNetwork(t *testing.T) { }, { - scenario: "ipv4 sockets can connect to one another on a loopback interface", - function: testLocalNetworkConnectLoopbackIPv4, + scenario: "ipv4 stream sockets can connect to one another on a loopback interface", + function: testLocalNetworkConnectStreamLoopbackIPv4, }, { - scenario: "ipv6 sockets can connect to one another on a loopback interface", - function: testLocalNetworkConnectLoopbackIPv6, + scenario: "ipv6 stream sockets can connect to one another on a loopback interface", + function: testLocalNetworkConnectStreamLoopbackIPv6, }, { - scenario: "ipv4 sockets can connect to one another on a network interface", - function: testLocalNetworkConnectInterfaceIPv4, + scenario: "ipv4 stream sockets can connect to one another on a network interface", + function: testLocalNetworkConnectStreamInterfaceIPv4, }, { - scenario: "ipv6 sockets can connect to one another on a network interface", - function: testLocalNetworkConnectInterfaceIPv6, + scenario: "ipv6 stream sockets can connect to one another on a network interface", + function: testLocalNetworkConnectStreamInterfaceIPv6, }, { - scenario: "ipv4 sockets in different namespaces can connect to one another", - function: testLocalNetworkConnectNamespacesIPv4, + scenario: "ipv4 stream sockets in different namespaces can connect to one another", + function: testLocalNetworkConnectStreamNamespacesIPv4, }, { - scenario: "ipv6 sockets in different namespaces can connect to one another", - function: testLocalNetworkConnectNamespacesIPv6, + scenario: "ipv6 stream sockets in different namespaces can connect to one another", + function: testLocalNetworkConnectStreamNamespacesIPv6, }, { - scenario: "local sockets can establish connections to foreign networks when a dial function is configured", - function: testLocalNetworkOutboundConnect, + scenario: "stream sockets can establish connections to foreign networks when a dial function is configured", + function: testLocalNetworkOutboundConnectStream, }, { - scenario: "local sockets can receive inbound connections from foreign network when a listen function is configured", + scenario: "stream sockets can receive inbound connections from foreign network when a listen function is configured", function: testLocalNetworkInboundAccept, }, + + { + scenario: "ipv4 datagram sockets can connect to one another on the loopback interface", + function: testLocalNetworkConnectDatagramIPv4, + }, + + { + scenario: "ipv6 datagram sockets can connect to one another on the loopback interface", + function: testLocalNetworkConnectDatagramIPv6, + }, + + { + scenario: "ipv4 datagram sockets can exchange datagrams without being connected to one another", + function: testLocalNetworkExchangeDatagramIPv4, + }, + + { + scenario: "ipv6 datagram sockets can exchange datagrams without being connected to one another", + function: testLocalNetworkExchangeDatagramIPv6, + }, + + { + scenario: "datagram sockets can receive messages from foreign networks when a listen packet function is configured", + function: testLocalNetworkInboundDatagram, + }, + + { + scenario: "datagram sockets can send messages to foreign networks when a listen packet function is configured", + function: testLocalNetworkOutboundDatagram, + }, } ipv4, ipnet4, err := net.ParseCIDR("192.168.0.1/24") @@ -114,49 +144,49 @@ func testLocalNetworkInterfaces(t *testing.T, n *network.LocalNetwork) { assert.Equal(t, en0Addrs[1].String(), "fe80::1/64") } -func testLocalNetworkConnectLoopbackIPv4(t *testing.T, n *network.LocalNetwork) { - testLocalNetworkConnect(t, n, &network.SockaddrInet4{ +func testLocalNetworkConnectStreamLoopbackIPv4(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnectStream(t, n, &network.SockaddrInet4{ Addr: [4]byte{127, 0, 0, 1}, Port: 80, }) } -func testLocalNetworkConnectLoopbackIPv6(t *testing.T, n *network.LocalNetwork) { - testLocalNetworkConnect(t, n, &network.SockaddrInet6{ +func testLocalNetworkConnectStreamLoopbackIPv6(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnectStream(t, n, &network.SockaddrInet6{ Addr: [16]byte{15: 1}, Port: 80, }) } -func testLocalNetworkConnectInterfaceIPv4(t *testing.T, n *network.LocalNetwork) { - testLocalNetworkConnect(t, n, &network.SockaddrInet4{ +func testLocalNetworkConnectStreamInterfaceIPv4(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnectStream(t, n, &network.SockaddrInet4{ Addr: [4]byte{192, 168, 0, 1}, Port: 80, }) } -func testLocalNetworkConnectInterfaceIPv6(t *testing.T, n *network.LocalNetwork) { - testLocalNetworkConnect(t, n, &network.SockaddrInet6{ +func testLocalNetworkConnectStreamInterfaceIPv6(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnectStream(t, n, &network.SockaddrInet6{ Addr: [16]byte{0: 0xfe, 1: 0x80, 15: 1}, Port: 80, }) } -func testLocalNetworkConnect(t *testing.T, n *network.LocalNetwork, bind network.Sockaddr) { +func testLocalNetworkConnectStream(t *testing.T, n *network.LocalNetwork, bind network.Sockaddr) { ns, err := n.CreateNamespace(nil) assert.OK(t, err) - testNamespaceConnect(t, ns, bind) + testNamespaceConnectStream(t, ns, bind) } -func testLocalNetworkConnectNamespacesIPv4(t *testing.T, n *network.LocalNetwork) { - testLocalNetworkConnectNamespaces(t, n, network.INET) +func testLocalNetworkConnectStreamNamespacesIPv4(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnectStreamNamespaces(t, n, network.INET) } -func testLocalNetworkConnectNamespacesIPv6(t *testing.T, n *network.LocalNetwork) { - testLocalNetworkConnectNamespaces(t, n, network.INET6) +func testLocalNetworkConnectStreamNamespacesIPv6(t *testing.T, n *network.LocalNetwork) { + testLocalNetworkConnectStreamNamespaces(t, n, network.INET6) } -func testLocalNetworkConnectNamespaces(t *testing.T, n *network.LocalNetwork, family network.Family) { +func testLocalNetworkConnectStreamNamespaces(t *testing.T, n *network.LocalNetwork, family network.Family) { ns1, err := n.CreateNamespace(nil) assert.OK(t, err) @@ -227,7 +257,7 @@ func testLocalNetworkConnectNamespaces(t *testing.T, n *network.LocalNetwork, fa assert.Equal(t, peer, nil) } -func testLocalNetworkOutboundConnect(t *testing.T, n *network.LocalNetwork) { +func testLocalNetworkOutboundConnectStream(t *testing.T, n *network.LocalNetwork) { ns1 := network.Host() ifaces1, err := ns1.Interfaces() @@ -368,3 +398,119 @@ func testLocalNetworkInboundAccept(t *testing.T, n *network.LocalNetwork) { _, err = conn.Read(buf) assert.Equal(t, err, io.EOF) } + +func testLocalNetworkConnectDatagramIPv4(t *testing.T, n *network.LocalNetwork) { + ns, err := n.CreateNamespace(nil) + assert.OK(t, err) + testNamespaceConnectDatagram(t, ns, &network.SockaddrInet4{ + Addr: [4]byte{192, 168, 0, 1}, + }) +} + +func testLocalNetworkConnectDatagramIPv6(t *testing.T, n *network.LocalNetwork) { + ns, err := n.CreateNamespace(nil) + assert.OK(t, err) + testNamespaceConnectDatagram(t, ns, &network.SockaddrInet6{ + Addr: [16]byte{0: 0xfe, 1: 0x80, 15: 1}, + }) +} + +func testLocalNetworkExchangeDatagramIPv4(t *testing.T, n *network.LocalNetwork) { + ns, err := n.CreateNamespace(nil) + assert.OK(t, err) + testNamespaceExchangeDatagram(t, ns, &network.SockaddrInet4{ + Addr: [4]byte{192, 168, 0, 1}, + }) +} + +func testLocalNetworkExchangeDatagramIPv6(t *testing.T, n *network.LocalNetwork) { + ns, err := n.CreateNamespace(nil) + assert.OK(t, err) + testNamespaceExchangeDatagram(t, ns, &network.SockaddrInet6{ + Addr: [16]byte{0: 0xfe, 1: 0x80, 15: 1}, + }) +} + +func testLocalNetworkInboundDatagram(t *testing.T, n *network.LocalNetwork) { + ns, err := n.CreateNamespace(nil, + network.ListenPacketFunc(func(ctx context.Context, network, address string) (net.PacketConn, error) { + _, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + return net.ListenPacket(network, net.JoinHostPort("127.0.0.1", port)) + }), + ) + assert.OK(t, err) + + sock, err := ns.Socket(network.INET, network.DGRAM, network.UDP) + assert.OK(t, err) + defer sock.Close() + + assert.OK(t, sock.Bind(&network.SockaddrInet4{})) + addr, err := sock.Name() + assert.OK(t, err) + + addrPort := network.SockaddrAddrPort(addr) + connAddr := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(addrPort.Port()))) + conn, err := net.Dial("udp", connAddr) + assert.OK(t, err) + defer conn.Close() + localAddr := conn.LocalAddr().(*net.UDPAddr) + + size, err := conn.Write([]byte("message")) + assert.OK(t, err) + assert.Equal(t, size, 7) + + assert.OK(t, waitReadyRead(sock)) + buf := make([]byte, 32) + size, _, peer, err := sock.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, size, 7) + assert.Equal(t, string(buf[:7]), "message") + assert.Equal(t, network.SockaddrAddrPort(peer), localAddr.AddrPort()) +} + +func testLocalNetworkOutboundDatagram(t *testing.T, n *network.LocalNetwork) { + hostAddrs, err := findNonLoopbackIPv4HostAddress() + assert.OK(t, err) + + ns, err := n.CreateNamespace(nil, + network.ListenPacketFunc(func(ctx context.Context, network, address string) (net.PacketConn, error) { + _, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + return net.ListenPacket(network, net.JoinHostPort("", port)) + }), + ) + assert.OK(t, err) + + conn, err := net.ListenPacket("udp4", net.JoinHostPort(hostAddrs[0].String(), "0")) + assert.OK(t, err) + defer conn.Close() + + sock, err := ns.Socket(network.INET, network.DGRAM, network.UDP) + assert.OK(t, err) + defer sock.Close() + + connAddr := conn.LocalAddr().(*net.UDPAddr) + addrPort := connAddr.AddrPort() + sendAddr := &network.SockaddrInet4{ + Addr: addrPort.Addr().As4(), + Port: int(addrPort.Port()), + } + + size, err := sock.SendTo([][]byte{[]byte("message")}, sendAddr, 0) + assert.OK(t, err) + assert.Equal(t, size, 7) + + addr, err := sock.Name() + assert.OK(t, err) + + buf := make([]byte, 32) + size, peer, err := conn.ReadFrom(buf) + assert.OK(t, err) + assert.Equal(t, size, 7) + assert.Equal(t, peer.(*net.UDPAddr).AddrPort().Port(), network.SockaddrAddrPort(addr).Port()) +} diff --git a/internal/network/local_unix.go b/internal/network/local_unix.go index 9ad52d57..cfd91c62 100644 --- a/internal/network/local_unix.go +++ b/internal/network/local_unix.go @@ -466,7 +466,11 @@ func (s *localSocket) Connect(addr Sockaddr) error { s.peer.Store(addr) s.state.set(connected) - return EINPROGRESS + + if s.socktype != DGRAM { + return EINPROGRESS + } + return nil } func (s *localSocket) connect(serverFd int, addr any) error { diff --git a/internal/network/network_test.go b/internal/network/network_test.go index 6690dc4f..c9848cb8 100644 --- a/internal/network/network_test.go +++ b/internal/network/network_test.go @@ -1,6 +1,8 @@ package network_test import ( + "fmt" + "net" "testing" "time" @@ -8,19 +10,65 @@ import ( "github.com/stealthrocket/timecraft/internal/network" ) -func testNamespaceConnectLoopbackIPv4(t *testing.T, ns network.Namespace) { - testNamespaceConnect(t, ns, &network.SockaddrInet4{ +func findNonLoopbackIPv4HostAddress() ([]net.Addr, error) { + addrs, err := findNonLoopbackHostAddresses() + if err != nil { + return nil, err + } + var ipv4Addrs []net.Addr + for _, addr := range addrs { + switch a := addr.(type) { + case *net.IPNet: + if a.IP.To4() != nil { + ipv4Addrs = append(ipv4Addrs, &net.IPAddr{IP: a.IP}) + } + } + } + if len(ipv4Addrs) == 0 { + return nil, fmt.Errorf("no IPv4 addresses were found that were not on the loopback interface") + } + return ipv4Addrs, nil +} + +func findNonLoopbackHostAddresses() ([]net.Addr, error) { + ifaces, err := network.Host().Interfaces() + if err != nil { + return nil, err + } + var nonLoopbackAddrs []net.Addr + for _, iface := range ifaces { + flags := iface.Flags() + if (flags & net.FlagUp) == 0 { + continue + } + if (flags & net.FlagLoopback) != 0 { + continue + } + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + nonLoopbackAddrs = append(nonLoopbackAddrs, addrs...) + } + if len(nonLoopbackAddrs) == 0 { + return nil, fmt.Errorf("no addresses were found that were not on the loopback interface") + } + return nonLoopbackAddrs, nil +} + +func testNamespaceConnectStreamLoopbackIPv4(t *testing.T, ns network.Namespace) { + testNamespaceConnectStream(t, ns, &network.SockaddrInet4{ Addr: [4]byte{127, 0, 0, 1}, }) } -func testNamespaceConnectLoopbackIPv6(t *testing.T, ns network.Namespace) { - testNamespaceConnect(t, ns, &network.SockaddrInet6{ +func testNamespaceConnectStreamLoopbackIPv6(t *testing.T, ns network.Namespace) { + testNamespaceConnectStream(t, ns, &network.SockaddrInet6{ Addr: [16]byte{15: 1}, }) } -func testNamespaceConnect(t *testing.T, ns network.Namespace, bind network.Sockaddr) { +func testNamespaceConnectStream(t *testing.T, ns network.Namespace, bind network.Sockaddr) { family := network.SockaddrFamily(bind) server, err := ns.Socket(family, network.STREAM, network.TCP) @@ -66,6 +114,70 @@ func testNamespaceConnect(t *testing.T, ns network.Namespace, bind network.Socka assert.Equal(t, peer, nil) } +func testNamespaceConnectDatagramLoopbackIPv4(t *testing.T, ns network.Namespace) { + testNamespaceConnectDatagram(t, ns, &network.SockaddrInet4{ + Addr: [4]byte{127, 0, 0, 1}, + }) +} + +func testNamespaceConnectDatagramLoopbackIPv6(t *testing.T, ns network.Namespace) { + testNamespaceConnectDatagram(t, ns, &network.SockaddrInet6{ + Addr: [16]byte{15: 1}, + }) +} + +func testNamespaceConnectDatagram(t *testing.T, ns network.Namespace, bind network.Sockaddr) { + family := network.SockaddrFamily(bind) + + server, err := ns.Socket(family, network.DGRAM, network.UDP) + assert.OK(t, err) + defer server.Close() + + assert.OK(t, server.Bind(bind)) + addr, err := server.Name() + assert.OK(t, err) + + client, err := ns.Socket(family, network.DGRAM, network.UDP) + assert.OK(t, err) + defer client.Close() + + assert.OK(t, client.Connect(addr)) + assert.OK(t, waitReadyWrite(client)) + + name, err := client.Name() + assert.OK(t, err) + assert.NotEqual(t, name, nil) + + peer, err := client.Peer() + assert.OK(t, err) + assert.Equal(t, network.SockaddrAddrPort(peer), network.SockaddrAddrPort(addr)) + + wn, err := client.SendTo([][]byte{[]byte("Hello, World!")}, nil, 0) + assert.OK(t, err) + assert.Equal(t, wn, 13) + assert.OK(t, waitReadyRead(server)) + + buf := make([]byte, 32) + rn, rflags, peer, err := server.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, rn, 13) + assert.Equal(t, rflags, 0) + assert.Equal(t, string(buf[:13]), "Hello, World!") + assert.Equal(t, network.SockaddrAddrPort(peer), network.SockaddrAddrPort(name)) + + wn, err = server.SendTo([][]byte{[]byte("How are you?")}, peer, 0) + assert.OK(t, err) + assert.Equal(t, wn, 12) + assert.OK(t, waitReadyRead(client)) + + rn, rflags, peer, err = client.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, rn, 12) + assert.Equal(t, rflags, 0) + assert.Equal(t, string(buf[:12]), "How are you?") + assert.Equal(t, network.SockaddrAddrPort(peer), network.SockaddrAddrPort(addr)) +} + func testNamespaceExchangeDatagramLoopbackIPv4(t *testing.T, ns network.Namespace) { testNamespaceExchangeDatagram(t, ns, &network.SockaddrInet4{ Addr: [4]byte{127, 0, 0, 1}, @@ -123,6 +235,18 @@ func testNamespaceExchangeDatagram(t *testing.T, ns network.Namespace, bind netw assert.Equal(t, rflags, network.TRUNC) assert.Equal(t, string(buf[:11]), "How are you") assert.Equal(t, network.SockaddrAddrPort(addr), network.SockaddrAddrPort(addr2)) + + wn, err = socket1.SendTo([][]byte{[]byte("How are you?")}, addr, 0) + assert.OK(t, err) + assert.Equal(t, wn, 12) + assert.OK(t, waitReadyRead(socket2)) + + rn, rflags, addr, err = socket2.RecvFrom([][]byte{buf}, 0) + assert.OK(t, err) + assert.Equal(t, rn, 12) + assert.Equal(t, rflags, 0) + assert.Equal(t, string(buf[:12]), "How are you?") + assert.Equal(t, network.SockaddrAddrPort(addr), network.SockaddrAddrPort(addr1)) } func waitReadyRead(socket network.Socket) error { From 70774670d1c695b398ce376adc4b109dbf1aaf02 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 19 Jul 2023 15:55:01 -0700 Subject: [PATCH 15/16] fix tests Signed-off-by: Achille Roussel --- sdk/go/timecraft/dial.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/go/timecraft/dial.go b/sdk/go/timecraft/dial.go index c88250ec..6d2720f3 100644 --- a/sdk/go/timecraft/dial.go +++ b/sdk/go/timecraft/dial.go @@ -58,7 +58,7 @@ func DialTLS(ctx context.Context, network, addr string) (net.Conn, error) { var errno syscall.Errno err = rawConn.Control(func(fd uintptr) { - errno = setsockopt(int32(fd), htls.Level, htls.Option, unsafe.Pointer(unsafe.SliceData(host)), uint32(len(hostname))) + errno = setsockopt(int32(fd), htls.Level, htls.ServerName, unsafe.Pointer(unsafe.SliceData(host)), uint32(len(hostname))) }) if errno != 0 { err = os.NewSyscallError("setsockopt", errno) From 34c16987276886ca52e0d4f188eb6d813a953e8d Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 19 Jul 2023 18:18:07 -0700 Subject: [PATCH 16/16] network: use netip.Addr instead of net.IP Signed-off-by: Achille Roussel --- internal/ipam/ipam.go | 17 ++-- internal/ipam/ipv4.go | 12 +-- internal/ipam/ipv6.go | 12 +-- internal/network/local.go | 150 ++++++++++++++------------------- internal/network/local_test.go | 7 +- internal/network/network.go | 28 +++++- 6 files changed, 110 insertions(+), 116 deletions(-) diff --git a/internal/ipam/ipam.go b/internal/ipam/ipam.go index 1e13a85f..4e07713e 100644 --- a/internal/ipam/ipam.go +++ b/internal/ipam/ipam.go @@ -2,24 +2,25 @@ // IPv6 networks. package ipam -import "net" +import "net/netip" // Pool is an interface implemented by the IPv4Pool and IPv6Pool types to // abstract the type of IP addresses that are managed by the pool. type Pool interface { // Obtains the next IP address, or returns nil if the pool was exhausted. - GetIP() net.IP + GetAddr() (netip.Addr, bool) // Returns an IP address to the pool. The ip address must have been obtained // by a previous call to GetIP or the method panics. - PutIP(net.IP) + PutAddr(netip.Addr) } // NewPool constructs a pool of IP addresses for the network passed as argument. -func NewPool(ipnet *net.IPNet) Pool { - ones, _ := ipnet.Mask.Size() - if ipv4 := ipnet.IP.To4(); ipv4 != nil { - return NewIPv4Pool((IPv4)(ipv4), ones) +func NewPool(prefix netip.Prefix) Pool { + addr := prefix.Addr() + bits := prefix.Bits() + if addr.Is4() { + return NewIPv4Pool(addr.As4(), bits) } else { - return NewIPv6Pool((IPv6)(ipnet.IP), ones) + return NewIPv6Pool(addr.As16(), bits) } } diff --git a/internal/ipam/ipv4.go b/internal/ipam/ipv4.go index e7bc6ecd..83e97b9a 100644 --- a/internal/ipam/ipv4.go +++ b/internal/ipam/ipv4.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "math/bits" - "net" "net/netip" ) @@ -70,12 +69,9 @@ func (p *IPv4Pool) Reset(ip IPv4, nbits int) { p.bits.clear() } -func (p *IPv4Pool) GetIP() net.IP { +func (p *IPv4Pool) GetAddr() (netip.Addr, bool) { ip, ok := p.Get() - if !ok { - return nil - } - return ip[:] + return netip.AddrFrom4(ip), ok } func (p *IPv4Pool) Get() (IPv4, bool) { @@ -91,8 +87,8 @@ func (p *IPv4Pool) Get() (IPv4, bool) { return p.base.add(i), true } -func (p *IPv4Pool) PutIP(ip net.IP) { - p.Put((IPv4)(ip)) +func (p *IPv4Pool) PutAddr(ip netip.Addr) { + p.Put(ip.As4()) } func (p *IPv4Pool) Put(ip IPv4) { diff --git a/internal/ipam/ipv6.go b/internal/ipam/ipv6.go index 2597a651..1665d462 100644 --- a/internal/ipam/ipv6.go +++ b/internal/ipam/ipv6.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "math/bits" - "net" "net/netip" ) @@ -96,12 +95,9 @@ func (p *IPv6Pool) Reset(ip IPv6, nbits int) { p.bits.clear() } -func (p *IPv6Pool) GetIP() net.IP { +func (p *IPv6Pool) GetAddr() (netip.Addr, bool) { ip, ok := p.Get() - if !ok { - return nil - } - return ip[:] + return netip.AddrFrom16(ip), ok } func (p *IPv6Pool) Get() (IPv6, bool) { @@ -117,8 +113,8 @@ func (p *IPv6Pool) Get() (IPv6, bool) { return p.base.add(i), true } -func (p *IPv6Pool) PutIP(ip net.IP) { - p.Put((IPv6)(ip)) +func (p *IPv6Pool) PutAddr(ip netip.Addr) { + p.Put(ip.As16()) } func (p *IPv6Pool) Put(ip IPv6) { diff --git a/internal/network/local.go b/internal/network/local.go index 5c5006f8..058b10da 100644 --- a/internal/network/local.go +++ b/internal/network/local.go @@ -9,30 +9,16 @@ import ( "sync" "sync/atomic" + "golang.org/x/exp/slices" + "github.com/stealthrocket/timecraft/internal/ipam" ) var ErrIPAM = errors.New("IP pool exhausted") -var localAddrs = [2]net.IPNet{ - net.IPNet{ - IP: net.IP{ - 127, 0, 0, 1, - }, - Mask: net.IPMask{ - 255, 0, 0, 0, - }, - }, - net.IPNet{ - IP: net.IP{ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - }, - Mask: net.IPMask{ - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - }, - }, +var localAddrs = [2]netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 8), + netip.PrefixFrom(netip.IPv6Loopback(), 128), } type LocalOption func(*LocalNamespace) @@ -50,22 +36,21 @@ func ListenPacketFunc(listenPacket func(context.Context, string, string) (net.Pa } type LocalNetwork struct { - addrs []net.IPNet + addrs []netip.Prefix ipams []ipam.Pool mutex sync.RWMutex routes map[netip.Addr]*localInterface } -func NewLocalNetwork(addrs ...*net.IPNet) *LocalNetwork { +func NewLocalNetwork(addrs ...netip.Prefix) *LocalNetwork { n := &LocalNetwork{ - addrs: make([]net.IPNet, len(addrs)), + addrs: slices.Clone(addrs), ipams: make([]ipam.Pool, len(addrs)), routes: make(map[netip.Addr]*localInterface), } for i, addr := range addrs { - n.addrs[i] = *addr - n.ipams[i] = ipam.NewPool(&n.addrs[i]) + n.ipams[i] = ipam.NewPool(addr) } return n } @@ -83,7 +68,7 @@ func (n *LocalNetwork) CreateNamespace(host Namespace, opts ...LocalOption) (*Lo index: 1, name: "en0", flags: net.FlagUp, - addrs: make([]net.IPNet, 0, len(n.addrs)), + addrs: make([]netip.Prefix, 0, len(n.addrs)), }, } @@ -95,15 +80,12 @@ func (n *LocalNetwork) CreateNamespace(host Namespace, opts ...LocalOption) (*Lo defer n.mutex.Unlock() for i, ipam := range n.ipams { - ip := ipam.GetIP() - if ip == nil { + ip, ok := ipam.GetAddr() + if !ok { n.detach(&ns.en0) return nil, fmt.Errorf("%s: %w", ipam, ErrIPAM) } - ns.en0.addrs = append(ns.en0.addrs, net.IPNet{ - IP: ip, - Mask: n.addrs[i].Mask, - }) + ns.en0.addrs = append(ns.en0.addrs, netip.PrefixFrom(ip, n.addrs[i].Bits())) } ns.lo0.ports = make([]map[localPort]*localSocket, len(ns.lo0.addrs)) @@ -114,28 +96,21 @@ func (n *LocalNetwork) CreateNamespace(host Namespace, opts ...LocalOption) (*Lo } func (n *LocalNetwork) attach(iface *localInterface) { - for _, ipnet := range iface.addrs { - n.routes[addrFromIP(ipnet.IP)] = iface + for _, prefix := range iface.addrs { + n.routes[prefix.Addr()] = iface } } func (n *LocalNetwork) detach(iface *localInterface) { - for i, ipnet := range iface.addrs { - delete(n.routes, addrFromIP(ipnet.IP)) - n.ipams[i].PutIP(ipnet.IP) + for i, prefix := range iface.addrs { + addr := prefix.Addr() + delete(n.routes, addr) + n.ipams[i].PutAddr(addr) } } -func (n *LocalNetwork) lookup(ip net.IP) *localInterface { - return n.routes[addrFromIP(ip)] -} - -func addrFromIP(ip net.IP) netip.Addr { - if ipv4 := ip.To4(); ipv4 != nil { - return netip.AddrFrom4(([4]byte)(ipv4)) - } else { - return netip.AddrFrom16(([16]byte)(ip)) - } +func (n *LocalNetwork) lookup(addr netip.Addr) *localInterface { + return n.routes[addr] } type LocalNamespace struct { @@ -202,16 +177,16 @@ func (ns *LocalNamespace) interfaces() []*localInterface { } func (ns *LocalNamespace) bindInet4(sock *localSocket, addr *SockaddrInet4) error { - return ns.bind(sock, addr.Addr[:], addr.Port) + return ns.bind(sock, addrPortFromInet4(addr)) } func (ns *LocalNamespace) bindInet6(sock *localSocket, addr *SockaddrInet6) error { - return ns.bind(sock, addr.Addr[:], addr.Port) + return ns.bind(sock, addrPortFromInet6(addr)) } -func (ns *LocalNamespace) bind(sock *localSocket, addr net.IP, port int) error { +func (ns *LocalNamespace) bind(sock *localSocket, addrPort netip.AddrPort) error { for _, iface := range ns.interfaces() { - if err := iface.bind(sock, addr, port); err != nil { + if err := iface.bind(sock, addrPort); err != nil { return err } } @@ -219,20 +194,21 @@ func (ns *LocalNamespace) bind(sock *localSocket, addr net.IP, port int) error { } func (ns *LocalNamespace) lookupInet4(sock *localSocket, addr *SockaddrInet4) (*localSocket, error) { - return ns.lookup(sock, addr.Addr[:], addr.Port) + return ns.lookup(sock, addrPortFromInet4(addr)) } func (ns *LocalNamespace) lookupInet6(sock *localSocket, addr *SockaddrInet6) (*localSocket, error) { - return ns.lookup(sock, addr.Addr[:], addr.Port) + return ns.lookup(sock, addrPortFromInet6(addr)) } -func (ns *LocalNamespace) lookup(sock *localSocket, addr net.IP, port int) (*localSocket, error) { +func (ns *LocalNamespace) lookup(sock *localSocket, addrPort netip.AddrPort) (*localSocket, error) { for _, iface := range ns.interfaces() { - if peer := iface.lookup(sock, addr, port); peer != nil { + if peer := iface.lookup(sock, addrPort); peer != nil { return peer, nil } } + addr := addrPort.Addr() if addr.IsUnspecified() { return nil, EHOSTUNREACH } @@ -247,7 +223,7 @@ func (ns *LocalNamespace) lookup(sock *localSocket, addr net.IP, port int) (*loc n.mutex.RUnlock() if iface != nil { - if peer := iface.lookup(sock, addr, port); peer != nil { + if peer := iface.lookup(sock, addrPort); peer != nil { return peer, nil } return nil, EHOSTUNREACH @@ -256,16 +232,16 @@ func (ns *LocalNamespace) lookup(sock *localSocket, addr net.IP, port int) (*loc } func (ns *LocalNamespace) unlinkInet4(sock *localSocket, addr *SockaddrInet4) { - ns.unlink(sock, addr.Addr[:], addr.Port) + ns.unlink(sock, addrPortFromInet4(addr)) } func (ns *LocalNamespace) unlinkInet6(sock *localSocket, addr *SockaddrInet6) { - ns.unlink(sock, addr.Addr[:], addr.Port) + ns.unlink(sock, addrPortFromInet6(addr)) } -func (ns *LocalNamespace) unlink(sock *localSocket, addr net.IP, port int) { +func (ns *LocalNamespace) unlink(sock *localSocket, addrPort netip.AddrPort) { for _, iface := range ns.interfaces() { - iface.unlink(sock, addr, port) + iface.unlink(sock, addrPort) } } @@ -279,7 +255,7 @@ type localInterface struct { name string haddr net.HardwareAddr flags net.Flags - addrs []net.IPNet + addrs []netip.Prefix mutex sync.RWMutex ports []map[localPort]*localSocket @@ -307,8 +283,14 @@ func (i *localInterface) Flags() net.Flags { func (i *localInterface) Addrs() ([]net.Addr, error) { addrs := make([]net.Addr, len(i.addrs)) - for j := range addrs { - addrs[j] = &i.addrs[j] + for j, prefix := range i.addrs { + addr := prefix.Addr() + ones := prefix.Bits() + bits := addr.BitLen() + addrs[j] = &net.IPNet{ + IP: addr.AsSlice(), + Mask: net.CIDRMask(ones, bits), + } } return addrs, nil } @@ -317,23 +299,17 @@ func (i *localInterface) MulticastAddrs() ([]net.Addr, error) { return nil, nil } -func (i *localInterface) bind(sock *localSocket, addr net.IP, port int) error { - link := localPort{sock.protocol, uint16(port)} - ipv4 := addr.To4() - name := (Sockaddr)(nil) - - if ipv4 != nil { - name = &SockaddrInet4{Addr: ([4]byte)(ipv4), Port: port} - } else { - name = &SockaddrInet6{Addr: ([16]byte)(addr), Port: port} - } +func (i *localInterface) bind(sock *localSocket, addrPort netip.AddrPort) error { + link := localPort{sock.protocol, addrPort.Port()} + name := SockaddrFromAddrPort(addrPort) + addr := addrPort.Addr() i.mutex.Lock() defer i.mutex.Unlock() if link.port != 0 { for j, a := range i.addrs { - if !socketAndInterfaceMatch(addr, a.IP) { + if !socketAndInterfaceMatch(addr, a.Addr()) { continue } if _, used := i.ports[j][link]; used { @@ -349,7 +325,7 @@ func (i *localInterface) bind(sock *localSocket, addr net.IP, port int) error { for port = 49152; port <= 65535; port++ { link.port = uint16(port) for j, a := range i.addrs { - if !socketAndInterfaceMatch(addr, a.IP) { + if !socketAndInterfaceMatch(addr, a.Addr()) { continue } if _, used := i.ports[j][link]; used { @@ -373,7 +349,7 @@ func (i *localInterface) bind(sock *localSocket, addr net.IP, port int) error { a := i.addrs[j] p := i.ports[j] - if socketAndInterfaceMatch(addr, a.IP) { + if socketAndInterfaceMatch(addr, a.Addr()) { if p == nil { p = make(map[localPort]*localSocket) i.ports[j] = p @@ -387,7 +363,9 @@ func (i *localInterface) bind(sock *localSocket, addr net.IP, port int) error { return nil } -func (i *localInterface) lookup(sock *localSocket, addr net.IP, port int) *localSocket { +func (i *localInterface) lookup(sock *localSocket, addrPort netip.AddrPort) *localSocket { + addr := addrPort.Addr() + i.mutex.RLock() defer i.mutex.RUnlock() @@ -395,8 +373,8 @@ func (i *localInterface) lookup(sock *localSocket, addr net.IP, port int) *local a := i.addrs[j] p := i.ports[j] - if socketAndInterfaceMatch(addr, a.IP) { - link := localPort{sock.protocol, uint16(port)} + if socketAndInterfaceMatch(addr, a.Addr()) { + link := localPort{sock.protocol, addrPort.Port()} peer := p[link] if peer != nil { return peer @@ -407,7 +385,9 @@ func (i *localInterface) lookup(sock *localSocket, addr net.IP, port int) *local return nil } -func (i *localInterface) unlink(sock *localSocket, addr net.IP, port int) { +func (i *localInterface) unlink(sock *localSocket, addrPort netip.AddrPort) { + addr := addrPort.Addr() + i.mutex.Lock() defer i.mutex.Unlock() @@ -415,8 +395,8 @@ func (i *localInterface) unlink(sock *localSocket, addr net.IP, port int) { a := i.addrs[j] p := i.ports[j] - if socketAndInterfaceMatch(addr, a.IP) { - link := localPort{sock.protocol, uint16(port)} + if socketAndInterfaceMatch(addr, a.Addr()) { + link := localPort{sock.protocol, addrPort.Port()} peer := p[link] if sock == peer { delete(p, link) @@ -425,12 +405,12 @@ func (i *localInterface) unlink(sock *localSocket, addr net.IP, port int) { } } -func socketAndInterfaceMatch(sock, iface net.IP) bool { - return (sock.IsUnspecified() && family(sock) == family(iface)) || sock.Equal(iface) +func socketAndInterfaceMatch(sock, iface netip.Addr) bool { + return (sock.IsUnspecified() && family(sock) == family(iface)) || sock == iface } -func family(ip net.IP) Family { - if ip.To4() != nil { +func family(addr netip.Addr) Family { + if addr.Is4() { return INET } else { return INET6 diff --git a/internal/network/local_test.go b/internal/network/local_test.go index 5a61060d..6d66a606 100644 --- a/internal/network/local_test.go +++ b/internal/network/local_test.go @@ -93,15 +93,12 @@ func TestLocalNetwork(t *testing.T) { }, } - ipv4, ipnet4, err := net.ParseCIDR("192.168.0.1/24") + ipnet4, err := netip.ParsePrefix("192.168.0.1/24") assert.OK(t, err) - ipv6, ipnet6, err := net.ParseCIDR("fe80::1/64") + ipnet6, err := netip.ParsePrefix("fe80::1/64") assert.OK(t, err) - ipnet4.IP = ipv4 - ipnet6.IP = ipv6 - for _, test := range tests { t.Run(test.scenario, func(t *testing.T) { test.function(t, diff --git a/internal/network/network.go b/internal/network/network.go index 1cbafa6d..15ab24f8 100644 --- a/internal/network/network.go +++ b/internal/network/network.go @@ -147,14 +147,38 @@ func SockaddrAddr(sa Sockaddr) netip.Addr { func SockaddrAddrPort(sa Sockaddr) netip.AddrPort { switch a := sa.(type) { case *SockaddrInet4: - return netip.AddrPortFrom(netip.AddrFrom4(a.Addr), uint16(a.Port)) + return addrPortFromInet4(a) case *SockaddrInet6: - return netip.AddrPortFrom(netip.AddrFrom16(a.Addr), uint16(a.Port)) + return addrPortFromInet6(a) default: return netip.AddrPort{} } } +func addrPortFromInet4(a *SockaddrInet4) netip.AddrPort { + return netip.AddrPortFrom(netip.AddrFrom4(a.Addr), uint16(a.Port)) +} + +func addrPortFromInet6(a *SockaddrInet6) netip.AddrPort { + return netip.AddrPortFrom(netip.AddrFrom16(a.Addr), uint16(a.Port)) +} + +func SockaddrFromAddrPort(addrPort netip.AddrPort) Sockaddr { + addr := addrPort.Addr() + port := addrPort.Port() + if addr.Is4() { + return &SockaddrInet4{ + Addr: addr.As4(), + Port: int(port), + } + } else { + return &SockaddrInet6{ + Addr: addr.As16(), + Port: int(port), + } + } +} + func errInterfaceIndexNotFound(index int) error { return fmt.Errorf("index=%d: %w", index, ErrInterfaceNotFound) }