diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..76d8593 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,19 @@ +name: Build + +on: + push: + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.23' + + - name: Build + run: CGO_ENABLED=0 go build diff --git a/cmd/server.go b/cmd/server.go index aa51790..2cf3b87 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -24,10 +24,12 @@ var ServerCmd = &cobra.Command{ } var serverCmdArgs struct { - ip []string - wgBlock string - wgBlockPerIp string - cloud string + ip []string + wgBlock string + wgBlockPerIp string + cloud string + region string + internalNetworkCidr string } func init() { @@ -39,6 +41,10 @@ func init() { "", "WireGuard block size for each --ip flag, if multiple are provided") ServerCmd.Flags().StringVar(&serverCmdArgs.cloud, "cloud", "", "Cloud provider for IP metadata (watches for changes)") + ServerCmd.Flags().StringVar(&serverCmdArgs.region, "region", "", + "Region of the server") + ServerCmd.Flags().StringVar(&serverCmdArgs.internalNetworkCidr, "internal-network-cidr", + "10.0.0.0/24", "Internal network CIDR to route to") } func runServer(cmd *cobra.Command, args []string) error { @@ -56,6 +62,9 @@ func runServer(cmd *cobra.Command, args []string) error { if serverCmdArgs.wgBlock == "" { return errors.New("missing required flag: --wg-block") } + if serverCmdArgs.region == "" { + return errors.New("missing required flag: --region") + } wgBlock, err := netip.ParsePrefix(serverCmdArgs.wgBlock) if err != nil || !wgBlock.Addr().Is4() { @@ -97,7 +106,7 @@ func runServer(cmd *cobra.Command, args []string) error { ctx, done := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password) + sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password, serverCmdArgs.region, serverCmdArgs.internalNetworkCidr) if err != nil { done() return err @@ -106,26 +115,17 @@ func runServer(cmd *cobra.Command, args []string) error { defer sm.Wait() defer done() - if cloud == "aws" { - initialIps, err := pollAws(lib.NewAwsMetadata(), make(ipSet), sm) + for _, ipStr := range serverCmdArgs.ip { + ip, err := netip.ParseAddr(ipStr) + if err != nil || !ip.Is4() { + return fmt.Errorf("invalid IPv4 address: %q", ipStr) + } + err = sm.Start(ip) if err != nil { return err } - - pollAwsLoop(ctx, sm, initialIps) - } else { - for _, ipStr := range serverCmdArgs.ip { - ip, err := netip.ParseAddr(ipStr) - if err != nil || !ip.Is4() { - return fmt.Errorf("invalid IPv4 address: %q", ipStr) - } - err = sm.Start(ip) - if err != nil { - return err - } - } - sm.Wait() } + sm.Wait() return nil } @@ -145,56 +145,3 @@ func parseIpSet(ipStrs []string) (ipSet, error) { } return m, nil } - -// pollAws gets the current set of IP associations from AWS and starts/stops the -// server for those IPs. -func pollAws(awsClient *lib.AwsMetadata, currentIps ipSet, sm *lib.ServerManager) (ipSet, error) { - interfaces, err := awsClient.GetAddresses() - - if err != nil { - return currentIps, fmt.Errorf("failed to get AWS MAC addresses: %v", err) - } - - newIps, err := parseIpSet(interfaces[0].PrivateIps) - if err != nil { - return currentIps, err - } - - for ip := range currentIps { - if _, ok := newIps[ip]; !ok { - sm.Stop(ip) - delete(currentIps, ip) - } - } - - for ip := range newIps { - if _, ok := currentIps[ip]; !ok { - if err := sm.Start(ip); err != nil { - return currentIps, fmt.Errorf("error starting new ip: %v", err) - } - currentIps[ip] = struct{}{} - } - } - return currentIps, nil -} - -// pollAwsLoop polls AWS in a blocking loop on an interval of AWS_POLL_DURATION -// until ctx is done. -func pollAwsLoop(ctx context.Context, sm *lib.ServerManager, initialIps ipSet) { - currentIps := initialIps - awsClient := lib.NewAwsMetadata() - ticker := time.NewTicker(awsPollDuration) - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - var err error - currentIps, err = pollAws(awsClient, currentIps, sm) - if err != nil { - fmt.Printf("error during aws poll: %v", err) - } - } - } -} diff --git a/lib/iputils.go b/lib/iputils.go index 2c15931..28d3525 100644 --- a/lib/iputils.go +++ b/lib/iputils.go @@ -5,6 +5,7 @@ import ( "log" "net" "net/netip" + "strings" "sync" "github.com/vishvananda/netlink" @@ -42,6 +43,99 @@ func getDefaultInterface() (netlink.Link, error) { return nil, fmt.Errorf("failed to find default route") } +func getInternalInterface() (netlink.Link, error) { + links, err := netlink.LinkList() + if err != nil { + return nil, fmt.Errorf("failed to list interfaces: %v", err) + } + + // Define interface types to skip + skipTypes := map[string]bool{ + "veth": true, + "bridge": true, + "dummy": true, + "wireguard": true, + "docker": true, + } + + var candidates []netlink.Link + for _, link := range links { + // Skip interfaces that aren't up + if link.Attrs().OperState != netlink.OperUp { + continue + } + + // Skip certain interface types + if skipTypes[link.Type()] { + continue + } + + // Skip interfaces with certain name patterns + name := link.Attrs().Name + if strings.HasPrefix(name, "veth") || + strings.HasPrefix(name, "docker") || + strings.HasPrefix(name, "wg") { + continue + } + + addrs, err := netlink.AddrList(link, netlink.FAMILY_V4) + if err != nil { + log.Printf("failed to get addresses for interface %s: %v", name, err) + continue + } + + // Check for RFC1918 addresses + for _, addr := range addrs { + if isRFC1918(addr.IP) { + candidates = append(candidates, link) + break + } + } + } + + // If we found candidates, prefer physical interfaces over others + for _, link := range candidates { + name := link.Attrs().Name + // Common patterns for physical interfaces + if strings.HasPrefix(name, "en") || + strings.HasPrefix(name, "eth") || + strings.HasPrefix(name, "enp") || + strings.HasPrefix(name, "ens") { + return link, nil + } + } + + // If no physical interface found, return the first candidate + if len(candidates) > 0 { + return candidates[0], nil + } + + return nil, fmt.Errorf("failed to find internal interface") +} + +// Helper function to check if an interface is a VLAN interface +func isVLANInterface(link netlink.Link) bool { + _, isVlan := link.(*netlink.Vlan) + if !isVlan { + // Also check the name for "@" which indicates a VLAN interface in Linux + return strings.Contains(link.Attrs().Name, ".") + } + return true +} + +// Helper function to check RFC1918 addresses +func isRFC1918(ip net.IP) bool { + if ip4 := ip.To4(); ip4 != nil { + // Convert to 32-bit integer for easier comparison + ipInt := uint32(ip4[0])<<24 | uint32(ip4[1])<<16 | uint32(ip4[2])<<8 | uint32(ip4[3]) + + return (ipInt >= 0x0A000000 && ipInt <= 0x0AFFFFFF) || // 10.0.0.0/8 + (ipInt >= 0xAC100000 && ipInt <= 0xAC1FFFFF) || // 172.16.0.0/12 + (ipInt >= 0xC0A80000 && ipInt <= 0xC0A8FFFF) // 192.168.0.0/16 + } + return false +} + // IpAllocator is a simple IP address allocator that produces IP addresses // within a prefix, in increasing order of available IPs. // diff --git a/lib/server.go b/lib/server.go index 788155f..775ffd0 100644 --- a/lib/server.go +++ b/lib/server.go @@ -57,6 +57,12 @@ type Server struct { // Currently only setting this to the default interface is supported. BindIface netlink.Link + // InternalBindIface is the interface for internal network traffic. + InternalBindIface netlink.Link + + // InternalNetworkCidr is the CIDR block of the internal network. + InternalNetworkCidr string + // Password is needed to authenticate connection requests. Password string @@ -75,6 +81,9 @@ type Server struct { // Ctx is the shutdown context for the server. Ctx context.Context + // Region is the region of the server. + Region string + ipAllocator *IpAllocator mu sync.Mutex // Protects the fields below. @@ -90,6 +99,13 @@ func (srv *Server) InitState() error { } srv.BindIface = iface } + if srv.Region == "us-west" && srv.InternalBindIface == nil { + iface, err := getInternalInterface() + if err != nil { + return err + } + srv.InternalBindIface = iface + } srv.ipAllocator = NewIpAllocator(srv.WgCidr) // Reserve the first IP address for the server itself. @@ -332,43 +348,45 @@ func (srv *Server) StartIptables() error { return fmt.Errorf("failed to add inbound TCP MSS rule: %v", err) } - // SNAT rule for internal network traffic - // TODO(adityamaru): Make the source and the -o interface dynamic and not hardcoded. - rule = []string{ - "-s", "10.100.0.0/16", - "-d", "10.0.0.0/24", - "-o", "bond0.2", - "-j", "SNAT", "--to-source", "10.0.0.40", - "-m", "comment", "--comment", "SNAT for WireGuard to internal network", - } - if err := srv.Ipt.AppendUnique("nat", "POSTROUTING", rule...); err != nil { - return fmt.Errorf("failed to add SNAT rule: %v", err) - } + // SNAT rule for internal network traffic. This is currently only applicable for boxes in + // the US. + if srv.Region == "us-west" { + rule = []string{ + "-s", srv.WgCidr.String(), + "-d", srv.InternalNetworkCidr, + "-o", srv.InternalBindIface.Attrs().Name, + "-j", "SNAT", "--to-source", srv.BindAddr.String(), + "-m", "comment", "--comment", "SNAT for WireGuard to internal network", + } + if err := srv.Ipt.AppendUnique("nat", "POSTROUTING", rule...); err != nil { + return fmt.Errorf("failed to add SNAT rule: %v", err) + } - // FORWARD rule from WireGuard to internal network - rule = []string{ - "-i", srv.Ifname(), - "-o", "bond0.2", - "-s", "10.100.0.0/16", - "-d", "10.0.0.0/24", - "-j", "ACCEPT", - "-m", "comment", "--comment", "Forward from WireGuard to internal network", - } - if err := srv.Ipt.AppendUnique("filter", "FORWARD", rule...); err != nil { - return fmt.Errorf("failed to add FORWARD rule: %v", err) - } + // FORWARD rule from WireGuard to internal network + rule = []string{ + "-i", srv.Ifname(), + "-o", srv.InternalBindIface.Attrs().Name, + "-s", srv.WgCidr.String(), + "-d", srv.InternalNetworkCidr, + "-j", "ACCEPT", + "-m", "comment", "--comment", "Forward from WireGuard to internal network", + } + if err := srv.Ipt.AppendUnique("filter", "FORWARD", rule...); err != nil { + return fmt.Errorf("failed to add FORWARD rule: %v", err) + } - // FORWARD rule from internal network to WireGuard - rule = []string{ - "-i", "bond0.2", - "-o", srv.Ifname(), - "-s", "10.0.0.0/24", - "-d", "10.100.0.0/16", - "-j", "ACCEPT", - "-m", "comment", "--comment", "Forward from internal network to WireGuard", - } - if err := srv.Ipt.AppendUnique("filter", "FORWARD", rule...); err != nil { - return fmt.Errorf("failed to add FORWARD rule: %v", err) + // FORWARD rule from internal network to WireGuard + rule = []string{ + "-i", srv.InternalBindIface.Attrs().Name, + "-o", srv.Ifname(), + "-s", srv.InternalNetworkCidr, + "-d", srv.WgCidr.String(), + "-j", "ACCEPT", + "-m", "comment", "--comment", "Forward from internal network to WireGuard", + } + if err := srv.Ipt.AppendUnique("filter", "FORWARD", rule...); err != nil { + return fmt.Errorf("failed to add FORWARD rule: %v", err) + } } return nil @@ -420,43 +438,44 @@ func (srv *Server) CleanupIptables() { log.Printf("failed to remove inbound TCP MSS rule: %v", err) } - // Remove SNAT rule for internal traffic - rule = []string{ - "-s", srv.WgCidr.String(), - "-d", "10.0.0.0/24", - "-o", srv.BindIface.Attrs().Name, - "-j", "SNAT", - "--to-source", srv.BindAddr.String(), - "-m", "comment", "--comment", fmt.Sprintf("vprox SNAT rule for internal traffic from %s", srv.Ifname()), - } - if err := srv.Ipt.Delete("nat", "POSTROUTING", rule...); err != nil { - log.Printf("failed to remove SNAT rule for internal traffic: %v", err) - } + if srv.Region == "us-west" { + // Remove SNAT rule for internal traffic + rule = []string{ + "-s", srv.WgCidr.String(), + "-d", srv.InternalNetworkCidr, + "-o", srv.InternalBindIface.Attrs().Name, + "-j", "SNAT", "--to-source", srv.BindAddr.String(), + "-m", "comment", "--comment", "SNAT for WireGuard to internal network", + } + if err := srv.Ipt.Delete("nat", "POSTROUTING", rule...); err != nil { + log.Printf("failed to remove SNAT rule for internal traffic: %v", err) + } - // Remove forward rule from WireGuard to internal network - rule = []string{ - "-i", srv.Ifname(), - "-o", srv.BindIface.Attrs().Name, - "-s", srv.WgCidr.String(), - "-d", "10.0.0.0/24", - "-j", "ACCEPT", - "-m", "comment", "--comment", fmt.Sprintf("vprox forward rule from %s to internal network", srv.Ifname()), - } - if err := srv.Ipt.Delete("filter", "FORWARD", rule...); err != nil { - log.Printf("failed to remove forward rule from WireGuard to internal network: %v", err) - } + // Remove forward rule from WireGuard to internal network + rule = []string{ + "-i", srv.Ifname(), + "-o", srv.InternalBindIface.Attrs().Name, + "-s", srv.WgCidr.String(), + "-d", srv.InternalNetworkCidr, + "-j", "ACCEPT", + "-m", "comment", "--comment", "Forward from WireGuard to internal network", + } + if err := srv.Ipt.Delete("filter", "FORWARD", rule...); err != nil { + log.Printf("failed to remove forward rule from WireGuard to internal network: %v", err) + } - // Remove forward rule from internal network to WireGuard - rule = []string{ - "-i", srv.BindIface.Attrs().Name, - "-o", srv.Ifname(), - "-s", "10.0.0.0/24", - "-d", srv.WgCidr.String(), - "-j", "ACCEPT", - "-m", "comment", "--comment", fmt.Sprintf("vprox forward rule from internal network to %s", srv.Ifname()), - } - if err := srv.Ipt.Delete("filter", "FORWARD", rule...); err != nil { - log.Printf("failed to remove forward rule from internal network to WireGuard: %v", err) + // Remove forward rule from internal network to WireGuard + rule = []string{ + "-i", srv.InternalBindIface.Attrs().Name, + "-o", srv.Ifname(), + "-s", srv.InternalNetworkCidr, + "-d", srv.WgCidr.String(), + "-j", "ACCEPT", + "-m", "comment", "--comment", "Forward from internal network to WireGuard", + } + if err := srv.Ipt.Delete("filter", "FORWARD", rule...); err != nil { + log.Printf("failed to remove forward rule from internal network to WireGuard: %v", err) + } } } diff --git a/lib/server_manager.go b/lib/server_manager.go index addbb90..876c86b 100644 --- a/lib/server_manager.go +++ b/lib/server_manager.go @@ -21,15 +21,17 @@ type ServerInfo struct { // ServerManager handles creating and terminating servers on ips // ServerManager is not thread safe for concurrent access. type ServerManager struct { - wgClient *wgctrl.Client - ipt *iptables.IPTables - key wgtypes.Key - password string - ctx context.Context - waitGroup *sync.WaitGroup - wgBlock netip.Prefix - wgBlockPerIp uint - activeServers map[netip.Addr]ServerInfo + wgClient *wgctrl.Client + ipt *iptables.IPTables + key wgtypes.Key + password string + ctx context.Context + waitGroup *sync.WaitGroup + wgBlock netip.Prefix + wgBlockPerIp uint + activeServers map[netip.Addr]ServerInfo + region string + internalNetworkCidr string // freeIndices and nextFreeIndex together track usage of the range 0..numWgBlocks freeIndices []uint16 // stack of indices that are free @@ -37,7 +39,7 @@ type ServerManager struct { } // NewServerManager creates a new server manager -func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Context, key wgtypes.Key, password string) (*ServerManager, error) { +func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Context, key wgtypes.Key, password string, region string, internalNetworkCidr string) (*ServerManager, error) { // Make a shared WireGuard client. wgClient, err := wgctrl.New() if err != nil { @@ -64,6 +66,8 @@ func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Conte sm.wgBlock = wgBlock.Masked() sm.wgBlockPerIp = wgBlockPerIp sm.activeServers = make(map[netip.Addr]ServerInfo) + sm.region = region + sm.internalNetworkCidr = internalNetworkCidr return sm, nil } @@ -104,14 +108,16 @@ func (sm *ServerManager) Start(ip netip.Addr) error { wgCidr := netip.PrefixFrom(subnetStart.Next(), int(sm.wgBlockPerIp)) srv := &Server{ - Key: sm.key, - BindAddr: ip, - Password: sm.password, - Index: i, - Ipt: sm.ipt, - WgClient: sm.wgClient, - WgCidr: wgCidr, - Ctx: subctx, + Key: sm.key, + BindAddr: ip, + Password: sm.password, + Index: i, + Ipt: sm.ipt, + WgClient: sm.wgClient, + WgCidr: wgCidr, + Ctx: subctx, + Region: sm.region, + InternalNetworkCidr: sm.internalNetworkCidr, } if err := srv.InitState(); err != nil { _ = cancel // cancel should be discarded