Skip to content

Commit

Permalink
use gorm serialiser instead of custom hooks (#2156)
Browse files Browse the repository at this point in the history
* add sqlite to debug/test image

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* test using gorm serialiser instead of custom hooks

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
  • Loading branch information
kradalby authored Oct 2, 2024
1 parent 3964dec commit bc9e83b
Show file tree
Hide file tree
Showing 21 changed files with 243 additions and 354 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.debug
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ENV GOPATH /go
WORKDIR /go/src/headscale

RUN apt-get update \
&& apt-get install --no-install-recommends --yes less jq \
&& apt-get install --no-install-recommends --yes less jq sqlite3 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN mkdir -p /var/run/headscale
Expand Down
21 changes: 16 additions & 5 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@ import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"tailscale.com/util/set"
)

func init() {
schema.RegisterSerializer("text", TextSerialiser{})
}

var errDatabaseNotSupported = errors.New("database type not supported")

// KV is a key-value store in a psql table. For future use...
Expand All @@ -33,7 +38,8 @@ type KV struct {
}

type HSDatabase struct {
DB *gorm.DB
DB *gorm.DB
cfg *types.DatabaseConfig

baseDomain string
}
Expand Down Expand Up @@ -191,7 +197,7 @@ func NewHeadscaleDatabase(

type NodeAux struct {
ID uint64
EnabledRoutes types.IPPrefixes
EnabledRoutes []netip.Prefix `gorm:"serializer:json"`
}

nodesAux := []NodeAux{}
Expand All @@ -214,7 +220,7 @@ func NewHeadscaleDatabase(
}

err = tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
Where("node_id = ? AND prefix = ?", node.ID, prefix).
First(&types.Route{}).
Error
if err == nil {
Expand All @@ -229,7 +235,7 @@ func NewHeadscaleDatabase(
NodeID: node.ID,
Advertised: true,
Enabled: true,
Prefix: types.IPPrefix(prefix),
Prefix: prefix,
}
if err := tx.Create(&route).Error; err != nil {
log.Error().Err(err).Msg("Error creating route")
Expand Down Expand Up @@ -476,7 +482,8 @@ func NewHeadscaleDatabase(
}

db := HSDatabase{
DB: dbConn,
DB: dbConn,
cfg: &cfg,

baseDomain: baseDomain,
}
Expand Down Expand Up @@ -676,6 +683,10 @@ func (hsdb *HSDatabase) Close() error {
return err
}

if hsdb.cfg.Type == types.DatabaseSqlite && hsdb.cfg.Sqlite.WriteAheadLog {
db.Exec("VACUUM")
}

return db.Close()
}

Expand Down
36 changes: 28 additions & 8 deletions hscontrol/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)

func TestMigrations(t *testing.T) {
ipp := func(p string) types.IPPrefix {
return types.IPPrefix(netip.MustParsePrefix(p))
ipp := func(p string) netip.Prefix {
return netip.MustParsePrefix(p)
}
r := func(id uint64, p string, a, e, i bool) types.Route {
return types.Route{
Expand Down Expand Up @@ -56,9 +57,7 @@ func TestMigrations(t *testing.T) {
r(31, "::/0", true, false, false),
r(32, "192.168.0.24/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
Expand Down Expand Up @@ -103,9 +102,7 @@ func TestMigrations(t *testing.T) {
r(13, "::/0", true, true, false),
r(13, "10.18.80.2/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
Expand Down Expand Up @@ -172,6 +169,29 @@ func TestMigrations(t *testing.T) {
}
},
},
{
dbPath: "testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx)
})
assert.NoError(t, err)

for _, node := range nodes {
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
assert.Contains(t, node.MachineKey.String(), "mkey:")
assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey")
assert.Contains(t, node.NodeKey.String(), "nodekey:")
assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey")
assert.Contains(t, node.DiscoKey.String(), "discokey:")
assert.NotNil(t, node.IPv4)
assert.NotNil(t, node.IPv4)
assert.Len(t, node.Endpoints, 1)
assert.NotNil(t, node.Hostinfo)
assert.NotNil(t, node.MachineKey)
}
},
},
}

for _, tt := range tests {
Expand Down
37 changes: 0 additions & 37 deletions hscontrol/db/ip_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package db

import (
"database/sql"
"fmt"
"net/netip"
"strings"
Expand Down Expand Up @@ -294,15 +293,7 @@ func TestBackfillIPAddresses(t *testing.T) {
v4 := fmt.Sprintf("100.64.0.%d", i)
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
return &types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: v4,
},
IPv4: nap(v4),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: v6,
},
IPv6: nap(v6),
}
}
Expand Down Expand Up @@ -334,15 +325,7 @@ func TestBackfillIPAddresses(t *testing.T) {

want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
Expand All @@ -367,15 +350,7 @@ func TestBackfillIPAddresses(t *testing.T) {

want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
Expand All @@ -400,10 +375,6 @@ func TestBackfillIPAddresses(t *testing.T) {

want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
},
},
Expand All @@ -428,10 +399,6 @@ func TestBackfillIPAddresses(t *testing.T) {

want: types.Nodes{
&types.Node{
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
Expand Down Expand Up @@ -477,13 +444,9 @@ func TestBackfillIPAddresses(t *testing.T) {

comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
"ID",
"MachineKeyDatabaseField",
"NodeKeyDatabaseField",
"DiscoKeyDatabaseField",
"User",
"UserID",
"Endpoints",
"HostinfoDatabaseField",
"Hostinfo",
"Routes",
"CreatedAt",
Expand Down
14 changes: 10 additions & 4 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"encoding/json"
"errors"
"fmt"
"net/netip"
Expand Down Expand Up @@ -207,21 +208,26 @@ func SetTags(
) error {
if len(tags) == 0 {
// if no tags are provided, we remove all forced tags
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil {
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
}

return nil
}

var newTags types.StringList
var newTags []string
for _, tag := range tags {
if !slices.Contains(newTags, tag) {
newTags = append(newTags, tag)
}
}

if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
b, err := json.Marshal(newTags)
if err != nil {
return err
}

if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
return fmt.Errorf("failed to update tags for node in the database: %w", err)
}

Expand Down Expand Up @@ -569,7 +575,7 @@ func enableRoutes(tx *gorm.DB,
for _, prefix := range newRoutes {
route := types.Route{}
err := tx.Preload("Node").
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
First(&route).Error
if err == nil {
route.Enabled = true
Expand Down
16 changes: 10 additions & 6 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()

v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1))
node := types.Node{
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
Expand Down Expand Up @@ -239,6 +239,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {

adminNode, err := db.GetNodeByID(1)
c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User)
c.Assert(adminNode.IPv4, check.NotNil)
c.Assert(adminNode.IPv6, check.IsNil)
c.Assert(err, check.IsNil)

testNode, err := db.GetNodeByID(2)
Expand All @@ -247,9 +249,11 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {

adminPeers, err := db.ListPeers(adminNode.ID)
c.Assert(err, check.IsNil)
c.Assert(len(adminPeers), check.Equals, 9)

testPeers, err := db.ListPeers(testNode.ID)
c.Assert(err, check.IsNil)
c.Assert(len(testPeers), check.Equals, 9)

adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
c.Assert(err, check.IsNil)
Expand All @@ -259,14 +263,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {

peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules)

c.Log(peersOfAdminNode)
c.Log(peersOfTestNode)

c.Assert(len(peersOfTestNode), check.Equals, 9)
c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1")
c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3")
c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5")

c.Log(peersOfAdminNode)
c.Assert(len(peersOfAdminNode), check.Equals, 9)
c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2")
c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4")
Expand Down Expand Up @@ -346,7 +350,7 @@ func (s *Suite) TestSetTags(c *check.C) {
c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
c.Assert(node.ForcedTags, check.DeepEquals, sTags)

// assign duplicate tags, expect no errors but no doubles in DB
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
Expand All @@ -357,15 +361,15 @@ func (s *Suite) TestSetTags(c *check.C) {
c.Assert(
node.ForcedTags,
check.DeepEquals,
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
[]string{"tag:bar", "tag:test", "tag:unknown"},
)

// test removing tags
err = db.SetTags(node.ID, []string{})
c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{}))
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
}

func TestHeadscale_generateGivenName(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/db/preauth_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func CreatePreAuthKey(
Ephemeral: ephemeral,
CreatedAt: &now,
Expiration: expiration,
Tags: types.StringList(aclTags),
Tags: aclTags,
}

if err := tx.Save(&key).Error; err != nil {
Expand Down
Loading

0 comments on commit bc9e83b

Please sign in to comment.