Skip to content

Commit

Permalink
Check content-type before importing to avoid importing unexpected con…
Browse files Browse the repository at this point in the history
…tents

This commit adds a check in the http-datasource format reader to avoid importing unexpected content-types. This avoid issues with certain servers with unexpected behavior.

Signed-off-by: Alvaro Romero <alromero@redhat.com>
  • Loading branch information
alromeros committed Sep 26, 2023
1 parent f7f95c5 commit 951735b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 12 deletions.
27 changes: 24 additions & 3 deletions pkg/importer/http-datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ const (
nbdkitPid = "/tmp/nbdkit.pid"
nbdkitSocket = "/tmp/nbdkit.sock"
defaultUserAgent = "cdi-golang-importer"
contentType = "Content-Type"
contentLenght = "Content-Length"
)

// HTTPDataSource is the data provider for http(s) endpoints.
Expand Down Expand Up @@ -96,7 +98,7 @@ func NewHTTPDataSource(endpoint, accessKey, secKey, certDir string, contentType
return nil, errors.Wrap(err, "Error getting extra headers for HTTP client")
}

httpReader, contentLength, brokenForQemuImg, err := createHTTPReader(ctx, ep, accessKey, secKey, certDir, extraHeaders, secretExtraHeaders)
httpReader, contentLength, brokenForQemuImg, err := createHTTPReader(ctx, ep, accessKey, secKey, certDir, extraHeaders, secretExtraHeaders, contentType)
if err != nil {
cancel()
return nil, err
Expand Down Expand Up @@ -294,7 +296,7 @@ func addExtraheaders(req *http.Request, extraHeaders []string) {
req.Header.Add("User-Agent", defaultUserAgent)
}

func createHTTPReader(ctx context.Context, ep *url.URL, accessKey, secKey, certDir string, extraHeaders, secretExtraHeaders []string) (io.ReadCloser, uint64, bool, error) {
func createHTTPReader(ctx context.Context, ep *url.URL, accessKey, secKey, certDir string, extraHeaders, secretExtraHeaders []string, contentType cdiv1.DataVolumeContentType) (io.ReadCloser, uint64, bool, error) {
var brokenForQemuImg bool
client, err := createHTTPClient(certDir)
if err != nil {
Expand Down Expand Up @@ -334,6 +336,14 @@ func createHTTPReader(ctx context.Context, ep *url.URL, accessKey, secKey, certD
return nil, uint64(0), true, errors.Errorf("expected status code 200, got %d. Status: %s", resp.StatusCode, resp.Status)
}

if contentType == cdiv1.DataVolumeKubeVirt {
// Don't proceed with the import if we are expecting a KubeVirt disk image
// but content-type is unexpected
if err := checkHTTPContentType(resp); err != nil {
return nil, uint64(0), true, err
}
}

acceptRanges, ok := resp.Header["Accept-Ranges"]
if !ok || acceptRanges[0] == "none" {
klog.V(2).Infof("Accept-Ranges isn't bytes, avoiding qemu-img")
Expand Down Expand Up @@ -416,7 +426,7 @@ func getContentLength(client *http.Client, ep *url.URL, accessKey, secKey string
func parseHTTPHeader(resp *http.Response) uint64 {
var err error
total := uint64(0)
if val, ok := resp.Header["Content-Length"]; ok {
if val, ok := resp.Header[contentLenght]; ok {
total, err = strconv.ParseUint(val[0], 10, 64)
if err != nil {
klog.Errorf("could not convert content length, got %v", err)
Expand All @@ -427,6 +437,17 @@ func parseHTTPHeader(resp *http.Response) uint64 {
return total
}

func checkHTTPContentType(resp *http.Response) error {
if val, ok := resp.Header[contentType]; ok {
// TODO: Improve this with a list of unwanted content types
if strings.HasPrefix(val[0], "text/") {
klog.Errorf("http: unexpected content type %s", val[0])
return errors.Errorf("unexpected content type %s. Aborting import", val[0])
}
}
return nil
}

// Check for any extra headers to pass along. Return secret headers separately so callers can suppress logging them.
func getExtraHeaders() ([]string, []string, error) {
extraHeaders := getExtraHeadersFromEnvironment()
Expand Down
54 changes: 45 additions & 9 deletions pkg/importer/http-datasource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ var _ = Describe("Http client", func() {

var _ = Describe("Http reader", func() {
It("should fail when passed an invalid cert directory", func() {
_, total, _, err := createHTTPReader(context.Background(), nil, "", "", "/invalid", nil, nil)
_, total, _, err := createHTTPReader(context.Background(), nil, "", "", "/invalid", nil, nil, cdiv1.DataVolumeKubeVirt)
Expect(err).To(HaveOccurred())
Expect(uint64(0)).To(Equal(total))
})
Expand All @@ -249,7 +249,7 @@ var _ = Describe("Http reader", func() {
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
r, total, _, err := createHTTPReader(context.Background(), ep, "user", "password", "", nil, nil)
r, total, _, err := createHTTPReader(context.Background(), ep, "user", "password", "", nil, nil, cdiv1.DataVolumeKubeVirt)
Expect(err).ToNot(HaveOccurred())
Expect(uint64(25)).To(Equal(total))
err = r.Close()
Expand All @@ -272,7 +272,7 @@ var _ = Describe("Http reader", func() {
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
r, total, _, err := createHTTPReader(context.Background(), ep, "user", "password", "", nil, nil)
r, total, _, err := createHTTPReader(context.Background(), ep, "user", "password", "", nil, nil, cdiv1.DataVolumeKubeVirt)
Expect(err).ToNot(HaveOccurred())
Expect(uint64(25)).To(Equal(total))
err = r.Close()
Expand All @@ -294,7 +294,7 @@ var _ = Describe("Http reader", func() {
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil)
r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt)
Expect(brokenForQemuImg).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(uint64(25)).To(Equal(total))
Expand All @@ -315,13 +315,49 @@ var _ = Describe("Http reader", func() {
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
r, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil)
r, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt)
Expect(err).ToNot(HaveOccurred())
Expect(uint64(0)).To(Equal(total))
err = r.Close()
Expect(err).ToNot(HaveOccurred())
})

It("should error if we expect kubevirt disk img but Content-Type is text/html", func() {
redirTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer w.WriteHeader(http.StatusOK)
w.Header().Add("Content-Type", "text/html")
}))
defer redirTs.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, redirTs.URL, http.StatusFound)
}))
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
_, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt)
Expect(err).To(HaveOccurred())
Expect(uint64(0)).To(Equal(total))
Expect("unexpected content type text/html. Aborting import").To(Equal(err.Error()))
})

It("should not care about Content-Type if we expect archive", func() {
redirTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer w.WriteHeader(http.StatusOK)
w.Header().Add("Content-Type", "text/html")
w.Header().Add("Content-Length", "25")
}))
defer redirTs.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, redirTs.URL, http.StatusFound)
}))
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
_, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeArchive)
Expect(err).ToNot(HaveOccurred())
Expect(uint64(25)).To(Equal(total))
})

It("should continue even if HEAD is rejected, but mark broken for qemu-img", func() {
redirTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" {
Expand All @@ -339,7 +375,7 @@ var _ = Describe("Http reader", func() {
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil)
r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt)
Expect(brokenForQemuImg).To(BeTrue())
Expect(err).ToNot(HaveOccurred())
Expect(uint64(25)).To(Equal(total))
Expand All @@ -359,7 +395,7 @@ var _ = Describe("Http reader", func() {
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil)
r, total, brokenForQemuImg, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt)
Expect(brokenForQemuImg).To(BeTrue())
Expect(err).ToNot(HaveOccurred())
Expect(uint64(25)).To(Equal(total))
Expand All @@ -374,7 +410,7 @@ var _ = Describe("Http reader", func() {
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
_, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil)
_, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", nil, nil, cdiv1.DataVolumeKubeVirt)
Expect(err).To(HaveOccurred())
Expect(uint64(0)).To(Equal(total))
Expect("expected status code 200, got 500. Status: 500 Internal Server Error").To(Equal(err.Error()))
Expand All @@ -391,7 +427,7 @@ var _ = Describe("Http reader", func() {
defer ts.Close()
ep, err := url.Parse(ts.URL)
Expect(err).ToNot(HaveOccurred())
r, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", []string{"Extra-Header: 123"}, nil)
r, total, _, err := createHTTPReader(context.Background(), ep, "", "", "", []string{"Extra-Header: 123"}, nil, cdiv1.DataVolumeKubeVirt)
Expect(err).ToNot(HaveOccurred())
Expect(uint64(0)).To(Equal(total))
err = r.Close()
Expand Down

0 comments on commit 951735b

Please sign in to comment.