Skip to content

Commit

Permalink
[ENT] fix trie output for package, source and vulnerability (#1863)
Browse files Browse the repository at this point in the history
* fix trie output for package, source and vulnerability

Signed-off-by: pxp928 <parth.psu@gmail.com>

* fix IDs for type and namespace. Fix neighbors query

Signed-off-by: pxp928 <parth.psu@gmail.com>

* use guac-split-@@ to concatenate type and namespace for ID

Signed-off-by: pxp928 <parth.psu@gmail.com>

---------

Signed-off-by: pxp928 <parth.psu@gmail.com>
  • Loading branch information
pxp928 authored Apr 24, 2024
1 parent eed71a5 commit 5ff8e90
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 56 deletions.
5 changes: 5 additions & 0 deletions pkg/assembler/backends/ent/backend/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ import (
"github.com/guacsec/guac/pkg/assembler/graphql/model"
)

const (
// guacIDSplit is used as a separator to concatenate the type and namespace to create an ID
guacIDSplit = "guac-split-@@"
)

type globalID struct {
nodeType string
id string
Expand Down
10 changes: 5 additions & 5 deletions pkg/assembler/backends/ent/backend/neighbors.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ func (b *EntBackend) Neighbors(ctx context.Context, nodeID string, usingOnly []m
return []model.Node{}, fmt.Errorf("failed to get pkgName neighbors with id: %s with error: %w", nodeID, err)
}
case pkgNamespaceString:
neighbors, err = b.packageNamespaceNeighbors(ctx, nodeID, processUsingOnly(usingOnly))
neighbors, err = b.packageNamespaceNeighbors(ctx, foundGlobalID.id, processUsingOnly(usingOnly))
if err != nil {
return []model.Node{}, fmt.Errorf("failed to get pkgNamespace neighbors with id: %s with error: %w", nodeID, err)
}
case pkgTypeString:
neighbors, err = b.packageTypeNeighbors(ctx, nodeID, processUsingOnly(usingOnly))
neighbors, err = b.packageTypeNeighbors(ctx, foundGlobalID.id, processUsingOnly(usingOnly))
if err != nil {
return []model.Node{}, fmt.Errorf("failed to get pkgType neighbors with id: %s with error: %w", nodeID, err)
}
Expand All @@ -163,12 +163,12 @@ func (b *EntBackend) Neighbors(ctx context.Context, nodeID string, usingOnly []m
return []model.Node{}, fmt.Errorf("failed to get source name neighbors with id: %s with error: %w", nodeID, err)
}
case srcNamespaceString:
neighbors, err = b.srcNamespaceNeighbors(ctx, nodeID, processUsingOnly(usingOnly))
neighbors, err = b.srcNamespaceNeighbors(ctx, foundGlobalID.id, processUsingOnly(usingOnly))
if err != nil {
return []model.Node{}, fmt.Errorf("failed to get source namespace neighbors with id: %s with error: %w", nodeID, err)
}
case srcTypeString:
neighbors, err = b.srcTypeNeighbors(ctx, nodeID, processUsingOnly(usingOnly))
neighbors, err = b.srcTypeNeighbors(ctx, foundGlobalID.id, processUsingOnly(usingOnly))
if err != nil {
return []model.Node{}, fmt.Errorf("failed to get source type neighbors with id: %s with error: %w", nodeID, err)
}
Expand All @@ -178,7 +178,7 @@ func (b *EntBackend) Neighbors(ctx context.Context, nodeID string, usingOnly []m
return []model.Node{}, fmt.Errorf("failed to get vulnID neighbors with id: %s with error: %w", nodeID, err)
}
case vulnTypeString:
neighbors, err = b.vulnTypeNeighbors(ctx, nodeID, processUsingOnly(usingOnly))
neighbors, err = b.vulnTypeNeighbors(ctx, foundGlobalID.id, processUsingOnly(usingOnly))
if err != nil {
return []model.Node{}, fmt.Errorf("failed to get vuln type neighbors with id: %s with error: %w", nodeID, err)
}
Expand Down
47 changes: 29 additions & 18 deletions pkg/assembler/backends/ent/backend/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
stdsql "database/sql"
"fmt"
"sort"
"strings"

"entgo.io/ent/dialect/sql"
"github.com/google/uuid"
Expand All @@ -38,8 +39,8 @@ import (
)

const (
pkgTypeString = "pkgType"
pkgNamespaceString = "pkgNamespace"
pkgTypeString = "package_types"
pkgNamespaceString = "package_namespaces"
)

func (b *EntBackend) Packages(ctx context.Context, pkgSpec *model.PkgSpec) ([]*model.Package, error) {
Expand All @@ -61,7 +62,7 @@ func (b *EntBackend) Packages(ctx context.Context, pkgSpec *model.PkgSpec) ([]*m
pkgNames = append(pkgNames, backReferencePackageVersion(collectedPkgVersion))
}

return collect(pkgNames, toModelPackage), nil
return toModelPackageTrie(pkgNames), nil
}

func packageQueryPredicates(pkgSpec *model.PkgSpec) predicate.PackageVersion {
Expand Down Expand Up @@ -138,6 +139,8 @@ func upsertBulkPackage(ctx context.Context, tx *ent.Tx, pkgInputs []*model.IDorP
batches := chunk(pkgInputs, MaxBatchSize)
pkgNameIDs := make([]string, 0)
pkgVersionIDs := make([]string, 0)
pkgTypes := map[string]string{}
pkgNamespaces := map[string]string{}

for _, pkgs := range batches {
pkgNameCreates := make([]*ent.PackageNameCreate, len(pkgs))
Expand All @@ -153,6 +156,8 @@ func upsertBulkPackage(ctx context.Context, tx *ent.Tx, pkgInputs []*model.IDorP
pkgVersionCreates[i] = generatePackageVersionCreate(tx, &pkgVersionID, &pkgNameID, pkgInput)

pkgNameIDs = append(pkgNameIDs, pkgNameID.String())
pkgTypes[pkgNameID.String()] = pkgInput.PackageInput.Type
pkgNamespaces[pkgNameID.String()] = strings.Join([]string{pkgInput.PackageInput.Type, stringOrEmpty(pkgInput.PackageInput.Namespace)}, guacIDSplit)
pkgVersionIDs = append(pkgVersionIDs, pkgVersionID.String())
}

Expand Down Expand Up @@ -182,8 +187,8 @@ func upsertBulkPackage(ctx context.Context, tx *ent.Tx, pkgInputs []*model.IDorP
var collectedPkgIDs []model.PackageIDs
for i := range pkgVersionIDs {
collectedPkgIDs = append(collectedPkgIDs, model.PackageIDs{
PackageTypeID: toGlobalID(pkgTypeString, pkgNameIDs[i]),
PackageNamespaceID: toGlobalID(pkgNamespaceString, pkgNameIDs[i]),
PackageTypeID: toGlobalID(pkgTypeString, pkgTypes[pkgNameIDs[i]]),
PackageNamespaceID: toGlobalID(pkgNamespaceString, pkgNamespaces[pkgNameIDs[i]]),
PackageNameID: toGlobalID(ent.TypePackageName, pkgNameIDs[i]),
PackageVersionID: toGlobalID(ent.TypePackageVersion, pkgVersionIDs[i])})
}
Expand Down Expand Up @@ -227,8 +232,8 @@ func upsertPackage(ctx context.Context, tx *ent.Tx, pkg model.IDorPkgInput) (*mo
}

return &model.PackageIDs{
PackageTypeID: toGlobalID(pkgTypeString, pkgNameID.String()),
PackageNamespaceID: toGlobalID(pkgNamespaceString, pkgNameID.String()),
PackageTypeID: toGlobalID(pkgTypeString, pkg.PackageInput.Type),
PackageNamespaceID: toGlobalID(pkgNamespaceString, strings.Join([]string{pkg.PackageInput.Type, stringOrEmpty(pkg.PackageInput.Namespace)}, guacIDSplit)),
PackageNameID: toGlobalID(packagename.Table, pkgNameID.String()),
PackageVersionID: toGlobalID(packageversion.Table, pkgVersionID.String())}, nil
}
Expand Down Expand Up @@ -401,7 +406,7 @@ func (b *EntBackend) packageTypeNeighbors(ctx context.Context, nodeID string, al
if allowedEdges[model.EdgePackageTypePackageNamespace] {
query := b.client.PackageName.Query().
Where([]predicate.PackageName{
optionalPredicate(&nodeID, IDEQ),
optionalPredicate(&nodeID, packagename.TypeEQ),
}...).
Limit(MaxPageSize)

Expand All @@ -412,11 +417,11 @@ func (b *EntBackend) packageTypeNeighbors(ctx context.Context, nodeID string, al

for _, foundPkgName := range pkgNames {
out = append(out, &model.Package{
ID: toGlobalID(pkgTypeString, foundPkgName.ID.String()),
ID: toGlobalID(pkgTypeString, foundPkgName.Type),
Type: foundPkgName.Type,
Namespaces: []*model.PackageNamespace{
{
ID: toGlobalID(pkgNamespaceString, foundPkgName.ID.String()),
ID: toGlobalID(pkgNamespaceString, strings.Join([]string{foundPkgName.Type, foundPkgName.Namespace}, guacIDSplit)),
Namespace: foundPkgName.Namespace,
Names: []*model.PackageName{},
},
Expand All @@ -430,9 +435,15 @@ func (b *EntBackend) packageTypeNeighbors(ctx context.Context, nodeID string, al
func (b *EntBackend) packageNamespaceNeighbors(ctx context.Context, nodeID string, allowedEdges edgeMap) ([]model.Node, error) {
var out []model.Node

// split to find the type and namespace value
splitQueryValue := strings.Split(nodeID, guacIDSplit)
if len(splitQueryValue) != 2 {
return out, fmt.Errorf("invalid query for packageNamespaceNeighbors with ID %s", nodeID)
}
query := b.client.PackageName.Query().
Where([]predicate.PackageName{
optionalPredicate(&nodeID, IDEQ),
optionalPredicate(&splitQueryValue[0], packagename.TypeEQ),
optionalPredicate(&splitQueryValue[1], packagename.NamespaceEQ),
}...).
Limit(MaxPageSize)

Expand All @@ -444,11 +455,11 @@ func (b *EntBackend) packageNamespaceNeighbors(ctx context.Context, nodeID strin
for _, foundPkgName := range pkgNames {
if allowedEdges[model.EdgePackageNamespacePackageName] {
out = append(out, &model.Package{
ID: toGlobalID(pkgTypeString, foundPkgName.ID.String()),
ID: toGlobalID(pkgTypeString, foundPkgName.Type),
Type: foundPkgName.Type,
Namespaces: []*model.PackageNamespace{
{
ID: toGlobalID(pkgNamespaceString, foundPkgName.ID.String()),
ID: toGlobalID(pkgNamespaceString, strings.Join([]string{foundPkgName.Type, foundPkgName.Namespace}, ":")),
Namespace: foundPkgName.Namespace,
Names: []*model.PackageName{{
ID: toGlobalID(packagename.Table, foundPkgName.ID.String()),
Expand All @@ -461,7 +472,7 @@ func (b *EntBackend) packageNamespaceNeighbors(ctx context.Context, nodeID strin
}
if allowedEdges[model.EdgePackageNamespacePackageType] {
out = append(out, &model.Package{
ID: toGlobalID(pkgTypeString, foundPkgName.ID.String()),
ID: toGlobalID(pkgTypeString, foundPkgName.Type),
Type: foundPkgName.Type,
Namespaces: []*model.PackageNamespace{},
})
Expand Down Expand Up @@ -552,11 +563,11 @@ func (b *EntBackend) packageNameNeighbors(ctx context.Context, nodeID string, al
for _, foundPkgName := range pkgNames {
if allowedEdges[model.EdgePackageNamePackageNamespace] {
out = append(out, &model.Package{
ID: toGlobalID(pkgTypeString, foundPkgName.ID.String()),
ID: toGlobalID(pkgTypeString, foundPkgName.Type),
Type: foundPkgName.Type,
Namespaces: []*model.PackageNamespace{
{
ID: toGlobalID(pkgNamespaceString, foundPkgName.ID.String()),
ID: toGlobalID(pkgNamespaceString, strings.Join([]string{foundPkgName.Type, foundPkgName.Namespace}, guacIDSplit)),
Namespace: foundPkgName.Namespace,
Names: []*model.PackageName{},
},
Expand Down Expand Up @@ -692,11 +703,11 @@ func (b *EntBackend) packageVersionNeighbors(ctx context.Context, nodeID string,
pkgNames = append(pkgNames, backReferencePackageVersion(foundPkgVersion))
for _, foundPkgName := range pkgNames {
out = append(out, &model.Package{
ID: toGlobalID(pkgTypeString, foundPkgName.ID.String()),
ID: toGlobalID(pkgTypeString, foundPkgName.Type),
Type: foundPkgName.Type,
Namespaces: []*model.PackageNamespace{
{
ID: toGlobalID(pkgNamespaceString, foundPkgName.ID.String()),
ID: toGlobalID(pkgNamespaceString, strings.Join([]string{foundPkgName.Type, foundPkgName.Namespace}, guacIDSplit)),
Namespace: foundPkgName.Namespace,
Names: []*model.PackageName{{
ID: toGlobalID(packagename.Table, foundPkgName.ID.String()),
Expand Down
2 changes: 1 addition & 1 deletion pkg/assembler/backends/ent/backend/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (b *EntBackend) FindSoftware(ctx context.Context, searchText string) ([]mod
return nil, err
}
results = append(results, collect(sources, func(v *ent.SourceName) model.PackageSourceOrArtifact {
return toModelSourceName(v)
return toModelSource(v)
})...)

artifacts, err := b.client.Artifact.Query().Where(
Expand Down
Loading

0 comments on commit 5ff8e90

Please sign in to comment.