From 29c2da02ab5d3ab22f0745478e6c6d72fd80ab8e Mon Sep 17 00:00:00 2001 From: jcd Date: Mon, 25 Sep 2023 20:29:18 +1000 Subject: [PATCH] Find all methods that use unsafe pointers. Previously we went package-by-package and iterated through (ssa.Package).Members, but this only contains functions, not methods. Instead we iterate through the result of ssautil.AllFunctions. Add test cases for a package-level variable whose initialization expression converts an unsafe pointer to *int, and a method which does the same. --- analyzer/analyzer.go | 35 +++++++++++---------------------- testing/analyzepackages_test.go | 2 ++ testpkgs/useunsafe/useunsafe.go | 9 +++++++++ 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/analyzer/analyzer.go b/analyzer/analyzer.go index 4e9bc36..f494065 100644 --- a/analyzer/analyzer.go +++ b/analyzer/analyzer.go @@ -359,7 +359,7 @@ func getPackageNodesWithCapability(pkgs []*packages.Package, log.Fatal("Some packages had errors. Aborting analysis.") } graph, ssaProg, allFunctions := buildGraph(pkgs, true) - unsafePointerFunctions := findUnsafePointerConversions(pkgs, ssaProg) + unsafePointerFunctions := findUnsafePointerConversions(pkgs, ssaProg, allFunctions) ssaProg = nil // possibly save memory; we don't use ssaProg again safe, nodesByCapability = getNodeCapabilities(graph, classifier) extraNodesByCapability = make(nodesetPerCapability) @@ -437,7 +437,7 @@ func getPackageNodesWithCapability(pkgs []*packages.Package, // findUnsafePointerConversions uses analysis of the syntax tree to find // functions which convert unsafe.Pointer values to another type. -func findUnsafePointerConversions(pkgs []*packages.Package, ssaProg *ssa.Program) (unsafePointer map[*ssa.Function]struct{}) { +func findUnsafePointerConversions(pkgs []*packages.Package, ssaProg *ssa.Program, allFunctions map[*ssa.Function]bool) (unsafePointer map[*ssa.Function]struct{}) { // AST nodes corresponding to functions which convert unsafe.Pointer values. unsafeFunctionNodes := make(map[ast.Node]struct{}) // Packages which contain variables that are initialized using @@ -468,32 +468,21 @@ func findUnsafePointerConversions(pkgs []*packages.Package, ssaProg *ssa.Program // Find the *ssa.Function pointers corresponding to the syntax nodes found // above. unsafePointerFunctions := make(map[*ssa.Function]struct{}) - var processFunction func(f *ssa.Function) - processFunction = func(f *ssa.Function) { + for f := range allFunctions { if _, ok := unsafeFunctionNodes[f.Syntax()]; ok { unsafePointerFunctions[f] = struct{}{} } - // Process child functions, e.g. function literals contained inside f. - for _, fn := range f.AnonFuncs { - processFunction(fn) - } } for _, pkg := range ssaProg.AllPackages() { - _, initUsesUnsafePointer := packagesWithUnsafePointerUseInInitialization[pkg.Pkg] - // pkg.Members contains all "top-level" functions; other functions are - // reached recursively through those. - for _, m := range pkg.Members { - if f, ok := m.(*ssa.Function); ok { - if initUsesUnsafePointer && f.Name() == "init" { - // This package had an unsafe.Pointer conversion in the initialization - // expression for a package-scoped variable. f is the "init" function - // for the package, so we add it to unsafePointerFunctions. - // There will always be an init function for each package; if one - // didn't exist in the source, a synthetic one will have been - // created. - unsafePointerFunctions[f] = struct{}{} - } - processFunction(f) + if _, ok := packagesWithUnsafePointerUseInInitialization[pkg.Pkg]; ok { + // This package had an unsafe.Pointer conversion in the initialization + // expression for a package-scoped variable, so we add the package's + // "init" function to unsafePointerFunctions. + // There will always be an init function for each package; if one + // didn't exist in the source, a synthetic one will have been + // created. + if f := pkg.Func("init"); f != nil { + unsafePointerFunctions[f] = struct{}{} } } } diff --git a/testing/analyzepackages_test.go b/testing/analyzepackages_test.go index 8618924..0da7af9 100644 --- a/testing/analyzepackages_test.go +++ b/testing/analyzepackages_test.go @@ -160,6 +160,8 @@ func TestExpectedOutput(t *testing.T) { {Fn: []string{`useunsafe.Indirect2`, `useunsafe.init\$1`}}, {Fn: []string{`useunsafe.NestedFunctions\$1\$1\$1`}, Cap: `CAPABILITY_UNSAFE_POINTER`}, {Fn: []string{`useunsafe.ReturnFunction\$1`}, Cap: `CAPABILITY_UNSAFE_POINTER`}, + {Fn: []string{`useunsafe.T\).M`}, Cap: "CAPABILITY_UNSAFE_POINTER"}, + {Fn: []string{`useunsafe.init$`}, Cap: `CAPABILITY_UNSAFE_POINTER`}, {Fn: []string{`useunsafe.init\$1`}, Cap: `CAPABILITY_UNSAFE_POINTER`}, } for _, path := range expectedPaths { diff --git a/testpkgs/useunsafe/useunsafe.go b/testpkgs/useunsafe/useunsafe.go index 102b4d1..d160188 100644 --- a/testpkgs/useunsafe/useunsafe.go +++ b/testpkgs/useunsafe/useunsafe.go @@ -34,6 +34,7 @@ var ( I int U up = up(&I) IP *int + Y = *(*int)(U) Z = func() int { return *(*int)(U) } @@ -83,3 +84,11 @@ func Ok() uintptr { var p unsafe.Pointer return (uintptr)(p) } + +// T is a type with a method that uses an unsafe.Pointer. +type T struct{} + +// M uses an unsafe pointer. +func (t T) M() int { + return *(*int)(U) +}