whatcanGOwrong

This commit is contained in:
2024-09-19 21:38:24 -04:00
commit d0ae4d841d
17908 changed files with 4096831 additions and 0 deletions
@@ -0,0 +1,168 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.18
// +build go1.18
package vulncheck
import (
"context"
"fmt"
"sort"
"golang.org/x/tools/go/packages"
"golang.org/x/vuln/internal"
"golang.org/x/vuln/internal/buildinfo"
"golang.org/x/vuln/internal/client"
"golang.org/x/vuln/internal/govulncheck"
)
// Bin is an abstraction of Go binary containing
// minimal information needed by govulncheck.
type Bin struct {
Modules []*packages.Module `json:"modules,omitempty"`
PkgSymbols []buildinfo.Symbol `json:"pkgSymbols,omitempty"`
GoVersion string `json:"goVersion,omitempty"`
GOOS string `json:"goos,omitempty"`
GOARCH string `json:"goarch,omitempty"`
}
// Binary detects presence of vulnerable symbols in bin and
// emits findings to handler.
func Binary(ctx context.Context, handler govulncheck.Handler, bin *Bin, cfg *govulncheck.Config, client *client.Client) error {
vr, err := binary(ctx, handler, bin, cfg, client)
if err != nil {
return err
}
if cfg.ScanLevel.WantSymbols() {
return emitCallFindings(handler, binaryCallstacks(vr))
}
return nil
}
// binary detects presence of vulnerable symbols in bin.
// It does not compute call graphs so the corresponding
// info in Result will be empty.
func binary(ctx context.Context, handler govulncheck.Handler, bin *Bin, cfg *govulncheck.Config, client *client.Client) (*Result, error) {
graph := NewPackageGraph(bin.GoVersion)
graph.AddModules(bin.Modules...)
mods := append(bin.Modules, graph.GetModule(internal.GoStdModulePath))
mv, err := FetchVulnerabilities(ctx, client, mods)
if err != nil {
return nil, err
}
// Emit OSV entries immediately in their raw unfiltered form.
if err := emitOSVs(handler, mv); err != nil {
return nil, err
}
if bin.GOOS == "" || bin.GOARCH == "" {
fmt.Printf("warning: failed to extract build system specification GOOS: %s GOARCH: %s\n", bin.GOOS, bin.GOARCH)
}
affVulns := affectingVulnerabilities(mv, bin.GOOS, bin.GOARCH)
if err := emitModuleFindings(handler, affVulns); err != nil {
return nil, err
}
if !cfg.ScanLevel.WantPackages() || len(affVulns) == 0 {
return &Result{}, nil
}
// Group symbols per package to avoid querying affVulns all over again.
var pkgSymbols map[string][]string
if len(bin.PkgSymbols) == 0 {
// The binary exe is stripped. We currently cannot detect inlined
// symbols for stripped binaries (see #57764), so we report
// vulnerabilities at the go.mod-level precision.
pkgSymbols = allKnownVulnerableSymbols(affVulns)
} else {
pkgSymbols = make(map[string][]string)
for _, sym := range bin.PkgSymbols {
pkgSymbols[sym.Pkg] = append(pkgSymbols[sym.Pkg], sym.Name)
}
}
impVulns := binImportedVulnPackages(graph, pkgSymbols, affVulns)
// Emit information on imported vulnerable packages now to
// mimic behavior of source.
if err := emitPackageFindings(handler, impVulns); err != nil {
return nil, err
}
// Return result immediately if not in symbol mode to mimic the
// behavior of source.
if !cfg.ScanLevel.WantSymbols() || len(impVulns) == 0 {
return &Result{Vulns: impVulns}, nil
}
symVulns := binVulnSymbols(graph, pkgSymbols, affVulns)
return &Result{Vulns: symVulns}, nil
}
func binImportedVulnPackages(graph *PackageGraph, pkgSymbols map[string][]string, affVulns affectingVulns) []*Vuln {
var vulns []*Vuln
for pkg := range pkgSymbols {
for _, osv := range affVulns.ForPackage(pkg) {
vuln := &Vuln{
OSV: osv,
Package: graph.GetPackage(pkg),
}
vulns = append(vulns, vuln)
}
}
return vulns
}
func binVulnSymbols(graph *PackageGraph, pkgSymbols map[string][]string, affVulns affectingVulns) []*Vuln {
var vulns []*Vuln
for pkg, symbols := range pkgSymbols {
// sort symbols for deterministic results
sort.SliceStable(symbols, func(i, j int) bool { return symbols[i] < symbols[j] })
for _, symbol := range symbols {
for _, osv := range affVulns.ForSymbol(pkg, symbol) {
vuln := &Vuln{
OSV: osv,
Symbol: symbol,
Package: graph.GetPackage(pkg),
}
vulns = append(vulns, vuln)
}
}
}
return vulns
}
// allKnownVulnerableSymbols returns all known vulnerable symbols for packages in graph.
// If all symbols of a package are vulnerable, that is modeled as a wild car symbol "<pkg-path>/*".
func allKnownVulnerableSymbols(affVulns affectingVulns) map[string][]string {
pkgSymbols := make(map[string][]string)
for _, mv := range affVulns {
for _, osv := range mv.Vulns {
for _, affected := range osv.Affected {
for _, p := range affected.EcosystemSpecific.Packages {
syms := p.Symbols
if len(syms) == 0 {
// If every symbol of pkg is vulnerable, we would ideally
// compute every symbol mentioned in the pkg and then add
// Vuln entry for it, just as we do in Source. However,
// we don't have code of pkg here and we don't even have
// pkg symbols used in stripped binary, so we add a placeholder
// symbol.
//
// Note: this should not affect output of govulncheck since
// in binary mode no symbol/call stack information is
// communicated back to the user.
syms = []string{fmt.Sprintf("%s/*", p.Path)}
}
pkgSymbols[p.Path] = append(pkgSymbols[p.Path], syms...)
}
}
}
}
return pkgSymbols
}
@@ -0,0 +1,92 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.18
// +build go1.18
package vulncheck
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/tools/go/packages"
"golang.org/x/vuln/internal/buildinfo"
"golang.org/x/vuln/internal/govulncheck"
"golang.org/x/vuln/internal/test"
)
func TestBinary(t *testing.T) {
bin := &Bin{
Modules: []*packages.Module{
{Path: "golang.org/entry"},
{Path: "golang.org/cmod", Version: "v1.1.3"},
{Path: "golang.org/amod", Version: "v1.1.3"},
{Path: "golang.org/bmod", Version: "v0.5.0"},
},
GoVersion: "go1.20",
GOOS: "linux",
GOARCH: "amd64",
PkgSymbols: []buildinfo.Symbol{
{Pkg: "golang.org/entry", Name: "main"},
{Pkg: "golang.org/cmod/c", Name: "C"},
{Pkg: "golang.org/amod/avuln", Name: "VulnData.Vuln1"}, // assume linker skips VulnData.Vuln2
{Pkg: "golang.org/bmod/bvuln", Name: "NoVuln"}, // assume linker skips NoVuln
{Pkg: "archive/zip", Name: "OpenReader"},
},
}
c, err := newTestClient()
if err != nil {
t.Fatal(err)
}
// Test imports only mode
cfg := &govulncheck.Config{ScanLevel: "package"}
res, err := binary(context.Background(), test.NewMockHandler(), bin, cfg, c)
if err != nil {
t.Fatal(err)
}
// With package scan level, all vulnerable packages should be detected.
want := []*Vuln{
{Package: &packages.Package{PkgPath: "golang.org/bmod/bvuln"}},
{Package: &packages.Package{PkgPath: "golang.org/amod/avuln"}},
{Package: &packages.Package{PkgPath: "archive/zip"}},
}
less := func(v1, v2 *Vuln) bool {
return (v1.Package.PkgPath + "." + v1.Symbol) < (v2.Package.PkgPath + "." + v2.Symbol)
}
equal := func(v1, v2 *Vuln) bool {
if v1.Symbol != v2.Symbol {
return false
}
if v1.Package != nil && v2.Package != nil {
return v1.Package.PkgPath == v2.Package.PkgPath
}
return true // we don't care about these cases here
}
if diff := cmp.Diff(want, res.Vulns, cmpopts.SortSlices(less), cmp.Comparer(equal)); diff != "" {
t.Errorf("(-want, +got): %s", diff)
}
// Test the symbols.
cfg.ScanLevel = "symbol"
res, err = binary(context.Background(), test.NewMockHandler(), bin, cfg, c)
if err != nil {
t.Fatal(err)
}
want = []*Vuln{
{Symbol: "OpenReader", Package: &packages.Package{PkgPath: "archive/zip"}},
{Symbol: "VulnData.Vuln1", Package: &packages.Package{PkgPath: "golang.org/amod/avuln"}},
}
if diff := cmp.Diff(want, res.Vulns, cmpopts.SortSlices(less), cmp.Comparer(equal)); diff != "" {
t.Errorf("(-want, +got): %s", diff)
}
}
@@ -0,0 +1,55 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package vulncheck detects uses of known vulnerabilities
in Go programs.
Vulncheck identifies vulnerability uses in Go programs
at the level of call graph, package import graph, and module
requires graph. For instance, vulncheck identifies which
vulnerable functions and methods are transitively called
from the program entry points. vulncheck also detects
transitively imported packages and required modules that
contain known vulnerable functions and methods.
We recommend using the command line tool [govulncheck] to
detect vulnerabilities in your code.
# Usage
The two main APIs of vulncheck, [Source] and [Binary], allow vulnerability
detection in Go source code and binaries, respectively.
[Source] accepts a list of [Package] objects, which
are a trimmed version of [golang.org/x/tools/go/packages.Package] objects to
reduce memory consumption. [Binary] accepts a path to a Go binary file that
must have been compiled with Go 1.18 or greater.
Both [Source] and [Binary] require information about known
vulnerabilities in the form of a vulnerability database,
specifically a [golang.org/x/vuln/internal/client.Client].
The vulnerabilities
are modeled using the [golang.org/x/vuln/internal/osv] format.
# Results
The results of vulncheck are slices of the call graph, package imports graph,
and module requires graph leading to the use of an identified vulnerability.
The parts of these graphs not related to any vulnerabilities are omitted.
The [CallStacks] and [ImportChains] functions search the returned slices for
user-friendly representative call stacks and import chains. These call stacks
and import chains are provided as examples of vulnerability uses in the client
code.
# Limitations
There are some limitations with vulncheck. Please see the
[documented limitations] for more information.
[govulncheck]: https://pkg.go.dev/golang.org/x/vuln/cmd/govulncheck
[documented limitations]: https://go.dev/security/vulncheck#limitations.
*/
package vulncheck
@@ -0,0 +1,163 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"go/token"
"sort"
"golang.org/x/tools/go/packages"
"golang.org/x/vuln/internal"
"golang.org/x/vuln/internal/govulncheck"
"golang.org/x/vuln/internal/osv"
)
// emitOSVs emits all OSV vuln entries in modVulns to handler.
func emitOSVs(handler govulncheck.Handler, modVulns []*ModVulns) error {
for _, mv := range modVulns {
for _, v := range mv.Vulns {
if err := handler.OSV(v); err != nil {
return err
}
}
}
return nil
}
// emitModuleFindings emits module-level findings for vulnerabilities in modVulns.
func emitModuleFindings(handler govulncheck.Handler, affVulns affectingVulns) error {
for _, vuln := range affVulns {
for _, osv := range vuln.Vulns {
if err := handler.Finding(&govulncheck.Finding{
OSV: osv.ID,
FixedVersion: FixedVersion(modPath(vuln.Module), modVersion(vuln.Module), osv.Affected),
Trace: []*govulncheck.Frame{frameFromModule(vuln.Module, osv.Affected)},
}); err != nil {
return err
}
}
}
return nil
}
// emitPackageFinding emits package-level findings fod vulnerabilities in vulns.
func emitPackageFindings(handler govulncheck.Handler, vulns []*Vuln) error {
for _, v := range vulns {
if err := handler.Finding(&govulncheck.Finding{
OSV: v.OSV.ID,
FixedVersion: FixedVersion(modPath(v.Package.Module), modVersion(v.Package.Module), v.OSV.Affected),
Trace: []*govulncheck.Frame{frameFromPackage(v.Package)},
}); err != nil {
return err
}
}
return nil
}
// emitCallFindings emits call-level findings for vulnerabilities
// that have a call stack in callstacks.
func emitCallFindings(handler govulncheck.Handler, callstacks map[*Vuln]CallStack) error {
var vulns []*Vuln
for v := range callstacks {
vulns = append(vulns, v)
}
sort.SliceStable(vulns, func(i, j int) bool {
return vulns[i].Symbol < vulns[j].Symbol
})
for _, vuln := range vulns {
stack := callstacks[vuln]
if stack == nil {
continue
}
fixed := FixedVersion(modPath(vuln.Package.Module), modVersion(vuln.Package.Module), vuln.OSV.Affected)
if err := handler.Finding(&govulncheck.Finding{
OSV: vuln.OSV.ID,
FixedVersion: fixed,
Trace: traceFromEntries(stack),
}); err != nil {
return err
}
}
return nil
}
// traceFromEntries creates a sequence of
// frames from vcs. Position of a Frame is the
// call position of the corresponding stack entry.
func traceFromEntries(vcs CallStack) []*govulncheck.Frame {
var frames []*govulncheck.Frame
for i := len(vcs) - 1; i >= 0; i-- {
e := vcs[i]
fr := frameFromPackage(e.Function.Package)
fr.Function = e.Function.Name
fr.Receiver = e.Function.Receiver()
isSink := i == (len(vcs) - 1)
fr.Position = posFromStackEntry(e, isSink)
frames = append(frames, fr)
}
return frames
}
func posFromStackEntry(e StackEntry, sink bool) *govulncheck.Position {
var p *token.Position
if sink && e.Function != nil && e.Function.Pos != nil {
// For sinks, i.e., vulns we take the position
// of the symbol.
p = e.Function.Pos
} else if e.Call != nil && e.Call.Pos != nil {
// Otherwise, we take the position of
// the call statement.
p = e.Call.Pos
}
if p == nil {
return nil
}
return &govulncheck.Position{
Filename: p.Filename,
Offset: p.Offset,
Line: p.Line,
Column: p.Column,
}
}
func frameFromPackage(pkg *packages.Package) *govulncheck.Frame {
fr := &govulncheck.Frame{}
if pkg != nil {
fr.Module = pkg.Module.Path
fr.Version = pkg.Module.Version
fr.Package = pkg.PkgPath
}
if pkg.Module.Replace != nil {
fr.Module = pkg.Module.Replace.Path
fr.Version = pkg.Module.Replace.Version
}
return fr
}
func frameFromModule(mod *packages.Module, affected []osv.Affected) *govulncheck.Frame {
fr := &govulncheck.Frame{
Module: mod.Path,
Version: mod.Version,
}
if mod.Path == internal.GoStdModulePath {
for _, a := range affected {
if a.Module.Path != mod.Path {
continue
}
fr.Package = a.EcosystemSpecific.Packages[0].Path
}
}
if mod.Replace != nil {
fr.Module = mod.Replace.Path
fr.Version = mod.Replace.Version
}
return fr
}
@@ -0,0 +1,56 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"strings"
"golang.org/x/tools/go/ssa"
)
// entryPoints returns functions of topPackages considered entry
// points of govulncheck analysis: main, inits, and exported methods
// and functions.
//
// TODO(https://go.dev/issue/57221): currently, entry functions
// that are generics are not considered an entry point.
func entryPoints(topPackages []*ssa.Package) []*ssa.Function {
var entries []*ssa.Function
for _, pkg := range topPackages {
if pkg.Pkg.Name() == "main" {
// for "main" packages the only valid entry points are the "main"
// function and any "init#" functions, even if there are other
// exported functions or types. similarly to isEntry it should be
// safe to ignore the validity of the main or init# signatures,
// since the compiler will reject malformed definitions,
// and the init function is synthetic
entries = append(entries, memberFuncs(pkg.Members["main"], pkg.Prog)...)
for name, member := range pkg.Members {
if strings.HasPrefix(name, "init#") || name == "init" {
entries = append(entries, memberFuncs(member, pkg.Prog)...)
}
}
continue
}
for _, member := range pkg.Members {
for _, f := range memberFuncs(member, pkg.Prog) {
if isEntry(f) {
entries = append(entries, f)
}
}
}
}
return entries
}
func isEntry(f *ssa.Function) bool {
// it should be safe to ignore checking that the signature of the "init" function
// is valid, since it is synthetic
if f.Name() == "init" && f.Synthetic == "package initializer" {
return true
}
return f.Synthetic == "" && f.Object() != nil && f.Object().Exported()
}
@@ -0,0 +1,42 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"context"
"fmt"
"golang.org/x/tools/go/packages"
"golang.org/x/vuln/internal/client"
)
// FetchVulnerabilities fetches vulnerabilities that affect the supplied modules.
func FetchVulnerabilities(ctx context.Context, c *client.Client, modules []*packages.Module) ([]*ModVulns, error) {
mreqs := make([]*client.ModuleRequest, len(modules))
for i, mod := range modules {
modPath := mod.Path
if mod.Replace != nil {
modPath = mod.Replace.Path
}
mreqs[i] = &client.ModuleRequest{
Path: modPath,
}
}
resps, err := c.ByModules(ctx, mreqs)
if err != nil {
return nil, fmt.Errorf("fetching vulnerabilities: %v", err)
}
var mv []*ModVulns
for i, resp := range resps {
if len(resp.Entries) == 0 {
continue
}
mv = append(mv, &ModVulns{
Module: modules[i],
Vulns: resp.Entries,
})
}
return mv, nil
}
@@ -0,0 +1,56 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck_test
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/tools/go/packages"
"golang.org/x/vuln/internal/client"
"golang.org/x/vuln/internal/osv"
"golang.org/x/vuln/internal/vulncheck"
)
func TestFetchVulnerabilities(t *testing.T) {
a := &osv.Entry{ID: "a", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/a"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Fixed: "2.0.0"}}}}}}}
b := &osv.Entry{ID: "b", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/b"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Fixed: "1.1.1"}}}}}}}
c := &osv.Entry{ID: "c", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/d"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Fixed: "2.0.0"}}}}}}}
d := &osv.Entry{ID: "e", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/e"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Fixed: "2.2.0"}}}}}}}
mc, err := client.NewInMemoryClient([]*osv.Entry{a, b, c, d})
if err != nil {
t.Fatal(err)
}
got, err := vulncheck.FetchVulnerabilities(context.Background(), mc, []*packages.Module{
{Path: "example.mod/a", Version: "v1.0.0"},
{Path: "example.mod/b", Version: "v1.0.4"},
{Path: "example.mod/c", Replace: &packages.Module{Path: "example.mod/d", Version: "v1.0.0"}, Version: "v2.0.0"},
{Path: "example.mod/e", Replace: &packages.Module{Path: "../local/example.mod/d", Version: "v1.0.1"}, Version: "v2.1.0"},
})
if err != nil {
t.Fatalf("FetchVulnerabilities failed: %s", err)
}
want := []*vulncheck.ModVulns{
{
Module: &packages.Module{Path: "example.mod/a", Version: "v1.0.0"},
Vulns: []*osv.Entry{a},
},
{
Module: &packages.Module{Path: "example.mod/b", Version: "v1.0.4"},
Vulns: []*osv.Entry{b},
},
{
Module: &packages.Module{Path: "example.mod/c", Replace: &packages.Module{Path: "example.mod/d", Version: "v1.0.0"}, Version: "v2.0.0"},
Vulns: []*osv.Entry{c},
},
}
if diff := cmp.Diff(got, want); diff != "" {
t.Fatalf("mismatch (-want, +got):\n%s", diff)
}
}
@@ -0,0 +1,102 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"runtime"
"sort"
"golang.org/x/vuln/internal/client"
"golang.org/x/vuln/internal/osv"
"golang.org/x/vuln/internal/semver"
)
// newTestClient returns a client that reads
// a database with the following vulnerable symbols:
//
// golang.org/amod/avuln.{VulnData.Vuln1, vulnData.Vuln2}
// golang.org/bmod/bvuln.Vuln
// archive/zip.OpenReader
func newTestClient() (*client.Client, error) {
return client.NewInMemoryClient(
[]*osv.Entry{
{
ID: "VA",
Affected: []osv.Affected{{
Module: osv.Module{Path: "golang.org/amod"},
Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "1.0.0"}, {Fixed: "1.0.4"}, {Introduced: "1.1.2"}}}},
EcosystemSpecific: osv.EcosystemSpecific{Packages: []osv.Package{{
Path: "golang.org/amod/avuln",
Symbols: []string{"VulnData.Vuln1", "VulnData.Vuln2"}},
}},
}},
},
{
ID: "VB",
Affected: []osv.Affected{{
Module: osv.Module{Path: "golang.org/bmod"},
Ranges: []osv.Range{{Type: osv.RangeTypeSemver}},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "golang.org/bmod/bvuln",
Symbols: []string{"Vuln"},
}},
},
}},
},
{
ID: "STD",
Affected: []osv.Affected{{
Module: osv.Module{Path: osv.GoStdModulePath},
// Range is populated also using runtime info for testing binaries since
// setting fixed Go version for binaries is very difficult.
Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "1.18"}, {Introduced: semver.GoTagToSemver(runtime.Version())}}}},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "archive/zip",
Symbols: []string{"OpenReader"},
}},
},
}},
}})
}
type edge struct {
// src and dest are ids of source and
// destination nodes in a callgraph edge.
src, dst string
}
func callGraphToStrMap(r *Result) map[string][]string {
// seen edges, to avoid repetitions
seen := make(map[edge]bool)
m := make(map[string][]string)
for _, v := range r.Vulns {
updateCallGraph(m, v.CallSink, seen)
}
sortStrMap(m)
return m
}
func updateCallGraph(callGraph map[string][]string, f *FuncNode, seen map[edge]bool) {
fName := f.String()
for _, callsite := range f.CallSites {
e := edge{src: callsite.Parent.Name, dst: f.Name}
if seen[e] {
continue
}
seen[e] = true
callerName := callsite.Parent.String()
callGraph[callerName] = append(callGraph[callerName], fName)
updateCallGraph(callGraph, callsite.Parent, seen)
}
}
// sortStrMap sorts the map string slice values to make them deterministic.
func sortStrMap(m map[string][]string) {
for _, strs := range m {
sort.Strings(strs)
}
}
@@ -0,0 +1,194 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"fmt"
"strings"
"golang.org/x/tools/go/packages"
"golang.org/x/vuln/internal"
"golang.org/x/vuln/internal/semver"
)
// PackageGraph holds a complete module and package graph.
// Its primary purpose is to allow fast access to the nodes by path.
type PackageGraph struct {
modules map[string]*packages.Module
packages map[string]*packages.Package
}
func NewPackageGraph(goVersion string) *PackageGraph {
graph := &PackageGraph{
modules: map[string]*packages.Module{},
packages: map[string]*packages.Package{},
}
graph.AddModules(&packages.Module{
Path: internal.GoStdModulePath,
Version: semver.GoTagToSemver(goVersion),
})
return graph
}
// AddModules adds the modules and any replace modules provided.
// It will ignore modules that have duplicate paths to ones the graph already holds.
func (g *PackageGraph) AddModules(mods ...*packages.Module) {
for _, mod := range mods {
if _, found := g.modules[mod.Path]; found {
//TODO: check duplicates are okay?
continue
}
g.modules[mod.Path] = mod
if mod.Replace != nil {
g.AddModules(mod.Replace)
}
}
}
// .
func (g *PackageGraph) GetModule(path string) *packages.Module {
if mod, ok := g.modules[path]; ok {
return mod
}
mod := &packages.Module{
Path: path,
Version: "",
}
g.AddModules(mod)
return mod
}
// AddPackages adds the packages and the full graph of imported packages.
// It will ignore packages that have duplicate paths to ones the graph already holds.
func (g *PackageGraph) AddPackages(pkgs ...*packages.Package) {
for _, pkg := range pkgs {
if _, found := g.packages[pkg.PkgPath]; found {
//TODO: check duplicates are okay?
continue
}
g.packages[pkg.PkgPath] = pkg
g.fixupPackage(pkg)
for _, child := range pkg.Imports {
g.AddPackages(child)
}
}
}
func (g *PackageGraph) fixupPackage(pkg *packages.Package) {
if pkg.Module != nil {
g.AddModules(pkg.Module)
return
}
pkg.Module = g.findModule(pkg.PkgPath)
}
// findModule finds a module for package.
// It does a longest prefix search amongst the existing modules, if that does
// not find anything, it returns the "unknown" module.
func (g *PackageGraph) findModule(pkgPath string) *packages.Module {
//TODO: better stdlib test
if !strings.Contains(pkgPath, ".") {
return g.GetModule(internal.GoStdModulePath)
}
for _, m := range g.modules {
//TODO: not first match, best match...
if pkgPath == m.Path || strings.HasPrefix(pkgPath, m.Path+"/") {
return m
}
}
return g.GetModule(internal.UnknownModulePath)
}
// GetPackage returns the package matching the path.
// If the graph does not already know about the package, a new one is added.
func (g *PackageGraph) GetPackage(path string) *packages.Package {
if pkg, ok := g.packages[path]; ok {
return pkg
}
pkg := &packages.Package{
PkgPath: path,
}
g.AddPackages(pkg)
return pkg
}
// LoadPackages loads the packages specified by the patterns into the graph.
// See golang.org/x/tools/go/packages.Load for details of how it works.
func (g *PackageGraph) LoadPackagesAndMods(cfg *packages.Config, tags []string, patterns []string) ([]*packages.Package, []*packages.Module, error) {
if len(tags) > 0 {
cfg.BuildFlags = []string{fmt.Sprintf("-tags=%s", strings.Join(tags, ","))}
}
cfg.Mode |=
packages.NeedDeps |
packages.NeedImports |
packages.NeedModule |
packages.NeedSyntax |
packages.NeedTypes |
packages.NeedTypesInfo |
packages.NeedName
pkgs, err := packages.Load(cfg, patterns...)
if err != nil {
return nil, nil, err
}
var perrs []packages.Error
packages.Visit(pkgs, nil, func(p *packages.Package) {
perrs = append(perrs, p.Errors...)
})
if len(perrs) > 0 {
err = &packageError{perrs}
}
g.AddPackages(pkgs...)
return pkgs, extractModules(pkgs), err
}
// extractModules collects modules in `pkgs` up to uniqueness of
// module path and version.
func extractModules(pkgs []*packages.Package) []*packages.Module {
modMap := map[string]*packages.Module{}
seen := map[*packages.Package]bool{}
var extract func(*packages.Package, map[string]*packages.Module)
extract = func(pkg *packages.Package, modMap map[string]*packages.Module) {
if pkg == nil || seen[pkg] {
return
}
if pkg.Module != nil {
if pkg.Module.Replace != nil {
modMap[pkg.Module.Replace.Path] = pkg.Module
} else {
modMap[pkg.Module.Path] = pkg.Module
}
}
seen[pkg] = true
for _, imp := range pkg.Imports {
extract(imp, modMap)
}
}
for _, pkg := range pkgs {
extract(pkg, modMap)
}
modules := []*packages.Module{}
for _, mod := range modMap {
modules = append(modules, mod)
}
return modules
}
// packageError contains errors from loading a set of packages.
type packageError struct {
Errors []packages.Error
}
func (e *packageError) Error() string {
var b strings.Builder
fmt.Fprintln(&b, "\nThere are errors with the provided package patterns:")
fmt.Fprintln(&b, "")
for _, e := range e.Errors {
fmt.Fprintln(&b, e)
}
fmt.Fprintln(&b, "\nFor details on package patterns, see https://pkg.go.dev/cmd/go#hdr-Package_lists_and_patterns.")
return b.String()
}
@@ -0,0 +1,55 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/ssa"
)
// forwardSlice computes the transitive closure of functions forward reachable
// via calls in cg or referred to in an instruction starting from `sources`.
func forwardSlice(sources map[*ssa.Function]bool, cg *callgraph.Graph) map[*ssa.Function]bool {
seen := make(map[*ssa.Function]bool)
var visit func(f *ssa.Function)
visit = func(f *ssa.Function) {
if seen[f] {
return
}
seen[f] = true
if n := cg.Nodes[f]; n != nil {
for _, e := range n.Out {
if e.Site != nil {
visit(e.Callee.Func)
}
}
}
var buf [10]*ssa.Value // avoid alloc in common case
for _, b := range f.Blocks {
for _, instr := range b.Instrs {
for _, op := range instr.Operands(buf[:0]) {
if fn, ok := (*op).(*ssa.Function); ok {
visit(fn)
}
}
}
}
}
for source := range sources {
visit(source)
}
return seen
}
// pruneSet removes functions in `set` that are in `toPrune`.
func pruneSet(set, toPrune map[*ssa.Function]bool) {
for f := range set {
if !toPrune[f] {
delete(set, f)
}
}
}
@@ -0,0 +1,119 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"path"
"reflect"
"testing"
"golang.org/x/tools/go/callgraph/cha"
"golang.org/x/tools/go/packages/packagestest"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)
// funcNames returns a set of function names for `funcs`.
func funcNames(funcs map[*ssa.Function]bool) map[string]bool {
fs := make(map[string]bool)
for f := range funcs {
fs[dbFuncName(f)] = true
}
return fs
}
func TestSlicing(t *testing.T) {
// test program
p := `
package slice
func X() {}
func Y() {}
// not reachable
func id(i int) int {
return i
}
// not reachable
func inc(i int) int {
return i + 1
}
func Apply(b bool, h func()) {
if b {
func() {
print("applied")
}()
return
}
h()
}
type I interface {
Foo()
}
type A struct{}
func (a A) Foo() {}
// not reachable
func (a A) Bar() {}
type B struct{}
func (b B) Foo() {}
func debug(s string) {
print(s)
}
func Do(i I, input string) {
debug(input)
i.Foo()
func(x string) {
func(l int) {
print(l)
}(len(x))
}(input)
}`
e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
{
Name: "some/module",
Files: map[string]interface{}{"slice/slice.go": p},
},
})
graph := NewPackageGraph("go1.18")
pkgs, _, err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "/module/slice")})
if err != nil {
t.Fatal(err)
}
prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
prog.Build()
pkg := ssaPkgs[0]
sources := map[*ssa.Function]bool{pkg.Func("Apply"): true, pkg.Func("Do"): true}
fs := funcNames(forwardSlice(sources, cha.CallGraph(prog)))
want := map[string]bool{
"Apply": true,
"Apply$1": true,
"X": true,
"Y": true,
"Do": true,
"Do$1": true,
"Do$1$1": true,
"debug": true,
"A.Foo": true,
"B.Foo": true,
}
if !reflect.DeepEqual(want, fs) {
t.Errorf("want %v; got %v", want, fs)
}
}
@@ -0,0 +1,308 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"context"
"sync"
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa"
"golang.org/x/vuln/internal/client"
"golang.org/x/vuln/internal/govulncheck"
"golang.org/x/vuln/internal/osv"
)
// Source detects vulnerabilities in pkgs and emits the findings to handler.
func Source(ctx context.Context, handler govulncheck.Handler, pkgs []*packages.Package, mods []*packages.Module, cfg *govulncheck.Config, client *client.Client, graph *PackageGraph) error {
vr, err := source(ctx, handler, pkgs, mods, cfg, client, graph)
if err != nil {
return err
}
if cfg.ScanLevel.WantSymbols() {
return emitCallFindings(handler, sourceCallstacks(vr))
}
return nil
}
// source detects vulnerabilities in packages. It emits findings to handler
// and produces a Result that contains info on detected vulnerabilities.
//
// Assumes that pkgs are non-empty and belong to the same program.
func source(ctx context.Context, handler govulncheck.Handler, pkgs []*packages.Package, mods []*packages.Module, cfg *govulncheck.Config, client *client.Client, graph *PackageGraph) (*Result, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// If we are building the callgraph, build ssa and the callgraph in parallel
// with fetching vulnerabilities. If the vulns set is empty, return without
// waiting for SSA construction or callgraph to finish.
var (
wg sync.WaitGroup // guards entries, cg, and buildErr
entries []*ssa.Function
cg *callgraph.Graph
buildErr error
)
if cfg.ScanLevel.WantSymbols() {
fset := pkgs[0].Fset
wg.Add(1)
go func() {
defer wg.Done()
prog, ssaPkgs := buildSSA(pkgs, fset)
entries = entryPoints(ssaPkgs)
cg, buildErr = callGraph(ctx, prog, entries)
}()
}
mv, err := FetchVulnerabilities(ctx, client, mods)
if err != nil {
return nil, err
}
// Emit OSV entries immediately in their raw unfiltered form.
if err := emitOSVs(handler, mv); err != nil {
return nil, err
}
affVulns := affectingVulnerabilities(mv, "", "")
if err := emitModuleFindings(handler, affVulns); err != nil {
return nil, err
}
if !cfg.ScanLevel.WantPackages() || len(affVulns) == 0 {
return &Result{}, nil
}
impVulns := importedVulnPackages(pkgs, affVulns)
// Emit information on imported vulnerable packages now as
// call graph computation might take a while.
if err := emitPackageFindings(handler, impVulns); err != nil {
return nil, err
}
// Return result immediately if not in symbol mode or
// if there are no vulnerabilities imported.
if !cfg.ScanLevel.WantSymbols() || len(impVulns) == 0 {
return &Result{Vulns: impVulns}, nil
}
wg.Wait() // wait for build to finish
if buildErr != nil {
return nil, err
}
entryFuncs, callVulns := calledVulnSymbols(entries, affVulns, cg, graph)
return &Result{EntryFunctions: entryFuncs, Vulns: callVulns}, nil
}
// importedVulnPackages detects imported vulnerable packages.
func importedVulnPackages(pkgs []*packages.Package, affVulns affectingVulns) []*Vuln {
var vulns []*Vuln
analyzed := make(map[*packages.Package]bool) // skip analyzing the same package multiple times
var vulnImports func(pkg *packages.Package)
vulnImports = func(pkg *packages.Package) {
if analyzed[pkg] {
return
}
osvs := affVulns.ForPackage(pkg.PkgPath)
// Create Vuln entry for each OSV entry for pkg.
for _, osv := range osvs {
vuln := &Vuln{
OSV: osv,
Package: pkg,
}
vulns = append(vulns, vuln)
}
analyzed[pkg] = true
for _, imp := range pkg.Imports {
vulnImports(imp)
}
}
for _, pkg := range pkgs {
vulnImports(pkg)
}
return vulns
}
// calledVulnSymbols detects vuln symbols transitively reachable from sources
// via call graph cg.
//
// A slice of call graph is computed related to the reachable vulnerabilities. Each
// reachable Vuln has attached FuncNode that can be upward traversed to the entry points.
// Entry points that reach the vulnerable symbols are also returned.
func calledVulnSymbols(sources []*ssa.Function, affVulns affectingVulns, cg *callgraph.Graph, graph *PackageGraph) ([]*FuncNode, []*Vuln) {
sinksWithVulns := vulnFuncs(cg, affVulns)
// Compute call graph backwards reachable
// from vulnerable functions and methods.
var sinks []*callgraph.Node
for n := range sinksWithVulns {
sinks = append(sinks, n)
}
bcg := callGraphSlice(sinks, false)
// Interesect backwards call graph with forward
// reachable graph to remove redundant edges.
var filteredSources []*callgraph.Node
for _, e := range sources {
if n, ok := bcg.Nodes[e]; ok {
filteredSources = append(filteredSources, n)
}
}
fcg := callGraphSlice(filteredSources, true)
// Get the sinks that are in fact reachable from entry points.
filteredSinks := make(map[*callgraph.Node][]*osv.Entry)
for n, vs := range sinksWithVulns {
if fn, ok := fcg.Nodes[n.Func]; ok {
filteredSinks[fn] = vs
}
}
// Transform the resulting call graph slice into
// vulncheck representation.
return vulnCallGraph(filteredSources, filteredSinks, graph)
}
// callGraphSlice computes a slice of callgraph beginning at starts
// in the direction (forward/backward) controlled by forward flag.
func callGraphSlice(starts []*callgraph.Node, forward bool) *callgraph.Graph {
g := &callgraph.Graph{Nodes: make(map[*ssa.Function]*callgraph.Node)}
visited := make(map[*callgraph.Node]bool)
var visit func(*callgraph.Node)
visit = func(n *callgraph.Node) {
if visited[n] {
return
}
visited[n] = true
var edges []*callgraph.Edge
if forward {
edges = n.Out
} else {
edges = n.In
}
for _, edge := range edges {
nCallee := g.CreateNode(edge.Callee.Func)
nCaller := g.CreateNode(edge.Caller.Func)
callgraph.AddEdge(nCaller, edge.Site, nCallee)
if forward {
visit(edge.Callee)
} else {
visit(edge.Caller)
}
}
}
for _, s := range starts {
visit(s)
}
return g
}
// vulnCallGraph creates vulnerability call graph in terms of sources and sinks.
func vulnCallGraph(sources []*callgraph.Node, sinks map[*callgraph.Node][]*osv.Entry, graph *PackageGraph) ([]*FuncNode, []*Vuln) {
var entries []*FuncNode
var vulns []*Vuln
nodes := make(map[*ssa.Function]*FuncNode)
// First create entries and sinks and store relevant information.
for _, s := range sources {
fn := createNode(nodes, s.Func, graph)
entries = append(entries, fn)
}
for s, osvs := range sinks {
f := s.Func
funNode := createNode(nodes, s.Func, graph)
// Populate CallSink field for each detected vuln symbol.
for _, osv := range osvs {
vulns = append(vulns, calledVuln(funNode, osv, dbFuncName(f), funNode.Package))
}
}
visited := make(map[*callgraph.Node]bool)
var visit func(*callgraph.Node)
visit = func(n *callgraph.Node) {
if visited[n] {
return
}
visited[n] = true
for _, edge := range n.In {
nCallee := createNode(nodes, edge.Callee.Func, graph)
nCaller := createNode(nodes, edge.Caller.Func, graph)
call := edge.Site
cs := &CallSite{
Parent: nCaller,
Name: call.Common().Value.Name(),
RecvType: callRecvType(call),
Resolved: resolved(call),
Pos: instrPosition(call),
}
nCallee.CallSites = append(nCallee.CallSites, cs)
visit(edge.Caller)
}
}
for s := range sinks {
visit(s)
}
return entries, vulns
}
// vulnFuncs returns vulnerability information for vulnerable functions in cg.
func vulnFuncs(cg *callgraph.Graph, affVulns affectingVulns) map[*callgraph.Node][]*osv.Entry {
m := make(map[*callgraph.Node][]*osv.Entry)
for f, n := range cg.Nodes {
vulns := affVulns.ForSymbol(pkgPath(f), dbFuncName(f))
if len(vulns) > 0 {
m[n] = vulns
}
}
return m
}
// pkgPath returns the path of the f's enclosing package, if any.
// Otherwise, returns "".
func pkgPath(f *ssa.Function) string {
if f.Package() != nil && f.Package().Pkg != nil {
return f.Package().Pkg.Path()
}
return ""
}
func createNode(nodes map[*ssa.Function]*FuncNode, f *ssa.Function, graph *PackageGraph) *FuncNode {
if fn, ok := nodes[f]; ok {
return fn
}
fn := &FuncNode{
Name: f.Name(),
Package: graph.GetPackage(pkgPath(f)),
RecvType: funcRecvType(f),
Pos: funcPosition(f),
}
nodes[f] = fn
return fn
}
func calledVuln(call *FuncNode, osv *osv.Entry, symbol string, pkg *packages.Package) *Vuln {
return &Vuln{
Symbol: symbol,
Package: pkg,
OSV: osv,
CallSink: call,
}
}
@@ -0,0 +1,521 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"context"
"path"
"reflect"
"testing"
"golang.org/x/tools/go/packages/packagestest"
"golang.org/x/vuln/internal/client"
"golang.org/x/vuln/internal/govulncheck"
"golang.org/x/vuln/internal/osv"
"golang.org/x/vuln/internal/test"
)
// TestCalls checks for call graph vuln slicing correctness.
// The inlined test code has the following call graph
//
// x.X
// / | \
// / d.D1 avuln.VulnData.Vuln1
// / / |
// c.C1 d.internal.Vuln1
// |
// avuln.VulnData.Vuln2
//
// --------------------y.Y-------------------------------
// / / \ \ \ \
// / / \ \ \ \
// / / \ \ \ \
// c.C4 c.vulnWrap.V.Vuln1(=nil) c.C2 bvuln.Vuln c.C3 c.C3$1
// | | |
// y.benign e.E
//
// and this slice
//
// x.X
// / | \
// / d.D1 avuln.VulnData.Vuln1
// / /
// c.C1
// |
// avuln.VulnData.Vuln2
//
// y.Y
// |
// bvuln.Vuln
// | |
// e.E
//
// related to avuln.VulnData.{Vuln1, Vuln2} and bvuln.Vuln vulnerabilities.
func TestCalls(t *testing.T) {
e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
{
Name: "golang.org/entry",
Files: map[string]interface{}{
"x/x.go": `
package x
import (
"golang.org/cmod/c"
"golang.org/dmod/d"
)
func X(x bool) {
if x {
c.C1().Vuln1() // vuln use: Vuln1
} else {
d.D1() // no vuln use
}
}
`,
"y/y.go": `
package y
import (
"golang.org/cmod/c"
)
func Y(y bool) {
if y {
c.C2()() // vuln use: bvuln.Vuln
} else {
c.C3()()
w := c.C4(benign)
w.V.Vuln1() // no vuln use: Vuln1 does not belong to vulnerable type
}
}
func benign(i c.I) {}
`}},
{
Name: "golang.org/cmod@v1.1.3",
Files: map[string]interface{}{"c/c.go": `
package c
import (
"golang.org/amod/avuln"
"golang.org/bmod/bvuln"
)
type I interface {
Vuln1()
}
func C1() I {
v := avuln.VulnData{}
v.Vuln2() // vuln use
return v
}
func C2() func() {
return bvuln.Vuln
}
func C3() func() {
return func() {}
}
type vulnWrap struct {
V I
}
func C4(f func(i I)) vulnWrap {
f(avuln.VulnData{})
return vulnWrap{}
}
`},
},
{
Name: "golang.org/dmod@v0.5.0",
Files: map[string]interface{}{"d/d.go": `
package d
import (
"golang.org/cmod/c"
)
type internal struct{}
func (i internal) Vuln1() {}
func D1() {
c.C1() // transitive vuln use
var i c.I
i = internal{}
i.Vuln1() // no vuln use
}
`},
},
{
Name: "golang.org/amod@v1.1.3",
Files: map[string]interface{}{"avuln/avuln.go": `
package avuln
type VulnData struct {}
func (v VulnData) Vuln1() {}
func (v VulnData) Vuln2() {}
`},
},
{
Name: "golang.org/bmod@v0.5.0",
Files: map[string]interface{}{"bvuln/bvuln.go": `
package bvuln
import (
"golang.org/emod/e"
)
func Vuln() {
e.E(Vuln)
}
`},
},
{
Name: "golang.org/emod@v1.5.0",
Files: map[string]interface{}{"e/e.go": `
package e
func E(f func()) {
f()
}
`},
},
})
defer e.Cleanup()
// Load x and y as entry packages.
graph := NewPackageGraph("go1.18")
pkgs, mods, err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "entry/x"), path.Join(e.Temp(), "entry/y")})
if err != nil {
t.Fatal(err)
}
if len(pkgs) != 2 {
t.Fatal("failed to load x and y test packages")
}
c, err := newTestClient()
if err != nil {
t.Fatal(err)
}
cfg := &govulncheck.Config{ScanLevel: "symbol"}
result, err := source(context.Background(), test.NewMockHandler(), pkgs, mods, cfg, c, graph)
if err != nil {
t.Fatal(err)
}
// Check that we find the right number of vulnerabilities.
// There should be three entries as there are three vulnerable
// symbols in the two import-reachable OSVs.
if len(result.Vulns) != 3 {
t.Errorf("want 3 Vulns, got %d", len(result.Vulns))
}
// Check that call graph entry points are present.
if got := len(result.EntryFunctions); got != 2 {
t.Errorf("want 2 call graph entry points; got %v", got)
}
// Check that vulnerabilities are connected to the call graph.
// For the test example, all vulns should have a call sink.
for _, v := range result.Vulns {
if v.CallSink == nil {
t.Errorf("want CallSink !=0 for %v; got 0", v.Symbol)
}
}
wantCalls := map[string][]string{
"golang.org/entry/x.X": {"golang.org/amod/avuln.VulnData.Vuln1", "golang.org/cmod/c.C1", "golang.org/dmod/d.D1"},
"golang.org/cmod/c.C1": {"golang.org/amod/avuln.VulnData.Vuln2"},
"golang.org/dmod/d.D1": {"golang.org/cmod/c.C1"},
"golang.org/entry/y.Y": {"golang.org/bmod/bvuln.Vuln"},
"golang.org/bmod/bvuln.Vuln": {"golang.org/emod/e.E"},
"golang.org/emod/e.E": {"golang.org/bmod/bvuln.Vuln"},
}
if callStrMap := callGraphToStrMap(result); !reflect.DeepEqual(wantCalls, callStrMap) {
t.Errorf("want %v call graph; got %v", wantCalls, callStrMap)
}
}
func TestAllSymbolsVulnerable(t *testing.T) {
e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
{
Name: "golang.org/entry",
Files: map[string]interface{}{
"x/x.go": `
package x
import "golang.org/vmod/vuln"
func X() {
vuln.V1()
}`,
},
},
{
Name: "golang.org/vmod@v1.2.3",
Files: map[string]interface{}{"vuln/vuln.go": `
package vuln
func V1() {}
func V2() {}
func v() {}
type a struct{}
func (x a) foo() {}
func (x *a) bar() {}
`},
},
})
defer e.Cleanup()
client, err := client.NewInMemoryClient(
[]*osv.Entry{
{
ID: "V",
Affected: []osv.Affected{{
Module: osv.Module{Path: "golang.org/vmod"},
Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "1.2.0"}}}},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "golang.org/vmod/vuln",
Symbols: []string{},
}},
},
}},
},
},
)
if err != nil {
t.Fatal(err)
}
// Load x as entry package.
graph := NewPackageGraph("go1.18")
pkgs, mods, err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "entry/x")})
if err != nil {
t.Fatal(err)
}
if len(pkgs) != 1 {
t.Fatal("failed to load x test package")
}
cfg := &govulncheck.Config{ScanLevel: "symbol"}
result, err := source(context.Background(), test.NewMockHandler(), pkgs, mods, cfg, client, graph)
if err != nil {
t.Fatal(err)
}
if len(result.Vulns) != 2 { // init and V1
t.Errorf("want 2 Vulns, got %d", len(result.Vulns))
}
for _, v := range result.Vulns {
if v.CallSink == nil {
t.Errorf("expected a call sink for %s; got none", v.Symbol)
}
}
}
// TestNoSyntheticNodes checks that removing synthetic wrappers from
// call graph still produces correct results.
func TestNoSyntheticNodes(t *testing.T) {
e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
{
Name: "golang.org/entry",
Files: map[string]interface{}{
"x/x.go": `
package x
import "golang.org/amod/avuln"
type i interface {
Vuln1()
}
func X() {
v := &avuln.VulnData{}
var x i = v // to force creatation of wrapper method *avuln.VulnData.Vuln1
x.Vuln1()
}`,
},
},
{
Name: "golang.org/amod@v1.1.3",
Files: map[string]interface{}{"avuln/avuln.go": `
package avuln
type VulnData struct {}
func (v VulnData) Vuln1() {}
func (v VulnData) Vuln2() {}
`},
},
})
defer e.Cleanup()
// Load x as entry package.
graph := NewPackageGraph("go1.18")
pkgs, mods, err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "entry/x")})
if err != nil {
t.Fatal(err)
}
if len(pkgs) != 1 {
t.Fatal("failed to load x test package")
}
c, err := newTestClient()
if err != nil {
t.Fatal(err)
}
cfg := &govulncheck.Config{ScanLevel: "symbol"}
result, err := source(context.Background(), test.NewMockHandler(), pkgs, mods, cfg, c, graph)
if err != nil {
t.Fatal(err)
}
if len(result.Vulns) != 1 {
t.Errorf("want 1 Vuln, got %d", len(result.Vulns))
}
vuln := result.Vulns[0]
if vuln.Symbol != "VulnData.Vuln1" {
t.Fatalf("expected VulnData.Vuln1 as called symbol; got %s", vuln.Symbol)
}
stack := sourceCallstacks(result)[vuln]
// We don't want the call stack X -> *VulnData.Vuln1 (wrapper) -> VulnData.Vuln1.
// We want X -> VulnData.Vuln1.
if len(stack) != 2 {
t.Errorf("want stack of length 2; got stack of length %v", len(stack))
}
}
func TestRecursion(t *testing.T) {
e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
{
Name: "golang.org/entry",
Files: map[string]interface{}{
"x/x.go": `
package x
import "golang.org/bmod/bvuln"
func X() {
y()
bvuln.Vuln()
z()
}
func y() {
X()
}
func z() {}
`,
},
},
{
Name: "golang.org/bmod@v0.5.0",
Files: map[string]interface{}{"bvuln/bvuln.go": `
package bvuln
func Vuln() {}
`},
},
})
defer e.Cleanup()
// Load x as entry package.
graph := NewPackageGraph("go1.18")
pkgs, mods, err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "entry/x")})
if err != nil {
t.Fatal(err)
}
if len(pkgs) != 1 {
t.Fatal("failed to load x test package")
}
c, err := newTestClient()
if err != nil {
t.Fatal(err)
}
cfg := &govulncheck.Config{ScanLevel: "symbol"}
result, err := source(context.Background(), test.NewMockHandler(), pkgs, mods, cfg, c, graph)
if err != nil {
t.Fatal(err)
}
wantCalls := map[string][]string{
"golang.org/entry/x.X": {"golang.org/bmod/bvuln.Vuln", "golang.org/entry/x.y"},
"golang.org/entry/x.y": {"golang.org/entry/x.X"},
}
if callStrMap := callGraphToStrMap(result); !reflect.DeepEqual(wantCalls, callStrMap) {
t.Errorf("want %v call graph; got %v", wantCalls, callStrMap)
}
}
func TestIssue57174(t *testing.T) {
e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
{
Name: "golang.org/entry",
Files: map[string]interface{}{
"x/x.go": `
package x
import "golang.org/bmod/bvuln"
func P(d [][3]int) {
p(d)
}
func p[E interface{ [3]int | [4]int }](d []E) {
c := d[0]
if c[0] > 0 {
bvuln.Vuln()
}
}
`,
},
},
{
Name: "golang.org/bmod@v0.5.0",
Files: map[string]interface{}{"bvuln/bvuln.go": `
package bvuln
func Vuln() {}
`},
},
})
defer e.Cleanup()
// Load x as entry package.
graph := NewPackageGraph("go1.18")
pkgs, mods, err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "entry/x")})
if err != nil {
t.Fatal(err)
}
if len(pkgs) != 1 {
t.Fatal("failed to load x test package")
}
c, err := newTestClient()
if err != nil {
t.Fatal(err)
}
cfg := &govulncheck.Config{ScanLevel: "symbol"}
_, err = source(context.Background(), test.NewMockHandler(), pkgs, mods, cfg, c, graph)
if err != nil {
t.Fatal(err)
}
}
@@ -0,0 +1,317 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"bytes"
"context"
"go/token"
"go/types"
"sort"
"strings"
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/callgraph/cha"
"golang.org/x/tools/go/callgraph/vta"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa/ssautil"
"golang.org/x/tools/go/types/typeutil"
"golang.org/x/vuln/internal/osv"
"golang.org/x/vuln/internal/semver"
"golang.org/x/tools/go/ssa"
)
// buildSSA creates an ssa representation for pkgs. Returns
// the ssa program encapsulating the packages and top level
// ssa packages corresponding to pkgs.
func buildSSA(pkgs []*packages.Package, fset *token.FileSet) (*ssa.Program, []*ssa.Package) {
prog := ssa.NewProgram(fset, ssa.InstantiateGenerics)
imports := make(map[*packages.Package]*ssa.Package)
var createImports func(map[string]*packages.Package)
createImports = func(pkgs map[string]*packages.Package) {
for _, p := range pkgs {
if _, ok := imports[p]; !ok {
i := prog.CreatePackage(p.Types, p.Syntax, p.TypesInfo, true)
imports[p] = i
createImports(p.Imports)
}
}
}
for _, tp := range pkgs {
createImports(tp.Imports)
}
var ssaPkgs []*ssa.Package
for _, tp := range pkgs {
if sp, ok := imports[tp]; ok {
ssaPkgs = append(ssaPkgs, sp)
} else {
sp := prog.CreatePackage(tp.Types, tp.Syntax, tp.TypesInfo, false)
ssaPkgs = append(ssaPkgs, sp)
}
}
prog.Build()
return prog, ssaPkgs
}
// callGraph builds a call graph of prog based on VTA analysis.
func callGraph(ctx context.Context, prog *ssa.Program, entries []*ssa.Function) (*callgraph.Graph, error) {
entrySlice := make(map[*ssa.Function]bool)
for _, e := range entries {
entrySlice[e] = true
}
if err := ctx.Err(); err != nil { // cancelled?
return nil, err
}
initial := cha.CallGraph(prog)
allFuncs := ssautil.AllFunctions(prog)
fslice := forwardSlice(entrySlice, initial)
// Keep only actually linked functions.
pruneSet(fslice, allFuncs)
if err := ctx.Err(); err != nil { // cancelled?
return nil, err
}
vtaCg := vta.CallGraph(fslice, initial)
// Repeat the process once more, this time using
// the produced VTA call graph as the base graph.
fslice = forwardSlice(entrySlice, vtaCg)
pruneSet(fslice, allFuncs)
if err := ctx.Err(); err != nil { // cancelled?
return nil, err
}
cg := vta.CallGraph(fslice, vtaCg)
cg.DeleteSyntheticNodes()
return cg, nil
}
// dbTypeFormat formats the name of t according how types
// are encoded in vulnerability database:
// - pointer designation * is skipped
// - full path prefix is skipped as well
func dbTypeFormat(t types.Type) string {
switch tt := t.(type) {
case *types.Pointer:
return dbTypeFormat(tt.Elem())
case *types.Named:
return tt.Obj().Name()
default:
return types.TypeString(t, func(p *types.Package) string { return "" })
}
}
// dbFuncName computes a function name consistent with the namings used in vulnerability
// databases. Effectively, a qualified name of a function local to its enclosing package.
// If a receiver is a pointer, this information is not encoded in the resulting name. If
// a function has type argument/parameter, this information is omitted. The name of
// anonymous functions is simply "". The function names are unique subject to the enclosing
// package, but not globally.
//
// Examples:
//
// func (a A) foo (...) {...} -> A.foo
// func foo(...) {...} -> foo
// func (b *B) bar (...) {...} -> B.bar
// func (c C[T]) do(...) {...} -> C.do
func dbFuncName(f *ssa.Function) string {
selectBound := func(f *ssa.Function) types.Type {
// If f is a "bound" function introduced by ssa for a given type, return the type.
// When "f" is a "bound" function, it will have 1 free variable of that type within
// the function. This is subject to change when ssa changes.
if len(f.FreeVars) == 1 && strings.HasPrefix(f.Synthetic, "bound ") {
return f.FreeVars[0].Type()
}
return nil
}
selectThunk := func(f *ssa.Function) types.Type {
// If f is a "thunk" function introduced by ssa for a given type, return the type.
// When "f" is a "thunk" function, the first parameter will have that type within
// the function. This is subject to change when ssa changes.
params := f.Signature.Params() // params.Len() == 1 then params != nil.
if strings.HasPrefix(f.Synthetic, "thunk ") && params.Len() >= 1 {
if first := params.At(0); first != nil {
return first.Type()
}
}
return nil
}
var qprefix string
if recv := f.Signature.Recv(); recv != nil {
qprefix = dbTypeFormat(recv.Type())
} else if btype := selectBound(f); btype != nil {
qprefix = dbTypeFormat(btype)
} else if ttype := selectThunk(f); ttype != nil {
qprefix = dbTypeFormat(ttype)
}
if qprefix == "" {
return funcName(f)
}
return qprefix + "." + funcName(f)
}
// funcName returns the name of the ssa function f.
// It is f.Name() without additional type argument
// information in case of generics.
func funcName(f *ssa.Function) string {
n, _, _ := strings.Cut(f.Name(), "[")
return n
}
// memberFuncs returns functions associated with the `member`:
// 1) `member` itself if `member` is a function
// 2) `member` methods if `member` is a type
// 3) empty list otherwise
func memberFuncs(member ssa.Member, prog *ssa.Program) []*ssa.Function {
switch t := member.(type) {
case *ssa.Type:
methods := typeutil.IntuitiveMethodSet(t.Type(), &prog.MethodSets)
var funcs []*ssa.Function
for _, m := range methods {
if f := prog.MethodValue(m); f != nil {
funcs = append(funcs, f)
}
}
return funcs
case *ssa.Function:
return []*ssa.Function{t}
default:
return nil
}
}
// funcPosition gives the position of `f`. Returns empty token.Position
// if no file information on `f` is available.
func funcPosition(f *ssa.Function) *token.Position {
pos := f.Prog.Fset.Position(f.Pos())
return &pos
}
// instrPosition gives the position of `instr`. Returns empty token.Position
// if no file information on `instr` is available.
func instrPosition(instr ssa.Instruction) *token.Position {
pos := instr.Parent().Prog.Fset.Position(instr.Pos())
return &pos
}
func resolved(call ssa.CallInstruction) bool {
if call == nil {
return true
}
return call.Common().StaticCallee() != nil
}
func callRecvType(call ssa.CallInstruction) string {
if !call.Common().IsInvoke() {
return ""
}
buf := new(bytes.Buffer)
types.WriteType(buf, call.Common().Value.Type(), nil)
return buf.String()
}
func funcRecvType(f *ssa.Function) string {
v := f.Signature.Recv()
if v == nil {
return ""
}
buf := new(bytes.Buffer)
types.WriteType(buf, v.Type(), nil)
return buf.String()
}
func FixedVersion(modulePath, version string, affected []osv.Affected) string {
fixed := earliestValidFix(modulePath, version, affected)
// Add "v" prefix if one does not exist. moduleVersionString
// will later on replace it with "go" if needed.
if fixed != "" && !strings.HasPrefix(fixed, "v") {
fixed = "v" + fixed
}
return fixed
}
// earliestValidFix returns the earliest fix for version of modulePath that
// itself is not vulnerable in affected.
//
// Suppose we have a version "v1.0.0" and we use {...} to denote different
// affected regions. Assume for simplicity that all affected apply to the
// same input modulePath.
//
// {[v0.1.0, v0.1.9), [v1.0.0, v2.0.0)} -> v2.0.0
// {[v1.0.0, v1.5.0), [v2.0.0, v2.1.0}, {[v1.4.0, v1.6.0)} -> v2.1.0
func earliestValidFix(modulePath, version string, affected []osv.Affected) string {
var moduleAffected []osv.Affected
for _, a := range affected {
if a.Module.Path == modulePath {
moduleAffected = append(moduleAffected, a)
}
}
vFixes := validFixes(version, moduleAffected)
for _, fix := range vFixes {
if !fixNegated(fix, moduleAffected) {
return fix
}
}
return ""
}
// validFixes computes all fixes for version in affected and
// returns them sorted increasingly. Assumes that all affected
// apply to the same module.
func validFixes(version string, affected []osv.Affected) []string {
var fixes []string
for _, a := range affected {
for _, r := range a.Ranges {
if r.Type != osv.RangeTypeSemver {
continue
}
for _, e := range r.Events {
fix := e.Fixed
if fix != "" && semver.Less(version, fix) {
fixes = append(fixes, fix)
}
}
}
}
sort.SliceStable(fixes, func(i, j int) bool { return semver.Less(fixes[i], fixes[j]) })
return fixes
}
// fixNegated checks if fix is negated to by a re-introduction
// of a vulnerability in affected. Assumes that all affected apply
// to the same module.
func fixNegated(fix string, affected []osv.Affected) bool {
for _, a := range affected {
for _, r := range a.Ranges {
if semver.ContainsSemver(r, fix) {
return true
}
}
}
return false
}
func modPath(mod *packages.Module) string {
if mod.Replace != nil {
return mod.Replace.Path
}
return mod.Path
}
func modVersion(mod *packages.Module) string {
if mod.Replace != nil {
return mod.Replace.Version
}
return mod.Version
}
@@ -0,0 +1,256 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"path"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/tools/go/packages/packagestest"
"golang.org/x/tools/go/ssa/ssautil"
"golang.org/x/vuln/internal/osv"
)
func TestFixedVersion(t *testing.T) {
for _, test := range []struct {
name string
module string
version string
in []osv.Affected
want string
}{
{
name: "empty",
want: "",
},
{
name: "no semver",
module: "example.com/module",
version: "v1.2.0",
in: []osv.Affected{
{
Module: osv.Module{
Path: "example.com/module",
},
Ranges: []osv.Range{
{
Type: osv.RangeType("unspecified"),
Events: []osv.RangeEvent{
{Introduced: "v1.0.0"}, {Fixed: "v1.2.3"},
},
}},
},
},
want: "",
},
{
name: "one",
module: "example.com/module",
version: "v1.0.1",
in: []osv.Affected{
{
Module: osv.Module{
Path: "example.com/module",
},
Ranges: []osv.Range{
{
Type: osv.RangeTypeSemver,
Events: []osv.RangeEvent{
{Introduced: "v1.0.0"}, {Fixed: "v1.2.3"},
},
}},
},
},
want: "v1.2.3",
},
{
name: "several",
module: "example.com/module",
version: "v1.2.0",
in: []osv.Affected{
{
Module: osv.Module{
Path: "example.com/module",
},
Ranges: []osv.Range{
{
Type: osv.RangeTypeSemver,
Events: []osv.RangeEvent{
{Introduced: "v1.0.0"}, {Fixed: "v1.2.3"},
{Introduced: "v1.5.0"}, {Fixed: "v1.5.6"},
},
}},
},
{
Module: osv.Module{
Path: "example.com/module",
},
Ranges: []osv.Range{
{
Type: osv.RangeTypeSemver,
Events: []osv.RangeEvent{
{Introduced: "v1.3.0"}, {Fixed: "v1.4.1"},
},
}},
},
{
// This should be ignored.
Module: osv.Module{
Path: "example.com/anothermodule",
},
Ranges: []osv.Range{
{
Type: osv.RangeTypeSemver,
Events: []osv.RangeEvent{
{Introduced: "0"}, {Fixed: "v1.6.0"},
},
}},
},
},
want: "v1.2.3",
},
{
name: "no v prefix",
version: "1.18.1",
module: "example.com/module",
in: []osv.Affected{
{
Module: osv.Module{
Path: "example.com/module",
},
Ranges: []osv.Range{
{
Type: osv.RangeTypeSemver,
Events: []osv.RangeEvent{
{Fixed: "1.17.2"},
},
}},
},
{
Module: osv.Module{
Path: "example.com/module",
},
Ranges: []osv.Range{
{
Type: osv.RangeTypeSemver,
Events: []osv.RangeEvent{
{Introduced: "1.18.0"}, {Fixed: "1.18.4"},
},
}},
},
},
want: "v1.18.4",
},
{
name: "overlapping",
module: "example.com/module",
in: []osv.Affected{
{
Module: osv.Module{
Path: "example.com/module",
},
Ranges: []osv.Range{
{
Type: osv.RangeTypeSemver,
Events: []osv.RangeEvent{
// v1.2.3 is nominally the earliest fix,
// but it is contained in vulnerable range
// for the next affected value.
{Introduced: "v1.0.0"}, {Fixed: "v1.2.3"},
{Introduced: "v1.5.0"},
},
}},
},
{
Module: osv.Module{
Path: "example.com/module",
},
Ranges: []osv.Range{
{
Type: osv.RangeTypeSemver,
Events: []osv.RangeEvent{
{Introduced: "v1.2.0"}, {Fixed: "v1.4.1"},
},
}},
},
},
want: "v1.4.1",
},
} {
t.Run(test.name, func(t *testing.T) {
got := FixedVersion(test.module, test.version, test.in)
if got != test.want {
t.Errorf("got %q, want %q", got, test.want)
}
})
}
}
func TestDbSymbolName(t *testing.T) {
e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
{
Name: "golang.org/package",
Files: map[string]interface{}{
"x/x.go": `
package x
func Foo() {
// needed for ssautil.Allfunctions
x := a{}
x.Do()
x.NotDo()
b := B[a]{}
b.P()
b.Q(x)
Z[a]()
}
func bar() {}
type a struct{}
func (x a) Do() {}
func (x *a) NotDo() {
}
type B[T any] struct{}
func (b *B[T]) P() {}
func (b B[T]) Q(t T) {}
func Z[T any]() {}
`},
},
})
defer e.Cleanup()
graph := NewPackageGraph("go1.18")
pkgs, _, err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "package/x")})
if err != nil {
t.Fatal(err)
}
want := map[string]bool{
"init": true,
"bar": true,
"B.P": true,
"B.Q": true,
"a.Do": true,
"a.NotDo": true,
"Foo": true,
"Z": true,
}
// test dbFuncName
prog, _ := buildSSA(pkgs, pkgs[0].Fset)
got := make(map[string]bool)
for f := range ssautil.AllFunctions(prog) {
got[dbFuncName(f)] = true
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("(-want;got+): %s", diff)
}
}
@@ -0,0 +1,312 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"fmt"
"go/token"
"strings"
"time"
"golang.org/x/tools/go/packages"
"golang.org/x/vuln/internal"
"golang.org/x/vuln/internal/osv"
"golang.org/x/vuln/internal/semver"
)
// Result contains information on detected vulnerabilities.
// For call graph analysis, it provides information on reachability
// of vulnerable symbols through entry points of the program.
type Result struct {
// EntryFunctions are a subset of Functions representing vulncheck entry points.
EntryFunctions []*FuncNode
// Vulns contains information on detected vulnerabilities.
Vulns []*Vuln
}
// Vuln provides information on a detected vulnerability. For call
// graph mode, Vuln will also contain the information on how the
// vulnerability is reachable in the user call graph.
type Vuln struct {
// OSV contains information on the detected vulnerability in the shared
// vulnerability format.
//
// OSV, Symbol, and Package identify a vulnerability.
//
// Note that *osv.Entry may describe multiple symbols from multiple
// packages.
OSV *osv.Entry
// Symbol is the name of the detected vulnerable function or method.
Symbol string
// CallSink is the FuncNode corresponding to Symbol.
//
// When analyzing binaries, Symbol is not reachable, or cfg.ScanLevel
// is symbol, CallSink will be unavailable and set to nil.
CallSink *FuncNode
// Package of Symbol.
//
// When the package of symbol is not imported, Package will be
// unavailable and set to nil.
Package *packages.Package
}
// A FuncNode describes a function in the call graph.
type FuncNode struct {
// Name is the name of the function.
Name string
// RecvType is the receiver object type of this function, if any.
RecvType string
// Package is the package the function is part of.
Package *packages.Package
// Position describes the position of the function in the file.
Pos *token.Position
// CallSites is a set of call sites where this function is called.
CallSites []*CallSite
}
func (fn *FuncNode) String() string {
if fn.RecvType == "" {
return fmt.Sprintf("%s.%s", fn.Package.PkgPath, fn.Name)
}
return fmt.Sprintf("%s.%s", fn.RecvType, fn.Name)
}
// Receiver returns the FuncNode's receiver, with package path removed.
// Pointers are preserved if present.
func (fn *FuncNode) Receiver() string {
return strings.Replace(fn.RecvType, fmt.Sprintf("%s.", fn.Package.PkgPath), "", 1)
}
// A CallSite describes a function call.
type CallSite struct {
// Parent is the enclosing function where the call is made.
Parent *FuncNode
// Name stands for the name of the function (variable) being called.
Name string
// RecvType is the full path of the receiver object type, if any.
RecvType string
// Position describes the position of the function in the file.
Pos *token.Position
// Resolved indicates if the called function can be statically resolved.
Resolved bool
}
// affectingVulns is an internal structure for querying
// vulnerabilities that apply to the current program
// and platform under consideration.
type affectingVulns []*ModVulns
// ModVulns groups vulnerabilities per module.
type ModVulns struct {
Module *packages.Module
Vulns []*osv.Entry
}
func affectingVulnerabilities(vulns []*ModVulns, os, arch string) affectingVulns {
now := time.Now()
var filtered affectingVulns
for _, mod := range vulns {
module := mod.Module
modVersion := module.Version
if module.Replace != nil {
modVersion = module.Replace.Version
}
// TODO(https://golang.org/issues/49264): if modVersion == "", try vcs?
var filteredVulns []*osv.Entry
for _, v := range mod.Vulns {
// Ignore vulnerabilities that have been withdrawn
if v.Withdrawn != nil && v.Withdrawn.Before(now) {
continue
}
var filteredAffected []osv.Affected
for _, a := range v.Affected {
// Vulnerabilities from some databases might contain
// information on related but different modules that
// were, say, reported in the same CVE. We filter such
// information out as it might lead to incorrect results:
// Computing a latest fix could consider versions of these
// different packages.
if a.Module.Path != module.Path {
continue
}
// A module version is affected if
// - it is included in one of the affected version ranges
// - and module version is not ""
if modVersion == "" {
// Module version of "" means the module version is not available,
// and so we don't want to spam users with potential false alarms.
continue
}
if !semver.Affects(a.Ranges, modVersion) {
continue
}
var filteredImports []osv.Package
for _, p := range a.EcosystemSpecific.Packages {
if matchesPlatform(os, arch, p) {
filteredImports = append(filteredImports, p)
}
}
// If we pruned all existing Packages, then the affected is
// empty and we can filter it out. Note that Packages can
// be empty for vulnerabilities that have no package or
// symbol information available.
if len(a.EcosystemSpecific.Packages) != 0 && len(filteredImports) == 0 {
continue
}
a.EcosystemSpecific.Packages = filteredImports
filteredAffected = append(filteredAffected, a)
}
if len(filteredAffected) == 0 {
continue
}
// save the non-empty vulnerability with only
// affected symbols.
newV := *v
newV.Affected = filteredAffected
filteredVulns = append(filteredVulns, &newV)
}
filtered = append(filtered, &ModVulns{
Module: module,
Vulns: filteredVulns,
})
}
return filtered
}
func matchesPlatform(os, arch string, e osv.Package) bool {
return matchesPlatformComponent(os, e.GOOS) &&
matchesPlatformComponent(arch, e.GOARCH)
}
// matchesPlatformComponent reports whether a GOOS (or GOARCH)
// matches a list of GOOS (or GOARCH) values from an osv.EcosystemSpecificImport.
func matchesPlatformComponent(s string, ps []string) bool {
// An empty input or an empty GOOS or GOARCH list means "matches everything."
if s == "" || len(ps) == 0 {
return true
}
for _, p := range ps {
if s == p {
return true
}
}
return false
}
// ForPackage returns the vulnerabilities for the module which is the most
// specific prefix of importPath, or nil if there is no matching module with
// vulnerabilities.
func (aff affectingVulns) ForPackage(importPath string) []*osv.Entry {
isStd := IsStdPackage(importPath)
var mostSpecificMod *ModVulns
for _, mod := range aff {
md := mod
if isStd && mod.Module.Path == internal.GoStdModulePath {
// standard library packages do not have an associated module,
// so we relate them to the artificial stdlib module.
mostSpecificMod = md
} else if strings.HasPrefix(importPath, md.Module.Path) {
if mostSpecificMod == nil || len(mostSpecificMod.Module.Path) < len(md.Module.Path) {
mostSpecificMod = md
}
}
}
if mostSpecificMod == nil {
return nil
}
if mostSpecificMod.Module.Replace != nil {
// standard libraries do not have a module nor replace module
importPath = fmt.Sprintf("%s%s", mostSpecificMod.Module.Replace.Path, strings.TrimPrefix(importPath, mostSpecificMod.Module.Path))
}
vulns := mostSpecificMod.Vulns
packageVulns := []*osv.Entry{}
Vuln:
for _, v := range vulns {
for _, a := range v.Affected {
if len(a.EcosystemSpecific.Packages) == 0 {
// no packages means all packages are vulnerable
packageVulns = append(packageVulns, v)
continue Vuln
}
for _, p := range a.EcosystemSpecific.Packages {
if p.Path == importPath {
packageVulns = append(packageVulns, v)
continue Vuln
}
}
}
}
return packageVulns
}
// ForSymbol returns vulnerabilities for symbol in aff.ForPackage(importPath).
func (aff affectingVulns) ForSymbol(importPath, symbol string) []*osv.Entry {
vulns := aff.ForPackage(importPath)
if vulns == nil {
return nil
}
symbolVulns := []*osv.Entry{}
vulnLoop:
for _, v := range vulns {
for _, a := range v.Affected {
if len(a.EcosystemSpecific.Packages) == 0 {
// no packages means all symbols of all packages are vulnerable
symbolVulns = append(symbolVulns, v)
continue vulnLoop
}
for _, p := range a.EcosystemSpecific.Packages {
if p.Path != importPath {
continue
}
if len(p.Symbols) > 0 && !contains(p.Symbols, symbol) {
continue
}
symbolVulns = append(symbolVulns, v)
continue vulnLoop
}
}
}
return symbolVulns
}
func contains(symbols []string, target string) bool {
for _, s := range symbols {
if s == target {
return true
}
}
return false
}
func IsStdPackage(pkg string) bool {
if pkg == "" {
return false
}
// std packages do not have a "." in their path. For instance, see
// Contains in pkgsite/+/refs/heads/master/internal/stdlbib/stdlib.go.
if i := strings.IndexByte(pkg, '/'); i != -1 {
pkg = pkg[:i]
}
return !strings.Contains(pkg, ".")
}
@@ -0,0 +1,451 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"golang.org/x/tools/go/packages"
"golang.org/x/vuln/internal/osv"
)
func TestFilterVulns(t *testing.T) {
past := time.Now().Add(-3 * time.Hour)
mv := []*ModVulns{
{
Module: &packages.Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "a", Affected: []osv.Affected{
{Module: osv.Module{Path: "example.mod/a"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "1.0.0"}, {Fixed: "2.0.0"}}}}},
{Module: osv.Module{Path: "a.example.mod/a"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "1.0.0"}, {Fixed: "2.0.0"}}}}}, // should be filtered out
{Module: osv.Module{Path: "example.mod/a"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "0"}, {Fixed: "0.9.0"}}}}}, // should be filtered out
}},
{ID: "b", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/a"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "1.0.1"}}}},
EcosystemSpecific: osv.EcosystemSpecific{Packages: []osv.Package{{
GOOS: []string{"windows", "linux"},
}},
}}}},
{ID: "c", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/a"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "1.0.0"}, {Fixed: "1.0.1"}}}},
EcosystemSpecific: osv.EcosystemSpecific{Packages: []osv.Package{{
GOARCH: []string{"arm64", "amd64"},
}},
}}}},
{ID: "d", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/a"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"windows"},
}},
}}}},
},
},
{
Module: &packages.Module{
Path: "example.mod/b",
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "e", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/b"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOARCH: []string{"arm64"},
}},
}}}},
{ID: "f", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/b"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"linux"},
}},
}}}},
{ID: "g", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/b"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOARCH: []string{"amd64"},
}},
}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "0.0.1"}, {Fixed: "2.0.1"}}}}}}},
{ID: "h", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/b"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"windows"}, GOARCH: []string{"amd64"},
}},
}}}},
},
},
{
Module: &packages.Module{
Path: "example.mod/c",
},
Vulns: []*osv.Entry{
{ID: "i", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/c"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOARCH: []string{"amd64"},
}},
}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "0.0.0"}}}}}}},
{ID: "j", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/c"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOARCH: []string{"amd64"},
}},
}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Fixed: "3.0.0"}}}}}}},
{ID: "k"},
},
},
{
Module: &packages.Module{
Path: "example.mod/d",
Version: "v1.2.0",
},
Vulns: []*osv.Entry{
{ID: "l", Affected: []osv.Affected{
{Module: osv.Module{Path: "example.mod/d"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"windows"}, // should be filtered out
}},
}},
{Module: osv.Module{Path: "example.mod/d"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"linux"},
}},
}},
}},
},
},
{
Module: &packages.Module{
Path: "example.mod/w",
Version: "v1.3.0",
},
Vulns: []*osv.Entry{
{ID: "m", Withdrawn: &past, Affected: []osv.Affected{ // should be filtered out
{Module: osv.Module{Path: "example.mod/w"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"linux"},
}},
}},
}},
{ID: "n", Affected: []osv.Affected{
{Module: osv.Module{Path: "example.mod/w"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"linux"},
}},
}},
}},
},
},
}
want := affectingVulns{
{
Module: &packages.Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "a", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/a"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "1.0.0"}, {Fixed: "2.0.0"}}}}}}},
{ID: "c", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/a"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOARCH: []string{"arm64", "amd64"},
}},
}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "1.0.0"}, {Fixed: "1.0.1"}}}}}}},
},
},
{
Module: &packages.Module{
Path: "example.mod/b",
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "f", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/b"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"linux"},
}},
}}}},
{ID: "g", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/b"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOARCH: []string{"amd64"},
}},
}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver, Events: []osv.RangeEvent{{Introduced: "0.0.1"}, {Fixed: "2.0.1"}}}}}}},
},
},
{
Module: &packages.Module{
Path: "example.mod/c",
},
},
{
Module: &packages.Module{
Path: "example.mod/d",
Version: "v1.2.0",
},
Vulns: []*osv.Entry{
{ID: "l", Affected: []osv.Affected{{Module: osv.Module{Path: "example.mod/d"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"linux"},
}},
}}}},
},
},
{
Module: &packages.Module{
Path: "example.mod/w",
Version: "v1.3.0",
},
Vulns: []*osv.Entry{
{ID: "n", Affected: []osv.Affected{
{Module: osv.Module{Path: "example.mod/w"}, EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
GOOS: []string{"linux"},
}},
}},
}},
},
},
}
got := affectingVulnerabilities(mv, "linux", "amd64")
if diff := cmp.Diff(want, got, cmp.Exporter(func(t reflect.Type) bool {
return reflect.TypeOf(affectingVulns{}) == t || reflect.TypeOf(ModVulns{}) == t
})); diff != "" {
t.Errorf("(-want,+got):\n%s", diff)
}
}
func TestVulnsForPackage(t *testing.T) {
aff := affectingVulns{
{
Module: &packages.Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "a", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/a"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/a/b/c",
}},
},
}}},
},
},
{
Module: &packages.Module{
Path: "example.mod/a/b",
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "b", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/a/b"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/a/b/c",
}},
},
}}},
},
},
{
Module: &packages.Module{
Path: "example.mod/d",
Version: "v0.0.1",
},
Vulns: []*osv.Entry{
{ID: "d", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/d"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/d",
}},
},
}}},
},
},
}
got := aff.ForPackage("example.mod/a/b/c")
want := []*osv.Entry{
{ID: "b", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/a/b"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/a/b/c",
}},
},
}}},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("(-want,+got):\n%s", diff)
}
}
func TestVulnsForPackageReplaced(t *testing.T) {
aff := affectingVulns{
{
Module: &packages.Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "a", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/a"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/a/b/c",
}},
},
}}},
},
},
{
Module: &packages.Module{
Path: "example.mod/a/b",
Replace: &packages.Module{
Path: "example.mod/b",
},
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "c", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/b"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/b/c",
}},
},
}}},
},
},
}
got := aff.ForPackage("example.mod/a/b/c")
want := []*osv.Entry{
{ID: "c", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/b"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/b/c",
}},
},
}}},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("(-want,+got):\n%s", diff)
}
}
func TestVulnsForSymbol(t *testing.T) {
aff := affectingVulns{
{
Module: &packages.Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "a", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/a"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/a/b/c",
}},
},
}}},
},
},
{
Module: &packages.Module{
Path: "example.mod/a/b",
Version: "v1.0.0",
},
Vulns: []*osv.Entry{
{ID: "b", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/a/b"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/a/b/c",
Symbols: []string{"a"},
}},
},
}}},
{ID: "c", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/a/b"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/a/b/c",
Symbols: []string{"b"},
}},
},
}}},
},
},
}
got := aff.ForSymbol("example.mod/a/b/c", "a")
want := []*osv.Entry{
{ID: "b", Affected: []osv.Affected{{
Module: osv.Module{Path: "example.mod/a/b"},
EcosystemSpecific: osv.EcosystemSpecific{
Packages: []osv.Package{{
Path: "example.mod/a/b/c",
Symbols: []string{"a"},
}},
},
}}},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("(-want,+got):\n%s", diff)
}
}
func TestReceiver(t *testing.T) {
tcs := []struct {
name string
fn *FuncNode
want string
}{
{
name: "empty",
fn: &FuncNode{
RecvType: "",
Package: &packages.Package{PkgPath: "example.com/a/pkg"},
},
want: "",
},
{
name: "pointer",
fn: &FuncNode{
RecvType: "*example.com/a/pkg.Atype",
Package: &packages.Package{PkgPath: "example.com/a/pkg"},
},
want: "*Atype",
},
{
name: "not pointer",
fn: &FuncNode{
RecvType: "example.com/a/pkg.Atype",
Package: &packages.Package{PkgPath: "example.com/a/pkg"},
},
want: "Atype",
},
{
name: "no prefix",
fn: &FuncNode{
RecvType: "Atype",
Package: &packages.Package{PkgPath: "example.com/a/pkg"},
},
want: "Atype",
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
got := tc.fn.Receiver()
if got != tc.want {
t.Errorf("want %s; got %s", tc.want, got)
}
})
}
}
@@ -0,0 +1,447 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"container/list"
"fmt"
"go/ast"
"go/token"
"sort"
"strconv"
"strings"
"sync"
"unicode"
"golang.org/x/tools/go/packages"
)
// CallStack is a call stack starting with a client
// function or method and ending with a call to a
// vulnerable symbol.
type CallStack []StackEntry
// StackEntry is an element of a call stack.
type StackEntry struct {
// Function whose frame is on the stack.
Function *FuncNode
// Call is the call site inducing the next stack frame.
// nil when the frame represents the last frame in the stack.
Call *CallSite
}
// sourceCallstacks returns representative call stacks for each
// vulnerability in res. The returned call stacks are heuristically
// ordered by how seemingly easy is to understand them: shorter
// call stacks with less dynamic call sites appear earlier in the
// returned slices.
//
// sourceCallstacks performs a breadth-first search of res.CallGraph
// starting at the vulnerable symbol and going up until reaching an entry
// function or method in res.CallGraph.Entries. During this search,
// each function is visited at most once to avoid potential
// exponential explosion. Hence, not all call stacks are analyzed.
func sourceCallstacks(res *Result) map[*Vuln]CallStack {
var (
wg sync.WaitGroup
mu sync.Mutex
)
stackPerVuln := make(map[*Vuln]CallStack)
for _, vuln := range res.Vulns {
vuln := vuln
wg.Add(1)
go func() {
cs := sourceCallstack(vuln, res)
mu.Lock()
stackPerVuln[vuln] = cs
mu.Unlock()
wg.Done()
}()
}
wg.Wait()
updateInitPositions(stackPerVuln)
return stackPerVuln
}
// sourceCallstack finds a representative call stack for vuln.
// This is a shortest unique call stack with the least
// number of dynamic call sites.
func sourceCallstack(vuln *Vuln, res *Result) CallStack {
vulnSink := vuln.CallSink
if vulnSink == nil {
return nil
}
entries := make(map[*FuncNode]bool)
for _, e := range res.EntryFunctions {
entries[e] = true
}
seen := make(map[*FuncNode]bool)
// Do a BFS from the vuln sink to the entry points
// and find the representative call stack. This is
// the shortest call stack that goes through the
// least number of dynamic call sites. We first
// collect all candidate call stacks of the shortest
// length and then pick the best one accordingly.
var candidates []CallStack
candDepth := 0
queue := list.New()
queue.PushBack(&callChain{f: vulnSink})
// We want to avoid call stacks that go through
// other vulnerable symbols of the same package
// for the same vulnerability. In other words,
// we want unique call stacks.
skipSymbols := make(map[*FuncNode]bool)
for _, v := range res.Vulns {
if v.CallSink != nil && v != vuln &&
v.OSV == vuln.OSV && v.Package == vuln.Package {
skipSymbols[v.CallSink] = true
}
}
for queue.Len() > 0 {
front := queue.Front()
c := front.Value.(*callChain)
queue.Remove(front)
f := c.f
if seen[f] {
continue
}
seen[f] = true
// Pick a single call site for each function in determinstic order.
// A single call site is sufficient as we visit a function only once.
for _, cs := range callsites(f.CallSites, seen) {
nStack := &callChain{f: cs.Parent, call: cs, child: c}
if !skipSymbols[cs.Parent] {
queue.PushBack(nStack)
}
if entries[cs.Parent] {
ns := nStack.CallStack()
if len(candidates) == 0 || len(ns) == candDepth {
// The case where we either have not identified
// any call stacks or just found one of the same
// length as the previous ones.
candidates = append(candidates, ns)
candDepth = len(ns)
} else {
// We just found a candidate call stack whose
// length is greater than what we previously
// found. We can thus safely disregard this
// call stack and stop searching since we won't
// be able to find any better candidates.
queue.Init() // clear the list, effectively exiting the outer loop
}
}
}
}
// Sort candidate call stacks by their number of dynamic call
// sites and return the first one.
sort.SliceStable(candidates, func(i int, j int) bool {
s1, s2 := candidates[i], candidates[j]
if w1, w2 := weight(s1), weight(s2); w1 != w2 {
return w1 < w2
}
// At this point, the stableness/determinism of
// sorting is guaranteed by the determinism of
// the underlying call graph and the call stack
// search algorithm.
return true
})
if len(candidates) == 0 {
return nil
}
return candidates[0]
}
// callsites picks a call site from sites for each non-visited function.
// For each such function, the smallest (posLess) call site is chosen. The
// returned slice is sorted by caller functions (funcLess). Assumes callee
// of each call site is the same.
func callsites(sites []*CallSite, visited map[*FuncNode]bool) []*CallSite {
minCs := make(map[*FuncNode]*CallSite)
for _, cs := range sites {
if visited[cs.Parent] {
continue
}
if csLess(cs, minCs[cs.Parent]) {
minCs[cs.Parent] = cs
}
}
var fs []*FuncNode
for _, cs := range minCs {
fs = append(fs, cs.Parent)
}
sort.SliceStable(fs, func(i, j int) bool { return funcLess(fs[i], fs[j]) })
var css []*CallSite
for _, f := range fs {
css = append(css, minCs[f])
}
return css
}
// callChain models a chain of function calls.
type callChain struct {
call *CallSite // nil for entry points
f *FuncNode
child *callChain
}
// CallStack converts callChain to CallStack type.
func (c *callChain) CallStack() CallStack {
if c == nil {
return nil
}
return append(CallStack{StackEntry{Function: c.f, Call: c.call}}, c.child.CallStack()...)
}
// weight computes an approximate measure of how easy is to understand the call
// stack when presented to the client as a witness. The smaller the value, the more
// understandable the stack is. Currently defined as the number of unresolved
// call sites in the stack.
func weight(stack CallStack) int {
w := 0
for _, e := range stack {
if e.Call != nil && !e.Call.Resolved {
w += 1
}
}
return w
}
// csLess compares two call sites by their locations and, if needed,
// their string representation.
func csLess(cs1, cs2 *CallSite) bool {
if cs2 == nil {
return true
}
// fast code path
if p1, p2 := cs1.Pos, cs2.Pos; p1 != nil && p2 != nil {
if posLess(*p1, *p2) {
return true
}
if posLess(*p2, *p1) {
return false
}
// for sanity, should not occur in practice
return fmt.Sprintf("%v.%v", cs1.RecvType, cs2.Name) < fmt.Sprintf("%v.%v", cs2.RecvType, cs2.Name)
}
// code path rarely exercised
if cs2.Pos == nil {
return true
}
if cs1.Pos == nil {
return false
}
// should very rarely occur in practice
return fmt.Sprintf("%v.%v", cs1.RecvType, cs2.Name) < fmt.Sprintf("%v.%v", cs2.RecvType, cs2.Name)
}
// posLess compares two positions by their line and column number,
// and filename if needed.
func posLess(p1, p2 token.Position) bool {
if p1.Line < p2.Line {
return true
}
if p2.Line < p1.Line {
return false
}
if p1.Column < p2.Column {
return true
}
if p2.Column < p1.Column {
return false
}
return strings.Compare(p1.Filename, p2.Filename) == -1
}
// funcLess compares two function nodes by locations of
// corresponding functions and, if needed, their string representation.
func funcLess(f1, f2 *FuncNode) bool {
if p1, p2 := f1.Pos, f2.Pos; p1 != nil && p2 != nil {
if posLess(*p1, *p2) {
return true
}
if posLess(*p2, *p1) {
return false
}
// for sanity, should not occur in practice
return f1.String() < f2.String()
}
if f2.Pos == nil {
return true
}
if f1.Pos == nil {
return false
}
// should happen only for inits
return f1.String() < f2.String()
}
// updateInitPositions populates non-existing positions of init functions
// and their respective calls in callStacks (see #51575).
func updateInitPositions(callStacks map[*Vuln]CallStack) {
for _, cs := range callStacks {
for i := range cs {
updateInitPosition(&cs[i])
if i != len(cs)-1 {
updateInitCallPosition(&cs[i], cs[i+1])
}
}
}
}
// updateInitCallPosition updates the position of a call to init in a stack frame, if
// one already does not exist:
//
// P1.init -> P2.init: position of call to P2.init is the position of "import P2"
// statement in P1
//
// P.init -> P.init#d: P.init is an implicit init. We say it calls the explicit
// P.init#d at the place of "package P" statement.
func updateInitCallPosition(curr *StackEntry, next StackEntry) {
call := curr.Call
if !isInit(next.Function) || (call.Pos != nil && call.Pos.IsValid()) {
// Skip non-init functions and inits whose call site position is available.
return
}
var pos token.Position
if curr.Function.Name == "init" && curr.Function.Package == next.Function.Package {
// We have implicit P.init calling P.init#d. Set the call position to
// be at "package P" statement position.
pos = packageStatementPos(curr.Function.Package)
} else {
// Choose the beginning of the import statement as the position.
pos = importStatementPos(curr.Function.Package, next.Function.Package.PkgPath)
}
call.Pos = &pos
}
func importStatementPos(pkg *packages.Package, importPath string) token.Position {
var importSpec *ast.ImportSpec
spec:
for _, f := range pkg.Syntax {
for _, impSpec := range f.Imports {
// Import spec paths have quotation marks.
impSpecPath, err := strconv.Unquote(impSpec.Path.Value)
if err != nil {
panic(fmt.Sprintf("import specification: package path has no quotation marks: %v", err))
}
if impSpecPath == importPath {
importSpec = impSpec
break spec
}
}
}
if importSpec == nil {
// for sanity, in case of a wild call graph imprecision
return token.Position{}
}
// Choose the beginning of the import statement as the position.
return pkg.Fset.Position(importSpec.Pos())
}
func packageStatementPos(pkg *packages.Package) token.Position {
if len(pkg.Syntax) == 0 {
return token.Position{}
}
// Choose beginning of the package statement as the position. Pick
// the first file since it is as good as any.
return pkg.Fset.Position(pkg.Syntax[0].Package)
}
// updateInitPosition updates the position of P.init function in a stack frame if one
// is not available. The new position is the position of the "package P" statement.
func updateInitPosition(se *StackEntry) {
fun := se.Function
if !isInit(fun) || (fun.Pos != nil && fun.Pos.IsValid()) {
// Skip non-init functions and inits whose position is available.
return
}
pos := packageStatementPos(fun.Package)
fun.Pos = &pos
}
func isInit(f *FuncNode) bool {
// A source init function, or anonymous functions used in inits, will
// be named "init#x" by vulncheck (more precisely, ssa), where x is a
// positive integer. Implicit inits are named simply "init".
return f.Name == "init" || strings.HasPrefix(f.Name, "init#")
}
// binaryCallstacks computes representative call stacks for binary results.
func binaryCallstacks(vr *Result) map[*Vuln]CallStack {
callstacks := map[*Vuln]CallStack{}
for _, vv := range uniqueVulns(vr.Vulns) {
f := &FuncNode{Package: vv.Package, Name: vv.Symbol}
parts := strings.Split(vv.Symbol, ".")
if len(parts) != 1 {
f.RecvType = parts[0]
f.Name = parts[1]
}
callstacks[vv] = CallStack{StackEntry{Function: f}}
}
return callstacks
}
// uniqueVulns does for binary mode what sourceCallstacks does for source mode.
// It tries not to report redundant symbols. Since there are no call stacks in
// binary mode, the following approximate approach is used. Do not report unexported
// symbols for a <vulnID, pkg, module> triple if there are some exported symbols.
// Otherwise, report all unexported symbols to avoid not reporting anything.
func uniqueVulns(vulns []*Vuln) []*Vuln {
type key struct {
id string
pkg string
mod string
}
hasExported := make(map[key]bool)
for _, v := range vulns {
if isExported(v.Symbol) {
k := key{id: v.OSV.ID, pkg: v.Package.PkgPath, mod: v.Package.Module.Path}
hasExported[k] = true
}
}
var uniques []*Vuln
for _, v := range vulns {
k := key{id: v.OSV.ID, pkg: v.Package.PkgPath, mod: v.Package.Module.Path}
if isExported(v.Symbol) || !hasExported[k] {
uniques = append(uniques, v)
}
}
return uniques
}
// isExported checks if the symbol is exported. Assumes that the
// symbol is of the form "identifier" or "identifier1.identifier2".
func isExported(symbol string) bool {
parts := strings.Split(symbol, ".")
if len(parts) == 1 {
return unicode.IsUpper(rune(symbol[0]))
}
return unicode.IsUpper(rune(parts[1][0]))
}
@@ -0,0 +1,266 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"context"
"fmt"
"path"
"path/filepath"
"reflect"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/packages/packagestest"
"golang.org/x/vuln/internal/client"
"golang.org/x/vuln/internal/govulncheck"
"golang.org/x/vuln/internal/osv"
"golang.org/x/vuln/internal/test"
)
// stacksToString converts map *Vuln:stack to Vuln.Symbol:"f1->...->fN"
// string representation.
func stacksToString(stacks map[*Vuln]CallStack) map[string]string {
m := make(map[string]string)
for v, st := range stacks {
var stStr []string
for _, call := range st {
stStr = append(stStr, call.Function.Name)
}
m[v.Symbol] = strings.Join(stStr, "->")
}
return m
}
func TestSourceCallstacks(t *testing.T) {
// Call graph structure for the test program
// entry1 entry2
// | |
// interm1 |
// | \ /
// | interm2(interface)
// | / |
// vuln1 vuln2
o := &osv.Entry{ID: "o"}
e1 := &FuncNode{Name: "entry1"}
e2 := &FuncNode{Name: "entry2"}
i1 := &FuncNode{Name: "interm1", CallSites: []*CallSite{{Parent: e1, Resolved: true}}}
i2 := &FuncNode{Name: "interm2", CallSites: []*CallSite{{Parent: e2, Resolved: true}, {Parent: i1, Resolved: true}}}
v1 := &FuncNode{Name: "vuln1", CallSites: []*CallSite{{Parent: i1, Resolved: true}, {Parent: i2, Resolved: false}}}
v2 := &FuncNode{Name: "vuln2", CallSites: []*CallSite{{Parent: i2, Resolved: false}}}
vp := &packages.Package{PkgPath: "v1", Module: &packages.Module{Path: "m1"}}
vuln1 := &Vuln{CallSink: v1, Package: vp, OSV: o, Symbol: "vuln1"}
vuln2 := &Vuln{CallSink: v2, Package: vp, OSV: o, Symbol: "vuln2"}
res := &Result{
EntryFunctions: []*FuncNode{e1, e2},
Vulns: []*Vuln{vuln1, vuln2},
}
want := map[string]string{
"vuln1": "entry1->interm1->vuln1",
"vuln2": "entry2->interm2->vuln2",
}
stacks := sourceCallstacks(res)
if got := stacksToString(stacks); !reflect.DeepEqual(want, got) {
t.Errorf("want %v; got %v", want, got)
}
}
func TestSourceUniqueCallStack(t *testing.T) {
// Call graph structure for the test program
// entry1 entry2
// | |
// vuln1 interm1
// | |
// | interm2
// | /
// vuln2
o := &osv.Entry{ID: "o"}
e1 := &FuncNode{Name: "entry1"}
e2 := &FuncNode{Name: "entry2"}
i1 := &FuncNode{Name: "interm1", CallSites: []*CallSite{{Parent: e2}}}
i2 := &FuncNode{Name: "interm2", CallSites: []*CallSite{{Parent: i1}}}
v1 := &FuncNode{Name: "vuln1", CallSites: []*CallSite{{Parent: e1}}}
v2 := &FuncNode{Name: "vuln2", CallSites: []*CallSite{{Parent: v1}, {Parent: i2}}}
vp := &packages.Package{PkgPath: "v1", Module: &packages.Module{Path: "m1"}}
vuln1 := &Vuln{CallSink: v1, Package: vp, OSV: o, Symbol: "vuln1"}
vuln2 := &Vuln{CallSink: v2, Package: vp, OSV: o, Symbol: "vuln2"}
res := &Result{
EntryFunctions: []*FuncNode{e1, e2},
Vulns: []*Vuln{vuln1, vuln2},
}
want := map[string]string{
"vuln1": "entry1->vuln1",
"vuln2": "entry2->interm1->interm2->vuln2",
}
stacks := sourceCallstacks(res)
if got := stacksToString(stacks); !reflect.DeepEqual(want, got) {
t.Errorf("want %v; got %v", want, got)
}
}
// TestInits checks for correct positions of init functions
// and their respective calls (see #51575).
func TestInits(t *testing.T) {
testClient, err := client.NewInMemoryClient(
[]*osv.Entry{
{
ID: "A", Affected: []osv.Affected{{Module: osv.Module{Path: "golang.org/amod"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver}},
EcosystemSpecific: osv.EcosystemSpecific{Packages: []osv.Package{{
Path: "golang.org/amod/avuln", Symbols: []string{"A"}},
}},
}},
},
{
ID: "C", Affected: []osv.Affected{{Module: osv.Module{Path: "golang.org/cmod"}, Ranges: []osv.Range{{Type: osv.RangeTypeSemver}},
EcosystemSpecific: osv.EcosystemSpecific{Packages: []osv.Package{{
Path: "golang.org/cmod/cvuln", Symbols: []string{"C"}},
}},
}},
},
})
if err != nil {
t.Fatal(err)
}
e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
{
Name: "golang.org/entry",
Files: map[string]interface{}{
"x/x.go": `
package x
import (
_ "golang.org/amod/avuln"
_ "golang.org/bmod/b"
)
`,
},
},
{
Name: "golang.org/amod@v0.5.0",
Files: map[string]interface{}{"avuln/avuln.go": `
package avuln
func init() {
A()
}
func A() {}
`},
},
{
Name: "golang.org/bmod@v0.5.0",
Files: map[string]interface{}{"b/b.go": `
package b
import _ "golang.org/cmod/cvuln"
`},
},
{
Name: "golang.org/cmod@v0.5.0",
Files: map[string]interface{}{"cvuln/cvuln.go": `
package cvuln
var x int = C()
func C() int {
return 0
}
`},
},
})
defer e.Cleanup()
// Load x as entry package.
graph := NewPackageGraph("go1.18")
pkgs, mods, err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "entry/x")})
if err != nil {
t.Fatal(err)
}
if len(pkgs) != 1 {
t.Fatal("failed to load x test package")
}
cfg := &govulncheck.Config{ScanLevel: "symbol"}
result, err := source(context.Background(), test.NewMockHandler(), pkgs, mods, cfg, testClient, graph)
if err != nil {
t.Fatal(err)
}
cs := sourceCallstacks(result)
want := map[string][]string{
"A": {
// Entry init's position is the package statement.
// It calls avuln.init at avuln import statement.
"N:golang.org/entry/x.init F:x.go:2:4 C:x.go:5:5",
// implicit avuln.init is calls explicit init at the avuln
// package statement.
"N:golang.org/amod/avuln.init F:avuln.go:2:4 C:avuln.go:2:4",
"N:golang.org/amod/avuln.init#1 F:avuln.go:4:9 C:avuln.go:5:6",
"N:golang.org/amod/avuln.A F:avuln.go:8:9 C:",
},
"C": {
"N:golang.org/entry/x.init F:x.go:2:4 C:x.go:6:5",
"N:golang.org/bmod/b.init F:b.go:2:4 C:b.go:4:11",
"N:golang.org/cmod/cvuln.init F:cvuln.go:2:4 C:cvuln.go:4:17",
"N:golang.org/cmod/cvuln.C F:cvuln.go:6:9 C:",
},
}
if diff := cmp.Diff(want, fullStacksToString(cs)); diff != "" {
t.Errorf("modules mismatch (-want, +got):\n%s", diff)
}
}
// fullStacksToString is like stacksToString but the stack stringification
// is a slice of strings, each containing detailed information on each on
// the corresponding frame.
func fullStacksToString(callStacks map[*Vuln]CallStack) map[string][]string {
m := make(map[string][]string)
for v, cs := range callStacks {
var scs []string
for _, se := range cs {
fPos := se.Function.Pos
fp := fmt.Sprintf("%s:%d:%d", filepath.Base(fPos.Filename), fPos.Line, fPos.Column)
var cp string
if se.Call != nil && se.Call.Pos.IsValid() {
cPos := se.Call.Pos
cp = fmt.Sprintf("%s:%d:%d", filepath.Base(cPos.Filename), cPos.Line, cPos.Column)
}
sse := fmt.Sprintf("N:%s.%s\tF:%v\tC:%v", se.Function.Package.PkgPath, se.Function.Name, fp, cp)
scs = append(scs, sse)
}
m[v.OSV.ID] = scs
}
return m
}
func TestIsExported(t *testing.T) {
for _, tc := range []struct {
symbol string
want bool
}{
{"foo", false},
{"Foo", true},
{"x.foo", false},
{"X.foo", false},
{"x.Foo", true},
{"X.Foo", true},
} {
tc := tc
t.Run(tc.symbol, func(t *testing.T) {
if got := isExported(tc.symbol); tc.want != got {
t.Errorf("want %t; got %t", tc.want, got)
}
})
}
}