diff --git a/pkg/importer/http-datasource.go b/pkg/importer/http-datasource.go index 10316147eb..a60161cc78 100644 --- a/pkg/importer/http-datasource.go +++ b/pkg/importer/http-datasource.go @@ -48,6 +48,8 @@ const ( nbdkitPid = "/tmp/nbdkit.pid" nbdkitSocket = "/tmp/nbdkit.sock" defaultUserAgent = "cdi-golang-importer" + contentType = "Content-Type" + contentLenght = "Content-Length" ) // HTTPDataSource is the data provider for http(s) endpoints. @@ -96,7 +98,7 @@ func NewHTTPDataSource(endpoint, accessKey, secKey, certDir string, contentType return nil, errors.Wrap(err, "Error getting extra headers for HTTP client") } - httpReader, contentLength, brokenForQemuImg, err := createHTTPReader(ctx, ep, accessKey, secKey, certDir, extraHeaders, secretExtraHeaders) + httpReader, contentLength, brokenForQemuImg, err := createHTTPReader(ctx, ep, accessKey, secKey, certDir, extraHeaders, secretExtraHeaders, contentType) if err != nil { cancel() return nil, err @@ -294,7 +296,7 @@ func addExtraheaders(req *http.Request, extraHeaders []string) { req.Header.Add("User-Agent", defaultUserAgent) } -func createHTTPReader(ctx context.Context, ep *url.URL, accessKey, secKey, certDir string, extraHeaders, secretExtraHeaders []string) (io.ReadCloser, uint64, bool, error) { +func createHTTPReader(ctx context.Context, ep *url.URL, accessKey, secKey, certDir string, extraHeaders, secretExtraHeaders []string, contentType cdiv1.DataVolumeContentType) (io.ReadCloser, uint64, bool, error) { var brokenForQemuImg bool client, err := createHTTPClient(certDir) if err != nil { @@ -334,6 +336,14 @@ func createHTTPReader(ctx context.Context, ep *url.URL, accessKey, secKey, certD return nil, uint64(0), true, errors.Errorf("expected status code 200, got %d. Status: %s", resp.StatusCode, resp.Status) } + if contentType == cdiv1.DataVolumeKubeVirt { + // Don't proceed with the import if we are expecting a KubeVirt disk image + // but content-type is unexpected + if err := checkHTTPContentType(resp); err != nil { + return nil, uint64(0), true, err + } + } + acceptRanges, ok := resp.Header["Accept-Ranges"] if !ok || acceptRanges[0] == "none" { klog.V(2).Infof("Accept-Ranges isn't bytes, avoiding qemu-img") @@ -416,7 +426,7 @@ func getContentLength(client *http.Client, ep *url.URL, accessKey, secKey string func parseHTTPHeader(resp *http.Response) uint64 { var err error total := uint64(0) - if val, ok := resp.Header["Content-Length"]; ok { + if val, ok := resp.Header[contentLenght]; ok { total, err = strconv.ParseUint(val[0], 10, 64) if err != nil { klog.Errorf("could not convert content length, got %v", err) @@ -427,6 +437,17 @@ func parseHTTPHeader(resp *http.Response) uint64 { return total } +func checkHTTPContentType(resp *http.Response) error { + if val, ok := resp.Header[contentType]; ok { + // TODO: Improve this with a list of unwanted content types + if strings.HasPrefix(val[0], "text/") { + klog.Errorf("http: unexpected content type %s", val[0]) + return errors.Errorf("unexpected content type %s. Aborting import", val[0]) + } + } + return nil +} + // Check for any extra headers to pass along. Return secret headers separately so callers can suppress logging them. func getExtraHeaders() ([]string, []string, error) { extraHeaders := getExtraHeadersFromEnvironment() diff --git a/pkg/importer/http-datasource_test.go b/pkg/importer/http-datasource_test.go index 36dbe1fc7a..ea2de4f102 100644 --- a/pkg/importer/http-datasource_test.go +++ b/pkg/importer/http-datasource_test.go @@ -232,7 +232,7 @@ var _ = Describe("Http client", func() { var _ = Describe("Http reader", func() { It("should fail when passed an invalid cert directory", func() { - _, total, _, err := createHTTPReader(context.Background(), nil, "", "", "/invalid", nil, nil) + _, total, _, err := createHTTPReader(context.Background(), nil, "", "", "/invalid", nil, nil, cdiv1.DataVolumeKubeVirt) Expect(err).To(HaveOccurred()) Expect(uint64(0)).To(Equal(total)) }) @@ -249,7 +249,7 @@ var _ = Describe("Http reader", func() { defer ts.Close() ep, err := url.Parse(ts.URL) Expect(err).ToNot(HaveOccurred()) - r, total, _, err := createHTTPReader(context.Background(), ep, "user", "password", "", nil, nil) + r, total, _, err := createHTTPReader(context.Background(), ep, "user", "password", "", nil, nil, cdiv1.DataVolumeKubeVirt) Expect(err).ToNot(HaveOccurred()) Expect(uint64(25)).To(Equal(total)) err = r.Close() @@ -272,7 +272,7 @@ var _ = Describe("Http reader", func() { defer ts.Close() ep, err := url.Parse(ts.URL) Expect(err).ToNot(HaveOccurred()) - r, total, _, err := createHTTPReader(context.Background(), ep, "user", "password", "", nil, nil) + r, total, _, err := createHTTPReader(context.Background(), ep, "user", "password", "", nil, nil, cdiv1.DataVolumeKubeVirt) Expect(err).ToNot(HaveOccurred()) Expect(uint64(25)).To(Equal(total)) err = r.Close() @@ -294,7 +294,7 @@ var _ = Describe("Http reader", func() { defer ts.Close() ep, err := url.Parse(ts.URL) Expect(err).ToNot(HaveOccurred()) - r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil) + r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt) Expect(brokenForQemuImg).To(BeFalse()) Expect(err).ToNot(HaveOccurred()) Expect(uint64(25)).To(Equal(total)) @@ -315,13 +315,49 @@ var _ = Describe("Http reader", func() { defer ts.Close() ep, err := url.Parse(ts.URL) Expect(err).ToNot(HaveOccurred()) - r, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil) + r, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt) Expect(err).ToNot(HaveOccurred()) Expect(uint64(0)).To(Equal(total)) err = r.Close() Expect(err).ToNot(HaveOccurred()) }) + It("should error if we expect kubevirt disk img but Content-Type is text/html", func() { + redirTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer w.WriteHeader(http.StatusOK) + w.Header().Add("Content-Type", "text/html") + })) + defer redirTs.Close() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, redirTs.URL, http.StatusFound) + })) + defer ts.Close() + ep, err := url.Parse(ts.URL) + Expect(err).ToNot(HaveOccurred()) + _, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt) + Expect(err).To(HaveOccurred()) + Expect(uint64(0)).To(Equal(total)) + Expect("unexpected content type text/html. Aborting import").To(Equal(err.Error())) + }) + + It("should not care about Content-Type if we expect archive", func() { + redirTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer w.WriteHeader(http.StatusOK) + w.Header().Add("Content-Type", "text/html") + w.Header().Add("Content-Length", "25") + })) + defer redirTs.Close() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, redirTs.URL, http.StatusFound) + })) + defer ts.Close() + ep, err := url.Parse(ts.URL) + Expect(err).ToNot(HaveOccurred()) + _, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeArchive) + Expect(err).ToNot(HaveOccurred()) + Expect(uint64(25)).To(Equal(total)) + }) + It("should continue even if HEAD is rejected, but mark broken for qemu-img", func() { redirTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == "HEAD" { @@ -339,7 +375,7 @@ var _ = Describe("Http reader", func() { defer ts.Close() ep, err := url.Parse(ts.URL) Expect(err).ToNot(HaveOccurred()) - r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil) + r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt) Expect(brokenForQemuImg).To(BeTrue()) Expect(err).ToNot(HaveOccurred()) Expect(uint64(25)).To(Equal(total)) @@ -359,7 +395,7 @@ var _ = Describe("Http reader", func() { defer ts.Close() ep, err := url.Parse(ts.URL) Expect(err).ToNot(HaveOccurred()) - r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil) + r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt) Expect(brokenForQemuImg).To(BeTrue()) Expect(err).ToNot(HaveOccurred()) Expect(uint64(25)).To(Equal(total)) @@ -374,7 +410,7 @@ var _ = Describe("Http reader", func() { defer ts.Close() ep, err := url.Parse(ts.URL) Expect(err).ToNot(HaveOccurred()) - _, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil) + _, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt) Expect(err).To(HaveOccurred()) Expect(uint64(0)).To(Equal(total)) Expect("expected status code 200, got 500. Status: 500 Internal Server Error").To(Equal(err.Error())) @@ -391,7 +427,7 @@ var _ = Describe("Http reader", func() { defer ts.Close() ep, err := url.Parse(ts.URL) Expect(err).ToNot(HaveOccurred()) - r, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", []string{"Extra-Header: 123"}, nil) + r, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", []string{"Extra-Header: 123"}, nil, cdiv1.DataVolumeKubeVirt) Expect(err).ToNot(HaveOccurred()) Expect(uint64(0)).To(Equal(total)) err = r.Close()