//go:build linux && !no_fake_tcp // +build linux,!no_fake_tcp package faketcp import ( "crypto/rand" "encoding/binary" "errors" "fmt" "io" "io/ioutil" "net" "sync" "sync/atomic" "syscall" "time" "github.com/coreos/go-iptables/iptables" "github.com/google/gopacket" "github.com/google/gopacket/layers" ) var ( errOpNotImplemented = errors.New("operation not implemented") errTimeout = errors.New("timeout") expire = time.Minute ) // a message from NIC type message struct { bts []byte addr net.Addr } // a tcp flow information of a connection pair type tcpFlow struct { conn *net.TCPConn // the related system TCP connection of this flow handle *net.IPConn // the handle to send packets seq uint32 // TCP sequence number ack uint32 // TCP acknowledge number networkLayer gopacket.SerializableLayer // network layer header for tx ts time.Time // last packet incoming time buf gopacket.SerializeBuffer // a buffer for write tcpHeader layers.TCP } // TCPConn defines a TCP-packet oriented connection type TCPConn struct { die chan struct{} dieOnce sync.Once // the main golang sockets tcpconn *net.TCPConn // from net.Dial listener *net.TCPListener // from net.Listen // handles handles []*net.IPConn // packets captured from all related NICs will be delivered to this channel chMessage chan message // all TCP flows flowTable map[string]*tcpFlow flowsLock sync.Mutex // iptables iptables *iptables.IPTables iprule []string ip6tables *iptables.IPTables ip6rule []string // deadlines readDeadline atomic.Value writeDeadline atomic.Value // serialization opts gopacket.SerializeOptions } // lockflow locks the flow table and apply function `f` to the entry, and create one if not exist func (conn *TCPConn) lockflow(addr net.Addr, f func(e *tcpFlow)) { key := addr.String() conn.flowsLock.Lock() e := conn.flowTable[key] if e == nil { // entry first visit e = new(tcpFlow) e.ts = time.Now() e.buf = gopacket.NewSerializeBuffer() } f(e) conn.flowTable[key] = e conn.flowsLock.Unlock() } // clean expired flows func (conn *TCPConn) cleaner() { ticker := time.NewTicker(time.Minute) select { case <-conn.die: return case <-ticker.C: conn.flowsLock.Lock() for k, v := range conn.flowTable { if time.Now().Sub(v.ts) > expire { if v.conn != nil { setTTL(v.conn, 64) v.conn.Close() } delete(conn.flowTable, k) } } conn.flowsLock.Unlock() } } // captureFlow capture every inbound packets based on rules of BPF func (conn *TCPConn) captureFlow(handle *net.IPConn, port int) { buf := make([]byte, 2048) opt := gopacket.DecodeOptions{NoCopy: true, Lazy: true} for { n, addr, err := handle.ReadFromIP(buf) if err != nil { return } // try decoding TCP frame from buf[:n] packet := gopacket.NewPacket(buf[:n], layers.LayerTypeTCP, opt) transport := packet.TransportLayer() tcp, ok := transport.(*layers.TCP) if !ok { continue } // port filtering if int(tcp.DstPort) != port { continue } // address building var src net.TCPAddr src.IP = addr.IP src.Port = int(tcp.SrcPort) var orphan bool // flow maintaince conn.lockflow(&src, func(e *tcpFlow) { if e.conn == nil { // make sure it's related to net.TCPConn orphan = true // mark as orphan if it's not related net.TCPConn } // to keep track of TCP header related to this source e.ts = time.Now() if tcp.ACK { e.seq = tcp.Ack } if tcp.SYN { e.ack = tcp.Seq + 1 } if tcp.PSH { if e.ack == tcp.Seq { e.ack = tcp.Seq + uint32(len(tcp.Payload)) } } e.handle = handle }) // push data if it's not orphan if !orphan && tcp.PSH { payload := make([]byte, len(tcp.Payload)) copy(payload, tcp.Payload) select { case conn.chMessage <- message{payload, &src}: case <-conn.die: return } } } } // ReadFrom implements the PacketConn ReadFrom method. func (conn *TCPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { var timer *time.Timer var deadline <-chan time.Time if d, ok := conn.readDeadline.Load().(time.Time); ok && !d.IsZero() { timer = time.NewTimer(time.Until(d)) defer timer.Stop() deadline = timer.C } select { case <-deadline: return 0, nil, errTimeout case <-conn.die: return 0, nil, io.EOF case packet := <-conn.chMessage: n = copy(p, packet.bts) return n, packet.addr, nil } } // WriteTo implements the PacketConn WriteTo method. func (conn *TCPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { var deadline <-chan time.Time if d, ok := conn.writeDeadline.Load().(time.Time); ok && !d.IsZero() { timer := time.NewTimer(time.Until(d)) defer timer.Stop() deadline = timer.C } select { case <-deadline: return 0, errTimeout case <-conn.die: return 0, io.EOF default: raddr, err := net.ResolveTCPAddr("tcp", addr.String()) if err != nil { return 0, err } var lport int if conn.tcpconn != nil { lport = conn.tcpconn.LocalAddr().(*net.TCPAddr).Port } else { lport = conn.listener.Addr().(*net.TCPAddr).Port } conn.lockflow(addr, func(e *tcpFlow) { // if the flow doesn't have handle , assume this packet has lost, without notification if e.handle == nil { n = len(p) return } // build tcp header with local and remote port e.tcpHeader.SrcPort = layers.TCPPort(lport) e.tcpHeader.DstPort = layers.TCPPort(raddr.Port) binary.Read(rand.Reader, binary.LittleEndian, &e.tcpHeader.Window) e.tcpHeader.Window |= 0x8000 // make sure it's larger than 32768 e.tcpHeader.Ack = e.ack e.tcpHeader.Seq = e.seq e.tcpHeader.PSH = true e.tcpHeader.ACK = true // build IP header with src & dst ip for TCP checksum if raddr.IP.To4() != nil { ip := &layers.IPv4{ Protocol: layers.IPProtocolTCP, SrcIP: e.handle.LocalAddr().(*net.IPAddr).IP.To4(), DstIP: raddr.IP.To4(), } e.tcpHeader.SetNetworkLayerForChecksum(ip) } else { ip := &layers.IPv6{ NextHeader: layers.IPProtocolTCP, SrcIP: e.handle.LocalAddr().(*net.IPAddr).IP.To16(), DstIP: raddr.IP.To16(), } e.tcpHeader.SetNetworkLayerForChecksum(ip) } e.buf.Clear() gopacket.SerializeLayers(e.buf, conn.opts, &e.tcpHeader, gopacket.Payload(p)) if conn.tcpconn != nil { _, err = e.handle.Write(e.buf.Bytes()) } else { _, err = e.handle.WriteToIP(e.buf.Bytes(), &net.IPAddr{IP: raddr.IP}) } // increase seq in flow e.seq += uint32(len(p)) n = len(p) }) } return } // Close closes the connection. func (conn *TCPConn) Close() error { var err error conn.dieOnce.Do(func() { // signal closing close(conn.die) // close all established tcp connections if conn.tcpconn != nil { // client setTTL(conn.tcpconn, 64) err = conn.tcpconn.Close() } else if conn.listener != nil { err = conn.listener.Close() // server conn.flowsLock.Lock() for k, v := range conn.flowTable { if v.conn != nil { setTTL(v.conn, 64) v.conn.Close() } delete(conn.flowTable, k) } conn.flowsLock.Unlock() } // close handles for k := range conn.handles { conn.handles[k].Close() } // delete iptable if conn.iptables != nil { conn.iptables.Delete("filter", "OUTPUT", conn.iprule...) } if conn.ip6tables != nil { conn.ip6tables.Delete("filter", "OUTPUT", conn.ip6rule...) } }) return err } // LocalAddr returns the local network address. func (conn *TCPConn) LocalAddr() net.Addr { if conn.tcpconn != nil { return conn.tcpconn.LocalAddr() } else if conn.listener != nil { return conn.listener.Addr() } return nil } // SetDeadline implements the Conn SetDeadline method. func (conn *TCPConn) SetDeadline(t time.Time) error { if err := conn.SetReadDeadline(t); err != nil { return err } if err := conn.SetWriteDeadline(t); err != nil { return err } return nil } // SetReadDeadline implements the Conn SetReadDeadline method. func (conn *TCPConn) SetReadDeadline(t time.Time) error { conn.readDeadline.Store(t) return nil } // SetWriteDeadline implements the Conn SetWriteDeadline method. func (conn *TCPConn) SetWriteDeadline(t time.Time) error { conn.writeDeadline.Store(t) return nil } // SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header. func (conn *TCPConn) SetDSCP(dscp int) error { for k := range conn.handles { if err := setDSCP(conn.handles[k], dscp); err != nil { return err } } return nil } // SetReadBuffer sets the size of the operating system's receive buffer associated with the connection. func (conn *TCPConn) SetReadBuffer(bytes int) error { var err error for k := range conn.handles { if err := conn.handles[k].SetReadBuffer(bytes); err != nil { return err } } return err } // SetWriteBuffer sets the size of the operating system's transmit buffer associated with the connection. func (conn *TCPConn) SetWriteBuffer(bytes int) error { var err error for k := range conn.handles { if err := conn.handles[k].SetWriteBuffer(bytes); err != nil { return err } } return err } func (conn *TCPConn) SyscallConn() (syscall.RawConn, error) { if len(conn.handles) == 0 { return nil, errors.New("no handles") // How is it possible? } return conn.handles[0].SyscallConn() } // Dial connects to the remote TCP port, // and returns a single packet-oriented connection func Dial(network, address string) (*TCPConn, error) { // remote address resolve raddr, err := net.ResolveTCPAddr(network, address) if err != nil { return nil, err } // AF_INET handle, err := net.DialIP("ip:tcp", nil, &net.IPAddr{IP: raddr.IP}) if err != nil { return nil, err } // create an established tcp connection // will hack this tcp connection for packet transmission tcpconn, err := net.DialTCP(network, nil, raddr) if err != nil { return nil, err } // fields conn := new(TCPConn) conn.die = make(chan struct{}) conn.flowTable = make(map[string]*tcpFlow) conn.tcpconn = tcpconn conn.chMessage = make(chan message) conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn }) conn.handles = append(conn.handles, handle) conn.opts = gopacket.SerializeOptions{ FixLengths: true, ComputeChecksums: true, } go conn.captureFlow(handle, tcpconn.LocalAddr().(*net.TCPAddr).Port) go conn.cleaner() // iptables err = setTTL(tcpconn, 1) if err != nil { return nil, err } if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil { rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"} if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil { if !exists { if err = ipt.Append("filter", "OUTPUT", rule...); err == nil { conn.iprule = rule conn.iptables = ipt } } } } if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil { rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"} if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil { if !exists { if err = ipt.Append("filter", "OUTPUT", rule...); err == nil { conn.ip6rule = rule conn.ip6tables = ipt } } } } // discard everything go io.Copy(ioutil.Discard, tcpconn) return conn, nil } // Listen acts like net.ListenTCP, // and returns a single packet-oriented connection func Listen(network, address string) (*TCPConn, error) { // fields conn := new(TCPConn) conn.flowTable = make(map[string]*tcpFlow) conn.die = make(chan struct{}) conn.chMessage = make(chan message) conn.opts = gopacket.SerializeOptions{ FixLengths: true, ComputeChecksums: true, } // resolve address laddr, err := net.ResolveTCPAddr(network, address) if err != nil { return nil, err } // AF_INET ifaces, err := net.Interfaces() if err != nil { return nil, err } if laddr.IP == nil || laddr.IP.IsUnspecified() { // if address is not specified, capture on all ifaces var lasterr error for _, iface := range ifaces { if addrs, err := iface.Addrs(); err == nil { for _, addr := range addrs { if ipaddr, ok := addr.(*net.IPNet); ok { if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: ipaddr.IP}); err == nil { conn.handles = append(conn.handles, handle) go conn.captureFlow(handle, laddr.Port) } else { lasterr = err } } } } } if len(conn.handles) == 0 { return nil, lasterr } } else { if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: laddr.IP}); err == nil { conn.handles = append(conn.handles, handle) go conn.captureFlow(handle, laddr.Port) } else { return nil, err } } // start listening l, err := net.ListenTCP(network, laddr) if err != nil { return nil, err } conn.listener = l // start cleaner go conn.cleaner() // iptables drop packets marked with TTL = 1 // TODO: what if iptables is not available, the next hop will send back ICMP Time Exceeded, // is this still an acceptable behavior? if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil { rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"} if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil { if !exists { if err = ipt.Append("filter", "OUTPUT", rule...); err == nil { conn.iprule = rule conn.iptables = ipt } } } } if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil { rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"} if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil { if !exists { if err = ipt.Append("filter", "OUTPUT", rule...); err == nil { conn.ip6rule = rule conn.ip6tables = ipt } } } } // discard everything in original connection go func() { for { tcpconn, err := l.AcceptTCP() if err != nil { return } // if we cannot set TTL = 1, the only thing reasonable is panic if err := setTTL(tcpconn, 1); err != nil { panic(err) } // record net.Conn conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn }) // discard everything go io.Copy(ioutil.Discard, tcpconn) } }() return conn, nil } // setTTL sets the Time-To-Live field on a given connection func setTTL(c *net.TCPConn, ttl int) error { raw, err := c.SyscallConn() if err != nil { return err } addr := c.LocalAddr().(*net.TCPAddr) if addr.IP.To4() == nil { raw.Control(func(fd uintptr) { err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, ttl) }) } else { raw.Control(func(fd uintptr) { err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TTL, ttl) }) } return err } // setDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header. func setDSCP(c *net.IPConn, dscp int) error { raw, err := c.SyscallConn() if err != nil { return err } addr := c.LocalAddr().(*net.IPAddr) if addr.IP.To4() == nil { raw.Control(func(fd uintptr) { err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, dscp) }) } else { raw.Control(func(fd uintptr) { err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, dscp<<2) }) } return err }