Skip to content

Commit

Permalink
fix: fill uuid argument correctly in the config download URL
Browse files Browse the repository at this point in the history
It was broken, because `?uuid=` URL parses to `{"uuid": []string{""}}`.

Signed-off-by: Andrey Smirnov <smirnov.andrey@gmail.com>
(cherry picked from commit 3d77266)
  • Loading branch information
smira committed Jul 1, 2021
1 parent d6c5e50 commit 8aed6c2
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 33 deletions.
58 changes: 31 additions & 27 deletions internal/app/machined/pkg/runtime/v1alpha1/platform/metal/metal.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net"
"net/url"
"path/filepath"
"strings"

"github.com/google/uuid"
"github.com/talos-systems/go-blockdevice/blockdevice/filesystem"
Expand Down Expand Up @@ -51,45 +52,48 @@ func (m *Metal) Configuration(ctx context.Context) ([]byte, error) {

log.Printf("fetching machine config from: %q", *option)

u, err := url.Parse(*option)
downloadURL, err := PopulateURLParameters(*option, getSystemUUID)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", constants.KernelParamConfig, err)
return nil, err
}

values := u.Query()
switch downloadURL {
case constants.MetalConfigISOLabel:
return readConfigFromISO()
default:
return download.Download(ctx, downloadURL)
}
}

if len(values) > 0 {
for key, qValues := range values {
switch key {
case "uuid":
if len(qValues) != 1 {
uid, err := getSystemUUID()
if err != nil {
return nil, err
}
// PopulateURLParameters fills in empty parameters in the download URL.
func PopulateURLParameters(downloadURL string, getSystemUUID func() (uuid.UUID, error)) (string, error) {
u, err := url.Parse(downloadURL)
if err != nil {
return "", fmt.Errorf("failed to parse %s: %w", constants.KernelParamConfig, err)
}

values.Set("uuid", uid.String())
values := u.Query()

break
for key, qValues := range values {
switch key {
case "uuid":
// don't touch uuid field if it already has some value
if !(len(qValues) == 1 && len(strings.TrimSpace(qValues[0])) > 0) {
uid, err := getSystemUUID()
if err != nil {
return "", err
}

values.Set("uuid", qValues[0])
default:
log.Printf("unsupported query parameter: %q", key)
values.Set("uuid", uid.String())
}
default:
log.Printf("unsupported query parameter: %q", key)
}

u.RawQuery = values.Encode()

*option = u.String()
}

switch *option {
case constants.MetalConfigISOLabel:
return readConfigFromISO()
default:
return download.Download(ctx, *option)
}
u.RawQuery = values.Encode()

return u.String(), nil
}

func getSystemUUID() (uuid.UUID, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,63 @@

package metal_test

import "testing"
import (
"fmt"
"testing"

func TestEmpty(t *testing.T) {
// added for accurate coverage estimation
//
// please remove it once any unit-test is added
// for this package
"github.com/google/uuid"
"github.com/stretchr/testify/assert"

"github.com/talos-systems/talos/internal/app/machined/pkg/runtime/v1alpha1/platform/metal"
)

func TestPopulateURLParameters(t *testing.T) {
mockUUID := uuid.New()

for _, tt := range []struct {
name string
url string
expectedURL string
expectedError string
}{
{
name: "no uuid",
url: "http://example.com/metadata",
expectedURL: "http://example.com/metadata",
},
{
name: "empty uuid",
url: "http://example.com/metadata?uuid=",
expectedURL: fmt.Sprintf("http://example.com/metadata?uuid=%s", mockUUID.String()),
},
{
name: "uuid present",
url: "http://example.com/metadata?uuid=xyz",
expectedURL: "http://example.com/metadata?uuid=xyz",
},
{
name: "other parameters",
url: "http://example.com/metadata?foo=a",
expectedURL: "http://example.com/metadata?foo=a",
},
{
name: "multiple uuids",
url: "http://example.com/metadata?uuid=xyz&uuid=foo",
expectedURL: fmt.Sprintf("http://example.com/metadata?uuid=%s", mockUUID.String()),
},
} {
tt := tt

t.Run(tt.name, func(t *testing.T) {
output, err := metal.PopulateURLParameters(tt.url, func() (uuid.UUID, error) {
return mockUUID, nil
})

if tt.expectedError != "" {
assert.EqualError(t, err, tt.expectedError)
} else {
assert.Equal(t, output, tt.expectedURL)
}
})
}
}

0 comments on commit 8aed6c2

Please sign in to comment.