Skip to content

Commit

Permalink
better concurrency support on Linux (#306)
Browse files Browse the repository at this point in the history
* commandrun: keep track of newly created files
* product: only attest for opened files when tracing is enabled
* file: do not attempt to record an artifact if it was not opened by the process
---------

Signed-off-by: Joshua Wang <josh@joshdabo.sh>
  • Loading branch information
joshdabosh authored Jul 31, 2024
1 parent a61ca03 commit 3f491a3
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 10 deletions.
4 changes: 4 additions & 0 deletions attestation/commandrun/commandrun.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ func (rc *CommandRun) RunType() attestation.RunType {
return RunType
}

func (rc *CommandRun) TracingEnabled() bool {
return rc.enableTracing
}

func (r *CommandRun) runCmd(ctx *attestation.AttestationContext) error {
c := exec.Command(r.Cmd[0], r.Cmd[1:]...)
c.Dir = ctx.WorkingDir()
Expand Down
26 changes: 26 additions & 0 deletions attestation/commandrun/tracing_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ func (r *CommandRun) trace(c *exec.Cmd, actx *attestation.AttestationContext) ([
}

func (p *ptraceContext) runTrace() error {
defer p.retryOpenedFiles()

runtime.LockOSThread()
defer runtime.UnlockOSThread()
status := unix.WaitStatus(0)
Expand Down Expand Up @@ -121,6 +123,26 @@ func (p *ptraceContext) runTrace() error {
}
}

func (p *ptraceContext) retryOpenedFiles() {
// after tracing, look through opened files to try to resolve any newly created files
procInfo := p.getProcInfo(p.parentPid)

for file, digestSet := range procInfo.OpenedFiles {
if digestSet != nil {
continue
}

newDigest, err := cryptoutil.CalculateDigestSetFromFile(file, p.hash)

if err != nil {
delete(procInfo.OpenedFiles, file)
continue
}

procInfo.OpenedFiles[file] = newDigest
}
}

func (p *ptraceContext) nextSyscall(pid int) error {
regs := unix.PtraceRegs{}
if err := unix.PtraceGetRegs(pid, &regs); err != nil {
Expand Down Expand Up @@ -213,6 +235,10 @@ func (p *ptraceContext) handleSyscall(pid int, regs unix.PtraceRegs) error {
procInfo := p.getProcInfo(pid)
digestSet, err := cryptoutil.CalculateDigestSetFromFile(file, p.hash)
if err != nil {
if _, isPathErr := err.(*os.PathError); isPathErr {
procInfo.OpenedFiles[file] = nil
}

return err
}

Expand Down
14 changes: 9 additions & 5 deletions attestation/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
// recordArtifacts will walk basePath and record the digests of each file with each of the functions in hashes.
// If file already exists in baseArtifacts and the two artifacts are equal the artifact will not be in the
// returned map of artifacts.
func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}) (map[string]cryptoutil.DigestSet, error) {
func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}, processWasTraced bool, openedFiles map[string]bool) (map[string]cryptoutil.DigestSet, error) {
artifacts := make(map[string]cryptoutil.DigestSet)
err := filepath.Walk(basePath, func(path string, info fs.FileInfo, err error) error {
if err != nil {
Expand Down Expand Up @@ -57,15 +57,15 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest
}

visitedSymlinks[linkedPath] = struct{}{}
symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks)
symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks, processWasTraced, openedFiles)
if err != nil {
return err
}

for artifactPath, artifact := range symlinkedArtifacts {
// all artifacts in the symlink should be recorded relative to our basepath
joinedPath := filepath.Join(relPath, artifactPath)
if shouldRecord(joinedPath, artifact, baseArtifacts) {
if shouldRecord(joinedPath, artifact, baseArtifacts, processWasTraced, openedFiles) {
artifacts[filepath.Join(relPath, artifactPath)] = artifact
}
}
Expand All @@ -78,7 +78,7 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest
return err
}

if shouldRecord(relPath, artifact, baseArtifacts) {
if shouldRecord(relPath, artifact, baseArtifacts, processWasTraced, openedFiles) {
artifacts[relPath] = artifact
}

Expand All @@ -89,9 +89,13 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest
}

// shouldRecord determines whether artifact should be recorded.
// if the process was traced and the artifact was not one of the opened files, return false
// if the artifact is already in baseArtifacts, check if it's changed
// if it is not equal to the existing artifact, return true, otherwise return false
func shouldRecord(path string, artifact cryptoutil.DigestSet, baseArtifacts map[string]cryptoutil.DigestSet) bool {
func shouldRecord(path string, artifact cryptoutil.DigestSet, baseArtifacts map[string]cryptoutil.DigestSet, processWasTraced bool, openedFiles map[string]bool) bool {
if _, ok := openedFiles[path]; !ok && processWasTraced {
return false
}
if previous, ok := baseArtifacts[path]; ok && artifact.Equal(previous) {
return false
}
Expand Down
6 changes: 3 additions & 3 deletions attestation/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ func TestBrokenSymlink(t *testing.T) {
symTestDir := filepath.Join(dir, "symTestDir")
require.NoError(t, os.Symlink(testDir, symTestDir))

_, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{})
_, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool{})
require.NoError(t, err)

// remove the symlinks and make sure we don't get an error back
require.NoError(t, os.RemoveAll(testDir))
require.NoError(t, os.RemoveAll(testFile))
_, err = RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{})
_, err = RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool{})
require.NoError(t, err)
}

Expand All @@ -58,6 +58,6 @@ func TestSymlinkCycle(t *testing.T) {
require.NoError(t, os.Symlink(dir, symTestDir))

// if a symlink cycle weren't properly handled this would be an infinite loop
_, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{})
_, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool{})
require.NoError(t, err)
}
2 changes: 1 addition & 1 deletion attestation/material/material.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (a *Attestor) Schema() *jsonschema.Schema {
}

func (a *Attestor) Attest(ctx *attestation.AttestationContext) error {
materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{})
materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{}, false, map[string]bool{})
if err != nil {
return err
}
Expand Down
20 changes: 19 additions & 1 deletion attestation/product/product.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/gabriel-vasile/mimetype"
"github.com/gobwas/glob"
"github.com/in-toto/go-witness/attestation"
"github.com/in-toto/go-witness/attestation/commandrun"
"github.com/in-toto/go-witness/attestation/file"
"github.com/in-toto/go-witness/cryptoutil"
"github.com/in-toto/go-witness/registry"
Expand Down Expand Up @@ -181,7 +182,24 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error {
a.compiledExcludeGlob = compiledExcludeGlob

a.baseArtifacts = ctx.Materials()
products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{})

processWasTraced := false
openedFileSet := map[string]bool{}

for _, completedAttestor := range ctx.CompletedAttestors() {
attestor := completedAttestor.Attestor
if commandRunAttestor, ok := attestor.(*commandrun.CommandRun); ok && commandRunAttestor.TracingEnabled() {
processWasTraced = true

for _, process := range commandRunAttestor.Processes {
for fname := range process.OpenedFiles {
openedFileSet[fname] = true
}
}
}
}

products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{}, processWasTraced, openedFileSet)
if err != nil {
return err
}
Expand Down

0 comments on commit 3f491a3

Please sign in to comment.