Basic tailstrict support

This commit is contained in:
Stanisław Barzowski 2017-10-04 18:22:15 -04:00 committed by Dave Cunningham
parent 2db3d1c3cc
commit 0f049eaa38
20 changed files with 160 additions and 72 deletions

View File

@ -42,6 +42,13 @@ func (e *evaluator) evaluate(ph potentialValue) (value, error) {
return ph.getValue(e.i, e.trace) return ph.getValue(e.i, e.trace)
} }
func (e *evaluator) evaluateTailCall(ph potentialValue, tc tailCallStatus) (value, error) {
if tc == tailCall {
e.i.stack.tailCallTrimStack()
}
return ph.getValue(e.i, e.trace)
}
func (e *evaluator) Error(s string) error { func (e *evaluator) Error(s string) error {
err := makeRuntimeError(s, e.i.getCurrentStackTrace(e.trace)) err := makeRuntimeError(s, e.i.getCurrentStackTrace(e.trace))
return err return err
@ -161,12 +168,12 @@ func (e *evaluator) evaluateObject(pv potentialValue) (valueObject, error) {
return e.getObject(v) return e.getObject(v)
} }
func (e *evaluator) evalInCurrentContext(a ast.Node) (value, error) { func (e *evaluator) evalInCurrentContext(a ast.Node, tc tailCallStatus) (value, error) {
return e.i.evaluate(a) return e.i.evaluate(a, tc)
} }
func (e *evaluator) evalInCleanEnv(env *environment, ast ast.Node) (value, error) { func (e *evaluator) evalInCleanEnv(env *environment, ast ast.Node, trimmable bool) (value, error) {
return e.i.EvalInCleanEnv(e.trace, env, ast) return e.i.EvalInCleanEnv(e.trace, env, ast, trimmable)
} }
func (e *evaluator) lookUpVar(ident ast.Identifier) potentialValue { func (e *evaluator) lookUpVar(ident ast.Identifier) potentialValue {

View File

@ -74,17 +74,26 @@ type callFrame struct {
// Tracing information about the place where it was called from. // Tracing information about the place where it was called from.
trace *TraceElement trace *TraceElement
/** Reuse this stack frame for the purpose of tail call optimization. */ // Whether this frame can be removed from the stack when it doesn't affect
tailCall bool // TODO what is it? // the evaluation result, but in case of an error, it won't appear on the
// stack trace.
// It's used for tail call optimization.
trimmable bool
env environment env environment
} }
func dumpCallFrame(c *callFrame) string { func dumpCallFrame(c *callFrame) string {
return fmt.Sprintf("<callFrame isCall = %t location = %v tailCall = %t>", var loc ast.LocationRange
if c.trace == nil || c.trace.loc == nil {
loc = ast.MakeLocationRangeMessage("?")
} else {
loc = *c.trace.loc
}
return fmt.Sprintf("<callFrame isCall = %t location = %v trimmable = %t>",
c.isCall, c.isCall,
*c.trace.loc, loc,
c.tailCall, c.trimmable,
) )
} }
@ -105,24 +114,27 @@ func dumpCallStack(c *callStack) string {
} }
func (s *callStack) top() *callFrame { func (s *callStack) top() *callFrame {
return s.stack[len(s.stack)-1] r := s.stack[len(s.stack)-1]
return r
} }
func (s *callStack) pop() { // It might've been popped already by tail call optimization.
if s.top().isCall { // We check if it was trimmed by comparing the current stack size to the position
s.calls-- // of the frame we want to pop.
func (s *callStack) popIfExists(whichFrame int) {
if len(s.stack) == whichFrame {
if s.top().isCall {
s.calls--
}
s.stack = s.stack[:len(s.stack)-1]
} }
s.stack = s.stack[:len(s.stack)-1]
} }
// TODO(sbarzowski) I don't get this. When we have a tail call why can't we just /** If there is a trimmable frame followed by some locals, pop them all. */
// pop the last call from stack before pushing our new thing.
// https://github.com/google/go-jsonnet/pull/24#pullrequestreview-58524217
/** If there is a tailstrict annotated frame followed by some locals, pop them all. */
func (s *callStack) tailCallTrimStack() { func (s *callStack) tailCallTrimStack() {
for i := len(s.stack) - 1; i >= 0; i-- { for i := len(s.stack) - 1; i >= 0; i-- {
if s.stack[i].isCall { if s.stack[i].isCall {
if !s.stack[i].tailCall { // TODO(sbarzowski) we may need to check some more stuff if !s.stack[i].trimmable {
return return
} }
// Remove this stack frame and everything above it // Remove this stack frame and everything above it
@ -133,18 +145,23 @@ func (s *callStack) tailCallTrimStack() {
} }
} }
func (i *interpreter) newCall(trace *TraceElement, env environment) error { type tailCallStatus int
const (
nonTailCall tailCallStatus = iota
tailCall
)
func (i *interpreter) newCall(trace *TraceElement, env environment, trimmable bool) error {
s := &i.stack s := &i.stack
s.tailCallTrimStack()
if s.calls >= s.limit { if s.calls >= s.limit {
// TODO(sbarzowski) add tracing information
return makeRuntimeError("Max stack frames exceeded.", i.getCurrentStackTrace(trace)) return makeRuntimeError("Max stack frames exceeded.", i.getCurrentStackTrace(trace))
} }
s.stack = append(s.stack, &callFrame{ s.stack = append(s.stack, &callFrame{
isCall: true, isCall: true,
trace: trace, trace: trace,
env: env, env: env,
tailCall: false, trimmable: trimmable,
}) })
s.calls++ s.calls++
return nil return nil
@ -239,7 +256,7 @@ func (i *interpreter) getCurrentEnv(ast ast.Node) environment {
) )
} }
func (i *interpreter) evaluate(a ast.Node) (value, error) { func (i *interpreter) evaluate(a ast.Node, tc tailCallStatus) (value, error) {
e := &evaluator{ e := &evaluator{
trace: &TraceElement{ trace: &TraceElement{
loc: a.Loc(), loc: a.Loc(),
@ -290,7 +307,7 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
return result, nil return result, nil
case *ast.Conditional: case *ast.Conditional:
cond, err := e.evalInCurrentContext(ast.Cond) cond, err := e.evalInCurrentContext(ast.Cond, nonTailCall)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -299,15 +316,15 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
return nil, err return nil, err
} }
if condBool.value { if condBool.value {
return e.evalInCurrentContext(ast.BranchTrue) return e.evalInCurrentContext(ast.BranchTrue, tc)
} }
return e.evalInCurrentContext(ast.BranchFalse) return e.evalInCurrentContext(ast.BranchFalse, tc)
case *ast.DesugaredObject: case *ast.DesugaredObject:
// Evaluate all the field names. Check for null, dups, etc. // Evaluate all the field names. Check for null, dups, etc.
fields := make(simpleObjectFieldMap) fields := make(simpleObjectFieldMap)
for _, field := range ast.Fields { for _, field := range ast.Fields {
fieldNameValue, err := e.evalInCurrentContext(field.Name) fieldNameValue, err := e.evalInCurrentContext(field.Name, nonTailCall)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -339,7 +356,7 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
return makeValueSimpleObject(upValues, fields, asserts), nil return makeValueSimpleObject(upValues, fields, asserts), nil
case *ast.Error: case *ast.Error:
msgVal, err := e.evalInCurrentContext(ast.Expr) msgVal, err := e.evalInCurrentContext(ast.Expr, nonTailCall)
if err != nil { if err != nil {
// error when evaluating error message // error when evaluating error message
return nil, err return nil, err
@ -351,11 +368,11 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
return nil, e.Error(msg.getString()) return nil, e.Error(msg.getString())
case *ast.Index: case *ast.Index:
targetValue, err := e.evalInCurrentContext(ast.Target) targetValue, err := e.evalInCurrentContext(ast.Target, nonTailCall)
if err != nil { if err != nil {
return nil, err return nil, err
} }
index, err := e.evalInCurrentContext(ast.Index) index, err := e.evalInCurrentContext(ast.Index, nonTailCall)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -372,7 +389,8 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
return nil, err return nil, err
} }
// TODO(https://github.com/google/jsonnet/issues/377): non-integer indexes should be an error // TODO(https://github.com/google/jsonnet/issues/377): non-integer indexes should be an error
return e.evaluate(target.elements[int(indexInt.value)]) return e.evaluateTailCall(target.elements[int(indexInt.value)], tc)
case *valueString: case *valueString:
indexInt, err := e.getNumber(index) indexInt, err := e.getNumber(index)
if err != nil { if err != nil {
@ -415,10 +433,12 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
bindEnv.upValues[bind.Variable] = th bindEnv.upValues[bind.Variable] = th
} }
i.newLocal(vars) i.newLocal(vars)
sz := len(i.stack.stack)
// Add new stack frame, with new thunk for this variable // Add new stack frame, with new thunk for this variable
// execute body WRT stack frame. // execute body WRT stack frame.
v, err := e.evalInCurrentContext(ast.Body) v, err := e.evalInCurrentContext(ast.Body, tc)
i.stack.pop() i.stack.popIfExists(sz)
return v, err return v, err
case *ast.Self: case *ast.Self:
@ -426,10 +446,10 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
return sb.self, nil return sb.self, nil
case *ast.Var: case *ast.Var:
return e.evaluate(e.lookUpVar(ast.Id)) return e.evaluateTailCall(e.lookUpVar(ast.Id), tc)
case *ast.SuperIndex: case *ast.SuperIndex:
index, err := e.evalInCurrentContext(ast.Index) index, err := e.evalInCurrentContext(ast.Index, nonTailCall)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -440,7 +460,7 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
return objectIndex(e, i.stack.getSelfBinding().super(), indexStr.getString()) return objectIndex(e, i.stack.getSelfBinding().super(), indexStr.getString())
case *ast.InSuper: case *ast.InSuper:
index, err := e.evalInCurrentContext(ast.Index) index, err := e.evalInCurrentContext(ast.Index, nonTailCall)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -458,7 +478,7 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
case *ast.Apply: case *ast.Apply:
// Eval target // Eval target
target, err := e.evalInCurrentContext(ast.Target) target, err := e.evalInCurrentContext(ast.Target, nonTailCall)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -469,10 +489,10 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
// environment in which we can evaluate arguments // environment in which we can evaluate arguments
argEnv := i.getCurrentEnv(a) argEnv := i.getCurrentEnv(a)
arguments := callArguments{ arguments := callArguments{
positional: make([]potentialValue, len(ast.Arguments.Positional)), positional: make([]potentialValue, len(ast.Arguments.Positional)),
named: make([]namedCallArgument, len(ast.Arguments.Named)), named: make([]namedCallArgument, len(ast.Arguments.Named)),
tailstrict: ast.TailStrict,
} }
for i, arg := range ast.Arguments.Positional { for i, arg := range ast.Arguments.Positional {
arguments.positional[i] = makeThunk(argEnv, arg) arguments.positional[i] = makeThunk(argEnv, arg)
@ -481,8 +501,7 @@ func (i *interpreter) evaluate(a ast.Node) (value, error) {
for i, arg := range ast.Arguments.Named { for i, arg := range ast.Arguments.Named {
arguments.named[i] = namedCallArgument{name: arg.Name, pv: makeThunk(argEnv, arg.Arg)} arguments.named[i] = namedCallArgument{name: arg.Name, pv: makeThunk(argEnv, arg.Arg)}
} }
return e.evaluateTailCall(function.call(arguments), tc)
return e.evaluate(function.call(arguments))
default: default:
return nil, e.Error(fmt.Sprintf("Executing this AST type not implemented yet: %v", reflect.TypeOf(a))) return nil, e.Error(fmt.Sprintf("Executing this AST type not implemented yet: %v", reflect.TypeOf(a)))
@ -557,7 +576,7 @@ func (i *interpreter) manifestJSON(trace *TraceElement, v value) (interface{}, e
case *valueArray: case *valueArray:
result := make([]interface{}, 0, len(v.elements)) result := make([]interface{}, 0, len(v.elements))
for _, th := range v.elements { for _, th := range v.elements {
elVal, err := th.getValue(i, trace) // TODO(sbarzowski) perhaps manifestJSON should just take potentialValue elVal, err := e.evaluate(th)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -711,13 +730,17 @@ func (i *interpreter) manifestAndSerializeJSON(trace *TraceElement, v value, mul
return buf.String(), nil return buf.String(), nil
} }
func (i *interpreter) EvalInCleanEnv(fromWhere *TraceElement, env *environment, ast ast.Node) (value, error) { func (i *interpreter) EvalInCleanEnv(fromWhere *TraceElement, env *environment, ast ast.Node, trimmable bool) (value, error) {
err := i.newCall(fromWhere, *env) err := i.newCall(fromWhere, *env, trimmable)
if err != nil { if err != nil {
return nil, err return nil, err
} }
val, err := i.evaluate(ast) stackSize := len(i.stack.stack)
i.stack.pop()
val, err := i.evaluate(ast, tailCall)
i.stack.popIfExists(stackSize)
return val, err return val, err
} }
@ -750,7 +773,7 @@ func evaluateStd(i *interpreter) (value, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return i.EvalInCleanEnv(evalTrace, &beforeStdEnv, node) return i.EvalInCleanEnv(evalTrace, &beforeStdEnv, node, false)
} }
func prepareExtVars(i *interpreter, ext vmExtMap, kind string) map[ast.Identifier]potentialValue { func prepareExtVars(i *interpreter, ext vmExtMap, kind string) map[ast.Identifier]potentialValue {
@ -821,7 +844,7 @@ func evaluate(node ast.Node, ext vmExtMap, tla vmExtMap, maxStack int, importer
loc: &evalLoc, loc: &evalLoc,
} }
env := makeInitialEnv(node.Loc().FileName, i.baseStd) env := makeInitialEnv(node.Loc().FileName, i.baseStd)
result, err := i.EvalInCleanEnv(evalTrace, &env, node) result, err := i.EvalInCleanEnv(evalTrace, &env, node, false)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -9,16 +9,6 @@ RUNTIME ERROR: Too many values to format: 2, expected 1
format_codes_arr(codes, arr, i + 1, j, v + code) tailstrict format_codes_arr(codes, arr, i + 1, j, v + code) tailstrict
-------------------------------------------------
<std>:568:21-74 function <format_codes_arr>
format_codes_arr(codes, arr, i + 1, j3, v + s_padded) tailstrict;
-------------------------------------------------
<std>:525:21-69 function <format_codes_arr>
format_codes_arr(codes, arr, i + 1, j, v + code) tailstrict
------------------------------------------------- -------------------------------------------------
<std>:612:13-52 function <anonymous> <std>:612:13-52 function <anonymous>

View File

@ -18,11 +18,11 @@ RUNTIME ERROR: Not enough values to format, got 1
<builtin> builtin function <toString> <builtin> builtin function <toString>
------------------------------------------------- -------------------------------------------------
... (skipped 21 frames) ... (skipped 14 frames)
------------------------------------------------- -------------------------------------------------
<std>:525:21-69 function <format_codes_arr> <std>:568:21-74 function <format_codes_arr>
format_codes_arr(codes, arr, i + 1, j, v + code) tailstrict format_codes_arr(codes, arr, i + 1, j3, v + s_padded) tailstrict;
------------------------------------------------- -------------------------------------------------
<std>:612:13-52 function <anonymous> <std>:612:13-52 function <anonymous>

View File

@ -18,11 +18,11 @@ RUNTIME ERROR: Not enough values to format, got 1
<builtin> builtin function <toString> <builtin> builtin function <toString>
------------------------------------------------- -------------------------------------------------
... (skipped 21 frames) ... (skipped 14 frames)
------------------------------------------------- -------------------------------------------------
<std>:525:21-69 function <format_codes_arr> <std>:568:21-74 function <format_codes_arr>
format_codes_arr(codes, arr, i + 1, j, v + code) tailstrict format_codes_arr(codes, arr, i + 1, j3, v + s_padded) tailstrict;
------------------------------------------------- -------------------------------------------------
<std>:616:13-54 function <anonymous> <std>:616:13-54 function <anonymous>

View File

@ -21,11 +21,11 @@ RUNTIME ERROR: Format required number at 0, got string
padding(w - std.length(str), s) + str; padding(w - std.length(str), s) + str;
------------------------------------------------- -------------------------------------------------
... (skipped 16 frames) ... (skipped 11 frames)
------------------------------------------------- -------------------------------------------------
<std>:525:21-69 function <format_codes_arr> <std>:568:21-74 function <format_codes_arr>
format_codes_arr(codes, arr, i + 1, j, v + code) tailstrict format_codes_arr(codes, arr, i + 1, j3, v + s_padded) tailstrict;
------------------------------------------------- -------------------------------------------------
<std>:612:13-52 function <anonymous> <std>:612:13-52 function <anonymous>

View File

@ -0,0 +1 @@
true

View File

@ -0,0 +1 @@
local xxx=0; xxx==0

1
testdata/tailstrict.golden vendored Normal file
View File

@ -0,0 +1 @@
642

2
testdata/tailstrict.jsonnet vendored Normal file
View File

@ -0,0 +1,2 @@
local arr = [function(x) x] + std.makeArray(600, function(i) (function(x) arr[i](x + 1) tailstrict));
arr[600](42)

15
testdata/tailstrict2.golden vendored Normal file
View File

@ -0,0 +1,15 @@
RUNTIME ERROR: xxx
-------------------------------------------------
testdata/tailstrict2:1:13-20 function <e>
local e(x)=(error x);
-------------------------------------------------
testdata/tailstrict2:2:14-18 function <anonymous>
(function(x) e(x))("xxx") tailstrict
-------------------------------------------------
During evaluation

2
testdata/tailstrict2.jsonnet vendored Normal file
View File

@ -0,0 +1,2 @@
local e(x)=(error x);
(function(x) e(x))("xxx") tailstrict

15
testdata/tailstrict3.golden vendored Normal file
View File

@ -0,0 +1,15 @@
RUNTIME ERROR: xxx
-------------------------------------------------
testdata/tailstrict3:1:16-26 function <foo>
local foo(x, y=error "xxx")=x;
-------------------------------------------------
testdata/tailstrict3:2:1-8 $
foo(42) tailstrict
-------------------------------------------------
During evaluation

2
testdata/tailstrict3.jsonnet vendored Normal file
View File

@ -0,0 +1,2 @@
local foo(x, y=error "xxx")=x;
foo(42) tailstrict

1
testdata/tailstrict4.golden vendored Normal file
View File

@ -0,0 +1 @@
42

2
testdata/tailstrict4.jsonnet vendored Normal file
View File

@ -0,0 +1,2 @@
local foo(x, y=error "xxx")=x;
foo(42, y=5) tailstrict

1
testdata/tailstrict5.golden vendored Normal file
View File

@ -0,0 +1 @@
500500

7
testdata/tailstrict5.jsonnet vendored Normal file
View File

@ -0,0 +1,7 @@
local sum(x, v) =
if x <= 0 then
v
else
sum(x - 1, x + v) tailstrict;
sum(1000, 0)

View File

@ -70,7 +70,7 @@ func makeThunk(env environment, body ast.Node) *cachedThunk {
} }
func (t *thunk) getValue(i *interpreter, trace *TraceElement) (value, error) { func (t *thunk) getValue(i *interpreter, trace *TraceElement) (value, error) {
return i.EvalInCleanEnv(trace, &t.env, t.body) return i.EvalInCleanEnv(trace, &t.env, t.body, false)
} }
// callThunk represents a concrete, but not yet evaluated call to a function // callThunk represents a concrete, but not yet evaluated call to a function
@ -205,6 +205,16 @@ type closure struct {
params Parameters params Parameters
} }
func forceThunks(e *evaluator, args bindingFrame) error {
for _, arg := range args {
_, err := e.evaluate(arg)
if err != nil {
return err
}
}
return nil
}
func (closure *closure) EvalCall(arguments callArguments, e *evaluator) (value, error) { func (closure *closure) EvalCall(arguments callArguments, e *evaluator) (value, error) {
argThunks := make(bindingFrame) argThunks := make(bindingFrame)
parameters := closure.Parameters() parameters := closure.Parameters()
@ -234,11 +244,18 @@ func (closure *closure) EvalCall(arguments callArguments, e *evaluator) (value,
} }
} }
if arguments.tailstrict {
err := forceThunks(e, argThunks)
if err != nil {
return nil, err
}
}
calledEnvironment = makeEnvironment( calledEnvironment = makeEnvironment(
addBindings(closure.env.upValues, argThunks), addBindings(closure.env.upValues, argThunks),
closure.env.sb, closure.env.sb,
) )
return e.evalInCleanEnv(&calledEnvironment, closure.function.Body) return e.evalInCleanEnv(&calledEnvironment, closure.function.Body, arguments.tailstrict)
} }
func (closure *closure) Parameters() Parameters { func (closure *closure) Parameters() Parameters {

View File

@ -328,6 +328,7 @@ type potentialValueInEnv interface {
type callArguments struct { type callArguments struct {
positional []potentialValue positional []potentialValue
named []namedCallArgument named []namedCallArgument
tailstrict bool
} }
type namedCallArgument struct { type namedCallArgument struct {