Skip to content

Commit

Permalink
Merge pull request #5 from useblacksmith/add-region
Browse files Browse the repository at this point in the history
*: add a region flag to vprox and fix some internal network interface…
  • Loading branch information
adityamaru authored Nov 25, 2024
2 parents b01607d + 0b08cb9 commit 1233ced
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 162 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Build

on:
push:
pull_request:

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'

- name: Build
run: CGO_ENABLED=0 go build
95 changes: 21 additions & 74 deletions cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ var ServerCmd = &cobra.Command{
}

var serverCmdArgs struct {
ip []string
wgBlock string
wgBlockPerIp string
cloud string
ip []string
wgBlock string
wgBlockPerIp string
cloud string
region string
internalNetworkCidr string
}

func init() {
Expand All @@ -39,6 +41,10 @@ func init() {
"", "WireGuard block size for each --ip flag, if multiple are provided")
ServerCmd.Flags().StringVar(&serverCmdArgs.cloud, "cloud",
"", "Cloud provider for IP metadata (watches for changes)")
ServerCmd.Flags().StringVar(&serverCmdArgs.region, "region", "",
"Region of the server")
ServerCmd.Flags().StringVar(&serverCmdArgs.internalNetworkCidr, "internal-network-cidr",
"10.0.0.0/24", "Internal network CIDR to route to")
}

func runServer(cmd *cobra.Command, args []string) error {
Expand All @@ -56,6 +62,9 @@ func runServer(cmd *cobra.Command, args []string) error {
if serverCmdArgs.wgBlock == "" {
return errors.New("missing required flag: --wg-block")
}
if serverCmdArgs.region == "" {
return errors.New("missing required flag: --region")
}

wgBlock, err := netip.ParsePrefix(serverCmdArgs.wgBlock)
if err != nil || !wgBlock.Addr().Is4() {
Expand Down Expand Up @@ -97,7 +106,7 @@ func runServer(cmd *cobra.Command, args []string) error {

ctx, done := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)

sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password)
sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password, serverCmdArgs.region, serverCmdArgs.internalNetworkCidr)
if err != nil {
done()
return err
Expand All @@ -106,26 +115,17 @@ func runServer(cmd *cobra.Command, args []string) error {
defer sm.Wait()
defer done()

if cloud == "aws" {
initialIps, err := pollAws(lib.NewAwsMetadata(), make(ipSet), sm)
for _, ipStr := range serverCmdArgs.ip {
ip, err := netip.ParseAddr(ipStr)
if err != nil || !ip.Is4() {
return fmt.Errorf("invalid IPv4 address: %q", ipStr)
}
err = sm.Start(ip)
if err != nil {
return err
}

pollAwsLoop(ctx, sm, initialIps)
} else {
for _, ipStr := range serverCmdArgs.ip {
ip, err := netip.ParseAddr(ipStr)
if err != nil || !ip.Is4() {
return fmt.Errorf("invalid IPv4 address: %q", ipStr)
}
err = sm.Start(ip)
if err != nil {
return err
}
}
sm.Wait()
}
sm.Wait()

return nil
}
Expand All @@ -145,56 +145,3 @@ func parseIpSet(ipStrs []string) (ipSet, error) {
}
return m, nil
}

// pollAws gets the current set of IP associations from AWS and starts/stops the
// server for those IPs.
func pollAws(awsClient *lib.AwsMetadata, currentIps ipSet, sm *lib.ServerManager) (ipSet, error) {
interfaces, err := awsClient.GetAddresses()

if err != nil {
return currentIps, fmt.Errorf("failed to get AWS MAC addresses: %v", err)
}

newIps, err := parseIpSet(interfaces[0].PrivateIps)
if err != nil {
return currentIps, err
}

for ip := range currentIps {
if _, ok := newIps[ip]; !ok {
sm.Stop(ip)
delete(currentIps, ip)
}
}

for ip := range newIps {
if _, ok := currentIps[ip]; !ok {
if err := sm.Start(ip); err != nil {
return currentIps, fmt.Errorf("error starting new ip: %v", err)
}
currentIps[ip] = struct{}{}
}
}
return currentIps, nil
}

// pollAwsLoop polls AWS in a blocking loop on an interval of AWS_POLL_DURATION
// until ctx is done.
func pollAwsLoop(ctx context.Context, sm *lib.ServerManager, initialIps ipSet) {
currentIps := initialIps
awsClient := lib.NewAwsMetadata()
ticker := time.NewTicker(awsPollDuration)

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
var err error
currentIps, err = pollAws(awsClient, currentIps, sm)
if err != nil {
fmt.Printf("error during aws poll: %v", err)
}
}
}
}
94 changes: 94 additions & 0 deletions lib/iputils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log"
"net"
"net/netip"
"strings"
"sync"

"github.com/vishvananda/netlink"
Expand Down Expand Up @@ -42,6 +43,99 @@ func getDefaultInterface() (netlink.Link, error) {
return nil, fmt.Errorf("failed to find default route")
}

func getInternalInterface() (netlink.Link, error) {
links, err := netlink.LinkList()
if err != nil {
return nil, fmt.Errorf("failed to list interfaces: %v", err)
}

// Define interface types to skip
skipTypes := map[string]bool{
"veth": true,
"bridge": true,
"dummy": true,
"wireguard": true,
"docker": true,
}

var candidates []netlink.Link
for _, link := range links {
// Skip interfaces that aren't up
if link.Attrs().OperState != netlink.OperUp {
continue
}

// Skip certain interface types
if skipTypes[link.Type()] {
continue
}

// Skip interfaces with certain name patterns
name := link.Attrs().Name
if strings.HasPrefix(name, "veth") ||
strings.HasPrefix(name, "docker") ||
strings.HasPrefix(name, "wg") {
continue
}

addrs, err := netlink.AddrList(link, netlink.FAMILY_V4)
if err != nil {
log.Printf("failed to get addresses for interface %s: %v", name, err)
continue
}

// Check for RFC1918 addresses
for _, addr := range addrs {
if isRFC1918(addr.IP) {
candidates = append(candidates, link)
break
}
}
}

// If we found candidates, prefer physical interfaces over others
for _, link := range candidates {
name := link.Attrs().Name
// Common patterns for physical interfaces
if strings.HasPrefix(name, "en") ||
strings.HasPrefix(name, "eth") ||
strings.HasPrefix(name, "enp") ||
strings.HasPrefix(name, "ens") {
return link, nil
}
}

// If no physical interface found, return the first candidate
if len(candidates) > 0 {
return candidates[0], nil
}

return nil, fmt.Errorf("failed to find internal interface")
}

// Helper function to check if an interface is a VLAN interface
func isVLANInterface(link netlink.Link) bool {
_, isVlan := link.(*netlink.Vlan)
if !isVlan {
// Also check the name for "@" which indicates a VLAN interface in Linux
return strings.Contains(link.Attrs().Name, ".")
}
return true
}

// Helper function to check RFC1918 addresses
func isRFC1918(ip net.IP) bool {
if ip4 := ip.To4(); ip4 != nil {
// Convert to 32-bit integer for easier comparison
ipInt := uint32(ip4[0])<<24 | uint32(ip4[1])<<16 | uint32(ip4[2])<<8 | uint32(ip4[3])

return (ipInt >= 0x0A000000 && ipInt <= 0x0AFFFFFF) || // 10.0.0.0/8
(ipInt >= 0xAC100000 && ipInt <= 0xAC1FFFFF) || // 172.16.0.0/12
(ipInt >= 0xC0A80000 && ipInt <= 0xC0A8FFFF) // 192.168.0.0/16
}
return false
}

// IpAllocator is a simple IP address allocator that produces IP addresses
// within a prefix, in increasing order of available IPs.
//
Expand Down
Loading

0 comments on commit 1233ced

Please sign in to comment.