Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into upd-dnsproxy
Browse files Browse the repository at this point in the history
  • Loading branch information
Mizzick committed Jul 10, 2024
2 parents e2de50b + e269260 commit 3de6224
Show file tree
Hide file tree
Showing 14 changed files with 641 additions and 229 deletions.
10 changes: 9 additions & 1 deletion internal/dhcpsvc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dhcpsvc
import (
"fmt"
"log/slog"
"os"
"time"

"github.com/AdguardTeam/golibs/errors"
Expand All @@ -23,7 +24,8 @@ type Config struct {
// clients' hostnames.
LocalDomainName string

// TODO(e.burkov): Add DB path.
// DBFilePath is the path to the database file containing the DHCP leases.
DBFilePath string

// ICMPTimeout is the timeout for checking another DHCP server's presence.
ICMPTimeout time.Duration
Expand Down Expand Up @@ -64,6 +66,12 @@ func (conf *Config) Validate() (err error) {
errs = append(errs, err)
}

// This is a best-effort check for the file accessibility. The file will be
// checked again when it is opened later.
if _, err = os.Stat(conf.DBFilePath); err != nil && !errors.Is(err, os.ErrNotExist) {
errs = append(errs, fmt.Errorf("db file path %q: %w", conf.DBFilePath, err))
}

if len(conf.Interfaces) == 0 {
errs = append(errs, errNoInterfaces)

Expand Down
9 changes: 9 additions & 0 deletions internal/dhcpsvc/config_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package dhcpsvc_test

import (
"path/filepath"
"testing"

"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/testutil"
)

func TestConfig_Validate(t *testing.T) {
leasesPath := filepath.Join(t.TempDir(), "leases.json")

testCases := []struct {
name string
conf *dhcpsvc.Config
Expand All @@ -25,13 +28,15 @@ func TestConfig_Validate(t *testing.T) {
conf: &dhcpsvc.Config{
Enabled: true,
Interfaces: testInterfaceConf,
DBFilePath: leasesPath,
},
wantErrMsg: `bad domain name "": domain name is empty`,
}, {
conf: &dhcpsvc.Config{
Enabled: true,
LocalDomainName: testLocalTLD,
Interfaces: nil,
DBFilePath: leasesPath,
},
name: "no_interfaces",
wantErrMsg: "no interfaces specified",
Expand All @@ -40,6 +45,7 @@ func TestConfig_Validate(t *testing.T) {
Enabled: true,
LocalDomainName: testLocalTLD,
Interfaces: nil,
DBFilePath: leasesPath,
},
name: "no_interfaces",
wantErrMsg: "no interfaces specified",
Expand All @@ -50,6 +56,7 @@ func TestConfig_Validate(t *testing.T) {
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
"eth0": nil,
},
DBFilePath: leasesPath,
},
name: "nil_interface",
wantErrMsg: `interface "eth0": config is nil`,
Expand All @@ -63,6 +70,7 @@ func TestConfig_Validate(t *testing.T) {
IPv6: &dhcpsvc.IPv6Config{Enabled: false},
},
},
DBFilePath: leasesPath,
},
name: "nil_ipv4",
wantErrMsg: `interface "eth0": ipv4: config is nil`,
Expand All @@ -76,6 +84,7 @@ func TestConfig_Validate(t *testing.T) {
IPv6: nil,
},
},
DBFilePath: leasesPath,
},
name: "nil_ipv6",
wantErrMsg: `interface "eth0": ipv6: config is nil`,
Expand Down
192 changes: 192 additions & 0 deletions internal/dhcpsvc/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package dhcpsvc

import (
"context"
"encoding/json"
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"slices"
"strings"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/google/renameio/v2/maybe"
)

// dataVersion is the current version of the stored DHCP leases structure.
const dataVersion = 1

// databasePerm is the permissions for the database file.
const databasePerm fs.FileMode = 0o640

// dataLeases is the structure of the stored DHCP leases.
type dataLeases struct {
// Leases is the list containing stored DHCP leases.
Leases []*dbLease `json:"leases"`

// Version is the current version of the structure.
Version int `json:"version"`
}

// dbLease is the structure of stored lease.
type dbLease struct {
Expiry string `json:"expires"`
IP netip.Addr `json:"ip"`
Hostname string `json:"hostname"`
HWAddr string `json:"mac"`
IsStatic bool `json:"static"`
}

// compareNames returns the result of comparing the hostnames of dl and other
// lexicographically.
func (dl *dbLease) compareNames(other *dbLease) (res int) {
return strings.Compare(dl.Hostname, other.Hostname)
}

// toDBLease converts *Lease to *dbLease.
func toDBLease(l *Lease) (dl *dbLease) {
var expiryStr string
if !l.IsStatic {
// The front-end is waiting for RFC 3999 format of the time value. It
// also shouldn't got an Expiry field for static leases.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2692.
expiryStr = l.Expiry.Format(time.RFC3339)
}

return &dbLease{
Expiry: expiryStr,
Hostname: l.Hostname,
HWAddr: l.HWAddr.String(),
IP: l.IP,
IsStatic: l.IsStatic,
}
}

// toInternal converts dl to *Lease.
func (dl *dbLease) toInternal() (l *Lease, err error) {
mac, err := net.ParseMAC(dl.HWAddr)
if err != nil {
return nil, fmt.Errorf("parsing hardware address: %w", err)
}

expiry := time.Time{}
if !dl.IsStatic {
expiry, err = time.Parse(time.RFC3339, dl.Expiry)
if err != nil {
return nil, fmt.Errorf("parsing expiry time: %w", err)
}
}

return &Lease{
Expiry: expiry,
IP: dl.IP,
Hostname: dl.Hostname,
HWAddr: mac,
IsStatic: dl.IsStatic,
}, nil
}

// dbLoad loads stored leases. It must only be called before the service has
// been started.
func (srv *DHCPServer) dbLoad(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "loading db: %w") }()

file, err := os.Open(srv.dbFilePath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("reading db: %w", err)
}

srv.logger.DebugContext(ctx, "no db file found")

return nil
}

dl := &dataLeases{}
err = json.NewDecoder(file).Decode(dl)
if err != nil {
return fmt.Errorf("decoding db: %w", err)
}

srv.resetLeases()
srv.addDBLeases(ctx, dl.Leases)

return nil
}

// addDBLeases adds leases to the server.
func (srv *DHCPServer) addDBLeases(ctx context.Context, leases []*dbLease) {
var v4, v6 uint
for i, l := range leases {
lease, err := l.toInternal()
if err != nil {
srv.logger.WarnContext(ctx, "converting lease", "idx", i, slogutil.KeyError, err)

continue
}

iface, err := srv.ifaceForAddr(l.IP)
if err != nil {
srv.logger.WarnContext(ctx, "searching lease iface", "idx", i, slogutil.KeyError, err)

continue
}

err = srv.leases.add(lease, iface)
if err != nil {
srv.logger.WarnContext(ctx, "adding lease", "idx", i, slogutil.KeyError, err)

continue
}

if l.IP.Is4() {
v4++
} else {
v6++
}
}

// TODO(e.burkov): Group by interface.
srv.logger.InfoContext(ctx, "loaded leases", "v4", v4, "v6", v6, "total", len(leases))
}

// writeDB writes leases to the database file. It expects the
// [DHCPServer.leasesMu] to be locked.
func (srv *DHCPServer) dbStore(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "writing db: %w") }()

dl := &dataLeases{
// Avoid writing null into the database file if there are no leases.
Leases: make([]*dbLease, 0, srv.leases.len()),
Version: dataVersion,
}

srv.leases.rangeLeases(func(l *Lease) (cont bool) {
lease := toDBLease(l)
i, _ := slices.BinarySearchFunc(dl.Leases, lease, (*dbLease).compareNames)
dl.Leases = slices.Insert(dl.Leases, i, lease)

return true
})

buf, err := json.Marshal(dl)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

err = maybe.WriteFile(srv.dbFilePath, buf, databasePerm)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

srv.logger.InfoContext(ctx, "stored leases", "num", len(dl.Leases), "file", srv.dbFilePath)

return nil
}
4 changes: 4 additions & 0 deletions internal/dhcpsvc/db_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package dhcpsvc

// DatabasePerm is the permissions for the test database file.
const DatabasePerm = databasePerm
66 changes: 66 additions & 0 deletions internal/dhcpsvc/dhcpsvc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package dhcpsvc_test

import (
"net"
"net/netip"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/stretchr/testify/require"
)

// testLocalTLD is a common local TLD for tests.
const testLocalTLD = "local"

// testTimeout is a common timeout for tests and contexts.
const testTimeout time.Duration = 10 * time.Second

// discardLog is a logger to discard test output.
var discardLog = slogutil.NewDiscardLogger()

// testInterfaceConf is a common set of interface configurations for tests.
var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{
"eth0": {
IPv4: &dhcpsvc.IPv4Config{
Enabled: true,
GatewayIP: netip.MustParseAddr("192.168.0.1"),
SubnetMask: netip.MustParseAddr("255.255.255.0"),
RangeStart: netip.MustParseAddr("192.168.0.2"),
RangeEnd: netip.MustParseAddr("192.168.0.254"),
LeaseDuration: 1 * time.Hour,
},
IPv6: &dhcpsvc.IPv6Config{
Enabled: true,
RangeStart: netip.MustParseAddr("2001:db8::1"),
LeaseDuration: 1 * time.Hour,
RAAllowSLAAC: true,
RASLAACOnly: true,
},
},
"eth1": {
IPv4: &dhcpsvc.IPv4Config{
Enabled: true,
GatewayIP: netip.MustParseAddr("172.16.0.1"),
SubnetMask: netip.MustParseAddr("255.255.255.0"),
RangeStart: netip.MustParseAddr("172.16.0.2"),
RangeEnd: netip.MustParseAddr("172.16.0.255"),
LeaseDuration: 1 * time.Hour,
},
IPv6: &dhcpsvc.IPv6Config{
Enabled: true,
RangeStart: netip.MustParseAddr("2001:db9::1"),
LeaseDuration: 1 * time.Hour,
RAAllowSLAAC: true,
RASLAACOnly: true,
},
},
}

// mustParseMAC parses a hardware address from s and requires no errors.
func mustParseMAC(t require.TestingT, s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
require.NoError(t, err)

return mac
}
15 changes: 15 additions & 0 deletions internal/dhcpsvc/leaseindex.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,18 @@ func (idx *leaseIndex) update(l *Lease, iface *netInterface) (err error) {

return nil
}

// rangeLeases calls f for each lease in idx in an unspecified order until f
// returns false.
func (idx *leaseIndex) rangeLeases(f func(l *Lease) (cont bool)) {
for _, l := range idx.byName {
if !f(l) {
break
}
}
}

// len returns the number of leases in idx.
func (idx *leaseIndex) len() (l uint) {
return uint(len(idx.byAddr))
}
Loading

0 comments on commit 3de6224

Please sign in to comment.