diff --git a/controllers/ext_client.go b/controllers/ext_client.go index eb5308bba..1b1ec3dc7 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -386,6 +386,17 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } + + var gateway models.EgressGatewayRequest + gateway.NetID = params["network"] + gateway.Ranges = customExtClient.ExtraAllowedIPs + err := logic.ValidateEgressRange(gateway) + if err != nil { + logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + node, err := logic.GetNodeByID(nodeid) if err != nil { logger.Log(0, r.Header.Get("user"), @@ -530,6 +541,17 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { return } } + + var gateway models.EgressGatewayRequest + gateway.NetID = params["network"] + gateway.Ranges = update.ExtraAllowedIPs + err = logic.ValidateEgressRange(gateway) + if err != nil { + logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + var changedID = update.ClientID != oldExtClient.ClientID if !reflect.DeepEqual(update.DeniedACLs, oldExtClient.DeniedACLs) { diff --git a/controllers/node.go b/controllers/node.go index ed104b354..adc631057 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -414,7 +414,12 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) { return } gateway.NetID = params["network"] - gateway.NodeID = params["nodeid"] + err = logic.ValidateEgressRange(gateway) + if err != nil { + logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } node, err = logic.CreateEgressGateway(gateway) if err != nil { logger.Log(0, r.Header.Get("user"), diff --git a/go.mod b/go.mod index 83774b4b1..0dafb218f 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 github.com/rqlite/gorqlite v0.0.0-20240122221808-a8a425b1a6aa + github.com/seancfoley/ipaddress-go v1.6.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.9.0 github.com/txn2/txeh v1.5.5 @@ -49,6 +50,7 @@ require ( github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/seancfoley/bintree v1.3.1 // indirect github.com/spf13/pflag v1.0.5 // indirect ) diff --git a/go.sum b/go.sum index 1a869dd8c..cb78d0be7 100644 --- a/go.sum +++ b/go.sum @@ -70,6 +70,10 @@ github.com/rqlite/gorqlite v0.0.0-20240122221808-a8a425b1a6aa h1:hxMLFbj+F444JAS github.com/rqlite/gorqlite v0.0.0-20240122221808-a8a425b1a6aa/go.mod h1:xF/KoXmrRyahPfo5L7Szb5cAAUl53dMWBh9cMruGEZg= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/seancfoley/bintree v1.3.1 h1:cqmmQK7Jm4aw8gna0bP+huu5leVOgHGSJBEpUx3EXGI= +github.com/seancfoley/bintree v1.3.1/go.mod h1:hIUabL8OFYyFVTQ6azeajbopogQc2l5C/hiXMcemWNU= +github.com/seancfoley/ipaddress-go v1.6.0 h1:9z7yGmOnV4P2ML/dlR/kCJiv5tp8iHOOetJvxJh/R5w= +github.com/seancfoley/ipaddress-go v1.6.0/go.mod h1:TQRZgv+9jdvzHmKoPGBMxyiaVmoI0rYpfEk8Q/sL/Iw= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= diff --git a/logic/nodes.go b/logic/nodes.go index 72f07836d..62f49557c 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -19,6 +19,7 @@ import ( "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/validation" + "github.com/seancfoley/ipaddress-go/ipaddr" "golang.org/x/exp/slog" ) @@ -626,6 +627,39 @@ func ValidateParams(nodeid, netid string) (models.Node, error) { return node, nil } +func ValidateEgressRange(gateway models.EgressGatewayRequest) error { + network, err := GetNetworkSettings(gateway.NetID) + if err != nil { + slog.Error("error getting network with netid", "error", gateway.NetID, err.Error) + return errors.New("error getting network with netid: " + gateway.NetID + " " + err.Error()) + } + ipv4Net := network.AddressRange + ipv6Net := network.AddressRange6 + + for _, v := range gateway.Ranges { + if ipv4Net != "" { + if ContainsCIDR(ipv4Net, v) { + slog.Error("egress range should not be the same as or contained in the netmaker network address", "error", v, ipv4Net) + return errors.New("egress range should not be the same as or contained in the netmaker network address" + v + " " + ipv4Net) + } + } + if ipv6Net != "" { + if ContainsCIDR(ipv6Net, v) { + slog.Error("egress range should not be the same as or contained in the netmaker network address", "error", v, ipv6Net) + return errors.New("egress range should not be the same as or contained in the netmaker network address" + v + " " + ipv6Net) + } + } + } + + return nil +} + +func ContainsCIDR(net1, net2 string) bool { + one, two := ipaddr.NewIPAddressString(net1), + ipaddr.NewIPAddressString(net2) + return one.Contains(two) || two.Contains(one) +} + // GetAllFailOvers - gets all the nodes that are failovers func GetAllFailOvers() ([]models.Node, error) { nodes, err := GetAllNodes() diff --git a/logic/nodes_test.go b/logic/nodes_test.go new file mode 100644 index 000000000..e3331a6fd --- /dev/null +++ b/logic/nodes_test.go @@ -0,0 +1,33 @@ +package logic + +import ( + "testing" +) + +func TestContainsCIDR(t *testing.T) { + + b := ContainsCIDR("10.1.1.2/32", "10.1.1.0/24") + if !b { + t.Errorf("expected true, returned %v", b) + } + + b = ContainsCIDR("10.1.1.2/32", "10.5.1.0/24") + if b { + t.Errorf("expected false, returned %v", b) + } + + b = ContainsCIDR("fd52:65f5:d685:d11d::1/64", "fd52:65f5:d685:d11d::/64") + if !b { + t.Errorf("expected true, returned %v", b) + } + + b1 := ContainsCIDR("fd10:10::/64", "fd10::/16") + if !b1 { + t.Errorf("expected true, returned %v", b1) + } + + b1 = ContainsCIDR("fd10:10::/64", "fd10::/64") + if b1 { + t.Errorf("expected false, returned %v", b1) + } +}