diff --git a/main.go b/main.go index cd235fb..8327299 100644 --- a/main.go +++ b/main.go @@ -8,141 +8,146 @@ import ( "net/url" "os" "strings" + + "github.com/asaskevich/govalidator" ) +func validateUrl(u string) (*url.URL, error) { + // Verify the target is HTTP or HTTPS. + if !strings.HasPrefix(u, "http://") && !strings.HasPrefix(u, "https://") { + return nil, fmt.Errorf("%s: must be an http(s) URL.\n", u) + } + + // Validate the URL. + if !govalidator.IsURL(u) { + return nil, fmt.Errorf("%s: must be a valid URL.\n", u) + } + + // Parse the target as a URL. + ua, err := url.Parse(u) + if err != nil { + panic(err) + } + + return ua, nil +} + func main() { - if len(os.Args) != 4 { - fmt.Println("usage: multireq ") + if len(os.Args) < 3 { + fmt.Fprintln(os.Stderr, "USAGE: multireq [listen-addr] [target]...") os.Exit(1) } listen := os.Args[1] targets := os.Args[2:] - if !strings.HasPrefix(targets[0], "http") { - fmt.Println("must specify http targets") - os.Exit(1) - } - if !strings.HasPrefix(targets[1], "http") { - fmt.Println("must specify http targets") - os.Exit(1) - } + urls := make([]*url.URL, len(targets)) - ua, err := url.Parse(targets[0]) - if err != nil { - panic(err) + failed := false + for i, v := range targets { + ua, err := validateUrl(v) + if err != nil { + fmt.Fprintln(os.Stderr, err) + failed = true + continue + } + urls[i] = ua } - ub, err := url.Parse(targets[1]) - if err != nil { - panic(err) + if failed { + os.Exit(1) } http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - resp_a := make(chan *http.Response) - resp_b := make(chan *http.Response) + // Allocate a channel for the HTTP responses. + responses := make(chan *http.Response, len(targets)) + + // Allocate a cancellation channel for the targets' HTTP requests. + cancels := make([]chan struct{}, len(targets)) - fail_a := make(chan struct{}) - fail_b := make(chan struct{}) - fail := make(chan struct{}) + // Allocate a failure channel to collect failed HTTP requests. + fails := make(chan bool, len(targets)) - cancel_a := make(chan struct{}) - cancel_b := make(chan struct{}) + // Channel to indicate all requests have failed. + allFailed := make(chan struct{}) r.RequestURI = "" r.URL.Scheme = "http" - go func() { - req_a := *r - req_a.URL.Host = ua.Host - req_a.URL, _ = url.Parse(r.URL.String()) - req_a.Cancel = cancel_a - - rt := &http.Transport{DisableKeepAlives: true} - resp, err := rt.RoundTrip(&req_a) - if err != nil { - log.Printf("target A failed: %s", err) - close(fail_a) - } else if resp.StatusCode >= 500 || resp.StatusCode == 408 { - log.Printf("target A unsatisfying status: %d", resp.StatusCode) - close(fail_a) - } else { - resp_a <- resp - } - }() + // Create and send out an HTTP request to each target. + for i := range targets { + req := *r + req.URL.Host = urls[i].Host + req.URL, _ = url.Parse(r.URL.String()) + cancels[i] = make(chan struct{}) + req.Cancel = cancels[i] + + go func() { + rt := &http.Transport{DisableKeepAlives: true} + res, err := rt.RoundTrip(&req) + + switch { + case err != nil: + log.Printf("request failed: %s\n", err) + fails <- true + case res.StatusCode >= 500 || res.StatusCode == 408: + log.Printf("target (%s) unsatisfying status: %d", req.URL, res.StatusCode) + fails <- true + default: + log.Printf("target (%s) responded with %d", req.URL, res.StatusCode) + responses <- res + } + }() + } + // Listen to the requests' failure channel; close it when all requests fail. go func() { - req_b := *r - req_b.URL, _ = url.Parse(r.URL.String()) - req_b.URL.Host = ub.Host - req_b.Cancel = cancel_b - - rt := &http.Transport{DisableKeepAlives: true} - resp, err := rt.RoundTrip(&req_b) - if err != nil { - log.Printf("target B failed: %s", err) - close(fail_b) - } else if resp.StatusCode >= 500 || resp.StatusCode == 408 { - log.Printf("target B unsatisfying status: %d", resp.StatusCode) - close(fail_b) - } else { - resp_b <- resp + for i := 0; i < len(targets); i++ { + _, ok := <-fails + if !ok { + return + } } - }() - go func() { - <-fail_a - <-fail_b - close(fail) + // Close the channel, signalling that all requests have failed. + close(allFailed) }() - var ra *http.Response - var rb *http.Response - var resp *http.Response - done := false - OuterLoop: + // Wait for a successful response (or failures across the board). for { select { - case ra = <-resp_a: - if !done { - done = true - resp = ra - log.Print("close cancel_b") - close(cancel_b) - break OuterLoop + case res := <-responses: + // Close the cancel channels. + for _, c := range cancels { + close(c) } - case rb = <-resp_b: - if !done { - done = true - resp = rb - log.Print("close cancel_a") - close(cancel_a) - break OuterLoop + + // Copy headers over. + for k, v := range res.Header { + w.Header()[k] = v } - case <-fail: - log.Print("both failed") - break OuterLoop - } - } + w.WriteHeader(res.StatusCode) - if !done { - w.WriteHeader(503) - return - } + written, err := io.Copy(w, res.Body) + if err != nil { + log.Printf("io.Copy error: %s", err) + } - for k, v := range resp.Header { - w.Header()[k] = v - } - w.WriteHeader(resp.StatusCode) - written, err := io.Copy(w, resp.Body) - if err != nil { - log.Printf("io.Copy error: %s", err) + log.Printf("io.Copy %d bytes written", written) + return + case _, ok := <-allFailed: + if !ok { + // All requests have met failure. + w.WriteHeader(503) + w.(http.Flusher).Flush() + return + } + } } - log.Printf("io.Copy %d bytes written", written) }) log.Printf("listening on %s", listen) - err = http.ListenAndServe(listen, nil) + err := http.ListenAndServe(listen, nil) if err != nil { - fmt.Println(err) + fmt.Fprintln(os.Stderr, err) os.Exit(1) } }