4c060a78cc
Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
452 lines
11 KiB
Go
452 lines
11 KiB
Go
/*
|
|
Copyright 2015 The Kubernetes Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package protobuf
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/printer"
|
|
"go/token"
|
|
"io/ioutil"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
|
|
customreflect "k8s.io/code-generator/third_party/forked/golang/reflect"
|
|
)
|
|
|
|
// rewriteFile parses the Go source file at name (keeping comments), hands its
// AST to rewriteFn, and then overwrites the file with header followed by the
// printed and format.Source-formatted result.
//
// The file is opened for writing only after reading, parsing, rewriting,
// printing and formatting have all succeeded, so a failure in any of those
// steps leaves the original file contents in place.
func rewriteFile(name string, header []byte, rewriteFn func(*token.FileSet, *ast.File) error) error {
	fset := token.NewFileSet()

	src, err := ioutil.ReadFile(name)
	if err != nil {
		return err
	}
	file, err := parser.ParseFile(fset, name, src, parser.DeclarationErrors|parser.ParseComments)
	if err != nil {
		return err
	}
	if err = rewriteFn(fset, file); err != nil {
		return err
	}

	var out bytes.Buffer
	out.Write(header)
	if err = printer.Fprint(&out, fset, file); err != nil {
		return err
	}
	formatted, err := format.Source(out.Bytes())
	if err != nil {
		return err
	}

	dst, err := os.OpenFile(name, os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	// Safety net for the early return below; the explicit Close at the end
	// is the one whose error is reported.
	defer dst.Close()
	if _, err = dst.Write(formatted); err != nil {
		return err
	}
	return dst.Close()
}
|
|
|
|
// ExtractFunc is consulted for each type declared in a generated file. It may
// extract whatever information it needs from the given TypeSpec, and it
// returns true when that type should be removed from the destination file.
type ExtractFunc func(spec *ast.TypeSpec) bool
|
|
|
|
// OptionalFunc reports whether the given local type name refers to a type that
// has protobuf.nullable=true, and whose marshal functions should therefore be
// rewritten to stop going through the 'Items' accessor.
type OptionalFunc func(typeName string) bool
|
|
|
|
func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, optionalFn OptionalFunc, header []byte) error {
|
|
return rewriteFile(name, header, func(fset *token.FileSet, file *ast.File) error {
|
|
cmap := ast.NewCommentMap(fset, file, file.Comments)
|
|
|
|
// transform methods that point to optional maps or slices
|
|
for _, d := range file.Decls {
|
|
rewriteOptionalMethods(d, optionalFn)
|
|
}
|
|
|
|
// remove types that are already declared
|
|
decls := []ast.Decl{}
|
|
for _, d := range file.Decls {
|
|
if dropExistingTypeDeclarations(d, extractFn) {
|
|
continue
|
|
}
|
|
if dropEmptyImportDeclarations(d) {
|
|
continue
|
|
}
|
|
decls = append(decls, d)
|
|
}
|
|
file.Decls = decls
|
|
|
|
// remove unmapped comments
|
|
file.Comments = cmap.Filter(file).Comments()
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// rewriteOptionalMethods makes specific mutations to marshaller methods that belong to types identified
|
|
// as being "optional" (they may be nil on the wire). This allows protobuf to serialize a map or slice and
|
|
// properly discriminate between empty and nil (which is not possible in protobuf).
|
|
// TODO: move into upstream gogo-protobuf once https://github.com/gogo/protobuf/issues/181
|
|
// has agreement
|
|
func rewriteOptionalMethods(decl ast.Decl, isOptional OptionalFunc) {
|
|
switch t := decl.(type) {
|
|
case *ast.FuncDecl:
|
|
ident, ptr, ok := receiver(t)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
// correct initialization of the form `m.Field = &OptionalType{}` to
|
|
// `m.Field = OptionalType{}`
|
|
if t.Name.Name == "Unmarshal" {
|
|
ast.Walk(optionalAssignmentVisitor{fn: isOptional}, t.Body)
|
|
}
|
|
|
|
if !isOptional(ident.Name) {
|
|
return
|
|
}
|
|
|
|
switch t.Name.Name {
|
|
case "Unmarshal":
|
|
ast.Walk(&optionalItemsVisitor{}, t.Body)
|
|
case "MarshalTo", "Size", "String":
|
|
ast.Walk(&optionalItemsVisitor{}, t.Body)
|
|
fallthrough
|
|
case "Marshal":
|
|
// if the method has a pointer receiver, set it back to a normal receiver
|
|
if ptr {
|
|
t.Recv.List[0].Type = ident
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
type optionalAssignmentVisitor struct {
|
|
fn OptionalFunc
|
|
}
|
|
|
|
// Visit walks the provided node, transforming field initializations of the form
|
|
// m.Field = &OptionalType{} -> m.Field = OptionalType{}
|
|
func (v optionalAssignmentVisitor) Visit(n ast.Node) ast.Visitor {
|
|
switch t := n.(type) {
|
|
case *ast.AssignStmt:
|
|
if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
|
|
if !isFieldSelector(t.Lhs[0], "m", "") {
|
|
return nil
|
|
}
|
|
unary, ok := t.Rhs[0].(*ast.UnaryExpr)
|
|
if !ok || unary.Op != token.AND {
|
|
return nil
|
|
}
|
|
composite, ok := unary.X.(*ast.CompositeLit)
|
|
if !ok || composite.Type == nil || len(composite.Elts) != 0 {
|
|
return nil
|
|
}
|
|
if ident, ok := composite.Type.(*ast.Ident); ok && v.fn(ident.Name) {
|
|
t.Rhs[0] = composite
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
return v
|
|
}
|
|
|
|
// optionalItemsVisitor is a stateless ast.Visitor applied to the marshalling
// methods of optional types. It rewrites accesses to the receiver's Items
// field (m.Items) into direct uses of the receiver (m or *m).
type optionalItemsVisitor struct{}

// Visit walks the provided node, looking for specific patterns to transform that match
// the effective outcome of turning struct{ map[x]y || []x } into map[x]y or []x.
//
// Rewrites performed by the cases below:
//
//	range m.Items                      -> range m
//	m.Items[k] = ...                   -> (*m)[k] = ...
//	m.Items = ...                      -> *m = ...
//	... = append(m.Items, ...)         -> ... = append(*m, ...)
//	if m.Items == nil                  -> if *m == nil
//	if err := m.Items[len(m.Items)-1].Unmarshal(...); ...
//	                                   -> if err := (*m)[len(*m)-1].Unmarshal(...); ...
//	m.Items[x] (any other context)     -> m[x]
//	f(..., m.Items, ...)               -> f(..., m, ...)
//
// Returning nil tells ast.Walk not to descend into the current node's
// children; returning v continues the walk.
//
// NOTE(review): reads are rewritten to m while writes go through *m,
// presumably matching which methods keep a pointer receiver after
// rewriteOptionalMethods (Unmarshal does, the others do not) — confirm
// against the generated code.
func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor {
	switch t := n.(type) {
	case *ast.RangeStmt:
		// range m.Items -> range m
		if isFieldSelector(t.X, "m", "Items") {
			t.X = &ast.Ident{Name: "m"}
		}
	case *ast.AssignStmt:
		if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
			// Assignment targets are rewritten to go through *m.
			switch lhs := t.Lhs[0].(type) {
			case *ast.IndexExpr:
				// m.Items[k] = ... -> (*m)[k] = ...
				if isFieldSelector(lhs.X, "m", "Items") {
					lhs.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
				}
			default:
				// m.Items = ... -> *m = ...
				if isFieldSelector(t.Lhs[0], "m", "Items") {
					t.Lhs[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
				}
			}
			switch rhs := t.Rhs[0].(type) {
			case *ast.CallExpr:
				if ident, ok := rhs.Fun.(*ast.Ident); ok && ident.Name == "append" {
					// Walk the call first: the CallExpr case below turns an
					// m.Items argument into m. A first argument that is the
					// identifier m (whether produced by that walk or already
					// there) is then turned into *m.
					ast.Walk(v, rhs)
					if len(rhs.Args) > 0 {
						switch arg := rhs.Args[0].(type) {
						case *ast.Ident:
							if arg.Name == "m" {
								rhs.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
							}
						}
					}
					// The call was already walked above; do not descend again.
					return nil
				}
			}
		}
	case *ast.IfStmt:
		// if m.Items == nil -> if *m == nil
		switch cond := t.Cond.(type) {
		case *ast.BinaryExpr:
			if cond.Op == token.EQL {
				if isFieldSelector(cond.X, "m", "Items") && isIdent(cond.Y, "nil") {
					cond.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
				}
			}
		}
		if t.Init != nil {
			// Find form:
			// if err := m.Items[len(m.Items)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
			// 	return err
			// }
			switch s := t.Init.(type) {
			case *ast.AssignStmt:
				if call, ok := s.Rhs[0].(*ast.CallExpr); ok {
					if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
						if x, ok := sel.X.(*ast.IndexExpr); ok {
							// m.<field>[...] -> (*m)[...]; note that any field
							// selected from m matches here, not only Items.
							if sel2, ok := x.X.(*ast.SelectorExpr); ok {
								if ident, ok := sel2.X.(*ast.Ident); ok && ident.Name == "m" {
									x.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
								}
							}
							// len(m.Items) -> len(*m) on the left of the index expression
							if bin, ok := x.Index.(*ast.BinaryExpr); ok {
								if call2, ok := bin.X.(*ast.CallExpr); ok && len(call2.Args) == 1 {
									if isFieldSelector(call2.Args[0], "m", "Items") {
										call2.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
									}
								}
							}
						}
					}
				}
			}
		}
	case *ast.IndexExpr:
		// m.Items[x] -> m[x]; the children (including x) are not visited.
		if isFieldSelector(t.X, "m", "Items") {
			t.X = &ast.Ident{Name: "m"}
			return nil
		}
	case *ast.CallExpr:
		// f(..., m.Items, ...) -> f(..., m, ...); when any argument was
		// rewritten, the call's children are not visited.
		changed := false
		for i := range t.Args {
			if isFieldSelector(t.Args[i], "m", "Items") {
				t.Args[i] = &ast.Ident{Name: "m"}
				changed = true
			}
		}
		if changed {
			return nil
		}
	}
	return v
}
|
|
|
|
func isFieldSelector(n ast.Expr, name, field string) bool {
|
|
s, ok := n.(*ast.SelectorExpr)
|
|
if !ok || s.Sel == nil || (field != "" && s.Sel.Name != field) {
|
|
return false
|
|
}
|
|
return isIdent(s.X, name)
|
|
}
|
|
|
|
// isIdent reports whether n is a bare identifier spelled exactly value.
func isIdent(n ast.Expr, value string) bool {
	if ident, ok := n.(*ast.Ident); ok {
		return ident.Name == value
	}
	return false
}
|
|
|
|
// receiver returns the type identifier of f's receiver, and whether that
// receiver is a pointer. Only single receivers of the form T or *T are
// recognised; for plain functions and any other receiver shape (for example
// **T or a qualified/instantiated type), ok is false and the other results
// are zero.
func receiver(f *ast.FuncDecl) (ident *ast.Ident, pointer bool, ok bool) {
	if f.Recv == nil || len(f.Recv.List) != 1 {
		return nil, false, false
	}

	expr := f.Recv.List[0].Type
	if star, isStar := expr.(*ast.StarExpr); isStar {
		expr, pointer = star.X, true
	}

	named, isNamed := expr.(*ast.Ident)
	if !isNamed {
		return nil, false, false
	}
	return named, pointer, true
}
|
|
|
|
// dropExistingTypeDeclarations removes any type declaration for which extractFn returns true. The function
|
|
// returns true if the entire declaration should be dropped.
|
|
func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool {
|
|
switch t := decl.(type) {
|
|
case *ast.GenDecl:
|
|
if t.Tok != token.TYPE {
|
|
return false
|
|
}
|
|
specs := []ast.Spec{}
|
|
for _, s := range t.Specs {
|
|
switch spec := s.(type) {
|
|
case *ast.TypeSpec:
|
|
if extractFn(spec) {
|
|
continue
|
|
}
|
|
specs = append(specs, spec)
|
|
}
|
|
}
|
|
if len(specs) == 0 {
|
|
return true
|
|
}
|
|
t.Specs = specs
|
|
}
|
|
return false
|
|
}
|
|
|
|
// dropEmptyImportDeclarations removes blank (`_`) imports from an import
// declaration, so generated code cannot pull in packages purely for their
// side effects. It reports true when no imports are left, meaning the caller
// should drop the whole declaration; in that case decl itself is not
// modified. Declarations other than `import` GenDecls always yield false.
func dropEmptyImportDeclarations(decl ast.Decl) bool {
	gen, ok := decl.(*ast.GenDecl)
	if !ok || gen.Tok != token.IMPORT {
		return false
	}

	var kept []ast.Spec
	for _, s := range gen.Specs {
		imp, isImport := s.(*ast.ImportSpec)
		if !isImport {
			continue
		}
		if blank := imp.Name != nil && imp.Name.Name == "_"; !blank {
			kept = append(kept, imp)
		}
	}

	if len(kept) == 0 {
		return true
	}
	gen.Specs = kept
	return false
}
|
|
|
|
func RewriteTypesWithProtobufStructTags(name string, structTags map[string]map[string]string) error {
|
|
return rewriteFile(name, []byte{}, func(fset *token.FileSet, file *ast.File) error {
|
|
allErrs := []error{}
|
|
|
|
// set any new struct tags
|
|
for _, d := range file.Decls {
|
|
if errs := updateStructTags(d, structTags, []string{"protobuf"}); len(errs) > 0 {
|
|
allErrs = append(allErrs, errs...)
|
|
}
|
|
}
|
|
|
|
if len(allErrs) > 0 {
|
|
var s string
|
|
for _, err := range allErrs {
|
|
s += err.Error() + "\n"
|
|
}
|
|
return errors.New(s)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func updateStructTags(decl ast.Decl, structTags map[string]map[string]string, toCopy []string) []error {
|
|
var errs []error
|
|
t, ok := decl.(*ast.GenDecl)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
if t.Tok != token.TYPE {
|
|
return nil
|
|
}
|
|
|
|
for _, s := range t.Specs {
|
|
spec, ok := s.(*ast.TypeSpec)
|
|
if !ok {
|
|
continue
|
|
}
|
|
typeName := spec.Name.Name
|
|
fieldTags, ok := structTags[typeName]
|
|
if !ok {
|
|
continue
|
|
}
|
|
st, ok := spec.Type.(*ast.StructType)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
for i := range st.Fields.List {
|
|
f := st.Fields.List[i]
|
|
var name string
|
|
if len(f.Names) == 0 {
|
|
switch t := f.Type.(type) {
|
|
case *ast.Ident:
|
|
name = t.Name
|
|
case *ast.SelectorExpr:
|
|
name = t.Sel.Name
|
|
default:
|
|
errs = append(errs, fmt.Errorf("unable to get name for tag from struct %q, field %#v", spec.Name.Name, t))
|
|
continue
|
|
}
|
|
} else {
|
|
name = f.Names[0].Name
|
|
}
|
|
value, ok := fieldTags[name]
|
|
if !ok {
|
|
continue
|
|
}
|
|
var tags customreflect.StructTags
|
|
if f.Tag != nil {
|
|
oldTags, err := customreflect.ParseStructTags(strings.Trim(f.Tag.Value, "`"))
|
|
if err != nil {
|
|
errs = append(errs, fmt.Errorf("unable to read struct tag from struct %q, field %q: %v", spec.Name.Name, name, err))
|
|
continue
|
|
}
|
|
tags = oldTags
|
|
}
|
|
for _, name := range toCopy {
|
|
// don't overwrite existing tags
|
|
if tags.Has(name) {
|
|
continue
|
|
}
|
|
// append new tags
|
|
if v := reflect.StructTag(value).Get(name); len(v) > 0 {
|
|
tags = append(tags, customreflect.StructTag{Name: name, Value: v})
|
|
}
|
|
}
|
|
if len(tags) == 0 {
|
|
continue
|
|
}
|
|
if f.Tag == nil {
|
|
f.Tag = &ast.BasicLit{}
|
|
}
|
|
f.Tag.Value = tags.String()
|
|
}
|
|
}
|
|
return errs
|
|
}
|