diff --git a/ipfix/decoder.go b/ipfix/decoder.go index 15e16bcb..600728be 100644 --- a/ipfix/decoder.go +++ b/ipfix/decoder.go @@ -164,17 +164,19 @@ func (d *Decoder) decodeSet(mem MemCache, msg *Message) error { // This check is somewhat redundant with the switch-clause below, but the retrieve() operation should not be executed inside the loop. if setHeader.SetID > 255 { var ok bool - if tr, ok = mem.retrieve(setHeader.SetID, d.raddr); !ok { + if tr, ok = mem.retrieve(setHeader.SetID, d.raddr, msg.Header.DomainID); !ok { select { case rpcChan <- RPCRequest{ ID: setHeader.SetID, IP: d.raddr, + SrcID: msg.Header.DomainID, }: default: } - err = nonfatalError{fmt.Errorf("%s unknown ipfix template id# %d", + err = nonfatalError{fmt.Errorf("%s unknown ipfix template id# %d with domain ID %d", d.raddr.String(), setHeader.SetID, + msg.Header.DomainID, )} } } @@ -196,7 +198,7 @@ func (d *Decoder) decodeSet(mem MemCache, msg *Message) error { err = tr.unmarshalOpts(d.reader) } if err == nil { - mem.insert(tr.TemplateID, d.raddr, tr) + mem.insert(tr.TemplateID, d.raddr, tr, msg.Header.DomainID) } } else if setID >= 4 && setID <= 255 { // Reserved set, do not read any records diff --git a/ipfix/decoder_test.go b/ipfix/decoder_test.go index 4e56e1ca..1682538d 100644 --- a/ipfix/decoder_test.go +++ b/ipfix/decoder_test.go @@ -31,7 +31,7 @@ import ( var tpl, optsTpl, multiMessage, unknownDatasetMessage []byte func init() { - // IPFIX packet including template SetID:2, 25 fields + // IPFIX packet including template SetID:2, 25 fields, Domain id 33792 tpl = []byte{ 0x0, 0xa, 0x0, 0x7c, 0x58, 0x90, 0xd6, 0x40, 0x28, 0xf7, 0xa0, 0x4a, 0x0, 0x0, 0x84, 0x0, 0x0, 0x2, 0x0, 0x6c, 0x1, @@ -205,8 +205,8 @@ func TestUnknownDatasetsMessage(t *testing.T) { t.Error("Did not expect any result datasets, but got", l) } expectedErrorStr := `Multiple errors: -- 127.0.0.1 unknown ipfix template id# 264 -- 127.0.0.1 unknown ipfix template id# 264` +- 127.0.0.1 unknown ipfix template id# 264 with domain ID 1 +- 127.0.0.1 unknown ipfix template id# 264 with domain ID 1` if err == nil || err.Error() != expectedErrorStr { t.Error("Received unexpected erorr:", err) } diff --git a/ipfix/memcache.go b/ipfix/memcache.go index 5d2ee961..85ed053d 100644 --- a/ipfix/memcache.go +++ b/ipfix/memcache.go @@ -79,27 +79,32 @@ func GetCache(cacheFile string) MemCache { return m } -func (m MemCache) getShard(id uint16, addr net.IP) (*TemplatesShard, uint32) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, id) - key := append(addr, b...) - +func (m MemCache) getShard(templateId uint16, addr net.IP, domainId uint32) (*TemplatesShard, uint32) { + var key []byte hash := fnv.New32() + dId := make([]byte, 4) + tID := make([]byte, 2) + binary.LittleEndian.PutUint32(dId, domainId) + binary.BigEndian.PutUint16(tID, templateId) + key = append(key, addr...) + key = append(key, dId...) + key = append(key, tID...) + hash.Write(key) hSum32 := hash.Sum32() return m[uint(hSum32)%uint(shardNo)], hSum32 } -func (m MemCache) insert(id uint16, addr net.IP, tr TemplateRecord) { - shard, key := m.getShard(id, addr) +func (m MemCache) insert(id uint16, addr net.IP, tr TemplateRecord, domainID uint32) { + shard, key := m.getShard(id, addr, domainID) shard.Lock() defer shard.Unlock() shard.Templates[key] = Data{tr, time.Now().Unix()} } -func (m MemCache) retrieve(id uint16, addr net.IP) (TemplateRecord, bool) { - shard, key := m.getShard(id, addr) +func (m MemCache) retrieve(id uint16, addr net.IP, domainID uint32) (TemplateRecord, bool) { + shard, key := m.getShard(id, addr, domainID) shard.RLock() defer shard.RUnlock() v, ok := shard.Templates[key] diff --git a/ipfix/memcache_rpc.go b/ipfix/memcache_rpc.go index 26de0e4b..bcf3644a 100644 --- a/ipfix/memcache_rpc.go +++ b/ipfix/memcache_rpc.go @@ -58,6 +58,7 @@ type RPCConfig struct { type RPCRequest struct { ID uint16 IP net.IP + SrcID uint32 } type vFlowServer struct { @@ -91,7 +92,7 @@ func NewRPC(mCache MemCache) *IRPC { func (r *IRPC) Get(req RPCRequest, resp *TemplateRecord) error { var ok bool - *resp, ok = r.mCache.retrieve(req.ID, req.IP) + *resp, ok = r.mCache.retrieve(req.ID, req.IP, req.SrcID) if !ok { return errNotAvail } @@ -168,7 +169,7 @@ func RPC(m MemCache, config *RPCConfig) { continue } - m.insert(req.ID, req.IP, *tr) + m.insert(req.ID, req.IP, *tr, req.SrcID) break } diff --git a/ipfix/memcache_test.go b/ipfix/memcache_test.go index ac3327d9..16f36a26 100644 --- a/ipfix/memcache_test.go +++ b/ipfix/memcache_test.go @@ -33,7 +33,7 @@ func TestMemCacheRetrieve(t *testing.T) { mCache := GetCache("cache.file") d := NewDecoder(ip, tpl) d.Decode(mCache) - v, ok := mCache.retrieve(256, ip) + v, ok := mCache.retrieve(256, ip, 33792) if !ok { t.Error("expected mCache retrieve status true, got", ok) } @@ -48,9 +48,9 @@ func TestMemCacheInsert(t *testing.T) { mCache := GetCache("cache.file") tpl.TemplateID = 310 - mCache.insert(310, ip, tpl) + mCache.insert(310, ip, tpl, 513) - v, ok := mCache.retrieve(310, ip) + v, ok := mCache.retrieve(310, ip, 513) if !ok { t.Error("expected mCache retrieve status true, got", ok) } @@ -65,11 +65,11 @@ func TestMemCacheAllSetIds(t *testing.T) { mCache := GetCache("cache.file") tpl.TemplateID = 310 - mCache.insert(tpl.TemplateID, ip, tpl) + mCache.insert(tpl.TemplateID, ip, tpl, 513) tpl.TemplateID = 410 - mCache.insert(tpl.TemplateID, ip, tpl) + mCache.insert(tpl.TemplateID, ip, tpl, 513) tpl.TemplateID = 210 - mCache.insert(tpl.TemplateID, ip, tpl) + mCache.insert(tpl.TemplateID, ip, tpl, 513) expected := []int{210, 310, 410} actual := mCache.allSetIds() @@ -77,3 +77,33 @@ func TestMemCacheAllSetIds(t *testing.T) { t.Errorf("Expected set IDs %v, got %v", expected, actual) } } + +func TestMemCache_keyWithDifferentDomainIDs(t *testing.T) { + var tpl TemplateRecord + ip := net.ParseIP("127.0.0.1") + mCache := GetCache("cache.file") + + tpl.TemplateID = 310 + tpl.FieldCount = 19 + mCache.insert(tpl.TemplateID, ip, tpl, 513) + tpl.FieldCount = 21 + mCache.insert(tpl.TemplateID, ip, tpl, 514) + + v, ok := mCache.retrieve(tpl.TemplateID, ip, 513) + + if !ok { + t.Error("expected mCache retrieve status true, got", ok) + } + if v.FieldCount != 19 { + t.Error("expected template id#:310 with Field count#:19, got", v.TemplateID, v.FieldCount) + } + + v, ok = mCache.retrieve(tpl.TemplateID, ip, 514) + + if !ok { + t.Error("expected mCache retrieve status true, got", ok) + } + if v.FieldCount != 21 { + t.Error("expected template id#:310 with Field count#:21, got", v.TemplateID, v.FieldCount) + } +} \ No newline at end of file diff --git a/netflow/v9/decoder.go b/netflow/v9/decoder.go index 4023d8f3..55c4e393 100644 --- a/netflow/v9/decoder.go +++ b/netflow/v9/decoder.go @@ -426,11 +426,11 @@ func (d *Decoder) decodeSet(mem MemCache, msg *Message) error { // This check is somewhat redundant with the switch-clause below, but the retrieve() operation should not be executed inside the loop. if setHeader.FlowSetID > 255 { var ok bool - tr, ok = mem.retrieve(setHeader.FlowSetID, d.raddr) + tr, ok = mem.retrieve(setHeader.FlowSetID, d.raddr, msg.Header.SrcID) if !ok { - err = nonfatalError(fmt.Errorf("%s unknown netflow template id# %d", + err = nonfatalError(fmt.Errorf("%s unknown netflow template id# %d from sourceID %d", d.raddr.String(), - setHeader.FlowSetID, + setHeader.FlowSetID, msg.Header.SrcID, )) } } @@ -446,7 +446,7 @@ func (d *Decoder) decodeSet(mem MemCache, msg *Message) error { err = tr.unmarshalOpts(d.reader) } if err == nil { - mem.insert(tr.TemplateID, d.raddr, tr) + mem.insert(tr.TemplateID, d.raddr, tr, msg.Header.SrcID) } } else if setId >= 4 && setId <= 255 { // Reserved set, do not read any records diff --git a/netflow/v9/memcache.go b/netflow/v9/memcache.go index 2596ae46..69a52ece 100644 --- a/netflow/v9/memcache.go +++ b/netflow/v9/memcache.go @@ -78,27 +78,32 @@ func GetCache(cacheFile string) MemCache { return m } -func (m MemCache) getShard(id uint16, addr net.IP) (*TemplatesShard, uint32) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, id) - key := append(addr, b...) - +func (m MemCache) getShard(templateId uint16, addr net.IP, srcId uint32) (*TemplatesShard, uint32) { + var key []byte hash := fnv.New32() + sId := make([]byte, 4) + tID := make([]byte, 2) + binary.LittleEndian.PutUint32(sId, srcId) + binary.BigEndian.PutUint16(tID, templateId) + key = append(key, addr...) + key = append(key, sId...) + key = append(key, tID...) + hash.Write(key) hSum32 := hash.Sum32() return m[uint(hSum32)%uint(shardNo)], hSum32 } -func (m *MemCache) insert(id uint16, addr net.IP, tr TemplateRecord) { - shard, key := m.getShard(id, addr) +func (m *MemCache) insert(id uint16, addr net.IP, tr TemplateRecord, srcID uint32) { + shard, key := m.getShard(id, addr, srcID) shard.Lock() defer shard.Unlock() shard.Templates[key] = Data{tr, time.Now().Unix()} } -func (m *MemCache) retrieve(id uint16, addr net.IP) (TemplateRecord, bool) { - shard, key := m.getShard(id, addr) +func (m *MemCache) retrieve(id uint16, addr net.IP, srcID uint32) (TemplateRecord, bool) { + shard, key := m.getShard(id, addr, srcID) shard.RLock() defer shard.RUnlock() v, ok := shard.Templates[key]