Skip to content

Commit

Permalink
remove the use key stripping and store the proper keys (#1603)
Browse files Browse the repository at this point in the history
  • Loading branch information
kradalby authored Nov 16, 2023
1 parent 2af71c9 commit c0fd06e
Show file tree
Hide file tree
Showing 21 changed files with 99 additions and 198 deletions.
4 changes: 2 additions & 2 deletions cmd/headscale/cli/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,15 +529,15 @@ func nodesToPtables(

var machineKey key.MachinePublic
err := machineKey.UnmarshalText(
[]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)),
[]byte(node.MachineKey),
)
if err != nil {
machineKey = key.MachinePublic{}
}

var nodeKey key.NodePublic
err = nodeKey.UnmarshalText(
[]byte(util.NodePublicKeyEnsurePrefix(node.NodeKey)),
[]byte(node.NodeKey),
)
if err != nil {
return nil, err
Expand Down
3 changes: 1 addition & 2 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -911,10 +911,9 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
}

trimmedPrivateKey := strings.TrimSpace(string(privateKey))
privateKeyEnsurePrefix := util.PrivateKeyEnsurePrefix(trimmedPrivateKey)

var machineKey key.MachinePrivate
if err = machineKey.UnmarshalText([]byte(privateKeyEnsurePrefix)); err != nil {
if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil {
log.Info().
Str("path", path).
Msg("This might be due to a legacy (headscale pre-0.12) private key. " +
Expand Down
22 changes: 11 additions & 11 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (h *Headscale) handleRegister(
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if registerRequest.Followup != "" {
if _, ok := h.registrationCache.Get(util.NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
if _, ok := h.registrationCache.Get(registerRequest.NodeKey.String()); ok {
log.Debug().
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Expand Down Expand Up @@ -97,10 +97,10 @@ func (h *Headscale) handleRegister(
// We create the node and then keep it around until a callback
// happens
newNode := types.Node{
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
MachineKey: machineKey.String(),
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
NodeKey: registerRequest.NodeKey.String(),
LastSeen: &now,
Expiry: &time.Time{},
}
Expand Down Expand Up @@ -136,7 +136,7 @@ func (h *Headscale) handleRegister(
// So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it.
var storedMachineKey key.MachinePublic
err = storedMachineKey.UnmarshalText(
[]byte(util.MachinePublicKeyEnsurePrefix(node.MachineKey)),
[]byte(node.MachineKey),
)
if err != nil || storedMachineKey.IsZero() {
if err := h.db.NodeSetMachineKey(node, machineKey); err != nil {
Expand All @@ -156,7 +156,7 @@ func (h *Headscale) handleRegister(
// - Trying to log out (sending a expiry in the past)
// - A valid, registered node, looking for /map
// - Expired node wanting to reauthenticate
if node.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.NodeKey) {
if node.NodeKey == registerRequest.NodeKey.String() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() &&
Expand All @@ -176,7 +176,7 @@ func (h *Headscale) handleRegister(
}

// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if node.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
if node.NodeKey == registerRequest.OldNodeKey.String() &&
!node.IsExpired() {
h.handleNodeKeyRefresh(
writer,
Expand Down Expand Up @@ -207,9 +207,9 @@ func (h *Headscale) handleRegister(
// we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets?
node.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
node.NodeKey = registerRequest.NodeKey.String()
h.registrationCache.Set(
util.NodePublicKeyStripPrefix(registerRequest.NodeKey),
registerRequest.NodeKey.String(),
*node,
registerCacheExpiration,
)
Expand Down Expand Up @@ -294,7 +294,7 @@ func (h *Headscale) handleAuthKey(
Str("node", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")

nodeKey := util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
nodeKey := registerRequest.NodeKey.String()

// retrieve node information if it exist
// The error is not important, because if it does not
Expand Down Expand Up @@ -342,7 +342,7 @@ func (h *Headscale) handleAuthKey(
} else {
now := time.Now().UTC()

givenName, err := h.db.GenerateGivenName(util.MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname)
givenName, err := h.db.GenerateGivenName(machineKey.String(), registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Expand All @@ -359,7 +359,7 @@ func (h *Headscale) handleAuthKey(
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
MachineKey: machineKey.String(),
RegisterMethod: util.RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/auth_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (h *Headscale) RegistrationHandler(
body, _ := io.ReadAll(req.Body)

var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machineKeyStr)))
err := machineKey.UnmarshalText([]byte("mkey:" + machineKeyStr))
if err != nil {
log.Error().
Caller().
Expand Down
9 changes: 0 additions & 9 deletions hscontrol/db/addresses_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) {

node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down Expand Up @@ -83,9 +80,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) {

node := types.Node{
ID: uint64(index),
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down Expand Up @@ -173,9 +167,6 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {

node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down
22 changes: 22 additions & 0 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/netip"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -252,6 +253,27 @@ func NewHeadscaleDatabase(
return nil, err
}

// Ensure all keys have correct prefixes
// https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35
nodes := types.Nodes{}
if err := dbConn.Find(&nodes).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
}

for _, node := range nodes {
if !strings.HasPrefix(node.DiscoKey, "discokey:") {
node.DiscoKey = "discokey:" + node.DiscoKey
}

if !strings.HasPrefix(node.NodeKey, "nodekey:") {
node.NodeKey = "nodekey:" + node.NodeKey
}

if !strings.HasPrefix(node.MachineKey, "mkey:") {
node.MachineKey = "mkey:" + node.MachineKey
}
}

// TODO(kradalby): is this needed?
err = db.setValue("db_version", dbVersion)

Expand Down
16 changes: 8 additions & 8 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (hsdb *HSDatabase) GetNodeByMachineKey(
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
First(&mach, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil {
return nil, result.Error
}

Expand All @@ -203,7 +203,7 @@ func (hsdb *HSDatabase) GetNodeByNodeKey(
Preload("User").
Preload("Routes").
First(&node, "node_key = ?",
util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
nodeKey.String()); result.Error != nil {
return nil, result.Error
}

Expand All @@ -224,9 +224,9 @@ func (hsdb *HSDatabase) GetNodeByAnyKey(
Preload("User").
Preload("Routes").
First(&node, "machine_key = ? OR node_key = ? OR node_key = ?",
util.MachinePublicKeyStripPrefix(machineKey),
util.NodePublicKeyStripPrefix(nodeKey),
util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
machineKey.String(),
nodeKey.String(),
oldNodeKey.String()); result.Error != nil {
return nil, result.Error
}

Expand Down Expand Up @@ -397,7 +397,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")

if nodeInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok {
if nodeInterface, ok := cache.Get(nodeKey.String()); ok {
if registrationNode, ok := nodeInterface.(types.Node); ok {
user, err := hsdb.getUser(userName)
if err != nil {
Expand Down Expand Up @@ -507,7 +507,7 @@ func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic)
defer hsdb.mu.Unlock()

if err := hsdb.db.Model(node).Updates(types.Node{
NodeKey: util.NodePublicKeyStripPrefix(nodeKey),
NodeKey: nodeKey.String(),
}).Error; err != nil {
return err
}
Expand All @@ -524,7 +524,7 @@ func (hsdb *HSDatabase) NodeSetMachineKey(
defer hsdb.mu.Unlock()

if err := hsdb.db.Model(node).Updates(types.Node{
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
MachineKey: machineKey.String(),
}).Error; err != nil {
return err
}
Expand Down
10 changes: 5 additions & 5 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) {

node := types.Node{
ID: 0,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
MachineKey: machineKey.Public().String(),
NodeKey: nodeKey.Public().String(),
DiscoKey: "faa",
Hostname: "testnode",
UserID: user.ID,
Expand Down Expand Up @@ -113,8 +113,8 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {

node := types.Node{
ID: 0,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
MachineKey: machineKey.Public().String(),
NodeKey: nodeKey.Public().String(),
DiscoKey: "faa",
Hostname: "testnode",
UserID: user.ID,
Expand Down Expand Up @@ -575,7 +575,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
NodeKey: nodeKey.Public().String(),
DiscoKey: "faa",
Hostname: "test",
UserID: user.ID,
Expand Down
9 changes: 0 additions & 9 deletions hscontrol/db/preauth_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {

node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testest",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand All @@ -101,9 +98,6 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {

node := types.Node{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testest",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down Expand Up @@ -138,9 +132,6 @@ func (*Suite) TestEphemeralKey(c *check.C) {
now := time.Now().Add(-time.Second * 30)
node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testest",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down
21 changes: 0 additions & 21 deletions hscontrol/db/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ func (s *Suite) TestGetRoutes(c *check.C) {

node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_get_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down Expand Up @@ -80,9 +77,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {

node := types.Node{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down Expand Up @@ -154,9 +148,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
}
node1 := types.Node{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand All @@ -179,9 +170,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
}
node2 := types.Node{
ID: 2,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down Expand Up @@ -240,9 +228,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
now := time.Now()
node1 := types.Node{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down Expand Up @@ -277,9 +262,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
}
node2 := types.Node{
ID: 2,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down Expand Up @@ -382,9 +364,6 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
now := time.Now()
node1 := types.Node{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down
Loading

0 comments on commit c0fd06e

Please sign in to comment.