Skip to content

Commit

Permalink
network: use network context for DNS operations in readFromSRV (#5936)
Browse files Browse the repository at this point in the history
  • Loading branch information
algorandskiy authored Feb 13, 2024
1 parent caec33d commit d8c825d
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion cmd/catchpointdump/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ var netCmd = &cobra.Command{
if relayAddress != "" {
addrs = []string{relayAddress}
} else {
addrs, err = tools.ReadFromSRV("algobootstrap", "tcp", networkName, "", false)
addrs, err = tools.ReadFromSRV(context.Background(), "algobootstrap", "tcp", networkName, "", false)
if err != nil || len(addrs) == 0 {
reportErrorf("Unable to bootstrap records for '%s' : %v", networkName, err)
}
Expand Down
6 changes: 3 additions & 3 deletions network/wsNetwork.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ type WebsocketNetwork struct {
protocolVersion string

// resolveSRVRecords is a function that resolves SRV records for a given service, protocol and name
resolveSRVRecords func(service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (addrs []string, err error)
resolveSRVRecords func(ctx context.Context, service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (addrs []string, err error)
}

const (
Expand Down Expand Up @@ -1887,7 +1887,7 @@ func (wn *WebsocketNetwork) mergePrimarySecondaryRelayAddressSlices(network prot

func (wn *WebsocketNetwork) getDNSAddrs(dnsBootstrap string) (relaysAddresses []string, archiverAddresses []string) {
var err error
relaysAddresses, err = wn.resolveSRVRecords("algobootstrap", "tcp", dnsBootstrap, wn.config.FallbackDNSResolverAddress, wn.config.DNSSecuritySRVEnforced())
relaysAddresses, err = wn.resolveSRVRecords(wn.ctx, "algobootstrap", "tcp", dnsBootstrap, wn.config.FallbackDNSResolverAddress, wn.config.DNSSecuritySRVEnforced())
if err != nil {
// only log this warning on testnet or devnet
if wn.NetworkID == config.Devnet || wn.NetworkID == config.Testnet {
Expand All @@ -1896,7 +1896,7 @@ func (wn *WebsocketNetwork) getDNSAddrs(dnsBootstrap string) (relaysAddresses []
relaysAddresses = nil
}

archiverAddresses, err = wn.resolveSRVRecords("archive", "tcp", dnsBootstrap, wn.config.FallbackDNSResolverAddress, wn.config.DNSSecuritySRVEnforced())
archiverAddresses, err = wn.resolveSRVRecords(wn.ctx, "archive", "tcp", dnsBootstrap, wn.config.FallbackDNSResolverAddress, wn.config.DNSSecuritySRVEnforced())
if err != nil {
// only log this warning on testnet or devnet
if wn.NetworkID == config.Devnet || wn.NetworkID == config.Testnet {
Expand Down
2 changes: 1 addition & 1 deletion network/wsNetwork_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4156,7 +4156,7 @@ func TestRefreshRelayArchivePhonebookAddresses(t *testing.T) {
}

// Mock the SRV record lookup
netA.resolveSRVRecords = func(service string, protocol string, name string, fallbackDNSResolverAddress string,
netA.resolveSRVRecords = func(ctx context.Context, service string, protocol string, name string, fallbackDNSResolverAddress string,
secure bool) (addrs []string, err error) {
if service == "algobootstrap" && protocol == "tcp" && name == primarySRVBootstrap {
return primaryRelayResolvedRecords, nil
Expand Down
14 changes: 7 additions & 7 deletions tools/network/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
"github.com/algorand/go-algorand/logging"
)

func readFromSRV(service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (records []*net.SRV, err error) {
func readFromSRV(ctx context.Context, service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (records []*net.SRV, err error) {
log := logging.Base()
if name == "" {
log.Debug("no dns lookup due to empty name")
Expand All @@ -38,14 +38,14 @@ func readFromSRV(service string, protocol string, name string, fallbackDNSResolv
controller := NewResolveController(secure, fallbackDNSResolverAddress, log)

systemResolver := controller.SystemResolver()
_, records, sysLookupErr := systemResolver.LookupSRV(context.Background(), service, protocol, name)
_, records, sysLookupErr := systemResolver.LookupSRV(ctx, service, protocol, name)
if sysLookupErr != nil {
log.Infof("ReadFromBootstrap: DNS LookupSRV failed when using system resolver: %v", sysLookupErr)

var fallbackLookupErr error
if fallbackDNSResolverAddress != "" {
fallbackResolver := controller.FallbackResolver()
_, records, fallbackLookupErr = fallbackResolver.LookupSRV(context.Background(), service, protocol, name)
_, records, fallbackLookupErr = fallbackResolver.LookupSRV(ctx, service, protocol, name)
}
if fallbackLookupErr != nil {
log.Infof("ReadFromBootstrap: DNS LookupSRV failed when using fallback '%s' resolver: %v", fallbackDNSResolverAddress, fallbackLookupErr)
Expand All @@ -54,7 +54,7 @@ func readFromSRV(service string, protocol string, name string, fallbackDNSResolv
if fallbackLookupErr != nil || fallbackDNSResolverAddress == "" {
fallbackResolver := controller.DefaultResolver()
var defaultLookupErr error
_, records, defaultLookupErr = fallbackResolver.LookupSRV(context.Background(), service, protocol, name)
_, records, defaultLookupErr = fallbackResolver.LookupSRV(ctx, service, protocol, name)
if defaultLookupErr != nil {
err = fmt.Errorf("ReadFromBootstrap: DNS LookupSRV failed when using system resolver(%v), fallback resolver(%v), as well as using default resolver due to %v", sysLookupErr, fallbackLookupErr, defaultLookupErr)
return
Expand All @@ -65,8 +65,8 @@ func readFromSRV(service string, protocol string, name string, fallbackDNSResolv
}

// ReadFromSRV is a helper to collect SRV addresses for a given name
func ReadFromSRV(service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (addrs []string, err error) {
records, err := readFromSRV(service, protocol, name, fallbackDNSResolverAddress, secure)
func ReadFromSRV(ctx context.Context, service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (addrs []string, err error) {
records, err := readFromSRV(ctx, service, protocol, name, fallbackDNSResolverAddress, secure)
if err != nil {
return addrs, err
}
Expand All @@ -88,7 +88,7 @@ func ReadFromSRV(service string, protocol string, name string, fallbackDNSResolv

// ReadFromSRVPriority is a helper to collect SRV addresses with priorities for a given name
func ReadFromSRVPriority(service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (prioAddrs map[uint16][]string, err error) {
records, err := readFromSRV(service, protocol, name, fallbackDNSResolverAddress, secure)
records, err := readFromSRV(context.Background(), service, protocol, name, fallbackDNSResolverAddress, secure)
if err != nil {
return prioAddrs, err
}
Expand Down
5 changes: 3 additions & 2 deletions tools/network/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package network

import (
"context"
"testing"

"github.com/algorand/go-algorand/test/partitiontest"
Expand Down Expand Up @@ -55,10 +56,10 @@ func TestReadFromSRV(t *testing.T) {
fallback := ""
secure := true

addrs, err := ReadFromSRV("", protocol, name, fallback, secure)
addrs, err := ReadFromSRV(context.Background(), "", protocol, name, fallback, secure)
require.Error(t, err)

addrs, err = ReadFromSRV(service, protocol, name, fallback, secure)
addrs, err = ReadFromSRV(context.Background(), service, protocol, name, fallback, secure)
require.NoError(t, err)
require.GreaterOrEqual(t, len(addrs), 1)
addr := addrs[0]
Expand Down
3 changes: 2 additions & 1 deletion tools/network/telemetryURIUpdateService.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package network

import (
"context"
"net/url"
"strings"
"time"
Expand Down Expand Up @@ -132,5 +133,5 @@ func (t *telemetryURIUpdater) lookupTelemetryURL() (url *url.URL) {
}

func (t *telemetryURIUpdater) readFromSRV(protocol string, bootstrapID string) (addrs []string, err error) {
return ReadFromSRV("telemetry", protocol, bootstrapID, t.cfg.FallbackDNSResolverAddress, t.cfg.DNSSecuritySRVEnforced())
return ReadFromSRV(context.Background(), "telemetry", protocol, bootstrapID, t.cfg.FallbackDNSResolverAddress, t.cfg.DNSSecuritySRVEnforced())
}

0 comments on commit d8c825d

Please sign in to comment.