diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc index af16eb5792b..a95aa471365 100644 --- a/CHANGELOG.asciidoc +++ b/CHANGELOG.asciidoc @@ -44,6 +44,7 @@ https://github.com/elastic/beats/compare/v5.0.0-alpha3...master[Check the HEAD d *Packetbeat* - Add missing nil-check to memcached GapInStream handler. {issue}1162[1162] - Fix NFSv4 Operation returning the first found first-class operation available in compound requests. {pull}1821[1821] +- Fix TCP overlapping segments not being handled correctly. {pull}1898[1898] *Topbeat* diff --git a/packetbeat/protos/tcp/tcp.go b/packetbeat/protos/tcp/tcp.go index 8e7bbf76ab2..7f912173358 100644 --- a/packetbeat/protos/tcp/tcp.go +++ b/packetbeat/protos/tcp/tcp.go @@ -31,6 +31,14 @@ type Processor interface { Process(flow *flows.FlowID, hdr *layers.TCP, pkt *protos.Packet) } +type seqCompare int + +const ( + seqLT seqCompare = -1 + seqEq seqCompare = 0 + seqGT seqCompare = 1 +) + var ( debugf = logp.MakeDebug("tcp") isDebug = false @@ -119,28 +127,36 @@ func (tcp *Tcp) Process(id *flows.FlowID, tcphdr *layers.TCP, pkt *protos.Packet // protocol modules. defer logp.Recover("Process tcp exception") - debugf("tcp flow id: %p", id) - stream, created := tcp.getStream(pkt) if stream.conn == nil { return } + conn := stream.conn if id != nil { - id.AddConnectionID(uint64(stream.conn.id)) + id.AddConnectionID(uint64(conn.id)) + } + + if isDebug { + debugf("tcp flow id: %p", id) + } + + if len(pkt.Payload) == 0 && !tcphdr.FIN { + // return early if packet is not interesting. Still need to find/create + // stream first in order to update the TCP stream timer + return } - conn := stream.conn - tcp_start_seq := tcphdr.Seq - tcp_seq := tcp_start_seq + uint32(len(pkt.Payload)) + tcpStartSeq := tcphdr.Seq + tcpSeq := tcpStartSeq + uint32(len(pkt.Payload)) lastSeq := conn.lastSeq[stream.dir] if isDebug { debugf("pkt.start_seq=%v pkt.last_seq=%v stream.last_seq=%v (len=%d)", - tcp_start_seq, tcp_seq, lastSeq, len(pkt.Payload)) + tcpStartSeq, tcpSeq, lastSeq, len(pkt.Payload)) } if len(pkt.Payload) > 0 && lastSeq != 0 { - if tcpSeqBeforeEq(tcp_seq, lastSeq) { + if tcpSeqBeforeEq(tcpSeq, lastSeq) { if isDebug { debugf("Ignoring retransmitted segment. pkt.seq=%v len=%v stream.seq=%v", tcphdr.Seq, len(pkt.Payload), lastSeq) @@ -148,26 +164,41 @@ func (tcp *Tcp) Process(id *flows.FlowID, tcphdr *layers.TCP, pkt *protos.Packet return } - if tcpSeqBefore(lastSeq, tcp_start_seq) { - if !created { - gap := int(tcp_start_seq - lastSeq) - logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcp_start_seq, gap) - drop := stream.gapInStream(gap) - if drop { - if isDebug { - debugf("Dropping connection state because of gap") - } - - // drop application layer connection state and - // update stream_id for app layer analysers using stream_id for lookups - conn.id = tcp.getId() - conn.data = nil + switch tcpSeqCompare(lastSeq, tcpStartSeq) { + case seqLT: // lastSeq < tcpStartSeq => Gap in tcp stream detected + if created { + break + } + + gap := int(tcpStartSeq - lastSeq) + logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcpStartSeq, gap) + drop := stream.gapInStream(gap) + if drop { + if isDebug { + debugf("Dropping connection state because of gap") } + + // drop application layer connection state and + // update stream_id for app layer analysers using stream_id for lookups + conn.id = tcp.getId() + conn.data = nil + } + + case seqGT: + // lastSeq > tcpStartSeq => overlapping TCP segment detected. shrink packet + delta := lastSeq - tcpStartSeq + + if isDebug { + debugf("Overlapping tcp segment. last_seq %d, seq: %d, delta: %d", + lastSeq, tcpStartSeq, delta) } + + pkt.Payload = pkt.Payload[delta:] + tcphdr.Seq += delta } } - conn.lastSeq[stream.dir] = tcp_seq + conn.lastSeq[stream.dir] = tcpSeq stream.addPacket(pkt, tcphdr) } @@ -209,6 +240,18 @@ func (tcp *Tcp) getStream(pkt *protos.Packet) (stream TcpStream, created bool) { return TcpStream{conn: conn, dir: TcpDirectionOriginal}, true } +func tcpSeqCompare(seq1, seq2 uint32) seqCompare { + i := int32(seq1 - seq2) + switch { + case i == 0: + return seqEq + case i < 0: + return seqLT + default: + return seqGT + } +} + func tcpSeqBefore(seq1 uint32, seq2 uint32) bool { return int32(seq1-seq2) < 0 } diff --git a/packetbeat/protos/tcp/tcp_test.go b/packetbeat/protos/tcp/tcp_test.go index 3a663832548..7bd0960dbb8 100644 --- a/packetbeat/protos/tcp/tcp_test.go +++ b/packetbeat/protos/tcp/tcp_test.go @@ -3,7 +3,6 @@ package tcp import ( - "fmt" "math/rand" "net" "testing" @@ -186,45 +185,128 @@ func (p protocols) GetAllTcp() map[protos.Protocol]protos.TcpPlugin { retur func (p protocols) GetAllUdp() map[protos.Protocol]protos.UdpPlugin { return nil } func (p protocols) Register(proto protos.Protocol, plugin protos.Plugin) { return } -func TestGapInStreamShouldDropState(t *testing.T) { - gap := 0 - var state []byte - - data1 := []byte{1, 2, 3, 4} - data2 := []byte{5, 6, 7, 8} - - tp := &TestProtocol{Ports: []int{ServerPort}} - tp.gap = func(t *common.TcpTuple, d uint8, n int, p protos.ProtocolData) (protos.ProtocolData, bool) { - fmt.Printf("lost: %v\n", n) - gap += n - return p, true // drop state - } - tp.parse = func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData { - if priv == nil { - state = nil - } - state = append(state, p.Payload...) - return state +func TestTCSeqPayload(t *testing.T) { + type segment struct { + seq uint32 + payload []byte } - p := protocols{} - p.tcp = map[protos.Protocol]protos.TcpPlugin{ - httpProtocol: tp, + tests := []struct { + name string + segments []segment + expectedGaps int + expectedState []byte + }{ + {"No overlap", + []segment{ + {1, []byte{1, 2, 3, 4, 5}}, + {6, []byte{6, 7, 8, 9, 10}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + }, + {"Gap drop state", + []segment{ + {1, []byte{1, 2, 3, 4}}, + {15, []byte{5, 6, 7, 8}}, + }, + 10, + []byte{5, 6, 7, 8}, + }, + {"ACK same sequence number", + []segment{ + {1, []byte{1, 2}}, + {3, nil}, + {3, []byte{3, 4}}, + {5, []byte{5, 6}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6}, + }, + {"ACK same sequence number 2", + []segment{ + {1, nil}, + {2, nil}, + {2, []byte{1, 2}}, + {4, nil}, + {4, []byte{3, 4}}, + {6, []byte{5, 6}}, + {8, []byte{7, 8}}, + {10, nil}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + }, + {"Overlap, first segment bigger", + []segment{ + {1, []byte{1, 2}}, + {3, []byte{3, 4}}, + {3, []byte{3}}, + {5, []byte{5, 6}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6}, + }, + {"Overlap, second segment bigger", + []segment{ + {1, []byte{1, 2}}, + {3, []byte{3}}, + {3, []byte{3, 4}}, + {5, []byte{5, 6}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6}, + }, + {"Overlap, covered", + []segment{ + {1, []byte{1, 2, 3, 4}}, + {2, []byte{2, 3}}, + {5, []byte{5, 6}}, + }, + 0, + []byte{1, 2, 3, 4, 5, 6}, + }, } - tcp, _ := NewTcp(p) - addr := common.NewIpPortTuple(4, - net.ParseIP(ServerIp), ServerPort, - net.ParseIP(ClientIp), uint16(rand.Intn(65535))) + for i, test := range tests { + t.Logf("Test (%v): %v", i, test.name) + + gap := 0 + var state []byte + tcp, err := NewTcp(protocols{ + tcp: map[protos.Protocol]protos.TcpPlugin{ + httpProtocol: &TestProtocol{ + Ports: []int{ServerPort}, + gap: makeCountGaps(nil, &gap), + parse: makeCollectPayload(&state, true), + }, + }, + }) + if err != nil { + t.Fatal(err) + } - hdr := &layers.TCP{} - tcp.Process(nil, hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data1}) - hdr.Seq += uint32(len(data1) + 10) - tcp.Process(nil, hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data2}) + addr := common.NewIpPortTuple(4, + net.ParseIP(ServerIp), ServerPort, + net.ParseIP(ClientIp), uint16(rand.Intn(65535))) + + for _, segment := range test.segments { + hdr := &layers.TCP{Seq: segment.seq} + pkt := &protos.Packet{ + Ts: time.Now(), + Tuple: addr, + Payload: segment.payload, + } + tcp.Process(nil, hdr, pkt) + } - // validate - assert.Equal(t, 10, gap) - assert.Equal(t, data2, state) + assert.Equal(t, test.expectedGaps, gap) + if len(test.expectedState) != len(state) { + assert.Equal(t, len(test.expectedState), len(state)) + continue + } + assert.Equal(t, test.expectedState, state) + } } // Benchmark that runs with parallelism to help find concurrency related @@ -251,3 +333,42 @@ func BenchmarkParallelProcess(b *testing.B) { } }) } + +func makeCountGaps( + counter *int, + bytes *int, +) func(*common.TcpTuple, uint8, int, protos.ProtocolData) (protos.ProtocolData, bool) { + return func( + t *common.TcpTuple, + d uint8, + n int, + p protos.ProtocolData, + ) (protos.ProtocolData, bool) { + if counter != nil { + (*counter)++ + } + if bytes != nil { + *bytes += n + } + + return p, true // drop state + } +} + +func makeCollectPayload( + state *[]byte, + resetOnNil bool, +) func(*protos.Packet, *common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData { + return func( + p *protos.Packet, + t *common.TcpTuple, + d uint8, + priv protos.ProtocolData, + ) protos.ProtocolData { + if resetOnNil && priv == nil { + (*state) = nil + } + *state = append(*state, p.Payload...) + return *state + } +}