Skip to content

Commit

Permalink
Fix/overlapping tcp segments (#1898)
Browse files Browse the repository at this point in the history
* TCP table driven unit test

* Fix tcp segments overlaps

* Simplify overlap handling
  • Loading branch information
Steffen Siering authored and tsg committed Jun 27, 2016
1 parent f2c36ee commit 391e840
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 57 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ https://github.com/elastic/beats/compare/v5.0.0-alpha3...master[Check the HEAD d
*Packetbeat*
- Add missing nil-check to memcached GapInStream handler. {issue}1162[1162]
- Fix NFSv4 Operation returning the first found first-class operation available in compound requests. {pull}1821[1821]
- Fix TCP overlapping segments not being handled correctly. {pull}1898[1898]

*Topbeat*

Expand Down
89 changes: 66 additions & 23 deletions packetbeat/protos/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ type Processor interface {
Process(flow *flows.FlowID, hdr *layers.TCP, pkt *protos.Packet)
}

type seqCompare int

const (
seqLT seqCompare = -1
seqEq seqCompare = 0
seqGT seqCompare = 1
)

var (
debugf = logp.MakeDebug("tcp")
isDebug = false
Expand Down Expand Up @@ -119,55 +127,78 @@ func (tcp *Tcp) Process(id *flows.FlowID, tcphdr *layers.TCP, pkt *protos.Packet
// protocol modules.
defer logp.Recover("Process tcp exception")

debugf("tcp flow id: %p", id)

stream, created := tcp.getStream(pkt)
if stream.conn == nil {
return
}

conn := stream.conn
if id != nil {
id.AddConnectionID(uint64(stream.conn.id))
id.AddConnectionID(uint64(conn.id))
}

if isDebug {
debugf("tcp flow id: %p", id)
}

if len(pkt.Payload) == 0 && !tcphdr.FIN {
// return early if packet is not interesting. Still need to find/create
// stream first in order to update the TCP stream timer
return
}
conn := stream.conn

tcp_start_seq := tcphdr.Seq
tcp_seq := tcp_start_seq + uint32(len(pkt.Payload))
tcpStartSeq := tcphdr.Seq
tcpSeq := tcpStartSeq + uint32(len(pkt.Payload))
lastSeq := conn.lastSeq[stream.dir]
if isDebug {
debugf("pkt.start_seq=%v pkt.last_seq=%v stream.last_seq=%v (len=%d)",
tcp_start_seq, tcp_seq, lastSeq, len(pkt.Payload))
tcpStartSeq, tcpSeq, lastSeq, len(pkt.Payload))
}

if len(pkt.Payload) > 0 && lastSeq != 0 {
if tcpSeqBeforeEq(tcp_seq, lastSeq) {
if tcpSeqBeforeEq(tcpSeq, lastSeq) {
if isDebug {
debugf("Ignoring retransmitted segment. pkt.seq=%v len=%v stream.seq=%v",
tcphdr.Seq, len(pkt.Payload), lastSeq)
}
return
}

if tcpSeqBefore(lastSeq, tcp_start_seq) {
if !created {
gap := int(tcp_start_seq - lastSeq)
logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcp_start_seq, gap)
drop := stream.gapInStream(gap)
if drop {
if isDebug {
debugf("Dropping connection state because of gap")
}

// drop application layer connection state and
// update stream_id for app layer analysers using stream_id for lookups
conn.id = tcp.getId()
conn.data = nil
switch tcpSeqCompare(lastSeq, tcpStartSeq) {
case seqLT: // lastSeq < tcpStartSeq => Gap in tcp stream detected
if created {
break
}

gap := int(tcpStartSeq - lastSeq)
logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcpStartSeq, gap)
drop := stream.gapInStream(gap)
if drop {
if isDebug {
debugf("Dropping connection state because of gap")
}

// drop application layer connection state and
// update stream_id for app layer analysers using stream_id for lookups
conn.id = tcp.getId()
conn.data = nil
}

case seqGT:
// lastSeq > tcpStartSeq => overlapping TCP segment detected. shrink packet
delta := lastSeq - tcpStartSeq

if isDebug {
debugf("Overlapping tcp segment. last_seq %d, seq: %d, delta: %d",
lastSeq, tcpStartSeq, delta)
}

pkt.Payload = pkt.Payload[delta:]
tcphdr.Seq += delta
}
}

conn.lastSeq[stream.dir] = tcp_seq
conn.lastSeq[stream.dir] = tcpSeq
stream.addPacket(pkt, tcphdr)
}

Expand Down Expand Up @@ -209,6 +240,18 @@ func (tcp *Tcp) getStream(pkt *protos.Packet) (stream TcpStream, created bool) {
return TcpStream{conn: conn, dir: TcpDirectionOriginal}, true
}

func tcpSeqCompare(seq1, seq2 uint32) seqCompare {
i := int32(seq1 - seq2)
switch {
case i == 0:
return seqEq
case i < 0:
return seqLT
default:
return seqGT
}
}

func tcpSeqBefore(seq1 uint32, seq2 uint32) bool {
return int32(seq1-seq2) < 0
}
Expand Down
189 changes: 155 additions & 34 deletions packetbeat/protos/tcp/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package tcp

import (
"fmt"
"math/rand"
"net"
"testing"
Expand Down Expand Up @@ -186,45 +185,128 @@ func (p protocols) GetAllTcp() map[protos.Protocol]protos.TcpPlugin { retur
func (p protocols) GetAllUdp() map[protos.Protocol]protos.UdpPlugin { return nil }
func (p protocols) Register(proto protos.Protocol, plugin protos.Plugin) { return }

func TestGapInStreamShouldDropState(t *testing.T) {
gap := 0
var state []byte

data1 := []byte{1, 2, 3, 4}
data2 := []byte{5, 6, 7, 8}

tp := &TestProtocol{Ports: []int{ServerPort}}
tp.gap = func(t *common.TcpTuple, d uint8, n int, p protos.ProtocolData) (protos.ProtocolData, bool) {
fmt.Printf("lost: %v\n", n)
gap += n
return p, true // drop state
}
tp.parse = func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData {
if priv == nil {
state = nil
}
state = append(state, p.Payload...)
return state
func TestTCSeqPayload(t *testing.T) {
type segment struct {
seq uint32
payload []byte
}

p := protocols{}
p.tcp = map[protos.Protocol]protos.TcpPlugin{
httpProtocol: tp,
tests := []struct {
name string
segments []segment
expectedGaps int
expectedState []byte
}{
{"No overlap",
[]segment{
{1, []byte{1, 2, 3, 4, 5}},
{6, []byte{6, 7, 8, 9, 10}},
},
0,
[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
},
{"Gap drop state",
[]segment{
{1, []byte{1, 2, 3, 4}},
{15, []byte{5, 6, 7, 8}},
},
10,
[]byte{5, 6, 7, 8},
},
{"ACK same sequence number",
[]segment{
{1, []byte{1, 2}},
{3, nil},
{3, []byte{3, 4}},
{5, []byte{5, 6}},
},
0,
[]byte{1, 2, 3, 4, 5, 6},
},
{"ACK same sequence number 2",
[]segment{
{1, nil},
{2, nil},
{2, []byte{1, 2}},
{4, nil},
{4, []byte{3, 4}},
{6, []byte{5, 6}},
{8, []byte{7, 8}},
{10, nil},
},
0,
[]byte{1, 2, 3, 4, 5, 6, 7, 8},
},
{"Overlap, first segment bigger",
[]segment{
{1, []byte{1, 2}},
{3, []byte{3, 4}},
{3, []byte{3}},
{5, []byte{5, 6}},
},
0,
[]byte{1, 2, 3, 4, 5, 6},
},
{"Overlap, second segment bigger",
[]segment{
{1, []byte{1, 2}},
{3, []byte{3}},
{3, []byte{3, 4}},
{5, []byte{5, 6}},
},
0,
[]byte{1, 2, 3, 4, 5, 6},
},
{"Overlap, covered",
[]segment{
{1, []byte{1, 2, 3, 4}},
{2, []byte{2, 3}},
{5, []byte{5, 6}},
},
0,
[]byte{1, 2, 3, 4, 5, 6},
},
}
tcp, _ := NewTcp(p)

addr := common.NewIpPortTuple(4,
net.ParseIP(ServerIp), ServerPort,
net.ParseIP(ClientIp), uint16(rand.Intn(65535)))
for i, test := range tests {
t.Logf("Test (%v): %v", i, test.name)

gap := 0
var state []byte
tcp, err := NewTcp(protocols{
tcp: map[protos.Protocol]protos.TcpPlugin{
httpProtocol: &TestProtocol{
Ports: []int{ServerPort},
gap: makeCountGaps(nil, &gap),
parse: makeCollectPayload(&state, true),
},
},
})
if err != nil {
t.Fatal(err)
}

hdr := &layers.TCP{}
tcp.Process(nil, hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data1})
hdr.Seq += uint32(len(data1) + 10)
tcp.Process(nil, hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data2})
addr := common.NewIpPortTuple(4,
net.ParseIP(ServerIp), ServerPort,
net.ParseIP(ClientIp), uint16(rand.Intn(65535)))

for _, segment := range test.segments {
hdr := &layers.TCP{Seq: segment.seq}
pkt := &protos.Packet{
Ts: time.Now(),
Tuple: addr,
Payload: segment.payload,
}
tcp.Process(nil, hdr, pkt)
}

// validate
assert.Equal(t, 10, gap)
assert.Equal(t, data2, state)
assert.Equal(t, test.expectedGaps, gap)
if len(test.expectedState) != len(state) {
assert.Equal(t, len(test.expectedState), len(state))
continue
}
assert.Equal(t, test.expectedState, state)
}
}

// Benchmark that runs with parallelism to help find concurrency related
Expand All @@ -251,3 +333,42 @@ func BenchmarkParallelProcess(b *testing.B) {
}
})
}

func makeCountGaps(
counter *int,
bytes *int,
) func(*common.TcpTuple, uint8, int, protos.ProtocolData) (protos.ProtocolData, bool) {
return func(
t *common.TcpTuple,
d uint8,
n int,
p protos.ProtocolData,
) (protos.ProtocolData, bool) {
if counter != nil {
(*counter)++
}
if bytes != nil {
*bytes += n
}

return p, true // drop state
}
}

func makeCollectPayload(
state *[]byte,
resetOnNil bool,
) func(*protos.Packet, *common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData {
return func(
p *protos.Packet,
t *common.TcpTuple,
d uint8,
priv protos.ProtocolData,
) protos.ProtocolData {
if resetOnNil && priv == nil {
(*state) = nil
}
*state = append(*state, p.Payload...)
return *state
}
}

0 comments on commit 391e840

Please sign in to comment.