Make packet processing multi-threaded
If packet processing is not fast enough, packet capture may miss packets.
Processing now uses 10 goroutines, which distribute work based on the ephemeral port number (a minimal sketch of the sharding scheme follows the change summary below).
buger committed Aug 6, 2021
1 parent c9274ac commit 11d61dc
Showing 1 changed file with 57 additions and 23 deletions.
80 changes: 57 additions & 23 deletions tcp/tcp_message.go
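The core of the change is sharding the map of in-progress messages by port, so ten worker goroutines can each operate on their own shard under its own lock. Below is a minimal, self-contained sketch of that idea, not the committed code: the `shardedMessages` and `message` types and the `shardFor` helper are illustrative names, and the real parser picks the shard from `pckt.SrcPort % 10` or `pckt.DstPort % 10` depending on packet direction.

```go
package main

import (
	"fmt"
	"sync"
)

const numShards = 10

// message is a stand-in for the parser's in-progress TCP message state.
type message struct {
	id uint64
}

// shardedMessages mirrors the commit's idea: one map plus one mutex per shard,
// with the shard chosen by a port number modulo the shard count.
type shardedMessages struct {
	m  [numShards]map[uint64]*message
	mL [numShards]sync.RWMutex
}

func newShardedMessages() *shardedMessages {
	s := &shardedMessages{}
	for i := 0; i < numShards; i++ {
		s.m[i] = make(map[uint64]*message)
	}
	return s
}

// shardFor picks a shard from a port, the same way the commit
// uses SrcPort % 10 (or DstPort % 10 for the other direction).
func shardFor(port uint16) int {
	return int(port % numShards)
}

// store places a message into the shard owned by the given port.
func (s *shardedMessages) store(port uint16, id uint64, msg *message) {
	idx := shardFor(port)
	s.mL[idx].Lock()
	s.m[idx][id] = msg
	s.mL[idx].Unlock()
}

// load looks up a message in the shard owned by the given port.
func (s *shardedMessages) load(port uint16, id uint64) (*message, bool) {
	idx := shardFor(port)
	s.mL[idx].RLock()
	msg, ok := s.m[idx][id]
	s.mL[idx].RUnlock()
	return msg, ok
}

func main() {
	s := newShardedMessages()
	s.store(54321, 1, &message{id: 1})
	if msg, ok := s.load(54321, 1); ok {
		fmt.Println("found message", msg.id, "in shard", shardFor(54321))
	}
}
```

Because every packet of a given connection carries the same client (ephemeral) port, all packets belonging to one message resolve to the same shard, so the per-shard mutex is enough to keep each map consistent while the ten goroutines run in parallel.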
@@ -7,6 +7,7 @@ import (
"net"
"reflect"
"sort"
"sync"
"time"
"unsafe"
)
@@ -64,6 +65,7 @@ type Message struct {
packets []*Packet
parser *MessageParser
feedback interface{}
Idx uint16
Stats
}

@@ -198,7 +200,8 @@ type HintStart func(*Packet) (IsRequest, IsOutgoing bool)
// MessageParser holds data of all TCP messages in progress (still receiving/sending packets).
// A message is identified by its source port and dst port, and the last 4 bytes of the src IP.
type MessageParser struct {
m map[uint64]*Message
m []map[uint64]*Message
mL []sync.RWMutex

messageExpire time.Duration // the maximum time to wait for the final packet, minimum is 100ms
allowIncompete bool
@@ -226,18 +229,24 @@ func NewMessageParser(messages chan *Message, ports []uint16, ips []net.IP, mess
parser.packets = make(chan *PcapPacket, 10000)

if messages == nil {
messages = make(chan *Message, 1000)
messages = make(chan *Message, 100)
}
parser.messages = messages

parser.m = make(map[uint64]*Message)
parser.ticker = time.NewTicker(time.Millisecond * 100)
parser.close = make(chan struct{}, 1)

parser.ports = ports
parser.ips = ips

go parser.wait()
for i := 0; i < 10; i++ {
parser.m = append(parser.m, make(map[uint64]*Message))
parser.mL = append(parser.mL, sync.RWMutex{})
}

for i := 0; i < 10; i++ {
go parser.wait(i)
}

return parser
}

@@ -249,7 +258,7 @@ func (parser *MessageParser) PacketHandler(packet *PcapPacket) {
parser.packets <- packet
}

func (parser *MessageParser) wait() {
func (parser *MessageParser) wait(index int) {
var (
now time.Time
)
@@ -258,7 +267,7 @@ func (parser *MessageParser) wait() {
case pckt := <-parser.packets:
parser.processPacket(parser.parsePacket(pckt))
case now = <-parser.ticker.C:
parser.timer(now)
parser.timer(now, index)
case <-parser.close:
parser.ticker.Stop()
// parser.Close should wait for this function to return
@@ -298,19 +307,28 @@ func (parser *MessageParser) processPacket(pckt *Packet) {

// Trying to build a unique hash, but there is a small chance of collision
// No matter if it is a request or response, all packets in the same message have the same MessageID
m, ok := parser.m[pckt.MessageID()]
mID := pckt.MessageID()
mIDX := pckt.SrcPort % 10

parser.mL[mIDX].Lock()
m, ok := parser.m[mIDX][mID]
if !ok {
parser.mL[mIDX].Unlock()

mIDX = pckt.DstPort % 10
parser.mL[mIDX].Lock()
m, ok = parser.m[mIDX][mID]

if !ok {
parser.mL[mIDX].Unlock()
}
}

switch {
case ok:
if m.Direction == DirUnknown {
if in, out := parser.Start(pckt); in || out {
if in {
m.Direction = DirIncoming
} else {
m.Direction = DirOutcoming
}
}
}
parser.addPacket(m, pckt)

parser.mL[mIDX].Unlock()
return
case pckt.Direction == DirUnknown && parser.Start != nil:
if in, out := parser.Start(pckt); in || out {
@@ -322,12 +340,25 @@ func (parser *MessageParser) processPacket(pckt *Packet) {
}
}

if pckt.Direction == DirIncoming {
mIDX = pckt.SrcPort % 10
} else {
mIDX = pckt.DstPort % 10
}

parser.mL[mIDX].Lock()

m = new(Message)
m.Direction = pckt.Direction
parser.m[pckt.MessageID()] = m

parser.m[mIDX][mID] = m

m.Idx = mIDX
m.Start = pckt.Timestamp
m.parser = parser
parser.addPacket(m, pckt)

parser.mL[mIDX].Unlock()
}

func (parser *MessageParser) addPacket(m *Message, pckt *Packet) bool {
@@ -354,7 +385,7 @@ func (parser *MessageParser) Read() *Message {
func (parser *MessageParser) Emit(m *Message) {
stats.Add("message_count", 1)

delete(parser.m, m.packets[0].MessageID())
delete(parser.m[m.Idx], m.packets[0].MessageID())

parser.messages <- m
}
@@ -365,13 +396,14 @@ func GetUnexportedField(field reflect.Value) interface{} {

var failMsg int

func (parser *MessageParser) timer(now time.Time) {
func (parser *MessageParser) timer(now time.Time, index int) {
packetLen = 0
parser.mL[index].Lock()

packetQueueLen.Set(int64(len(parser.packets)))
messageQueueLen.Set(int64(len(parser.m)))
messageQueueLen.Set(int64(len(parser.m[index])))

for _, m := range parser.m {
for _, m := range parser.m[index] {
if now.Sub(m.End) > parser.messageExpire {
m.TimedOut = true
stats.Add("message_timeout_count", 1)
@@ -380,9 +412,11 @@ func (parser *MessageParser) timer(now time.Time) {
parser.Emit(m)
}

delete(parser.m, m.packets[0].MessageID())
delete(parser.m[index], m.packets[0].MessageID())
}
}

parser.mL[index].Unlock()
}

func (parser *MessageParser) Close() error {
