Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support multiple allowIPs for a remote connection from one of KMS servers #707

Merged
merged 2 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/ostracon/commands/show_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestShowValidatorWithKMS(t *testing.T) {
}
privval.WithMockKMS(t, dir, chainID, func(addr string, privKey crypto.PrivKey) {
config.PrivValidatorListenAddr = addr
config.PrivValidatorRemoteAddr = addr[:strings.Index(addr, ":")]
config.PrivValidatorRemoteAddresses = []string{addr[:strings.Index(addr, ":")]}
require.NoFileExists(t, config.PrivValidatorKeyFile())
output, err := captureStdout(func() {
err := showValidator(ShowValidatorCmd, nil, config)
Expand Down
10 changes: 6 additions & 4 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,12 @@ type BaseConfig struct { //nolint: maligned
// example) tcp://0.0.0.0:26659
PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"`

// Validator's remote address(without port) to allow a connection
// ostracon only allow a connection from this address
// example) 10.0.0.7
PrivValidatorRemoteAddr string `mapstructure:"priv_validator_raddr"`
// Validator's remote addresses to allow a connection
// Comma separated list of addresses to allow
// ostracon only allows a connection from these addresses separated by a comma
// example) 127.0.0.1
// example) 127.0.0.1,192.168.1.2
PrivValidatorRemoteAddresses []string `mapstructure:"priv_validator_raddrs"`

// A JSON file containing the private key to use for p2p authenticated encryption
NodeKey string `mapstructure:"node_key_file"`
Expand Down
6 changes: 4 additions & 2 deletions config/toml.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,11 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}"
priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}"

# Validator's remote address to allow a connection
# ostracon only allow a connection from this address
# Comma separated list of addresses to allow
# ostracon only allows a connection from these addresses separated by a comma
# example) 127.0.0.1
priv_validator_raddr = "127.0.0.1"
# example) 127.0.0.1,192.168.1.2
priv_validator_raddrs = "127.0.0.1"

# Path to the JSON file containing the private key to use for node authentication in the p2p protocol
node_key_file = "{{ js .BaseConfig.NodeKey }}"
Expand Down
2 changes: 1 addition & 1 deletion node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -1524,7 +1524,7 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) error {
}

func CreateAndStartPrivValidatorSocketClient(config *cfg.Config, chainID string, logger log.Logger) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddr)
pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddresses)
if err != nil {
return nil, fmt.Errorf("failed to start private validator: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func TestNodeSetPrivValTCP(t *testing.T) {
if err != nil {
return
}
config.BaseConfig.PrivValidatorRemoteAddr = addrPart
config.BaseConfig.PrivValidatorRemoteAddresses = []string{addrPart}

dialer := privval.DialTCPFn(addr, 100*time.Millisecond, ed25519.GenPrivKey())
dialerEndpoint := privval.NewSignerDialerEndpoint(
Expand Down
18 changes: 12 additions & 6 deletions privval/internal/ip_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ import (
"fmt"
"github.com/Finschia/ostracon/libs/log"
"net"
"strings"
)

type IpFilter struct {
allowAddr string
allowList []string
torao marked this conversation as resolved.
Show resolved Hide resolved
log log.Logger
}

func NewIpFilter(addr string, l log.Logger) *IpFilter {
func NewIpFilter(allowAddresses []string, l log.Logger) *IpFilter {
return &IpFilter{
allowAddr: addr,
allowList: allowAddresses,
log: l,
}
}
Expand All @@ -26,11 +27,11 @@ func (f *IpFilter) Filter(addr net.Addr) net.Addr {
}

func (f *IpFilter) String() string {
return f.allowAddr
return strings.Join(f.allowList, ",")
}

func (f *IpFilter) isAllowedAddr(addr net.Addr) bool {
if len(f.allowAddr) == 0 {
if len(f.allowList) == 0 {
return false
}
hostAddr, _, err := net.SplitHostPort(addr.String())
Expand All @@ -40,5 +41,10 @@ func (f *IpFilter) isAllowedAddr(addr net.Addr) bool {
}
return false
}
return f.allowAddr == hostAddr
for _, address := range f.allowList {
if address == hostAddr {
return true
}
}
return false
}
54 changes: 49 additions & 5 deletions privval/internal/ip_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package internal
import (
"github.com/stretchr/testify/assert"
"net"
"strings"
"testing"
)

Expand Down Expand Up @@ -71,21 +72,64 @@ func TestFilterRemoteConnectionByIP(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cut := NewIpFilter(tt.fields.allowIP, nil)
cut := NewIpFilter([]string{tt.fields.allowIP}, nil)
assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name)
})
}
}

func TestFilterRemoteConnectionByIPWithMultipleAllowIPs(t *testing.T) {
type fields struct {
allowList []string
remoteAddr net.Addr
expected net.Addr
}
tests := []struct {
name string
fields fields
}{
{
"should allow the one in the allow list",
struct {
allowList []string
remoteAddr net.Addr
expected net.Addr
}{[]string{"127.0.0.1", "192.168.1.1"}, addrStub{"192.168.1.1:45678"}, addrStub{"192.168.1.1:45678"}},
},
{
"should not allow any ip which is not in the allow list",
struct {
allowList []string
remoteAddr net.Addr
expected net.Addr
}{[]string{"127.0.0.1", "192.168.1.1"}, addrStub{"10.0.0.2:45678"}, nil},
},
{
"should works for IPv6 with one of correct ip in the allow list",
struct {
allowList []string
remoteAddr net.Addr
expected net.Addr
}{[]string{"2001:db8::1", "2001:db8::2"}, addrStub{"[2001:db8::1]:80"}, addrStub{"[2001:db8::1]:80"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cut := NewIpFilter(tt.fields.allowList, nil)
assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name)
})
}
}

func TestIpFilterShouldSetAllowAddress(t *testing.T) {
expected := "192.168.0.1"
expected := []string{"192.168.0.1"}

cut := NewIpFilter(expected, nil)

assert.Equal(t, expected, cut.allowAddr)
assert.Equal(t, expected, cut.allowList)
}

func TestIpFilterStringShouldReturnsIP(t *testing.T) {
expected := "127.0.0.1"
assert.Equal(t, expected, NewIpFilter(expected, nil).String())
expected := []string{"127.0.0.1", "192.168.1.10"}
assert.Equal(t, strings.Join(expected, ","), NewIpFilter(expected, nil).String())
}
7 changes: 3 additions & 4 deletions privval/signer_listener_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListene
}

// SignerListenerEndpointAllowAddress sets the address to allow
// connections from the only allowed address
//
func SignerListenerEndpointAllowAddress(protocol string, addr string) SignerListenerEndpointOption {
// connections from the only allowed addresses
func SignerListenerEndpointAllowAddress(protocol string, allowedAddresses []string) SignerListenerEndpointOption {
return func(sl *SignerListenerEndpoint) {
if protocol == "tcp" || len(protocol) == 0 {
sl.connFilter = internal.NewIpFilter(addr, sl.Logger)
sl.connFilter = internal.NewIpFilter(allowedAddresses, sl.Logger)
return
}
sl.connFilter = internal.NewNullObject()
Expand Down
4 changes: 2 additions & 2 deletions privval/signer_listener_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ func getMockEndpoints(
}

func TestSignerListenerEndpointAllowAddressSetIpFilterForTCP(t *testing.T) {
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", "127.0.0.1"))
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", []string{"127.0.0.1"}))
_, ok := cut.connFilter.(*internal.IpFilter)
assert.True(t, ok)
}

func TestSignerListenerEndpointAllowAddressSetNullObjectFilterForUDS(t *testing.T) {
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", "/mnt/uds/sock01"))
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", []string{"don't care"}))
_, ok := cut.connFilter.(*internal.NullObject)
assert.True(t, ok)
}
4 changes: 2 additions & 2 deletions privval/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func IsConnTimeout(err error) bool {
}

// NewSignerListener creates a new SignerListenerEndpoint using the corresponding listen address
func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*SignerListenerEndpoint, error) {
func NewSignerListener(logger log.Logger, listenAddr string, remoteAddresses []string) (*SignerListenerEndpoint, error) {
var listener net.Listener

protocol, address := tmnet.ProtocolAndAddress(listenAddr)
Expand All @@ -47,7 +47,7 @@ func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*Signe
)
}

pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(protocol, remoteAddr))
pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(protocol, remoteAddresses))

return pve, nil
}
Expand Down