diff --git a/cmd/api/src/queries/graph.go b/cmd/api/src/queries/graph.go index 205c7d5c8..159d80ea9 100644 --- a/cmd/api/src/queries/graph.go +++ b/cmd/api/src/queries/graph.go @@ -24,6 +24,7 @@ import ( "fmt" "net/http" "net/url" + "slices" "sort" "strconv" "strings" @@ -244,7 +245,18 @@ func (s *GraphQuery) GetAssetGroupNodes(ctx context.Context, assetGroupTag strin return err } else { for _, node := range assetGroupNodes { - node.Properties.Set("type", analysis.GetNodeKindDisplayLabel(node)) + // We need to filter out nodes that do not contain an exact tag match + var ( + systemTags, _ = node.Properties.Get(common.SystemTags.String()).String() + userTags, _ = node.Properties.Get(common.UserTags.String()).String() + allTags = append(strings.Split(systemTags, " "), strings.Split(userTags, " ")...) + ) + + if !slices.Contains(allTags, assetGroupTag) { + assetGroupNodes.Remove(node.ID) + } else { + node.Properties.Set("type", analysis.GetNodeKindDisplayLabel(node)) + } } return nil } diff --git a/cmd/api/src/queries/graph_integration_test.go b/cmd/api/src/queries/graph_integration_test.go index adff67d31..02813468d 100644 --- a/cmd/api/src/queries/graph_integration_test.go +++ b/cmd/api/src/queries/graph_integration_test.go @@ -295,19 +295,24 @@ func TestGetAssetGroupNodes(t *testing.T) { }, func(harness integration.HarnessDetails, db graph.Database) { graphQuery := queries.NewGraphQuery(db, cache.Cache{}, config.Configuration{}) - tierZeroNodes, err := graphQuery.GetAssetGroupNodes(context.Background(), ad.AdminTierZero) + tierZeroNodes, err := graphQuery.GetAssetGroupNodes(context.Background(), harness.AssetGroupNodesHarness.TierZeroTag) require.Nil(t, err) - customGroupNodes, err := graphQuery.GetAssetGroupNodes(context.Background(), "custom_tag") + customGroup1Nodes, err := graphQuery.GetAssetGroupNodes(context.Background(), harness.AssetGroupNodesHarness.CustomTag1) + require.Nil(t, err) + + customGroup2Nodes, err := graphQuery.GetAssetGroupNodes(context.Background(), harness.AssetGroupNodesHarness.CustomTag2) require.Nil(t, err) require.True(t, tierZeroNodes.Contains(harness.AssetGroupNodesHarness.GroupB)) require.True(t, tierZeroNodes.Contains(harness.AssetGroupNodesHarness.GroupC)) require.Equal(t, 2, len(tierZeroNodes)) - require.True(t, customGroupNodes.Contains(harness.AssetGroupNodesHarness.GroupD)) - require.True(t, customGroupNodes.Contains(harness.AssetGroupNodesHarness.GroupE)) - require.Equal(t, 2, len(customGroupNodes)) + require.True(t, customGroup1Nodes.Contains(harness.AssetGroupNodesHarness.GroupD)) + require.Equal(t, 1, len(customGroup1Nodes)) + + require.True(t, customGroup2Nodes.Contains(harness.AssetGroupNodesHarness.GroupE)) + require.Equal(t, 1, len(customGroup2Nodes)) }) } diff --git a/cmd/api/src/test/integration/harnesses.go b/cmd/api/src/test/integration/harnesses.go index a21edc15d..f05974a07 100644 --- a/cmd/api/src/test/integration/harnesses.go +++ b/cmd/api/src/test/integration/harnesses.go @@ -331,26 +331,34 @@ func (s *AssetGroupComboNodeHarness) Setup(testCtx *GraphTestContext) { } type AssetGroupNodesHarness struct { - GroupA *graph.Node - GroupB *graph.Node - GroupC *graph.Node - GroupD *graph.Node - GroupE *graph.Node + GroupA *graph.Node + GroupB *graph.Node + GroupC *graph.Node + GroupD *graph.Node + GroupE *graph.Node + TierZeroTag string + CustomTag1 string + CustomTag2 string } func (s *AssetGroupNodesHarness) Setup(testCtx *GraphTestContext) { domainSID := RandomDomainSID() + // use one tag value that contains the other as a substring to test that we only match exactly + s.TierZeroTag = ad.AdminTierZero + s.CustomTag1 = "custom_tag" + s.CustomTag2 = "another_custom_tag" + s.GroupA = testCtx.NewActiveDirectoryGroup("GroupA", domainSID) s.GroupB = testCtx.NewActiveDirectoryGroup("GroupB", domainSID) s.GroupC = testCtx.NewActiveDirectoryGroup("GroupC", domainSID) s.GroupD = testCtx.NewActiveDirectoryGroup("GroupD", domainSID) - s.GroupE = testCtx.NewActiveDirectoryGroup("GroupD", domainSID) + s.GroupE = testCtx.NewActiveDirectoryGroup("GroupE", domainSID) - s.GroupB.Properties.Set(common.SystemTags.String(), ad.AdminTierZero) - s.GroupC.Properties.Set(common.SystemTags.String(), ad.AdminTierZero) - s.GroupD.Properties.Set(common.UserTags.String(), "custom_tag") - s.GroupE.Properties.Set(common.UserTags.String(), "custom_tag another_tag") + s.GroupB.Properties.Set(common.SystemTags.String(), s.TierZeroTag) + s.GroupC.Properties.Set(common.SystemTags.String(), s.TierZeroTag) + s.GroupD.Properties.Set(common.UserTags.String(), s.CustomTag1) + s.GroupE.Properties.Set(common.UserTags.String(), s.CustomTag2) testCtx.UpdateNode(s.GroupB) testCtx.UpdateNode(s.GroupC)