Refactor to allow interleaving optional and positional params

This commit is contained in:
Dave Cunningham 2020-03-03 04:24:48 +00:00 committed by Stanisław Barzowski
parent 7cad41e894
commit 0e67cc3c68
11 changed files with 5418 additions and 5289 deletions

View File

@ -371,15 +371,17 @@ type Error struct {
type Function struct { type Function struct {
NodeBase NodeBase
ParenLeftFodder Fodder ParenLeftFodder Fodder
Parameters Parameters Parameters []Parameter
// Always false if there were no parameters. // Always false if there were no parameters.
TrailingComma bool TrailingComma bool
ParenRightFodder Fodder ParenRightFodder Fodder
Body Node Body Node
} }
// NamedParameter represents an optional named parameter of a function. // Parameter represents a parameter of function.
type NamedParameter struct { // If DefaultArg is set, it's an optional named parameter.
// Otherwise, it's a positional parameter and EqFodder is not used.
type Parameter struct {
NameFodder Fodder NameFodder Fodder
Name Identifier Name Identifier
EqFodder Fodder EqFodder Fodder
@ -395,13 +397,6 @@ type CommaSeparatedID struct {
CommaFodder Fodder CommaFodder Fodder
} }
// Parameters represents the required positional parameters and optional named
// parameters to a function definition.
type Parameters struct {
Required []CommaSeparatedID
Optional []NamedParameter
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Import represents import "file". // Import represents import "file".

View File

@ -35,17 +35,6 @@ func cloneForSpec(specPtr *ForSpec) {
} }
} }
// Updates fields of params to point to deep clones.
func cloneParameters(params *Parameters) {
if params == nil {
return
}
params.Optional = append(make([]NamedParameter, 0), params.Optional...)
for i := range params.Optional {
clone(&params.Optional[i].DefaultArg)
}
}
// Updates fields of field to point to deep clones. // Updates fields of field to point to deep clones.
func cloneField(field *ObjectField) { func cloneField(field *ObjectField) {
if field.Method != nil { if field.Method != nil {
@ -158,7 +147,12 @@ func clone(astPtr *Node) {
r := new(Function) r := new(Function)
*astPtr = r *astPtr = r
*r = *node *r = *node
cloneParameters(&r.Parameters) if r.Parameters != nil {
r.Parameters = append(make([]Parameter, 0), r.Parameters...)
for i := range r.Parameters {
clone(&r.Parameters[i].DefaultArg)
}
}
clone(&r.Body) clone(&r.Body)
case *Import: case *Import:

File diff suppressed because it is too large Load Diff

View File

@ -182,7 +182,11 @@ func builtinLength(i *interpreter, trace traceElement, x value) (value, error) {
case valueString: case valueString:
num = x.length() num = x.length()
case *valueFunction: case *valueFunction:
num = len(x.Parameters().required) for _, param := range x.Parameters() {
if param.defaultArg == nil {
num++
}
}
default: default:
return nil, i.typeErrorGeneral(x, trace) return nil, i.typeErrorGeneral(x, trace)
} }
@ -353,14 +357,14 @@ func builtinReverse(i *interpreter, trace traceElement, arrv value) (value, erro
} }
lenArr := len(arr.elements) // lenx holds the original array length lenArr := len(arr.elements) // lenx holds the original array length
reversed_array := make([]*cachedThunk, lenArr) // creates a slice that refer to a new array of length lenx reversedArray := make([]*cachedThunk, lenArr) // creates a slice that refer to a new array of length lenx
for i := 0; i < lenArr; i++ { for i := 0; i < lenArr; i++ {
j := lenArr - (i + 1) // j initially holds (lenx - 1) and decreases to 0 while i initially holds 0 and increase to (lenx - 1) j := lenArr - (i + 1) // j initially holds (lenx - 1) and decreases to 0 while i initially holds 0 and increase to (lenx - 1)
reversed_array[i] = arr.elements[j] reversedArray[i] = arr.elements[j]
} }
return makeValueArray(reversed_array), nil return makeValueArray(reversedArray), nil
} }
func builtinFilter(i *interpreter, trace traceElement, funcv, arrv value) (value, error) { func builtinFilter(i *interpreter, trace traceElement, funcv, arrv value) (value, error) {
@ -1197,25 +1201,24 @@ type builtin interface {
Name() ast.Identifier Name() ast.Identifier
} }
func flattenArgs(args callArguments, params parameters, defaults []value) []*cachedThunk { func flattenArgs(args callArguments, params []namedParameter, defaults []value) []*cachedThunk {
positions := make(map[ast.Identifier]int) positions := make(map[ast.Identifier]int)
for i := 0; i < len(params.required); i++ { for i, param := range params {
positions[params.required[i]] = i positions[param.name] = i
}
for i := 0; i < len(params.optional); i++ {
positions[params.optional[i].name] = i + len(params.required)
} }
flatArgs := make([]*cachedThunk, len(params.required)+len(params.optional)) flatArgs := make([]*cachedThunk, len(params))
// Bind positional arguments
copy(flatArgs, args.positional) copy(flatArgs, args.positional)
// Bind named arguments
for _, arg := range args.named { for _, arg := range args.named {
flatArgs[positions[arg.name]] = arg.pv flatArgs[positions[arg.name]] = arg.pv
} }
for i := 0; i < len(params.optional); i++ { // Bind defaults for unsatisfied named parameters
pos := len(params.required) + i for i := range params {
if flatArgs[pos] == nil { if flatArgs[i] == nil {
flatArgs[pos] = readyThunk(defaults[i]) flatArgs[i] = readyThunk(defaults[i])
} }
} }
return flatArgs return flatArgs
@ -1244,8 +1247,12 @@ func (b *unaryBuiltin) evalCall(args callArguments, i *interpreter, trace traceE
return b.function(i, builtinTrace, x) return b.function(i, builtinTrace, x)
} }
func (b *unaryBuiltin) Parameters() parameters { func (b *unaryBuiltin) Parameters() []namedParameter {
return parameters{required: b.parameters} ret := make([]namedParameter, len(b.parameters))
for i := range ret {
ret[i].name = b.parameters[i]
}
return ret
} }
func (b *unaryBuiltin) Name() ast.Identifier { func (b *unaryBuiltin) Name() ast.Identifier {
@ -1274,8 +1281,12 @@ func (b *binaryBuiltin) evalCall(args callArguments, i *interpreter, trace trace
return b.function(i, builtinTrace, x, y) return b.function(i, builtinTrace, x, y)
} }
func (b *binaryBuiltin) Parameters() parameters { func (b *binaryBuiltin) Parameters() []namedParameter {
return parameters{required: b.parameters} ret := make([]namedParameter, len(b.parameters))
for i := range ret {
ret[i].name = b.parameters[i]
}
return ret
} }
func (b *binaryBuiltin) Name() ast.Identifier { func (b *binaryBuiltin) Name() ast.Identifier {
@ -1308,8 +1319,12 @@ func (b *ternaryBuiltin) evalCall(args callArguments, i *interpreter, trace trac
return b.function(i, builtinTrace, x, y, z) return b.function(i, builtinTrace, x, y, z)
} }
func (b *ternaryBuiltin) Parameters() parameters { func (b *ternaryBuiltin) Parameters() []namedParameter {
return parameters{required: b.parameters} ret := make([]namedParameter, len(b.parameters))
for i := range ret {
ret[i].name = b.parameters[i]
}
return ret
} }
func (b *ternaryBuiltin) Name() ast.Identifier { func (b *ternaryBuiltin) Name() ast.Identifier {
@ -1318,25 +1333,44 @@ func (b *ternaryBuiltin) Name() ast.Identifier {
type generalBuiltinFunc func(*interpreter, traceElement, []value) (value, error) type generalBuiltinFunc func(*interpreter, traceElement, []value) (value, error)
// generalBuiltin covers cases that other builtin structures do not, type generalBuiltinParameter struct {
// in particular it can have any number of parameters. It can also
// have optional parameters.
type generalBuiltin struct {
name ast.Identifier name ast.Identifier
required ast.Identifiers
optional ast.Identifiers
// Note that the defaults are passed as values rather than AST nodes like in Parameters. // Note that the defaults are passed as values rather than AST nodes like in Parameters.
// This spares us unnecessary evaluation. // This spares us unnecessary evaluation.
defaultValues []value defaultValue value
}
// generalBuiltin covers cases that other builtin structures do not,
// in particular it can have any number of parameters. It can also
// have optional parameters. The optional ones have non-nil defaultValues
// at the same index.
type generalBuiltin struct {
name ast.Identifier
parameters []generalBuiltinParameter
function generalBuiltinFunc function generalBuiltinFunc
} }
func (b *generalBuiltin) Parameters() parameters { func (b *generalBuiltin) Parameters() []namedParameter {
optional := make([]namedParameter, len(b.optional)) ret := make([]namedParameter, len(b.parameters))
for i := range optional { for i := range ret {
optional[i] = namedParameter{name: b.optional[i]} ret[i].name = b.parameters[i].name
if b.parameters[i].defaultValue != nil {
// This is not actually used because the defaultValue is used instead.
// The only reason we don't leave it nil is because the checkArguments
// function uses the non-nil status to indicate that the parameter
// is optional.
ret[i].defaultArg = &ast.LiteralNull{}
} }
return parameters{required: b.required, optional: optional} }
return ret
}
func (b *generalBuiltin) defaultValues() []value {
ret := make([]value, len(b.parameters))
for i := range ret {
ret[i] = b.parameters[i].defaultValue
}
return ret
} }
func (b *generalBuiltin) Name() ast.Identifier { func (b *generalBuiltin) Name() ast.Identifier {
@ -1344,7 +1378,7 @@ func (b *generalBuiltin) Name() ast.Identifier {
} }
func (b *generalBuiltin) evalCall(args callArguments, i *interpreter, trace traceElement) (value, error) { func (b *generalBuiltin) evalCall(args callArguments, i *interpreter, trace traceElement) (value, error) {
flatArgs := flattenArgs(args, b.Parameters(), b.defaultValues) flatArgs := flattenArgs(args, b.Parameters(), b.defaultValues())
builtinTrace := getBuiltinTrace(trace, b.name) builtinTrace := getBuiltinTrace(trace, b.name)
values := make([]value, len(flatArgs)) values := make([]value, len(flatArgs))
for j := 0; j < len(values); j++ { for j := 0; j < len(values); j++ {
@ -1445,7 +1479,7 @@ var funcBuiltins = buildBuiltinMap([]builtin{
&unaryBuiltin{name: "base64", function: builtinBase64, parameters: ast.Identifiers{"input"}}, &unaryBuiltin{name: "base64", function: builtinBase64, parameters: ast.Identifiers{"input"}},
&unaryBuiltin{name: "encodeUTF8", function: builtinEncodeUTF8, parameters: ast.Identifiers{"str"}}, &unaryBuiltin{name: "encodeUTF8", function: builtinEncodeUTF8, parameters: ast.Identifiers{"str"}},
&unaryBuiltin{name: "decodeUTF8", function: builtinDecodeUTF8, parameters: ast.Identifiers{"arr"}}, &unaryBuiltin{name: "decodeUTF8", function: builtinDecodeUTF8, parameters: ast.Identifiers{"arr"}},
&generalBuiltin{name: "sort", function: builtinSort, required: ast.Identifiers{"arr"}, optional: ast.Identifiers{"keyF"}, defaultValues: []value{functionID}}, &generalBuiltin{name: "sort", function: builtinSort, parameters: []generalBuiltinParameter{{name: "arr"}, {name: "keyF", defaultValue: functionID}}},
&unaryBuiltin{name: "native", function: builtinNative, parameters: ast.Identifiers{"x"}}, &unaryBuiltin{name: "native", function: builtinNative, parameters: ast.Identifiers{"x"}},
// internal // internal

View File

@ -298,9 +298,11 @@ func specialChildren(node ast.Node) []ast.Node {
return nil return nil
case *ast.Function: case *ast.Function:
children := []ast.Node{node.Body} children := []ast.Node{node.Body}
for _, child := range node.Parameters.Optional { for _, child := range node.Parameters {
if child.DefaultArg != nil {
children = append(children, child.DefaultArg) children = append(children, child.DefaultArg)
} }
}
return children return children
case *ast.Import: case *ast.Import:
return nil return nil
@ -389,9 +391,11 @@ func addContext(node ast.Node, context *string, bind string) {
case *ast.Function: case *ast.Function:
funContext := functionContext(bind) funContext := functionContext(bind)
addContext(node.Body, funContext, anonymous) addContext(node.Body, funContext, anonymous)
for i := range node.Parameters.Optional { for i := range node.Parameters {
if node.Parameters[i].DefaultArg != nil {
// Default arguments have the same context as the function body. // Default arguments have the same context as the function body.
addContext(node.Parameters.Optional[i].DefaultArg, funContext, anonymous) addContext(node.Parameters[i].DefaultArg, funContext, anonymous)
}
} }
case *ast.Object: case *ast.Object:
// TODO(sbarzowski) include fieldname, maybe even chains // TODO(sbarzowski) include fieldname, maybe even chains

View File

@ -210,25 +210,25 @@ func (p *parser) parseArguments(elementKind string) (*token, *ast.Arguments, boo
} }
// TODO(sbarzowski) - this returned bool is weird // TODO(sbarzowski) - this returned bool is weird
func (p *parser) parseParameters(elementKind string) (*token, *ast.Parameters, bool, error) { func (p *parser) parseParameters(elementKind string) (*token, []ast.Parameter, bool, error) {
parenR, args, trailingComma, err := p.parseArguments(elementKind) parenR, args, trailingComma, err := p.parseArguments(elementKind)
if err != nil { if err != nil {
return nil, nil, false, err return nil, nil, false, err
} }
var params ast.Parameters var params []ast.Parameter
for _, arg := range args.Positional { for _, arg := range args.Positional {
idFodder, id, ok := astVarToIdentifier(arg.Expr) idFodder, id, ok := astVarToIdentifier(arg.Expr)
if !ok { if !ok {
return nil, nil, false, errors.MakeStaticError(fmt.Sprintf("Expected simple identifier but got a complex expression."), *arg.Expr.Loc()) return nil, nil, false, errors.MakeStaticError(fmt.Sprintf("Expected simple identifier but got a complex expression."), *arg.Expr.Loc())
} }
params.Required = append(params.Required, ast.CommaSeparatedID{ params = append(params, ast.Parameter{
NameFodder: idFodder, NameFodder: idFodder,
Name: *id, Name: *id,
CommaFodder: arg.CommaFodder, CommaFodder: arg.CommaFodder,
}) })
} }
for _, arg := range args.Named { for _, arg := range args.Named {
params.Optional = append(params.Optional, ast.NamedParameter{ params = append(params, ast.Parameter{
NameFodder: arg.NameFodder, NameFodder: arg.NameFodder,
Name: arg.Name, Name: arg.Name,
EqFodder: arg.EqFodder, EqFodder: arg.EqFodder,
@ -236,7 +236,7 @@ func (p *parser) parseParameters(elementKind string) (*token, *ast.Parameters, b
CommaFodder: arg.CommaFodder, CommaFodder: arg.CommaFodder,
}) })
} }
return parenR, &params, trailingComma, nil return parenR, params, trailingComma, nil
} }
// TODO(sbarzowski) add location to all individual binds // TODO(sbarzowski) add location to all individual binds
@ -260,7 +260,7 @@ func (p *parser) parseBind(binds *ast.LocalBinds) (*token, error) {
} }
fun = &ast.Function{ fun = &ast.Function{
ParenLeftFodder: parenL.fodder, ParenLeftFodder: parenL.fodder,
Parameters: *params, Parameters: params,
TrailingComma: gotComma, TrailingComma: gotComma,
ParenRightFodder: parenR.fodder, ParenRightFodder: parenR.fodder,
// Body gets filled in later. // Body gets filled in later.
@ -423,7 +423,7 @@ func (p *parser) parseObjectRemainderField(literalFields *LiteralFieldSet, tok *
methComma := false methComma := false
var parenL *token var parenL *token
var parenR *token var parenR *token
var params *ast.Parameters var params []ast.Parameter
if p.peek().kind == tokenParenL { if p.peek().kind == tokenParenL {
parenL = p.pop() parenL = p.pop()
var err error var err error
@ -460,7 +460,7 @@ func (p *parser) parseObjectRemainderField(literalFields *LiteralFieldSet, tok *
if isMethod { if isMethod {
method = &ast.Function{ method = &ast.Function{
ParenLeftFodder: parenL.fodder, ParenLeftFodder: parenL.fodder,
Parameters: *params, Parameters: params,
TrailingComma: methComma, TrailingComma: methComma,
ParenRightFodder: parenR.fodder, ParenRightFodder: parenR.fodder,
Body: body, Body: body,
@ -505,7 +505,7 @@ func (p *parser) parseObjectRemainderLocal(binds *ast.IdentifierSet, tok *token,
funcComma := false funcComma := false
var parenL *token var parenL *token
var parenR *token var parenR *token
var params *ast.Parameters var params []ast.Parameter
if p.peek().kind == tokenParenL { if p.peek().kind == tokenParenL {
parenL = p.pop() parenL = p.pop()
isMethod = true isMethod = true
@ -528,7 +528,7 @@ func (p *parser) parseObjectRemainderLocal(binds *ast.IdentifierSet, tok *token,
if isMethod { if isMethod {
method = &ast.Function{ method = &ast.Function{
ParenLeftFodder: parenL.fodder, ParenLeftFodder: parenL.fodder,
Parameters: *params, Parameters: params,
ParenRightFodder: parenR.fodder, ParenRightFodder: parenR.fodder,
TrailingComma: funcComma, TrailingComma: funcComma,
Body: body, Body: body,
@ -1050,7 +1050,7 @@ func (p *parser) parse(prec precedence) (ast.Node, error) {
return &ast.Function{ return &ast.Function{
NodeBase: ast.NewNodeBaseLoc(locFromTokenAST(begin, body), begin.fodder), NodeBase: ast.NewNodeBaseLoc(locFromTokenAST(begin, body), begin.fodder),
ParenLeftFodder: next.fodder, ParenLeftFodder: next.fodder,
Parameters: *params, Parameters: params,
TrailingComma: gotComma, TrailingComma: gotComma,
ParenRightFodder: parenR.fodder, ParenRightFodder: parenR.fodder,
Body: body, Body: body,

View File

@ -193,9 +193,7 @@ func desugarFields(nodeBase ast.NodeBase, fields *ast.ObjectFields, objLevel int
func simpleLambda(body ast.Node, paramName ast.Identifier) ast.Node { func simpleLambda(body ast.Node, paramName ast.Identifier) ast.Node {
return &ast.Function{ return &ast.Function{
Body: body, Body: body,
Parameters: ast.Parameters{ Parameters: []ast.Parameter{{Name: paramName}},
Required: []ast.CommaSeparatedID{{Name: paramName}},
},
} }
} }
@ -431,13 +429,15 @@ func desugar(astPtr *ast.Node, objLevel int) (err error) {
} }
case *ast.Function: case *ast.Function:
for i := range node.Parameters.Optional { for i := range node.Parameters {
param := &node.Parameters.Optional[i] param := &node.Parameters[i]
if param.DefaultArg != nil {
err = desugar(&param.DefaultArg, objLevel) err = desugar(&param.DefaultArg, objLevel)
if err != nil { if err != nil {
return return
} }
} }
}
err = desugar(&node.Body, objLevel) err = desugar(&node.Body, objLevel)
if err != nil { if err != nil {
return return

View File

@ -76,21 +76,17 @@ func analyzeVisit(a ast.Node, inObject bool, vars ast.IdentifierSet) error {
visitNext(a.Expr, inObject, vars, s) visitNext(a.Expr, inObject, vars, s)
case *ast.Function: case *ast.Function:
newVars := vars.Clone() newVars := vars.Clone()
for _, param := range a.Parameters.Required { for _, param := range a.Parameters {
newVars.Add(param.Name) newVars.Add(param.Name)
} }
for _, param := range a.Parameters.Optional { for _, param := range a.Parameters {
newVars.Add(param.Name) if param.DefaultArg != nil {
}
for _, param := range a.Parameters.Optional {
visitNext(param.DefaultArg, inObject, newVars, s) visitNext(param.DefaultArg, inObject, newVars, s)
} }
}
visitNext(a.Body, inObject, newVars, s) visitNext(a.Body, inObject, newVars, s)
// Parameters are free inside the body, but not visible here or outside // Parameters are free inside the body, but not visible here or outside
for _, param := range a.Parameters.Required { for _, param := range a.Parameters {
s.freeVars.Remove(param.Name)
}
for _, param := range a.Parameters.Optional {
s.freeVars.Remove(param.Name) s.freeVars.Remove(param.Name)
} }
case *ast.Import: case *ast.Import:

View File

@ -27,15 +27,14 @@ func cloneScope(oldScope vScope) vScope {
} }
func findVariablesInFunc(node *ast.Function, info *LintingInfo, scope vScope) { func findVariablesInFunc(node *ast.Function, info *LintingInfo, scope vScope) {
for _, param := range node.Parameters.Required { for _, param := range node.Parameters {
addVar(param.Name, node, info, scope, true) addVar(param.Name, node, info, scope, true)
} }
for _, param := range node.Parameters.Optional { for _, param := range node.Parameters {
addVar(param.Name, node, info, scope, true) if param.DefaultArg != nil {
}
for _, param := range node.Parameters.Optional {
findVariables(param.DefaultArg, info, scope) findVariables(param.DefaultArg, info, scope)
} }
}
findVariables(node.Body, info, scope) findVariables(node.Body, info, scope)
} }

View File

@ -16,7 +16,9 @@ limitations under the License.
package jsonnet package jsonnet
import "github.com/google/go-jsonnet/ast" import (
"github.com/google/go-jsonnet/ast"
)
// readyValue // readyValue
// ------------------------------------- // -------------------------------------
@ -147,7 +149,7 @@ type closure struct {
// arguments should be added to it, before executing it // arguments should be added to it, before executing it
env environment env environment
function *ast.Function function *ast.Function
params parameters params []namedParameter
} }
func forceThunks(i *interpreter, trace traceElement, args *bindingFrame) error { func forceThunks(i *interpreter, trace traceElement, args *bindingFrame) error {
@ -164,13 +166,7 @@ func (closure *closure) evalCall(arguments callArguments, i *interpreter, trace
argThunks := make(bindingFrame) argThunks := make(bindingFrame)
parameters := closure.Parameters() parameters := closure.Parameters()
for i, arg := range arguments.positional { for i, arg := range arguments.positional {
var name ast.Identifier argThunks[parameters[i].name] = arg
if i < len(parameters.required) {
name = parameters.required[i]
} else {
name = parameters.optional[i-len(parameters.required)].name
}
argThunks[name] = arg
} }
for _, arg := range arguments.named { for _, arg := range arguments.named {
@ -179,8 +175,7 @@ func (closure *closure) evalCall(arguments callArguments, i *interpreter, trace
var calledEnvironment environment var calledEnvironment environment
for i := range parameters.optional { for _, param := range parameters {
param := &parameters.optional[i]
if _, exists := argThunks[param.name]; !exists { if _, exists := argThunks[param.name]; !exists {
argThunks[param.name] = &cachedThunk{ argThunks[param.name] = &cachedThunk{
// Default arguments are evaluated in the same environment as function body // Default arguments are evaluated in the same environment as function body
@ -204,27 +199,20 @@ func (closure *closure) evalCall(arguments callArguments, i *interpreter, trace
return i.EvalInCleanEnv(trace, &calledEnvironment, closure.function.Body, arguments.tailstrict) return i.EvalInCleanEnv(trace, &calledEnvironment, closure.function.Body, arguments.tailstrict)
} }
func (closure *closure) Parameters() parameters { func (closure *closure) Parameters() []namedParameter {
return closure.params return closure.params
} }
func prepareClosureParameters(params ast.Parameters, env environment) parameters { func prepareClosureParameters(params []ast.Parameter, env environment) []namedParameter {
optionalParameters := make([]namedParameter, 0, len(params.Optional)) preparedParams := make([]namedParameter, 0, len(params))
for _, named := range params.Optional { for _, named := range params {
optionalParameters = append(optionalParameters, namedParameter{ preparedParams = append(preparedParams, namedParameter{
name: named.Name, name: named.Name,
defaultArg: named.DefaultArg, defaultArg: named.DefaultArg,
}) })
} }
requiredParameters := make([]ast.Identifier, 0, len(params.Required)) return preparedParams
for _, required := range params.Required {
requiredParameters = append(requiredParameters, required.Name)
}
return parameters{
required: requiredParameters,
optional: optionalParameters,
}
} }
func makeClosure(env environment, function *ast.Function) *closure { func makeClosure(env environment, function *ast.Function) *closure {
@ -265,8 +253,12 @@ func (native *NativeFunction) evalCall(arguments callArguments, i *interpreter,
} }
// Parameters returns a NativeFunction's parameters. // Parameters returns a NativeFunction's parameters.
func (native *NativeFunction) Parameters() parameters { func (native *NativeFunction) Parameters() []namedParameter {
return parameters{required: native.Params} ret := make([]namedParameter, len(native.Params))
for i := range ret {
ret[i].name = native.Params[i]
}
return ret
} }
// ------------------------------------- // -------------------------------------

View File

@ -348,7 +348,7 @@ type valueFunction struct {
// TODO(sbarzowski) better name? // TODO(sbarzowski) better name?
type evalCallable interface { type evalCallable interface {
evalCall(args callArguments, i *interpreter, trace traceElement) (value, error) evalCall(args callArguments, i *interpreter, trace traceElement) (value, error)
Parameters() parameters Parameters() []namedParameter
} }
func (f *valueFunction) call(i *interpreter, trace traceElement, args callArguments) (value, error) { func (f *valueFunction) call(i *interpreter, trace traceElement, args callArguments) (value, error) {
@ -359,37 +359,30 @@ func (f *valueFunction) call(i *interpreter, trace traceElement, args callArgume
return f.ec.evalCall(args, i, trace) return f.ec.evalCall(args, i, trace)
} }
func (f *valueFunction) Parameters() parameters { func (f *valueFunction) Parameters() []namedParameter {
return f.ec.Parameters() return f.ec.Parameters()
} }
func checkArguments(i *interpreter, trace traceElement, args callArguments, params parameters) error { func checkArguments(i *interpreter, trace traceElement, args callArguments, params []namedParameter) error {
received := make(map[ast.Identifier]bool)
accepted := make(map[ast.Identifier]bool)
numPassed := len(args.positional) numPassed := len(args.positional)
numExpected := len(params.required) + len(params.optional) maxExpected := len(params)
if numPassed > numExpected { if numPassed > maxExpected {
return i.Error(fmt.Sprintf("function expected %v positional argument(s), but got %v", numExpected, numPassed), trace) return i.Error(fmt.Sprintf("function expected %v positional argument(s), but got %v", maxExpected, numPassed), trace)
} }
for _, param := range params.required { // Parameter names the function will accept.
accepted[param] = true accepted := make(map[ast.Identifier]bool)
} for _, param := range params {
for _, param := range params.optional {
accepted[param.name] = true accepted[param.name] = true
} }
// Parameter names the call will bind.
received := make(map[ast.Identifier]bool)
for i := range args.positional { for i := range args.positional {
if i < len(params.required) { received[params[i].name] = true
received[params.required[i]] = true
} else {
received[params.optional[i-len(params.required)].name] = true
} }
}
for _, arg := range args.named { for _, arg := range args.named {
if _, present := received[arg.name]; present { if _, present := received[arg.name]; present {
return i.Error(fmt.Sprintf("Argument %v already provided", arg.name), trace) return i.Error(fmt.Sprintf("Argument %v already provided", arg.name), trace)
@ -400,9 +393,9 @@ func checkArguments(i *interpreter, trace traceElement, args callArguments, para
received[arg.name] = true received[arg.name] = true
} }
for _, param := range params.required { for _, param := range params {
if _, present := received[param]; !present { if _, present := received[param.name]; !present && param.defaultArg == nil {
return i.Error(fmt.Sprintf("Missing argument: %v", param), trace) return i.Error(fmt.Sprintf("Missing argument: %v", param.name), trace)
} }
} }