Skip to content

Commit

Permalink
lib,cmd: add -connect-to flag
Browse files Browse the repository at this point in the history
Closes #692, #691, #575

Co-authored-by: dank@kegel.com
Co-authored-by: Antonio M. Amaya <antoniomanuel.amaya@bbva.com>
  • Loading branch information
tsenart and AntonioMA committed Jul 29, 2024
1 parent ce93cd8 commit 647dacc
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 1 deletion.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ attack command:
TLS client PEM encoded certificate file
-chunked
Send body with chunked transfer encoding
-connect-to value
A mapping of (ip|host):port to use instead of a target URL's (ip|host):port. Can be repeated multiple times.
Identical src:port with different dst:port will round-robin over the different dst:port pairs.
Example: google.com:80:localhost:6060
-connections int
Max open idle connections per target host (default 10000)
-dns-ttl value
Expand Down Expand Up @@ -178,7 +182,6 @@ examples:
vegeta report -type=json results.bin > metrics.json
cat results.bin | vegeta plot > plot.html
cat results.bin | vegeta report -type="hist[0,100ms,200ms,300ms]"

```

#### `-cpus`
Expand Down
3 changes: 3 additions & 0 deletions attack.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func attackCmd() command {
fs.StringVar(&opts.promAddr, "prometheus-addr", "", "Prometheus exporter listen address [empty = disabled]. Example: 0.0.0.0:8880")
fs.Var(&dnsTTLFlag{&opts.dnsTTL}, "dns-ttl", "Cache DNS lookups for the given duration [-1 = disabled, 0 = forever]")
fs.BoolVar(&opts.sessionTickets, "session-tickets", false, "Enable TLS session resumption using session tickets")
fs.Var(&connectToFlag{&opts.connectTo}, "connect-to", "A mapping of (ip|host):port to use instead of a target URL's (ip|host):port. Can be repeated multiple times.\nIdentical src:port with different dst:port will round-robin over the different dst:port pairs.\nExample: google.com:80:localhost:6060")
systemSpecificFlags(fs, opts)

return command{fs, func(args []string) error {
Expand Down Expand Up @@ -108,6 +109,7 @@ type attackOpts struct {
promAddr string
dnsTTL time.Duration
sessionTickets bool
connectTo map[string][]string
}

// attack validates the attack arguments, sets up the
Expand Down Expand Up @@ -218,6 +220,7 @@ func attack(opts *attackOpts) (err error) {
vegeta.ProxyHeader(proxyHdr),
vegeta.ChunkedBody(opts.chunked),
vegeta.DNSCaching(opts.dnsTTL),
vegeta.ConnectTo(opts.connectTo),
vegeta.SessionTickets(opts.sessionTickets),
)

Expand Down
52 changes: 52 additions & 0 deletions flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math"
"net"
"net/http"
"sort"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -153,3 +154,54 @@ func (f *dnsTTLFlag) String() string {
}
return f.ttl.String()
}

const connectToFormat = "src:port:dst:port"

type connectToFlag struct {
addrMap *map[string][]string
}

func (c *connectToFlag) String() string {
if c.addrMap == nil {
return ""
}

addrMappings := make([]string, 0, len(*c.addrMap))
for k, v := range *c.addrMap {
addrMappings = append(addrMappings, k+":"+strings.Join(v, ","))
}

sort.Strings(addrMappings)
return strings.Join(addrMappings, ";")
}

func (c *connectToFlag) Set(s string) error {
if c.addrMap == nil {
return nil
}

if *c.addrMap == nil {
*c.addrMap = make(map[string][]string)
}

parts := strings.Split(s, ":")
if len(parts) != 4 {
return fmt.Errorf("invalid -connect-to %q, expected format: %s", s, connectToFormat)
}
srcAddr := parts[0] + ":" + parts[1]
dstAddr := parts[2] + ":" + parts[3]

// Parse source address
if _, _, err := net.SplitHostPort(srcAddr); err != nil {
return fmt.Errorf("invalid source address expression [%s], expected address:port", srcAddr)
}

// Parse destination address
if _, _, err := net.SplitHostPort(dstAddr); err != nil {
return fmt.Errorf("invalid destination address expression [%s], expected address:port", dstAddr)
}

(*c.addrMap)[srcAddr] = append((*c.addrMap)[srcAddr], dstAddr)

return nil
}
42 changes: 42 additions & 0 deletions lib/attack.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type Attacker struct {
maxWorkers uint64
maxBody int64
redirects int
seqmu sync.Mutex
seq uint64
began time.Time
chunked bool
}

Expand Down Expand Up @@ -272,6 +275,45 @@ func ProxyHeader(h http.Header) func(*Attacker) {
}
}

// ConnectTo returns a functional option which makes the attacker use the
// passed in map to translate target addr:port pairs. When used with DNSCaching,
// it must be used after it.
func ConnectTo(addrMap map[string][]string) func(*Attacker) {
return func(a *Attacker) {
if len(addrMap) == 0 {
return
}

tr, ok := a.client.Transport.(*http.Transport)
if !ok {
return
}

dial := tr.DialContext
if dial == nil {
dial = a.dialer.DialContext
}

type roundRobin struct {
addrs []string
n int
}

connectTo := make(map[string]*roundRobin, len(addrMap))
for k, v := range addrMap {
connectTo[k] = &roundRobin{addrs: v}
}

tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if cm, ok := connectTo[addr]; ok {
cm.n = (cm.n + 1) % len(cm.addrs)
addr = cm.addrs[cm.n]
}
return dial(ctx, network, addr)
}
}
}

// DNSCaching returns a functional option that enables DNS caching for
// the given ttl. When ttl is zero cached entries will never expire.
// When ttl is non-zero, this will start a refresh go-routine that updates
Expand Down
64 changes: 64 additions & 0 deletions lib/attack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -498,3 +499,66 @@ func TestFirstOfEachIPFamily(t *testing.T) {
})
}
}

func TestAttackConnectTo(t *testing.T) {
t.Parallel()
var mu sync.Mutex
hits := make(map[string]int)
srvs := make(map[string]int)

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
hits[r.Host]++
mu.Unlock()
})

addrs := make([]string, 3)
for i := range addrs {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
addrs[i] = ln.Addr().String()

srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
srvs[ln.Addr().String()]++
mu.Unlock()
handler.ServeHTTP(w, r)
}))

srv.Listener = ln
srv.Start()
t.Cleanup(srv.Close)
}

tr := NewStaticTargeter(
Target{Method: "GET", URL: "http://sapo.pt:80"},
Target{Method: "GET", URL: "http://sapo.pt:80"},
Target{Method: "GET", URL: "http://sapo.pt:80"},
Target{Method: "GET", URL: "http://" + addrs[0]},
)

atk := NewAttacker(
KeepAlive(false),
ConnectTo(map[string][]string{"sapo.pt:80": addrs}),
)

a := &attack{name: "TEST", began: time.Now()}
for i := 0; i < 4; i++ {
resp := atk.hit(tr, a)
if resp.Error != "" {
t.Fatal(resp.Error)
}
}

want := map[string]int{"sapo.pt:80": 3, addrs[0]: 1}
if diff := cmp.Diff(want, hits); diff != "" {
t.Errorf("unexpected hits (-want +got):\n%s", diff)
}

want = map[string]int{addrs[0]: 2, addrs[1]: 1, addrs[2]: 1}
if diff := cmp.Diff(want, srvs); diff != "" {
t.Errorf("unexpected hits (-want +got):\n%s", diff)
}
}

0 comments on commit 647dacc

Please sign in to comment.