From c88030e856665af0a7b59f1798ebeac30af1d6a1 Mon Sep 17 00:00:00 2001 From: Ian Alexander Date: Mon, 11 Aug 2025 14:28:10 -0400 Subject: [PATCH] rf: accept receivers with multiple type parameters This commit adds support for type-parameterized receivers, in particular those with multiple type parameters. --- inject.go | 18 ++++++++++++++++-- refactor/pkgref.go | 6 +++++- refactor/snap.go | 2 +- testdata/inject3.txt | 44 ++++++++++++++++++++++++++++++++++++++++++++ testdata/inject4.txt | 44 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 testdata/inject3.txt create mode 100644 testdata/inject4.txt diff --git a/inject.go b/inject.go index 0baabbc..1037911 100644 --- a/inject.go +++ b/inject.go @@ -14,6 +14,20 @@ import ( "rsc.io/rf/refactor" ) +// origin returns the origin for the given [types.Object]. Normally, this is +// the object itself, but for instantiated methods or fields it is the +// corresponding object on the generic type. +func origin(obj types.Object) types.Object { + switch t := obj.(type) { + case *types.Func: + return t.Origin() + case *types.Var: + return t.Origin() + default: + return obj + } +} + func cmdInject(snap *refactor.Snapshot, args string) error { items, _ := snap.EvalList(args) if len(items) < 2 { @@ -101,7 +115,7 @@ func cmdInject(snap *refactor.Snapshot, args string) error { if !ok || len(stack) < 2 { return } - obj := pkg.TypesInfo.Uses[id] + obj := origin(pkg.TypesInfo.Uses[id]) if converting[obj] == "" { return } @@ -165,7 +179,7 @@ func cmdInject(snap *refactor.Snapshot, args string) error { if !ok || len(stack) < 2 { return } - obj := pkg.TypesInfo.Uses[id] + obj := origin(pkg.TypesInfo.Uses[id]) if converting[obj] == "" && obj != targetObj { return } diff --git a/refactor/pkgref.go b/refactor/pkgref.go index 92ca09c..213dcf8 100644 --- a/refactor/pkgref.go +++ b/refactor/pkgref.go @@ -225,7 +225,11 @@ func (s *Snapshot) addPkgDeps(g *DepsGraph, p *Package) { t = p.X } if i, ok := t.(*ast.IndexExpr); ok { - // Method with a type-parameterized receiver. + // Method with a receiver having a single type parameter. + t = i.X + } + if i, ok := t.(*ast.IndexListExpr); ok { + // Method with receiver having multiple type parameters. t = i.X } id, ok := t.(*ast.Ident) diff --git a/refactor/snap.go b/refactor/snap.go index 2c0e90a..2083ee5 100644 --- a/refactor/snap.go +++ b/refactor/snap.go @@ -182,7 +182,7 @@ func (r *Refactor) load1(config Config) ([]*Snapshot, error) { } mod := strings.TrimSpace(string(bmod)) if filepath.Base(mod) != "go.mod" { - return nil, fmt.Errorf("no module found for " + dir) + return nil, fmt.Errorf("no module found for: %s", dir) } r.modRoot = filepath.Dir(mod) diff --git a/testdata/inject3.txt b/testdata/inject3.txt new file mode 100644 index 0000000..81e333d --- /dev/null +++ b/testdata/inject3.txt @@ -0,0 +1,44 @@ +inject G f +-- x.go -- +package p + +var G int + +type C[U any] struct{} + +type T struct{} + +func f() { g() } +func g() { g2(1) } +func g2(g int) { h(g) } +func h(i int) { var c C[T]; c.j(i+1) } +func (c *C[U]) j(k int) { m(k+G) } +func m(l int) { println(2*G+l) } +func z() { h(2) } +func y() {} + +-- stdout -- +diff old/x.go new/x.go +--- old/x.go ++++ new/x.go +@@ -6,12 +6,11 @@ + + type T struct{} + +-func f() { g() } +-func g() { g2(1) } +-func g2(g int) { h(g) } +-func h(i int) { var c C[T]; c.j(i+1) } +-func (c *C[U]) j(k int) { m(k+G) } +-func m(l int) { println(2*G+l) } +-func z() { h(2) } +-func y() {} +- ++func f() { g(G) } ++func g(g int) { g2(g, 1) } ++func g2(g_ int, g int) { h(g_, g) } ++func h(g int, i int) { var c C[T]; c.j(g, i+1) } ++func (c *C[U]) j(g int, k int) { m(g, k+g) } ++func m(g int, l int) { println(2*g + l) } ++func z() { h(G, 2) } ++func y() {} diff --git a/testdata/inject4.txt b/testdata/inject4.txt new file mode 100644 index 0000000..83ca3b9 --- /dev/null +++ b/testdata/inject4.txt @@ -0,0 +1,44 @@ +inject G f +-- x.go -- +package p + +var G int + +type C[K, V any] struct{} + +type T struct{} + +func f() { g() } +func g() { g2(1) } +func g2(g int) { h(g) } +func h(i int) { var c C[T, T]; c.j(i+1) } +func (c *C[K, V]) j(k int) { m(k+G) } +func m(l int) { println(2*G+l) } +func z() { h(2) } +func y() {} + +-- stdout -- +diff old/x.go new/x.go +--- old/x.go ++++ new/x.go +@@ -6,12 +6,11 @@ + + type T struct{} + +-func f() { g() } +-func g() { g2(1) } +-func g2(g int) { h(g) } +-func h(i int) { var c C[T, T]; c.j(i+1) } +-func (c *C[K, V]) j(k int) { m(k+G) } +-func m(l int) { println(2*G+l) } +-func z() { h(2) } +-func y() {} +- ++func f() { g(G) } ++func g(g int) { g2(g, 1) } ++func g2(g_ int, g int) { h(g_, g) } ++func h(g int, i int) { var c C[T, T]; c.j(g, i+1) } ++func (c *C[K, V]) j(g int, k int) { m(g, k+g) } ++func m(g int, l int) { println(2*g + l) } ++func z() { h(G, 2) } ++func y() {}