diff --git a/cmd/atelet/oci.go b/cmd/atelet/oci.go index 3936862e..3bf7503f 100644 --- a/cmd/atelet/oci.go +++ b/cmd/atelet/oci.go @@ -227,6 +227,19 @@ func untar(ctx context.Context, tarData io.Reader, rootPath string) error { switch hdr.Typeflag { case tar.TypeReg: // Regular file + // Same "later entry wins" handling: if any entry exists at the target path, + // remove it first. This ensures that: + // 1. If it's a symlink, we don't write through it (security vulnerability / incorrectness). + // 2. If it's a hardlink, we unlink it instead of truncating the shared inode. + // 3. If it's a directory, we recursively remove it so we can write the file. + if _, err := root.Lstat(name); err == nil { + if err := root.RemoveAll(name); err != nil { + return fmt.Errorf("while replacing existing path at %q before regular file: %w", name, err) + } + } else if !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("while checking existing path at %q before regular file: %w", name, err) + } + // Stream directly from tarReader to target file to avoid buffering in memory. outFile, err := root.OpenFile(name, os.O_CREATE|os.O_RDWR|os.O_TRUNC, mode) if err != nil { diff --git a/cmd/atelet/oci_test.go b/cmd/atelet/oci_test.go index 5e9f4ea8..4fbaf6a5 100644 --- a/cmd/atelet/oci_test.go +++ b/cmd/atelet/oci_test.go @@ -240,6 +240,96 @@ func TestUntar_LaterEntryWins(t *testing.T) { t.Errorf("symlink target = %q, want %q", got, "x") } }) + + t.Run("symlink overwritten by file", func(t *testing.T) { + entries := []tarEntry{ + {name: "etc/", typeflag: tar.TypeDir}, + {name: "etc/x", typeflag: tar.TypeReg, body: "original"}, + {name: "etc/link", typeflag: tar.TypeSymlink, linkname: "x"}, + {name: "etc/link", typeflag: tar.TypeReg, body: "replacement"}, + } + dir, err := runUntar(t, entries) + if err != nil { + t.Fatalf("untar: %v", err) + } + fi, err := os.Lstat(filepath.Join(dir, "etc/link")) + if err != nil { + t.Fatalf("lstat etc/link: %v", err) + } + if fi.Mode().IsRegular() { + got, err := os.ReadFile(filepath.Join(dir, "etc/link")) + if err != nil { + t.Fatalf("read etc/link: %v", err) + } + if string(got) != "replacement" { + t.Errorf("etc/link content = %q, want %q", got, "replacement") + } + } else { + t.Errorf("etc/link mode is not regular file: %v", fi.Mode()) + } + // Also verify etc/x was NOT overwritten + gotX, err := os.ReadFile(filepath.Join(dir, "etc/x")) + if err != nil { + t.Fatalf("read etc/x: %v", err) + } + if string(gotX) != "original" { + t.Errorf("etc/x content was overwritten to %q", gotX) + } + }) + + t.Run("file overwritten by symlink", func(t *testing.T) { + entries := []tarEntry{ + {name: "etc/", typeflag: tar.TypeDir}, + {name: "etc/link", typeflag: tar.TypeReg, body: "original-file"}, + {name: "etc/link", typeflag: tar.TypeSymlink, linkname: "target-doesnt-exist"}, + } + dir, err := runUntar(t, entries) + if err != nil { + t.Fatalf("untar: %v", err) + } + fi, err := os.Lstat(filepath.Join(dir, "etc/link")) + if err != nil { + t.Fatalf("lstat etc/link: %v", err) + } + if fi.Mode()&os.ModeSymlink == 0 { + t.Errorf("etc/link mode is not a symlink: %v", fi.Mode()) + } + got, err := os.Readlink(filepath.Join(dir, "etc/link")) + if err != nil { + t.Fatalf("readlink etc/link: %v", err) + } + if got != "target-doesnt-exist" { + t.Errorf("etc/link target = %q, want %q", got, "target-doesnt-exist") + } + }) + + t.Run("hardlink overwritten by file", func(t *testing.T) { + entries := []tarEntry{ + {name: "bin/", typeflag: tar.TypeDir}, + {name: "bin/sh", typeflag: tar.TypeReg, body: "sh-original"}, + {name: "bin/bash", typeflag: tar.TypeLink, linkname: "bin/sh"}, + {name: "bin/bash", typeflag: tar.TypeReg, body: "bash-new"}, + } + dir, err := runUntar(t, entries) + if err != nil { + t.Fatalf("untar: %v", err) + } + gotBash, err := os.ReadFile(filepath.Join(dir, "bin/bash")) + if err != nil { + t.Fatalf("read bin/bash: %v", err) + } + if string(gotBash) != "bash-new" { + t.Errorf("bin/bash content = %q, want %q", gotBash, "bash-new") + } + // Verify bin/sh was NOT modified! + gotSh, err := os.ReadFile(filepath.Join(dir, "bin/sh")) + if err != nil { + t.Fatalf("read bin/sh: %v", err) + } + if string(gotSh) != "sh-original" { + t.Errorf("bin/sh content was overwritten to %q (hardlink was not unlinked)", gotSh) + } + }) } func TestUntar_PathTraversal(t *testing.T) {