diff --git a/main.go b/main.go index c40344d..18d973b 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,8 @@ import ( "net" "net/http" "os" + "strconv" + "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" @@ -67,6 +69,8 @@ func main() { } } +const DEFAULT_CHECK_TIMEOUT = 60 + func startServer(ctx context.Context, d *daemon, tcpListener, metricsUsername, metricPassword string) error { log.Printf("Starting %s %s\n", name, version) l, err := net.Listen("tcp", tcpListener) @@ -91,18 +95,31 @@ func startServer(ctx context.Context, d *daemon, tcpListener, metricsUsername, m maStr := r.URL.Query().Get("multiaddr") cidStr := r.URL.Query().Get("cid") + timeoutStr := r.URL.Query().Get("timeoutSeconds") if cidStr == "" { - err = errors.New("missing 'cid' argument") + err = errors.New("missing 'cid' query parameter") + } + + timeout := DEFAULT_CHECK_TIMEOUT + if timeoutStr != "" { + timeout, err = strconv.Atoi(timeoutStr) + if err != nil { + http.Error(w, "Invalid timeout value (in seconds)", http.StatusBadRequest) + return + } } + log.Printf("Checking %s with timeout %d seconds", cidStr, timeout) + withTimeout, cancel := context.WithTimeout(r.Context(), time.Duration(timeout)*time.Second) + defer cancel() var err error var data interface{} if maStr == "" { - data, err = d.runCidCheck(r.Context(), cidStr) + data, err = d.runCidCheck(withTimeout, cidStr) } else { - data, err = d.runPeerCheck(r.Context(), maStr, cidStr) + data, err = d.runPeerCheck(withTimeout, maStr, cidStr) } if err == nil { diff --git a/web/index.html b/web/index.html index 32bcfdd..db7ec00 100644 --- a/web/index.html +++ b/web/index.html @@ -36,6 +36,11 @@