From 6490cb1973f6c79626b032a79e9525963512fee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Barzowski?= Date: Fri, 8 Mar 2019 12:06:31 +0100 Subject: [PATCH] Builtin implementation for std.sort Sort is something that is highly optimized in most languages and users can expect it to be fast. We can piggyback on the Go implementation. This change results in 100x speedup on bench.06.jsonnet. --- builtins.go | 215 +++++++++++++++++++++++++++++-------- testdata/std.sort.golden | 5 + testdata/std.sort.jsonnet | 1 + testdata/std.sort2.golden | 12 +++ testdata/std.sort2.jsonnet | 4 + testdata/std.sort3.golden | 15 +++ testdata/std.sort3.jsonnet | 1 + testdata/std.sort4.golden | 10 ++ testdata/std.sort4.jsonnet | 1 + thunks.go | 2 +- 10 files changed, 223 insertions(+), 43 deletions(-) create mode 100644 testdata/std.sort.golden create mode 100644 testdata/std.sort.jsonnet create mode 100644 testdata/std.sort2.golden create mode 100644 testdata/std.sort2.jsonnet create mode 100644 testdata/std.sort3.golden create mode 100644 testdata/std.sort3.jsonnet create mode 100644 testdata/std.sort4.golden create mode 100644 testdata/std.sort4.jsonnet diff --git a/builtins.go b/builtins.go index 0bd0ad5..4486957 100644 --- a/builtins.go +++ b/builtins.go @@ -31,7 +31,6 @@ import ( ) func builtinPlus(i *interpreter, trace TraceElement, x, y value) (value, error) { - // TODO(sbarzowski) more types, mixing types // TODO(sbarzowski) perhaps a more elegant way to dispatch switch right := y.(type) { case *valueString: @@ -128,25 +127,30 @@ func builtinModulo(i *interpreter, trace TraceElement, xv, yv value) (value, err return makeDoubleCheck(i, trace, math.Mod(x.value, y.value)) } -func builtinLess(i *interpreter, trace TraceElement, x, yv value) (value, error) { +func valueLess(i *interpreter, trace TraceElement, x, yv value) (bool, error) { switch left := x.(type) { case *valueNumber: right, err := i.getNumber(yv, trace) if err != nil { - return nil, err + return false, err } - return makeValueBoolean(left.value < right.value), nil + return left.value < right.value, nil case *valueString: right, err := i.getString(yv, trace) if err != nil { - return nil, err + return false, err } - return makeValueBoolean(stringLessThan(left, right)), nil + return stringLessThan(left, right), nil default: - return nil, i.typeErrorGeneral(x, trace) + return false, i.typeErrorGeneral(x, trace) } } +func builtinLess(i *interpreter, trace TraceElement, x, yv value) (value, error) { + b, err := valueLess(i, trace, x, yv) + return makeValueBoolean(b), err +} + func builtinGreater(i *interpreter, trace TraceElement, x, y value) (value, error) { return builtinLess(i, trace, y, x) } @@ -371,6 +375,82 @@ func builtinFilter(i *interpreter, trace TraceElement, funcv, arrv value) (value return makeValueArray(elems), nil } +type sortData struct { + i *interpreter + trace TraceElement + thunks []*cachedThunk + keys []value + err error +} + +func (d *sortData) Len() int { + return len(d.thunks) +} + +func (d *sortData) Less(i, j int) bool { + b, err := valueLess(d.i, d.trace, d.keys[i], d.keys[j]) + if err != nil { + d.err = err + panic("Error while comparing elements") + } + return b +} + +func (d *sortData) Swap(i, j int) { + d.thunks[i], d.thunks[j] = d.thunks[j], d.thunks[i] + d.keys[i], d.keys[j] = d.keys[j], d.keys[i] +} + +func (d *sortData) Sort() (err error) { + defer func() { + if d.err != nil { + if r := recover(); r != nil { + err = d.err + } + } + }() + sort.Stable(d) + return +} + +func arrayFromThunks(vs []value) *valueArray { + thunks := make([]*cachedThunk, len(vs)) + for i := range vs { + thunks[i] = readyThunk(vs[i]) + } + return makeValueArray(thunks) +} + +func builtinSort(i *interpreter, trace TraceElement, arguments []value) (value, error) { + arrv := arguments[0] + keyFv := arguments[1] + + arr, err := i.getArray(arrv, trace) + if err != nil { + return nil, err + } + keyF, err := i.getFunction(keyFv, trace) + if err != nil { + return nil, err + } + num := arr.length() + + data := sortData{i: i, trace: trace, thunks: make([]*cachedThunk, num), keys: make([]value, num)} + + for counter := 0; counter < num; counter++ { + var err error + data.thunks[counter] = arr.elements[counter] + data.keys[counter], err = keyF.call(i, trace, args(arr.elements[counter])) + if err != nil { + return nil, err + } + } + + data.Sort() + + return makeValueArray(data.thunks), nil +} + func builtinRange(i *interpreter, trace TraceElement, fromv, tov value) (value, error) { from, err := i.getInt(fromv, trace) if err != nil { @@ -909,9 +989,38 @@ func builtinNative(i *interpreter, trace TraceElement, name value) (value, error return &valueNull{}, nil } +// Utils for builtins - TODO(sbarzowski) move to a separate file in another commit + +type builtin interface { + evalCallable + Name() ast.Identifier +} + +func flattenArgs(args callArguments, params Parameters, defaults []value) []*cachedThunk { + positions := make(map[ast.Identifier]int) + for i := 0; i < len(params.required); i++ { + positions[params.required[i]] = i + } + for i := 0; i < len(params.optional); i++ { + positions[params.optional[i].name] = i + len(params.required) + } + + flatArgs := make([]*cachedThunk, len(params.required)+len(params.optional)) + + copy(flatArgs, args.positional) + for _, arg := range args.named { + flatArgs[positions[arg.name]] = arg.pv + } + for i := 0; i < len(params.optional); i++ { + pos := len(params.required) + i + if flatArgs[pos] == nil { + flatArgs[pos] = readyThunk(defaults[i]) + } + } + return flatArgs +} + type unaryBuiltinFunc func(*interpreter, TraceElement, value) (value, error) -type binaryBuiltinFunc func(*interpreter, TraceElement, value, value) (value, error) -type ternaryBuiltinFunc func(*interpreter, TraceElement, value, value, value) (value, error) type unaryBuiltin struct { name ast.Identifier @@ -925,7 +1034,7 @@ func getBuiltinTrace(trace TraceElement, name ast.Identifier) TraceElement { } func (b *unaryBuiltin) evalCall(args callArguments, i *interpreter, trace TraceElement) (value, error) { - flatArgs := flattenArgs(args, b.Parameters()) + flatArgs := flattenArgs(args, b.Parameters(), []value{}) builtinTrace := getBuiltinTrace(trace, b.name) x, err := flatArgs[0].getValue(i, trace) if err != nil { @@ -942,39 +1051,16 @@ func (b *unaryBuiltin) Name() ast.Identifier { return b.name } +type binaryBuiltinFunc func(*interpreter, TraceElement, value, value) (value, error) + type binaryBuiltin struct { name ast.Identifier function binaryBuiltinFunc parameters ast.Identifiers } -// flattenArgs transforms all arguments to a simple array of positional arguments. -// It's needed, because it's possible to use named arguments for required parameters. -// For example both `toString("x")` and `toString(a="x")` are allowed. -// It assumes that we have already checked for duplicates. -func flattenArgs(args callArguments, params Parameters) []*cachedThunk { - if len(args.named) == 0 { - return args.positional - } - if len(params.optional) != 0 { - panic("Can't normalize arguments if optional parameters are present") - } - needed := make(map[ast.Identifier]int) - - for i := len(args.positional); i < len(params.required); i++ { - needed[params.required[i]] = i - } - - flatArgs := make([]*cachedThunk, len(params.required)) - copy(flatArgs, args.positional) - for _, arg := range args.named { - flatArgs[needed[arg.name]] = arg.pv - } - return flatArgs -} - func (b *binaryBuiltin) evalCall(args callArguments, i *interpreter, trace TraceElement) (value, error) { - flatArgs := flattenArgs(args, b.Parameters()) + flatArgs := flattenArgs(args, b.Parameters(), []value{}) builtinTrace := getBuiltinTrace(trace, b.name) x, err := flatArgs[0].getValue(i, trace) if err != nil { @@ -995,6 +1081,8 @@ func (b *binaryBuiltin) Name() ast.Identifier { return b.name } +type ternaryBuiltinFunc func(*interpreter, TraceElement, value, value, value) (value, error) + type ternaryBuiltin struct { name ast.Identifier function ternaryBuiltinFunc @@ -1002,7 +1090,7 @@ type ternaryBuiltin struct { } func (b *ternaryBuiltin) evalCall(args callArguments, i *interpreter, trace TraceElement) (value, error) { - flatArgs := flattenArgs(args, b.Parameters()) + flatArgs := flattenArgs(args, b.Parameters(), []value{}) builtinTrace := getBuiltinTrace(trace, b.name) x, err := flatArgs[0].getValue(i, trace) if err != nil { @@ -1027,6 +1115,52 @@ func (b *ternaryBuiltin) Name() ast.Identifier { return b.name } +type generalBuiltinFunc func(*interpreter, TraceElement, []value) (value, error) + +// generalBuiltin covers cases that other builtin structures do not, +// in particular it can have any number of parameters. It can also +// have optional parameters. +type generalBuiltin struct { + name ast.Identifier + required ast.Identifiers + optional ast.Identifiers + // Note that the defaults are passed as values rather than AST nodes like in Parameters. + // This spares us unnecessary evaluation. + defaultValues []value + function generalBuiltinFunc +} + +func (b *generalBuiltin) Parameters() Parameters { + optional := make([]namedParameter, len(b.optional)) + for i := range optional { + optional[i] = namedParameter{name: b.optional[i]} + } + return Parameters{required: b.required, optional: optional} +} + +func (b *generalBuiltin) Name() ast.Identifier { + return b.name +} + +func (b *generalBuiltin) evalCall(args callArguments, i *interpreter, trace TraceElement) (value, error) { + flatArgs := flattenArgs(args, b.Parameters(), b.defaultValues) + builtinTrace := getBuiltinTrace(trace, b.name) + values := make([]value, len(flatArgs)) + for j := 0; j < len(values); j++ { + var err error + values[j], err = flatArgs[j].getValue(i, trace) + if err != nil { + return nil, err + } + } + return b.function(i, builtinTrace, values) +} + +// End of builtin utils + +var builtinID = &unaryBuiltin{name: "id", function: builtinIdentity, parameters: ast.Identifiers{"x"}} +var functionID = &valueFunction{ec: builtinID} + var bopBuiltins = []*binaryBuiltin{ // Note that % and `in` are desugared instead of being handled here ast.BopMult: &binaryBuiltin{name: "operator*", function: builtinMult, parameters: ast.Identifiers{"x", "y"}}, @@ -1058,11 +1192,6 @@ var uopBuiltins = []*unaryBuiltin{ ast.UopMinus: &unaryBuiltin{name: "operator- (unary)", function: builtinUnaryMinus, parameters: ast.Identifiers{"x"}}, } -type builtin interface { - evalCallable - Name() ast.Identifier -} - func buildBuiltinMap(builtins []builtin) map[string]evalCallable { result := make(map[string]evalCallable) for _, b := range builtins { @@ -1072,6 +1201,7 @@ func buildBuiltinMap(builtins []builtin) map[string]evalCallable { } var funcBuiltins = buildBuiltinMap([]builtin{ + builtinID, &unaryBuiltin{name: "extVar", function: builtinExtVar, parameters: ast.Identifiers{"x"}}, &unaryBuiltin{name: "length", function: builtinLength, parameters: ast.Identifiers{"x"}}, &unaryBuiltin{name: "toString", function: builtinToString, parameters: ast.Identifiers{"a"}}, @@ -1109,6 +1239,7 @@ var funcBuiltins = buildBuiltinMap([]builtin{ &unaryBuiltin{name: "parseJson", function: builtinParseJSON, parameters: ast.Identifiers{"str"}}, &unaryBuiltin{name: "encodeUTF8", function: builtinEncodeUTF8, parameters: ast.Identifiers{"str"}}, &unaryBuiltin{name: "decodeUTF8", function: builtinDecodeUTF8, parameters: ast.Identifiers{"arr"}}, + &generalBuiltin{name: "sort", function: builtinSort, required: ast.Identifiers{"arr"}, optional: ast.Identifiers{"keyF"}, defaultValues: []value{functionID}}, &unaryBuiltin{name: "native", function: builtinNative, parameters: ast.Identifiers{"x"}}, // internal diff --git a/testdata/std.sort.golden b/testdata/std.sort.golden new file mode 100644 index 0000000..a238abc --- /dev/null +++ b/testdata/std.sort.golden @@ -0,0 +1,5 @@ +[ + 1, + 2, + 3 +] diff --git a/testdata/std.sort.jsonnet b/testdata/std.sort.jsonnet new file mode 100644 index 0000000..8c911ed --- /dev/null +++ b/testdata/std.sort.jsonnet @@ -0,0 +1 @@ +std.sort([1,2,3]) diff --git a/testdata/std.sort2.golden b/testdata/std.sort2.golden new file mode 100644 index 0000000..3ab8a63 --- /dev/null +++ b/testdata/std.sort2.golden @@ -0,0 +1,12 @@ +[ + [ + 3, + 2, + 1 + ], + [ + 3, + 2, + 1 + ] +] diff --git a/testdata/std.sort2.jsonnet b/testdata/std.sort2.jsonnet new file mode 100644 index 0000000..57f5083 --- /dev/null +++ b/testdata/std.sort2.jsonnet @@ -0,0 +1,4 @@ +[ + std.sort([1,2,3], keyF=(function(x) -x)), + std.sort([1,2,3], function(x) -x), +] diff --git a/testdata/std.sort3.golden b/testdata/std.sort3.golden new file mode 100644 index 0000000..3e2ce41 --- /dev/null +++ b/testdata/std.sort3.golden @@ -0,0 +1,15 @@ +RUNTIME ERROR: foo +------------------------------------------------- + testdata/std.sort3:1:16-27 thunk from > + +std.sort([1,2, error "foo"]) + +------------------------------------------------- + testdata/std.sort3:1:1-29 builtin function + +std.sort([1,2, error "foo"]) + +------------------------------------------------- + During evaluation + + diff --git a/testdata/std.sort3.jsonnet b/testdata/std.sort3.jsonnet new file mode 100644 index 0000000..a9c5909 --- /dev/null +++ b/testdata/std.sort3.jsonnet @@ -0,0 +1 @@ +std.sort([1,2, error "foo"]) diff --git a/testdata/std.sort4.golden b/testdata/std.sort4.golden new file mode 100644 index 0000000..19205ba --- /dev/null +++ b/testdata/std.sort4.golden @@ -0,0 +1,10 @@ +RUNTIME ERROR: foo +------------------------------------------------- + testdata/std.sort4:1:15-26 thunk from >> + +std.sort([1, [error "foo"]]) + +------------------------------------------------- + During manifestation + + diff --git a/testdata/std.sort4.jsonnet b/testdata/std.sort4.jsonnet new file mode 100644 index 0000000..caf98ad --- /dev/null +++ b/testdata/std.sort4.jsonnet @@ -0,0 +1 @@ +std.sort([1, [error "foo"]]) diff --git a/thunks.go b/thunks.go index 2d28f50..6c6c84e 100644 --- a/thunks.go +++ b/thunks.go @@ -240,7 +240,7 @@ type NativeFunction struct { // evalCall evaluates a call to a NativeFunction and returns the result. func (native *NativeFunction) evalCall(arguments callArguments, i *interpreter, trace TraceElement) (value, error) { - flatArgs := flattenArgs(arguments, native.Parameters()) + flatArgs := flattenArgs(arguments, native.Parameters(), []value{}) nativeArgs := make([]interface{}, 0, len(flatArgs)) for _, arg := range flatArgs { v, err := i.evaluatePV(arg, trace)