From 3fd2609ae6a14e9c08ad1f8805a3097f79d0b55d Mon Sep 17 00:00:00 2001 From: Steffen Siering Date: Tue, 19 Jul 2016 15:17:40 +0200 Subject: [PATCH] Backport fix for overlapping tcp segments (#1898) (#1917) * TCP table driven unit test * Fix tcp segments overlaps --- CHANGELOG.asciidoc | 3 + packetbeat/protos/tcp/tcp.go | 80 ++++++++++--- packetbeat/protos/tcp/tcp_test.go | 189 ++++++++++++++++++++++++------ 3 files changed, 219 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc index 7c28c5da9e9..8e7ac667d27 100644 --- a/CHANGELOG.asciidoc +++ b/CHANGELOG.asciidoc @@ -28,6 +28,9 @@ https://github.com/elastic/beats/compare/v1.3.0...1.3[Check the HEAD diff] *Affecting all Beats* *Packetbeat* +- Add missing nil-check to memcached GapInStream handler. {issue}1162[1162] +- Fix NFSv4 Operation returning the first found first-class operation available in compound requests. {pull}1821[1821] +- Fix TCP overlapping segments not being handled correctly. {pull}1917[1917] *Topbeat* diff --git a/packetbeat/protos/tcp/tcp.go b/packetbeat/protos/tcp/tcp.go index 2a34989b99c..5319d2636c8 100644 --- a/packetbeat/protos/tcp/tcp.go +++ b/packetbeat/protos/tcp/tcp.go @@ -30,6 +30,14 @@ type Processor interface { Process(tcphdr *layers.TCP, pkt *protos.Packet) } +type seqCompare int + +const ( + seqLT seqCompare = -1 + seqEq seqCompare = 0 + seqGT seqCompare = 1 +) + var ( debugf = logp.MakeDebug("tcp") isDebug = false @@ -122,18 +130,25 @@ func (tcp *Tcp) Process(tcphdr *layers.TCP, pkt *protos.Packet) { if stream.conn == nil { return } + conn := stream.conn - tcp_start_seq := tcphdr.Seq - tcp_seq := tcp_start_seq + uint32(len(pkt.Payload)) + if len(pkt.Payload) == 0 && !tcphdr.FIN { + // return early if packet is not interesting. Still need to find/create + // stream first in order to update the TCP stream timer + return + } + + tcpStartSeq := tcphdr.Seq + tcpSeq := tcpStartSeq + uint32(len(pkt.Payload)) lastSeq := conn.lastSeq[stream.dir] if isDebug { debugf("pkt.start_seq=%v pkt.last_seq=%v stream.last_seq=%v (len=%d)", - tcp_start_seq, tcp_seq, lastSeq, len(pkt.Payload)) + tcpStartSeq, tcpSeq, lastSeq, len(pkt.Payload)) } if len(pkt.Payload) > 0 && lastSeq != 0 { - if tcpSeqBeforeEq(tcp_seq, lastSeq) { + if tcpSeqBeforeEq(tcpSeq, lastSeq) { if isDebug { debugf("Ignoring retransmitted segment. pkt.seq=%v len=%v stream.seq=%v", tcphdr.Seq, len(pkt.Payload), lastSeq) @@ -141,26 +156,41 @@ func (tcp *Tcp) Process(tcphdr *layers.TCP, pkt *protos.Packet) { return } - if tcpSeqBefore(lastSeq, tcp_start_seq) { - if !created { - gap := int(tcp_start_seq - lastSeq) - logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcp_start_seq, gap) - drop := stream.gapInStream(gap) - if drop { - if isDebug { - debugf("Dropping connection state because of gap") - } - - // drop application layer connection state and - // update stream_id for app layer analysers using stream_id for lookups - conn.id = tcp.getId() - conn.data = nil + switch tcpSeqCompare(lastSeq, tcpStartSeq) { + case seqLT: // lastSeq < tcpStartSeq => Gap in tcp stream detected + if created { + break + } + + gap := int(tcpStartSeq - lastSeq) + logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcpStartSeq, gap) + drop := stream.gapInStream(gap) + if drop { + if isDebug { + debugf("Dropping connection state because of gap") } + + // drop application layer connection state and + // update stream_id for app layer analysers using stream_id for lookups + conn.id = tcp.getId() + conn.data = nil + } + + case seqGT: + // lastSeq > tcpStartSeq => overlapping TCP segment detected. shrink packet + delta := lastSeq - tcpStartSeq + + if isDebug { + debugf("Overlapping tcp segment. last_seq %d, seq: %d, delta: %d", + lastSeq, tcpStartSeq, delta) } + + pkt.Payload = pkt.Payload[delta:] + tcphdr.Seq += delta } } - conn.lastSeq[stream.dir] = tcp_seq + conn.lastSeq[stream.dir] = tcpSeq stream.addPacket(pkt, tcphdr) } @@ -202,6 +232,18 @@ func (tcp *Tcp) getStream(pkt *protos.Packet) (stream TcpStream, created bool) { return TcpStream{conn: conn, dir: TcpDirectionOriginal}, true } +func tcpSeqCompare(seq1, seq2 uint32) seqCompare { + i := int32(seq1 - seq2) + switch { + case i == 0: + return seqEq + case i < 0: + return seqLT + default: + return seqGT + } +} + func tcpSeqBefore(seq1 uint32, seq2 uint32) bool { return int32(seq1-seq2) < 0 } diff --git a/packetbeat/protos/tcp/tcp_test.go b/packetbeat/protos/tcp/tcp_test.go index 63f81e98864..8d99f0fb469 100644 --- a/packetbeat/protos/tcp/tcp_test.go +++ b/packetbeat/protos/tcp/tcp_test.go @@ -1,7 +1,6 @@ package tcp import ( - "fmt" "math/rand" "net" "testing" @@ -166,45 +165,128 @@ func (p protocols) GetAllTcp() map[protos.Protocol]protos.TcpProtocolPlugin func (p protocols) GetAllUdp() map[protos.Protocol]protos.UdpProtocolPlugin { return nil } func (p protocols) Register(proto protos.Protocol, plugin protos.ProtocolPlugin) { return } -func TestGapInStreamShouldDropState(t *testing.T) { - gap := 0 - var state []byte - - data1 := []byte{1, 2, 3, 4} - data2 := []byte{5, 6, 7, 8} - - tp := &TestProtocol{Ports: []int{ServerPort}} - tp.gap = func(t *common.TcpTuple, d uint8, n int, p protos.ProtocolData) (protos.ProtocolData, bool) { - fmt.Printf("lost: %v\n", n) - gap += n - return p, true // drop state - } - tp.parse = func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData { - if priv == nil { - state = nil - } - state = append(state, p.Payload...) - return state +func TestTCSeqPayload(t *testing.T) { + type segment struct { + seq uint32 + payload []byte } - p := protocols{} - p.tcp = map[protos.Protocol]protos.TcpProtocolPlugin{ - protos.HttpProtocol: tp, + tests := []struct { + name string + segments []segment + expectedGaps int + expectedState []byte + }{ + {"No overlap", + []segment{ + {1, []byte{1, 2, 3, 4, 5}}, + {6, []byte{6, 7, 8, 9, 10}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + }, + {"Gap drop state", + []segment{ + {1, []byte{1, 2, 3, 4}}, + {15, []byte{5, 6, 7, 8}}, + }, + 10, + []byte{5, 6, 7, 8}, + }, + {"ACK same sequence number", + []segment{ + {1, []byte{1, 2}}, + {3, nil}, + {3, []byte{3, 4}}, + {5, []byte{5, 6}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6}, + }, + {"ACK same sequence number 2", + []segment{ + {1, nil}, + {2, nil}, + {2, []byte{1, 2}}, + {4, nil}, + {4, []byte{3, 4}}, + {6, []byte{5, 6}}, + {8, []byte{7, 8}}, + {10, nil}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + }, + {"Overlap, first segment bigger", + []segment{ + {1, []byte{1, 2}}, + {3, []byte{3, 4}}, + {3, []byte{3}}, + {5, []byte{5, 6}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6}, + }, + {"Overlap, second segment bigger", + []segment{ + {1, []byte{1, 2}}, + {3, []byte{3}}, + {3, []byte{3, 4}}, + {5, []byte{5, 6}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6}, + }, + {"Overlap, covered", + []segment{ + {1, []byte{1, 2, 3, 4}}, + {2, []byte{2, 3}}, + {5, []byte{5, 6}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6}, + }, } - tcp, _ := NewTcp(p) - addr := common.NewIpPortTuple(4, - net.ParseIP(ServerIp), ServerPort, - net.ParseIP(ClientIp), uint16(rand.Intn(65535))) + for i, test := range tests { + t.Logf("Test (%v): %v", i, test.name) + + gap := 0 + var state []byte + tcp, err := NewTcp(protocols{ + tcp: map[protos.Protocol]protos.TcpProtocolPlugin{ + protos.HttpProtocol: &TestProtocol{ + Ports: []int{ServerPort}, + gap: makeCountGaps(nil, &gap), + parse: makeCollectPayload(&state, true), + }, + }, + }) + if err != nil { + t.Fatal(err) + } - hdr := &layers.TCP{} - tcp.Process(hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data1}) - hdr.Seq += uint32(len(data1) + 10) - tcp.Process(hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data2}) + addr := common.NewIpPortTuple(4, + net.ParseIP(ServerIp), ServerPort, + net.ParseIP(ClientIp), uint16(rand.Intn(65535))) + + for _, segment := range test.segments { + hdr := &layers.TCP{Seq: segment.seq} + pkt := &protos.Packet{ + Ts: time.Now(), + Tuple: addr, + Payload: segment.payload, + } + tcp.Process(hdr, pkt) + } - // validate - assert.Equal(t, 10, gap) - assert.Equal(t, data2, state) + assert.Equal(t, test.expectedGaps, gap) + if len(test.expectedState) != len(state) { + assert.Equal(t, len(test.expectedState), len(state)) + continue + } + assert.Equal(t, test.expectedState, state) + } } // Benchmark that runs with parallelism to help find concurrency related @@ -231,3 +313,42 @@ func BenchmarkParallelProcess(b *testing.B) { } }) } + +func makeCountGaps( + counter *int, + bytes *int, +) func(*common.TcpTuple, uint8, int, protos.ProtocolData) (protos.ProtocolData, bool) { + return func( + t *common.TcpTuple, + d uint8, + n int, + p protos.ProtocolData, + ) (protos.ProtocolData, bool) { + if counter != nil { + (*counter)++ + } + if bytes != nil { + *bytes += n + } + + return p, true // drop state + } +} + +func makeCollectPayload( + state *[]byte, + resetOnNil bool, +) func(*protos.Packet, *common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData { + return func( + p *protos.Packet, + t *common.TcpTuple, + d uint8, + priv protos.ProtocolData, + ) protos.ProtocolData { + if resetOnNil && priv == nil { + (*state) = nil + } + *state = append(*state, p.Payload...) + return *state + } +}