Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle N targets #2

Closed
wants to merge 6 commits into from
Closed
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 93 additions & 103 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,141 +8,131 @@ import (
"net/url"
"os"
"strings"

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, reflection isn't necessary here. Let's just use a single channel.

"github.com/asaskevich/govalidator"
)

func validateUrl(u string) (*url.URL, error) {
// Verify the target is HTTP or HTTPS.
if !strings.HasPrefix(u, "http://") && !strings.HasPrefix(u, "https://") {
return nil, fmt.Errorf("%s: must be an http(s) URL.\n", u)
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style nitpick: spaces after closing brackets.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what this means. Are you saying you want trailing spaces after the closing brackets?

(I assumed you meant newline, so I added those.)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like this:

if thing {
    return nil, error.Broken()
}

// comment
if nextThing {

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah, refreshed and saw it. thanks!


// Validate the URL.
if !govalidator.IsURL(u) {
return nil, fmt.Errorf("%s: must be a valid URL.\n", u)
}

// Parse the target as a URL.
ua, err := url.Parse(u)
if err != nil {
panic(err)
}

return ua, nil
}

func main() {
if len(os.Args) != 4 {
fmt.Println("usage: multireq <listen addr> <target A> <target B>")
if len(os.Args) < 3 {
fmt.Fprintln(os.Stderr, "USAGE: multireq [listen-addr] [target]...")
os.Exit(1)
}

listen := os.Args[1]
targets := os.Args[2:]
if !strings.HasPrefix(targets[0], "http") {
fmt.Println("must specify http targets")
os.Exit(1)
}
if !strings.HasPrefix(targets[1], "http") {
fmt.Println("must specify http targets")
os.Exit(1)
}
urls := make([]*url.URL, len(targets))

ua, err := url.Parse(targets[0])
if err != nil {
panic(err)
failed := false
for i, v := range targets {
ua, err := validateUrl(v)
if err != nil {
fmt.Fprintln(os.Stderr, err)
failed = true
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could just put the os.Exit in here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True -- my intent here was to print out all of the malformed URLs before bailing.

continue
}
urls[i] = ua
}
ub, err := url.Parse(targets[1])
if err != nil {
panic(err)
if failed {
os.Exit(1)
}

http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
resp_a := make(chan *http.Response)
resp_b := make(chan *http.Response)
// Allocate a channel for the HTTP responses.
responses := make(chan *http.Response, len(targets))

fail_a := make(chan struct{})
fail_b := make(chan struct{})
fail := make(chan struct{})
// Allocate a cancellation channel for each target's HTTP request.
cancels := make([]chan struct{}, len(targets))

cancel_a := make(chan struct{})
cancel_b := make(chan struct{})
// Allocate a failure channel for each target's HTTP request.
fails := make([]chan struct{}, len(targets))

r.RequestURI = ""
r.URL.Scheme = "http"

go func() {
req_a := *r
req_a.URL.Host = ua.Host
req_a.URL, _ = url.Parse(r.URL.String())
req_a.Cancel = cancel_a

rt := &http.Transport{DisableKeepAlives: true}
resp, err := rt.RoundTrip(&req_a)
if err != nil {
log.Printf("target A failed: %s", err)
close(fail_a)
} else if resp.StatusCode >= 500 || resp.StatusCode == 408 {
log.Printf("target A unsatisfying status: %d", resp.StatusCode)
close(fail_a)
} else {
resp_a <- resp
}
}()
// Create and send out an HTTP request to each target.
for i := range targets {
req := *r
req.URL.Host = urls[i].Host
req.URL, _ = url.Parse(r.URL.String())
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this necessary?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh wait. i see, the host part changes

cancels[i] = make(chan struct{})
fails[i] = make(chan struct{})
req.Cancel = cancels[i]
fail := fails[i]

go func() {
rt := &http.Transport{DisableKeepAlives: true}
res, err := rt.RoundTrip(&req)
if err != nil {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets go with:

switch {
case err != nil:
...
case res.StatusCode >= 500 || res.StatusCode == 400:
...
default:
...
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh hey, I really like this pattern! Go's switches are really nice.

log.Printf("request failed: %s\n", err)
close(fail)
} else if res.StatusCode >= 500 || res.StatusCode == 408 {
log.Printf("target (%s) unsatisfying status: %d", req.URL, res.StatusCode)
close(fail)
} else {
log.Printf("target (%s) responded with %d", req.URL, res.StatusCode)
responses <- res
}
}()
}

// Listen to all requests' failure channels; issue a signal when all requests fail.
failed := make(chan struct{})
go func() {
req_b := *r
req_b.URL, _ = url.Parse(r.URL.String())
req_b.URL.Host = ub.Host
req_b.Cancel = cancel_b

rt := &http.Transport{DisableKeepAlives: true}
resp, err := rt.RoundTrip(&req_b)
if err != nil {
log.Printf("target B failed: %s", err)
close(fail_b)
} else if resp.StatusCode >= 500 || resp.StatusCode == 408 {
log.Printf("target B unsatisfying status: %d", resp.StatusCode)
close(fail_b)
} else {
resp_b <- resp
for i := range fails {
<-fails[i]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is back to the point where I don't think this goroutine is ever gonna exit. where do all the fails get closed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see now. I had misunderstood how this piece worked -- corrected.

}
close(failed)
}()

go func() {
<-fail_a
<-fail_b
close(fail)
}()

var ra *http.Response
var rb *http.Response
var resp *http.Response
done := false
OuterLoop:
for {
select {
case ra = <-resp_a:
if !done {
done = true
resp = ra
log.Print("close cancel_b")
close(cancel_b)
break OuterLoop
}
case rb = <-resp_b:
if !done {
done = true
resp = rb
log.Print("close cancel_a")
close(cancel_a)
break OuterLoop
}
case <-fail:
log.Print("both failed")
break OuterLoop
// Wait for a successful response (or failures across the board).
select {
case res := <-responses:
// Close all of the pending request channels.
for _, c := range cancels {
close(c)
}
}
// Copy headers over.
for k, v := range res.Header {
w.Header()[k] = v
}
w.WriteHeader(res.StatusCode)

if !done {
written, err := io.Copy(w, res.Body)
if err != nil {
log.Printf("io.Copy error: %s", err)
}
log.Printf("io.Copy %d bytes written", written)
case <-failed:
// All requests have met failure.
w.WriteHeader(503)
return
}

for k, v := range resp.Header {
w.Header()[k] = v
}
w.WriteHeader(resp.StatusCode)
written, err := io.Copy(w, resp.Body)
if err != nil {
log.Printf("io.Copy error: %s", err)
w.(http.Flusher).Flush()
}
log.Printf("io.Copy %d bytes written", written)
})

log.Printf("listening on %s", listen)
err = http.ListenAndServe(listen, nil)
err := http.ListenAndServe(listen, nil)
if err != nil {
fmt.Println(err)
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}