Skip to content

Commit

Permalink
fix #49, implement udp listeners
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoBBCha committed Dec 25, 2023
1 parent 1308edd commit d0ae589
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 119 deletions.
159 changes: 107 additions & 52 deletions cmd/agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
)

var listenerConntrack map[int32]net.Conn
var listenerMap map[int32]net.Listener
var listenerMap map[int32]interface{}
var connTrackID int32
var listenerID int32

Expand Down Expand Up @@ -59,7 +59,7 @@ func main() {
var conn net.Conn

listenerConntrack = make(map[int32]net.Conn)
listenerMap = make(map[int32]net.Listener)
listenerMap = make(map[int32]interface{})

for {
var err error
Expand Down Expand Up @@ -146,6 +146,25 @@ func (s *Listener) Close() error {
return s.Listener.Close()
}

// UDPListener is the base class implementing UDP listeners for Ligolo
type UDPListener struct {
*net.UDPConn
}

// NewUDPListener register a new UDP listener
func NewUDPListener(network string, addr string) (UDPListener, error) {
udpaddr, err := net.ResolveUDPAddr(network, addr)
if err != nil {
return UDPListener{}, nil
}

udplis, err := net.ListenUDP(network, udpaddr)
if err != nil {
return UDPListener{}, err
}
return UDPListener{udplis}, err
}

func handleConn(conn net.Conn) {
decoder := protocol.NewDecoder(conn)
if err := decoder.Decode(); err != nil {
Expand Down Expand Up @@ -251,7 +270,12 @@ func handleConn(conn net.Conn) {

var err error
if lis, ok := listenerMap[closeRequest.ListenerID]; ok {
err = lis.Close()
if l, ok := lis.(net.Listener); ok {
l.Close()
}
if l, ok := lis.(*net.UDPConn); ok {
l.Close()
}
} else {
err = errors.New("invalid listener id")
}
Expand All @@ -276,72 +300,103 @@ func handleConn(conn net.Conn) {
connTrackChan := make(chan int32)
stopChan := make(chan error)

listener, err := NewListener(listenRequest.Network, listenRequest.Address)
if err != nil {
if listenRequest.Network == "tcp" {
listener, err := NewListener(listenRequest.Network, listenRequest.Address)
if err != nil {
listenerResponse := protocol.ListenerResponsePacket{
ListenerID: 0,
Err: true,
ErrString: err.Error(),
}
if err := encoder.Encode(protocol.Envelope{
Type: protocol.MessageListenerResponse,
Payload: listenerResponse,
}); err != nil {
logrus.Error(err)
}
return
}
listenerMap[listenerID] = listener.Listener
listenerResponse := protocol.ListenerResponsePacket{
ListenerID: 0,
Err: true,
ErrString: err.Error(),
ListenerID: listenerID,
Err: false,
ErrString: "",
}
if err := encoder.Encode(protocol.Envelope{
Type: protocol.MessageListenerResponse,
Payload: listenerResponse,
}); err != nil {
logrus.Error(err)
}
return
}

listenerResponse := protocol.ListenerResponsePacket{
ListenerID: listenerID,
Err: false,
ErrString: "",
}
listenerMap[listenerID] = listener.Listener
listenerID++

if err := encoder.Encode(protocol.Envelope{
Type: protocol.MessageListenerResponse,
Payload: listenerResponse,
}); err != nil {
logrus.Error(err)
}

go func() {
if err := listener.ListenAndServe(connTrackChan); err != nil {
stopChan <- err
}
}()
defer listener.Close()

for {
var bindResponse protocol.ListenerBindReponse
select {
case err := <-stopChan:
logrus.Error(err)
bindResponse = protocol.ListenerBindReponse{
SockID: 0,
Err: true,
ErrString: err.Error(),
go func() {
if err := listener.ListenAndServe(connTrackChan); err != nil {
stopChan <- err
}
}()
defer listener.Close()

} else if listenRequest.Network == "udp" {
udplistener, err := NewUDPListener(listenRequest.Network, listenRequest.Address)
if err != nil {
listenerResponse := protocol.ListenerResponsePacket{
ListenerID: 0,
Err: true,
ErrString: err.Error(),
}
case connTrackID := <-connTrackChan:
bindResponse = protocol.ListenerBindReponse{
SockID: connTrackID,
Err: false,
if err := encoder.Encode(protocol.Envelope{
Type: protocol.MessageListenerResponse,
Payload: listenerResponse,
}); err != nil {
logrus.Error(err)
}
return
}
listenerMap[listenerID] = udplistener.UDPConn
listenerResponse := protocol.ListenerResponsePacket{
ListenerID: listenerID,
Err: false,
ErrString: "",
}

if err := encoder.Encode(protocol.Envelope{
Type: protocol.MessageListenerBindResponse,
Payload: bindResponse,
Type: protocol.MessageListenerResponse,
Payload: listenerResponse,
}); err != nil {
logrus.Error(err)
}
go relay.StartRelay(conn, udplistener)
}

if bindResponse.Err {
break
}
listenerID++
if listenRequest.Network == "tcp" {
for {
var bindResponse protocol.ListenerBindReponse
select {
case err := <-stopChan:
logrus.Error(err)
bindResponse = protocol.ListenerBindReponse{
SockID: 0,
Err: true,
ErrString: err.Error(),
}
case connTrackID := <-connTrackChan:
bindResponse = protocol.ListenerBindReponse{
SockID: connTrackID,
Err: false,
}
}

if err := encoder.Encode(protocol.Envelope{
Type: protocol.MessageListenerBindResponse,
Payload: bindResponse,
}); err != nil {
logrus.Error(err)
}

if bindResponse.Err {
break
}

}
}
case protocol.MessageListenerSockRequest:
sockRequest := e.(protocol.ListenerSockRequestPacket)
Expand Down
139 changes: 84 additions & 55 deletions cmd/proxy/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"strconv"
"strings"
"sync"
"time"
)

var AgentList map[int]proxy.LigoloAgent
Expand Down Expand Up @@ -228,11 +229,11 @@ func Run(stackSettings netstack.StackSettings) {
t := table.NewWriter()
t.SetStyle(table.StyleLight)
t.SetTitle("Active listeners")
t.AppendHeader(table.Row{"#", "Agent", "Agent listener address", "Proxy redirect address"})
t.AppendHeader(table.Row{"#", "Agent", "Network", "Agent listener address", "Proxy redirect address"})

ListenerListMutex.Lock()
for id, listener := range ListenerList {
t.AppendRow(table.Row{id, listener.Agent.Name, listener.ListenerAddr, listener.RedirectAddr})
t.AppendRow(table.Row{id, listener.Agent.String(), listener.Network, listener.ListenerAddr, listener.RedirectAddr})
}
ListenerListMutex.Unlock()
c.App.Println(t.Render())
Expand Down Expand Up @@ -379,79 +380,107 @@ func Run(stackSettings netstack.StackSettings) {
ListenerListMutex.Lock()
ListenerList[proxy.ListenerCounter] = listener
ListenerListMutex.Unlock()
currentListener := proxy.ListenerCounter
proxy.ListenerCounter++

go func() {
for {
// Wait for BindResponses
if err := protocolDecoder.Decode(); err != nil {
if err == io.EOF {
// Listener closed.
if netProto == "udp" {

// relay connections
go func() {
for {
// Check if deleted
if _, ok := ListenerList[currentListener]; !ok {
return
}
logrus.Error(err)
return
}

// We received a new BindResponse!
response := protocolDecoder.Envelope.Payload.(protocol.ListenerBindReponse)

if err := response.Err; err != false {
logrus.Error(response.ErrString)
return
}

logrus.Debugf("New socket opened : %d", response.SockID)

// relay connection
go func(sockID int32) {

forwarderSession, err := CurrentAgent.Session.Open()
// Dial the "to" target
lconn, err := net.Dial(netProto, c.Flags.String("to"))
if err != nil {
logrus.Error(err)
return
}

protocolEncoder := protocol.NewEncoder(forwarderSession)
protocolDecoder := protocol.NewDecoder(forwarderSession)

// Request socket access
socketRequestPacket := protocol.ListenerSockRequestPacket{SockID: sockID}
if err := protocolEncoder.Encode(protocol.Envelope{
Type: protocol.MessageListenerSockRequest,
Payload: socketRequestPacket,
}); err != nil {
logrus.Error(err)
return
// Relay conn
err = relay.StartPacketRelay(lconn, yamuxConnectionSession)
if err != nil {
logrus.WithFields(logrus.Fields{"listener": ListenerList[currentListener].String(), "error": err}).Error("Failed to relay UDP connection. Make sure that you are 'to' host is listening! Retrying...")
}
time.Sleep(2 * time.Second)
}
}()
}

if netProto == "tcp" {
go func() {
for {
// Wait for BindResponses
if err := protocolDecoder.Decode(); err != nil {
if err == io.EOF {
// Listener closed.
return
}
logrus.Error(err)
return
}

response := protocolDecoder.Envelope.Payload
if err := response.(protocol.ListenerSockResponsePacket).Err; err != false {
logrus.Error(response.(protocol.ListenerSockResponsePacket).ErrString)
return
}
// Got socket access!

logrus.Debug("Listener relay established!")
// We received a new BindResponse!
response := protocolDecoder.Envelope.Payload.(protocol.ListenerBindReponse)

// Dial the "to" target
lconn, err := net.Dial(netProto, c.Flags.String("to"))
if err != nil {
logrus.Error(err)
if err := response.Err; err != false {
logrus.Error(response.ErrString)
return
}

// relay connections
relay.StartRelay(lconn, forwarderSession)
}(response.SockID)
logrus.Debugf("New socket opened : %d", response.SockID)

// relay connection
go func(sockID int32) {

forwarderSession, err := CurrentAgent.Session.Open()
if err != nil {
logrus.Error(err)
return
}

protocolEncoder := protocol.NewEncoder(forwarderSession)
protocolDecoder := protocol.NewDecoder(forwarderSession)

// Request socket access
socketRequestPacket := protocol.ListenerSockRequestPacket{SockID: sockID}
if err := protocolEncoder.Encode(protocol.Envelope{
Type: protocol.MessageListenerSockRequest,
Payload: socketRequestPacket,
}); err != nil {
logrus.Error(err)
return
}
if err := protocolDecoder.Decode(); err != nil {
logrus.Error(err)
return
}

response := protocolDecoder.Envelope.Payload
if err := response.(protocol.ListenerSockResponsePacket).Err; err != false {
logrus.Error(response.(protocol.ListenerSockResponsePacket).ErrString)
return
}
// Got socket access!

logrus.Debug("Listener relay established!")

// Dial the "to" target
lconn, err := net.Dial(netProto, c.Flags.String("to"))
if err != nil {
logrus.Error(err)
return
}

// relay connections
relay.StartRelay(lconn, forwarderSession)
}(response.SockID)

}
}

}()
}()
}

return nil
},
Expand Down
Binary file modified doc/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed doc/tnplogo.png
Binary file not shown.
Loading

0 comments on commit d0ae589

Please sign in to comment.