From 23a084832d724eca26f2520102ea8a25415e0efe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AE=B5=E4=BB=AA?= Date: Sun, 17 Aug 2025 18:55:11 +0800 Subject: [PATCH] feat(go): support collect type parameters --- lang/golang/parser/file.go | 76 ++++++++++++++--------------- lang/golang/parser/utils.go | 52 +++++--------------- testdata/go/0_golang/pkg/generic.go | 39 +++++++++++++++ 3 files changed, 87 insertions(+), 80 deletions(-) create mode 100644 testdata/go/0_golang/pkg/generic.go diff --git a/lang/golang/parser/file.go b/lang/golang/parser/file.go index d5ec8701..f77f5841 100644 --- a/lang/golang/parser/file.go +++ b/lang/golang/parser/file.go @@ -120,7 +120,7 @@ func (p *GoParser) parseVar(ctx *fileContext, vspec *ast.ValueSpec, isConst bool v = p.newVar(ctx.module.Name, ctx.pkgPath, name.Name, isConst) v.FileLine = ctx.FileLine(vspec) - // always collect value's dependencies + // collect func value dependencies, in case of var a = func() {...} if val != nil && !isConst { collects := collectInfos{} ast.Inspect(*val, func(n ast.Node) bool { @@ -159,38 +159,13 @@ func (p *GoParser) parseVar(ctx *fileContext, vspec *ast.ValueSpec, isConst bool if isConst && v.Type == nil { v.Type = lastType } - var varType string - if v.Type != nil { - if v.Type.PkgPath == ctx.pkgPath { - varType = v.Type.Name - } else { - varType = v.Type.CallName() - } - if v.IsPointer { - varType = "*" + varType - } - } if !isConst { - v.Content = fmt.Sprintf("var %s %s", name.Name, varType) + v.Content = "var " + string(ctx.GetRawContent(vspec)) } else { - if varType != "" { - v.Content = fmt.Sprintf("const %s %s", name.Name, varType) - } else { - v.Content = fmt.Sprintf("const %s", name.Name) - } + v.Content = "const " + string(ctx.GetRawContent(vspec)) } - var comment string - if ctx.collectComment && doc != nil { - comment += string(ctx.GetRawContent(doc)) + "\n" - } - if ctx.collectComment && vspec.Doc != nil { - comment += string(ctx.GetRawContent(vspec.Doc)) + "\n" - v.FileLine.StartOffset = ctx.fset.Position(vspec.Pos()).Offset - } - v.Content = comment + v.Content - var finalVal string if val != nil { // refer codes @@ -229,11 +204,20 @@ func (p *GoParser) parseVar(ctx *fileContext, vspec *ast.ValueSpec, isConst bool lastValue = &tmp finalVal = strconv.FormatFloat(tmp, 'f', -1, 64) } - - if finalVal != "" { + if finalVal != "" && !strings.Contains(v.Content, " = ") { v.Content += " = " + finalVal } + var comment string + if ctx.collectComment && doc != nil { + comment += string(ctx.GetRawContent(doc)) + "\n" + } + if ctx.collectComment && vspec.Doc != nil { + comment += string(ctx.GetRawContent(vspec.Doc)) + "\n" + v.FileLine.StartOffset = ctx.fset.Position(vspec.Pos()).Offset + } + v.Content = comment + v.Content + typ = v.Type } return typ, v, lastValue @@ -441,22 +425,23 @@ func (p *GoParser) parseASTNode(ctx *fileContext, node ast.Node, collect *collec func (p *GoParser) parseFunc(ctx *fileContext, funcDecl *ast.FuncDecl) (*Function, bool) { // method receiver var receiver *Receiver - isMethod := funcDecl.Recv != nil - if strings.HasSuffix(ctx.filePath, "cmds/life_stat/main.go") && funcDecl.Name.Name == "init" { - - } + var tparams []Dependency + isMethod := funcDecl.Recv != nil && len(funcDecl.Recv.List) > 0 if isMethod { - // TODO: reserve the pointer message? - ti := ctx.GetTypeInfo(funcDecl.Recv.List[0].Type) - // name := "self" - // if len(funcDecl.Recv.List[0].Names) > 0 { - // name = funcDecl.Recv.List[0].Names[0].Name - // } + rt := funcDecl.Recv.List[0].Type + ti := ctx.GetTypeInfo(rt) receiver = &Receiver{ Type: ti.Id, IsPointer: ti.IsPointer, // Name: name, } + // collect receiver's type params + for _, d := range ti.Deps { + tparams = append(tparams, Dependency{ + Identity: d, + FileLine: ctx.FileLine(rt), // FIXME: location is not accurate, try parse Index AST to get it. + }) + } } fname := funcDecl.Name.Name @@ -474,6 +459,10 @@ func (p *GoParser) parseFunc(ctx *fileContext, funcDecl *ast.FuncDecl) (*Functio if funcDecl.Type.Results != nil { ctx.collectFields(funcDecl.Type.Results.List, &results) } + // collect type params + if funcDecl.Type.TypeParams != nil { + ctx.collectFields(funcDecl.Type.TypeParams.List, &tparams) + } // collect signature sig := ctx.GetRawContent(funcDecl.Type) @@ -510,6 +499,9 @@ set_func: f.Results = results f.GlobalVars = collects.globalVars f.Types = collects.tys + for _, t := range tparams { + f.Types = InsertDependency(f.Types, t) + } f.Signature = string(sig) return f, false } @@ -534,6 +526,10 @@ func (p *GoParser) parseType(ctx *fileContext, typDecl *ast.TypeSpec, doc *ast.C } } + if typDecl.TypeParams != nil { + ctx.collectFields(typDecl.TypeParams.List, &st.SubStruct) + } + st.FileLine = ctx.FileLine(typDecl) st.Content = string(ctx.GetRawContent(typDecl)) if ctx.collectComment && doc != nil { diff --git a/lang/golang/parser/utils.go b/lang/golang/parser/utils.go index 338c31f7..dedfbf7d 100644 --- a/lang/golang/parser/utils.go +++ b/lang/golang/parser/utils.go @@ -19,8 +19,6 @@ import ( "bytes" "fmt" "go/ast" - "go/parser" - "go/token" "go/types" "io" "os" @@ -51,35 +49,12 @@ func (c cache) Visited(val interface{}) bool { return ok } -func hasMain(file []byte) bool { - if !bytes.Contains(file, []byte("package main")) || !bytes.Contains(file, []byte("func main()")) { - return false - } - fset := token.NewFileSet() - f, err := parser.ParseFile(fset, "any.go", file, parser.SkipObjectResolution) - if err != nil { - return false - } - if f.Name.Name != "main" { - return false - } - for _, decl := range f.Decls { - if funcDecl, ok := decl.(*ast.FuncDecl); ok { - if funcDecl.Name.Name == "main" { - return true - } - } - } - return false -} - func isSysPkg(importPath string) bool { return !strings.Contains(strings.Split(importPath, "/")[0], ".") } var ( verReg = regexp.MustCompile(`/v\d+$`) - litReg = regexp.MustCompile(`[^a-zA-Z0-9_]`) ) func getPackageAlias(importPath string) string { @@ -98,14 +73,6 @@ func getPackageAlias(importPath string) string { return alias } -func splitVersion(module string) (string, string) { - if strings.Contains(module, "@") { - parts := strings.Split(module, "@") - return parts[0], parts[1] - } - return module, "" -} - func getModuleName(modFilePath string) (string, []byte, error) { file, err := os.Open(modFilePath) if err != nil { @@ -218,6 +185,18 @@ func getNamedTypes(typ types.Type, visited map[types.Type]bool) (tys []types.Obj case *types.Named: tys = append(tys, t.Obj()) isNamed = true + if targs := t.TypeArgs(); targs != nil { + for i := 0; i < targs.Len(); i++ { + typs, _, _ := getNamedTypes(targs.At(i), visited) + tys = append(tys, typs...) + } + } + if tparams := t.TypeParams(); tparams != nil { + for i := 0; i < tparams.Len(); i++ { + typs, _, _ := getNamedTypes(tparams.At(i), visited) + tys = append(tys, typs...) + } + } case *types.Struct: for i := 0; i < t.NumFields(); i++ { typs, _, _ := getNamedTypes(t.Field(i).Type(), visited) @@ -252,13 +231,6 @@ func getNamedTypes(typ types.Type, visited map[types.Type]bool) (tys []types.Obj return } -func extractName(typ string) string { - if strings.Contains(typ, ".") { - return strings.Split(typ, ".")[1] - } - return typ -} - func parseExpr(expr string) (interface{}, error) { // Create a map of parameters to pass to the expression evaluator. parameters := map[string]interface{}{ diff --git a/testdata/go/0_golang/pkg/generic.go b/testdata/go/0_golang/pkg/generic.go new file mode 100644 index 00000000..59c6357f --- /dev/null +++ b/testdata/go/0_golang/pkg/generic.go @@ -0,0 +1,39 @@ +/** + * Copyright 2025 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package pkg + +import ( + "fmt" + + "a.b/c/pkg/entity" +) + +type CaseGenericStruct[T entity.InterfaceB, U InterfaceA, V any] struct { + Prefix T + Subfix U + Data V +} + +func (s *CaseGenericStruct[_, _, _]) String() string { + return s.Prefix.String() + fmt.Sprintf("%v", s.Data) + s.Subfix.String() +} + +func CaseGenericFunc[U InterfaceA, T entity.InterfaceB, V any](a T, b U, c V) string { + return a.String() + fmt.Sprintf("%v", c) + b.String() +} + +var CaseGenericVar CaseGenericStruct[entity.InterfaceB, InterfaceA, int]