From 405726fae23ace72b22c410a77b7bd825608f2c8 Mon Sep 17 00:00:00 2001 From: Dave Cunningham Date: Wed, 7 Mar 2018 11:23:32 -0500 Subject: [PATCH] Clone ASTs to avoid aliasing and double-unescaping (#210) * Clone ASTs to avoid aliasing and double-unescaping --- ast/clone.go | 302 +++++++++++++++++++++++++++++++++++++++++++++++++++ desugarer.go | 23 ++-- 2 files changed, 317 insertions(+), 8 deletions(-) create mode 100644 ast/clone.go diff --git a/ast/clone.go b/ast/clone.go new file mode 100644 index 0000000..faa76f0 --- /dev/null +++ b/ast/clone.go @@ -0,0 +1,302 @@ +/* +Copyright 2018 Google Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ast + +import ( + "fmt" + "reflect" +) + +// Updates fields of specPtr to point to deep clones. +func cloneForSpec(specPtr *ForSpec) { + clone(&specPtr.Expr) + oldOuter := specPtr.Outer + specPtr.Outer = new(ForSpec) + *specPtr.Outer = *oldOuter + cloneForSpec(specPtr.Outer) + for i := range specPtr.Conditions { + clone(&specPtr.Conditions[i].Expr) + } +} + +// Updates fields of params to point to deep clones. +func cloneParameters(params *Parameters) { + if params == nil { + return + } + params.Optional = append(make([]NamedParameter, 0), params.Optional...) + for i := range params.Optional { + clone(¶ms.Optional[i].DefaultArg) + } +} + +// Updates fields of field to point to deep clones. +func cloneField(field *ObjectField) { + if field.Method != nil { + field.Method = Clone(field.Method).(*Function) + } + + oldParams := field.Params + if oldParams != nil { + field.Params = new(Parameters) + *field.Params = *oldParams + } + cloneParameters(field.Params) + + clone(&field.Expr1) + clone(&field.Expr2) + clone(&field.Expr3) +} + +// Updates fields of field to point to deep clones. +func cloneDesugaredField(field *DesugaredObjectField) { + clone(&field.Name) + clone(&field.Body) +} + +// Updates the NodeBase fields of astPtr to point to deep clones. +func cloneNodeBase(astPtr Node) { + if astPtr.Context() != nil { + newContext := new(string) + *newContext = *astPtr.Context() + astPtr.SetContext(newContext) + } + astPtr.SetFreeVariables(append(make(Identifiers, 0), astPtr.FreeVariables()...)) +} + +// Updates *astPtr to point to a deep clone of what it originally pointed at. +func clone(astPtr *Node) { + node := *astPtr + if node == nil { + return + } + + switch node := node.(type) { + case *Apply: + r := new(Apply) + *astPtr = r + *r = *node + clone(&r.Target) + r.Arguments.Positional = append(make(Nodes, 0), r.Arguments.Positional...) + for i := range r.Arguments.Positional { + clone(&r.Arguments.Positional[i]) + } + r.Arguments.Named = append(make([]NamedArgument, 0), r.Arguments.Named...) + for i := range r.Arguments.Named { + clone(&r.Arguments.Named[i].Arg) + } + + case *ApplyBrace: + r := new(ApplyBrace) + *astPtr = r + *r = *node + clone(&r.Left) + clone(&r.Right) + + case *Array: + r := new(Array) + *astPtr = r + *r = *node + r.Elements = append(make(Nodes, 0), r.Elements...) + for i := range r.Elements { + clone(&r.Elements[i]) + } + + case *ArrayComp: + r := new(ArrayComp) + *astPtr = r + *r = *node + clone(&r.Body) + cloneForSpec(&r.Spec) + + case *Assert: + r := new(Assert) + *astPtr = r + *r = *node + clone(&r.Cond) + clone(&r.Message) + clone(&r.Rest) + + case *Binary: + r := new(Binary) + *astPtr = r + *r = *node + clone(&r.Left) + clone(&r.Right) + + case *Conditional: + r := new(Conditional) + *astPtr = r + *r = *node + clone(&r.Cond) + clone(&r.BranchTrue) + clone(&r.BranchFalse) + + case *Dollar: + r := new(Dollar) + *astPtr = r + *r = *node + + case *Error: + r := new(Error) + *astPtr = r + *r = *node + clone(&r.Expr) + + case *Function: + r := new(Function) + *astPtr = r + *r = *node + cloneParameters(&r.Parameters) + clone(&r.Body) + + case *Import: + r := new(Import) + *astPtr = r + *r = *node + r.File = new(LiteralString) + *r.File = *node.File + + case *ImportStr: + r := new(ImportStr) + *astPtr = r + *r = *node + r.File = new(LiteralString) + *r.File = *node.File + + case *Index: + r := new(Index) + *astPtr = r + *r = *node + clone(&r.Target) + clone(&r.Index) + + case *Slice: + r := new(Slice) + *astPtr = r + *r = *node + clone(&r.Target) + clone(&r.BeginIndex) + clone(&r.EndIndex) + clone(&r.Step) + + case *Local: + r := new(Local) + *astPtr = r + *r = *node + r.Binds = append(make(LocalBinds, 0), r.Binds...) + for i := range r.Binds { + if r.Binds[i].Fun != nil { + r.Binds[i].Fun = Clone(r.Binds[i].Fun).(*Function) + } + clone(&r.Binds[i].Body) + } + clone(&r.Body) + + case *LiteralBoolean: + r := new(LiteralBoolean) + *astPtr = r + *r = *node + + case *LiteralNull: + r := new(LiteralNull) + *astPtr = r + *r = *node + + case *LiteralNumber: + r := new(LiteralNumber) + *astPtr = r + *r = *node + + case *LiteralString: + r := new(LiteralString) + *astPtr = r + *r = *node + + case *Object: + r := new(Object) + *astPtr = r + *r = *node + r.Fields = append(make(ObjectFields, 0), r.Fields...) + for i := range r.Fields { + cloneField(&r.Fields[i]) + } + + case *DesugaredObject: + r := new(DesugaredObject) + *astPtr = r + *r = *node + r.Fields = append(make(DesugaredObjectFields, 0), r.Fields...) + for i := range r.Fields { + cloneDesugaredField(&r.Fields[i]) + } + + case *ObjectComp: + r := new(ObjectComp) + *astPtr = r + *r = *node + r.Fields = append(make(ObjectFields, 0), r.Fields...) + for i := range r.Fields { + cloneField(&r.Fields[i]) + } + cloneForSpec(&r.Spec) + + case *Parens: + r := new(Parens) + *astPtr = r + *r = *node + clone(&r.Inner) + + case *Self: + r := new(Self) + *astPtr = r + *r = *node + + case *SuperIndex: + r := new(SuperIndex) + *astPtr = r + *r = *node + clone(&r.Index) + + case *InSuper: + r := new(InSuper) + *astPtr = r + *r = *node + clone(&r.Index) + + case *Unary: + r := new(Unary) + *astPtr = r + *r = *node + clone(&r.Expr) + + case *Var: + r := new(Var) + *astPtr = r + *r = *node + + default: + panic(fmt.Sprintf("ast.Clone() does not recognize ast: %s", reflect.TypeOf(node))) + } + + cloneNodeBase(*astPtr) +} + +func Clone(astPtr Node) Node { + clone(&astPtr) + return astPtr +} diff --git a/desugarer.go b/desugarer.go index 5bb4a6c..e803c0b 100644 --- a/desugarer.go +++ b/desugarer.go @@ -123,17 +123,17 @@ func desugarFields(location ast.LocationRange, fields *ast.ObjectFields, objLeve // Remove object-level locals newFields := []ast.ObjectField{} - var binds ast.LocalBinds - for _, local := range *fields { - if local.Kind != ast.ObjectLocal { - continue - } - binds = append(binds, ast.LocalBind{Variable: *local.Id, Body: local.Expr2}) - } for _, field := range *fields { if field.Kind == ast.ObjectLocal { continue } + var binds ast.LocalBinds + for _, local := range *fields { + if local.Kind != ast.ObjectLocal { + continue + } + binds = append(binds, ast.LocalBind{Variable: *local.Id, Body: ast.Clone(local.Expr2)}) + } if len(binds) > 0 { field.Expr2 = &ast.Local{ NodeBase: ast.NewNodeBaseLoc(*field.Expr2.Loc()), @@ -294,7 +294,10 @@ func buildDesugaredObject(nodeBase ast.NodeBase, fields ast.ObjectFields) *ast.D // Desugar Jsonnet expressions to reduce the number of constructs the rest of the implementation // needs to understand. - +// +// Note that despite the name, desugar() is not idempotent. String literals have their escape +// codes translated to low-level characters during desugaring. +// // Desugaring should happen immediately after parsing, i.e. before static analysis and execution. // Temporary variables introduced here should be prefixed with $ to ensure they do not clash with // variables used in user code. @@ -443,6 +446,9 @@ func desugar(astPtr *ast.Node, objLevel int) (err error) { } case *ast.Import: + // desugar() is allowed to update the pointer to point to something else, but will never do + // this for a LiteralString. We cannot simply do &node.File because the type is + // **ast.LiteralString which is not compatible with *ast.Node. var file ast.Node = node.File err = desugar(&file, objLevel) if err != nil { @@ -450,6 +456,7 @@ func desugar(astPtr *ast.Node, objLevel int) (err error) { } case *ast.ImportStr: + // See comment in ast.Import. var file ast.Node = node.File err = desugar(&file, objLevel) if err != nil {