diff --git a/decoders/netflowlegacy/netflow.go b/decoders/netflowlegacy/netflow.go index 9329f703..955f82ce 100644 --- a/decoders/netflowlegacy/netflow.go +++ b/decoders/netflowlegacy/netflow.go @@ -7,6 +7,10 @@ import ( "github.com/netsampler/goflow2/decoders/utils" ) +const ( + MAX_COUNT = 1536 +) + type ErrorVersion struct { version uint16 } @@ -44,6 +48,10 @@ func DecodeMessage(payload *bytes.Buffer) (interface{}, error) { packet.SamplingInterval = packet.SamplingInterval & 0x3FFF + if packet.Count > MAX_COUNT { + return nil, fmt.Errorf("Too many samples (%d > %d) in packet", packet.Count, MAX_COUNT) + } + packet.Records = make([]RecordsNetFlowV5, int(packet.Count)) for i := 0; i < int(packet.Count) && payload.Len() >= 48; i++ { record := RecordsNetFlowV5{} diff --git a/decoders/sflow/sflow.go b/decoders/sflow/sflow.go index dd096301..309e965a 100644 --- a/decoders/sflow/sflow.go +++ b/decoders/sflow/sflow.go @@ -16,6 +16,10 @@ const ( FORMAT_ETH = 2 FORMAT_IPV4 = 3 FORMAT_IPV6 = 4 + + MAX_SAMPLES = 512 + MAX_RECORDS = 8192 + MAX_ATTRS = 16383 ) type ErrorDecodingSFlow struct { @@ -251,6 +255,11 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, if int(extendedGateway.ASPathLength) > payload.Len()-4 { return flowRecord, errors.New(fmt.Sprintf("Invalid AS path length: %v.", extendedGateway.ASPathLength)) } + + if extendedGateway.ASPathLength > MAX_ATTRS { + return flowRecord, fmt.Errorf("AS path too large (%d > %d) in record", extendedGateway.ASPathLength, MAX_ATTRS) + } + asPath = make([]uint32, extendedGateway.ASPathLength) if len(asPath) > 0 { err = utils.BinaryDecoder(payload, asPath) @@ -265,6 +274,10 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, if err != nil { return flowRecord, err } + if extendedGateway.CommunitiesLength > MAX_ATTRS { + return flowRecord, fmt.Errorf("Communities list too large (%d > %d) in record", extendedGateway.CommunitiesLength, MAX_ATTRS) + } + if int(extendedGateway.CommunitiesLength) > payload.Len()-4 { return flowRecord, errors.New(fmt.Sprintf("Invalid Communities length: %v.", extendedGateway.ASPathLength)) } @@ -331,6 +344,11 @@ func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, err return sample, err } recordsCount = flowSample.FlowRecordsCount + + if recordsCount > MAX_RECORDS { + return flowSample, fmt.Errorf("Too many records (%d > %d) in packet", recordsCount, MAX_RECORDS) + } + flowSample.Records = make([]FlowRecord, recordsCount) sample = flowSample } else if format == FORMAT_ETH || format == FORMAT_IPV6 { @@ -342,6 +360,10 @@ func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, err Header: *header, CounterRecordsCount: recordsCount, } + + if recordsCount > MAX_RECORDS { + return flowSample, fmt.Errorf("Too many records (%d > %d) in packet", recordsCount, MAX_RECORDS) + } counterSample.Records = make([]CounterRecord, recordsCount) sample = counterSample } else if format == FORMAT_IPV4 { @@ -355,6 +377,10 @@ func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, err return sample, err } recordsCount = expandedFlowSample.FlowRecordsCount + + if recordsCount > MAX_RECORDS { + return flowSample, fmt.Errorf("Too many records (%d > %d) in packet", recordsCount, MAX_RECORDS) + } expandedFlowSample.Records = make([]FlowRecord, recordsCount) sample = expandedFlowSample } @@ -424,6 +450,11 @@ func DecodeMessage(payload *bytes.Buffer) (interface{}, error) { if err != nil { return packetV5, err } + + if packetV5.SamplesCount > MAX_SAMPLES { + return nil, fmt.Errorf("Too many samples (%d > %d) in packet", packetV5.SamplesCount, MAX_SAMPLES) + } + packetV5.Samples = make([]interface{}, int(packetV5.SamplesCount)) for i := 0; i < int(packetV5.SamplesCount) && payload.Len() >= 8; i++ { header := SampleHeader{}