diff options
Diffstat (limited to 'src/op.cc')
-rw-r--r-- | src/op.cc | 398 |
1 files changed, 270 insertions, 128 deletions
@@ -38,6 +38,16 @@ namespace ledger { +void intrusive_ptr_add_ref(const expr_t::op_t * op) +{ + op->acquire(); +} + +void intrusive_ptr_release(const expr_t::op_t * op) +{ + op->release(); +} + namespace { value_t split_cons_expr(expr_t::ptr_op_t op) { @@ -76,10 +86,12 @@ namespace { } } -expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth) +expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth, + scope_t * param_scope) { - scope_t * scope_ptr = &scope; - expr_t::ptr_op_t result; + scope_t * scope_ptr = &scope; + unique_ptr<scope_t> bound_scope; + expr_t::ptr_op_t result; #if defined(DEBUG_ON) if (SHOW_DEBUG("expr.compile")) { @@ -93,7 +105,12 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth) if (is_ident()) { DEBUG("expr.compile", "Lookup: " << as_ident() << " in " << scope_ptr); - if (ptr_op_t def = scope_ptr->lookup(symbol_t::FUNCTION, as_ident())) { + ptr_op_t def; + if (param_scope) + def = param_scope->lookup(symbol_t::FUNCTION, as_ident()); + if (! def) + def = scope_ptr->lookup(symbol_t::FUNCTION, as_ident()); + if (def) { // Identifier references are first looked up at the point of // definition, and then at the point of every use if they could // not be found there. @@ -113,9 +130,10 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth) } } else if (is_scope()) { - shared_ptr<scope_t> subscope(new symbol_scope_t(scope)); + shared_ptr<scope_t> subscope(new symbol_scope_t(*scope_t::empty_scope)); set_scope(subscope); - scope_ptr = subscope.get(); + bound_scope.reset(new bind_scope_t(*scope_ptr, *subscope.get())); + scope_ptr = bound_scope.get(); } else if (kind < TERMINALS) { result = this; @@ -123,7 +141,7 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth) else if (kind == O_DEFINE) { switch (left()->kind) { case IDENT: { - ptr_op_t node(right()->compile(*scope_ptr, depth + 1)); + ptr_op_t node(right()->compile(*scope_ptr, depth + 1, param_scope)); DEBUG("expr.compile", "Defining " << left()->as_ident() << " in " << scope_ptr); @@ -136,12 +154,12 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth) ptr_op_t node(new op_t(op_t::O_LAMBDA)); node->set_left(left()->right()); node->set_right(right()); - node = node->compile(*scope_ptr, depth + 1); + + node = node->compile(*scope_ptr, depth + 1, param_scope); DEBUG("expr.compile", "Defining " << left()->left()->as_ident() << " in " << scope_ptr); - scope_ptr->define(symbol_t::FUNCTION, left()->left()->as_ident(), - node); + scope_ptr->define(symbol_t::FUNCTION, left()->left()->as_ident(), node); break; } // fall through... @@ -151,12 +169,39 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth) } result = wrap_value(NULL_VALUE); } + else if (kind == O_LAMBDA) { + symbol_scope_t params(param_scope ? *param_scope : *scope_t::empty_scope); + + for (ptr_op_t sym = left(); + sym; + sym = sym->has_right() ? sym->right() : NULL) { + ptr_op_t varname = sym->kind == O_CONS ? sym->left() : sym; + + if (! varname->is_ident()) { + std::ostringstream buf; + varname->dump(buf, 0); + throw_(calc_error, + _("Invalid function or lambda parameter: %1") << buf.str()); + } else { + DEBUG("expr.compile", + "Defining function parameter " << varname->as_ident()); + params.define(symbol_t::FUNCTION, varname->as_ident(), + new op_t(PLUG)); + } + } + + ptr_op_t rhs(right()->compile(*scope_ptr, depth + 1, ¶ms)); + if (rhs == right()) + result = this; + else + result = copy(left(), rhs); + } if (! result) { - ptr_op_t lhs(left()->compile(*scope_ptr, depth + 1)); + ptr_op_t lhs(left()->compile(*scope_ptr, depth + 1, param_scope)); ptr_op_t rhs(kind > UNARY_OPERATORS && has_right() ? (kind == O_LOOKUP ? right() : - right()->compile(*scope_ptr, depth + 1)) : NULL); + right()->compile(*scope_ptr, depth + 1, param_scope)) : NULL); if (lhs == left() && (! rhs || rhs == right())) { result = this; @@ -182,6 +227,23 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth) return result; } +namespace { + expr_t::ptr_op_t lookup_ident(expr_t::ptr_op_t op, scope_t& scope) + { + expr_t::ptr_op_t def = op->left(); + + // If no definition was pre-compiled for this identifier, look it up + // in the current scope. + if (! def || def->kind == expr_t::op_t::PLUG) { + DEBUG("scope.symbols", "Looking for IDENT '" << op->as_ident() << "'"); + def = scope.lookup(symbol_t::FUNCTION, op->as_ident()); + } + if (! def) + throw_(calc_error, _("Unknown identifier '%1'") << op->as_ident()); + return def; + } +} + value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth) { try { @@ -206,23 +268,14 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth) result = NULL_VALUE; break; - case IDENT: { - ptr_op_t definition = left(); - // If no definition was pre-compiled for this identifier, look it up - // in the current scope. - if (! definition) { - DEBUG("scope.symbols", "Looking for IDENT '" << as_ident() << "'"); - definition = scope.lookup(symbol_t::FUNCTION, as_ident()); + case IDENT: + if (ptr_op_t definition = lookup_ident(this, scope)) { + // Evaluating an identifier is the same as calling its definition + // directly + result = definition->calc(scope, locus, depth + 1); + check_type_context(scope, result); } - if (! definition) - throw_(calc_error, _("Unknown identifier '%1'") << as_ident()); - - // Evaluating an identifier is the same as calling its definition - // directly - result = definition->calc(scope, locus, depth + 1); - check_type_context(scope, result); break; - } case FUNCTION: { // Evaluating a FUNCTION is the same as calling it directly; this @@ -260,68 +313,14 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth) break; } - case O_CALL: { - ptr_op_t func = left(); - const string& name(func->as_ident()); - - func = func->left(); - if (! func) - func = scope.lookup(symbol_t::FUNCTION, name); - if (! func) - throw_(calc_error, _("Calling unknown function '%1'") << name); - - call_scope_t call_args(scope, locus, depth + 1); - if (has_right()) - call_args.set_args(split_cons_expr(right())); - - try { - if (func->is_function()) - result = func->as_function()(call_args); - else - result = func->calc(call_args, locus, depth + 1); - } - catch (const std::exception&) { - add_error_context(_("While calling function '%1':" << name)); - throw; - } - + case O_CALL: + result = calc_call(scope, locus, depth); check_type_context(scope, result); break; - } - - case O_LAMBDA: { - call_scope_t& call_args(find_scope<call_scope_t>(scope, true)); - std::size_t args_count(call_args.size()); - std::size_t args_index(0); - symbol_scope_t call_scope(call_args); - - for (ptr_op_t sym = left(); - sym; - sym = sym->has_right() ? sym->right() : NULL) { - ptr_op_t varname = sym->kind == O_CONS ? sym->left() : sym; - if (! varname->is_ident()) { - throw_(calc_error, _("Invalid function definition")); - } - else if (args_index == args_count) { - call_scope.define(symbol_t::FUNCTION, varname->as_ident(), - wrap_value(NULL_VALUE)); - } - else { - DEBUG("expr.compile", - "Defining function parameter " << varname->as_ident()); - call_scope.define(symbol_t::FUNCTION, varname->as_ident(), - wrap_value(call_args[args_index++])); - } - } - - if (args_index < args_count) - throw_(calc_error, - _("Too few arguments in function call (saw %1, wanted %2)") - << args_count << args_index); - result = right()->calc(call_scope, locus, depth + 1); + case O_LAMBDA: + result = expr_value(this); break; - } case O_MATCH: result = (right()->calc(scope, locus, depth + 1).as_mask() @@ -402,51 +401,12 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth) break; case O_CONS: - result = left()->calc(scope, locus, depth + 1); - if (has_right()) { - value_t temp; - temp.push_back(result); - - ptr_op_t next = right(); - while (next) { - ptr_op_t value_op; - if (next->kind == O_CONS) { - value_op = next->left(); - next = next->has_right() ? next->right() : NULL; - } else { - value_op = next; - next = NULL; - } - temp.push_back(value_op->calc(scope, locus, depth + 1)); - } - result = temp; - } + result = calc_cons(scope, locus, depth); break; - case O_SEQ: { - // An O_SEQ is very similar to an O_CONS except that only the last - // result value in the series is kept. O_CONS builds up a list. - // - // Another feature of O_SEQ is that it pushes a new symbol scope - // onto the stack. We evaluate the left side here to catch any - // side-effects, such as definitions in the case of 'x = 1; x'. - result = left()->calc(scope, locus, depth + 1); - if (has_right()) { - ptr_op_t next = right(); - while (next) { - ptr_op_t value_op; - if (next->kind == O_SEQ) { - value_op = next->left(); - next = next->right(); - } else { - value_op = next; - next = NULL; - } - result = value_op->calc(scope, locus, depth + 1); - } - } + case O_SEQ: + result = calc_seq(scope, locus, depth); break; - } default: throw_(calc_error, _("Unexpected expr node '%1'") << op_context(this)); @@ -473,6 +433,184 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth) } namespace { + expr_t::ptr_op_t find_definition(expr_t::ptr_op_t op, scope_t& scope, + expr_t::ptr_op_t * locus, const int depth, + int recursion_depth = 0) + { + // If the object we are apply call notation to is a FUNCTION value + // or a O_LAMBDA expression, then this is the object we want to + // call. + if (op->is_function() || op->kind == expr_t::op_t::O_LAMBDA) + return op; + + if (recursion_depth > 256) + throw_(value_error, _("Function recursion_depth too deep (> 256)")); + + // If it's an identifier, look up its definition and see if it's a + // function. + if (op->is_ident()) + return find_definition(lookup_ident(op, scope), scope, + locus, depth, recursion_depth + 1); + + // Value objects might be callable if they contain an expression. + if (op->is_value()) { + value_t def(op->as_value()); + if (is_expr(def)) + return find_definition(as_expr(def), scope, locus, depth, + recursion_depth + 1); + else + throw_(value_error, _("Cannot call %1 as a function") << def.label()); + } + + // Resolve ordinary expressions. + return find_definition(expr_t::op_t::wrap_value(op->calc(scope, locus, + depth + 1)), + scope, locus, depth + 1, recursion_depth + 1); + } + + value_t call_lambda(expr_t::ptr_op_t func, scope_t& scope, + call_scope_t& call_args, expr_t::ptr_op_t * locus, + const int depth) + { + std::size_t args_index(0); + std::size_t args_count(call_args.size()); + + symbol_scope_t args_scope(*scope_t::empty_scope); + + for (expr_t::ptr_op_t sym = func->left(); + sym; + sym = sym->has_right() ? sym->right() : NULL) { + expr_t::ptr_op_t varname = + sym->kind == expr_t::op_t::O_CONS ? sym->left() : sym; + if (! varname->is_ident()) { + throw_(calc_error, _("Invalid function definition")); + } + else if (args_index == args_count) { + DEBUG("expr.calc", "Defining function argument as null: " + << varname->as_ident()); + args_scope.define(symbol_t::FUNCTION, varname->as_ident(), + expr_t::op_t::wrap_value(NULL_VALUE)); + } + else { + DEBUG("expr.calc", "Defining function argument from call_args: " + << varname->as_ident()); + args_scope.define(symbol_t::FUNCTION, varname->as_ident(), + expr_t::op_t::wrap_value(call_args[args_index++])); + } + } + + if (args_index < args_count) + throw_(calc_error, + _("Too few arguments in function call (saw %1, wanted %2)") + << args_count << args_index); + + if (func->right()->is_scope()) { + bind_scope_t outer_scope(scope, *func->right()->as_scope()); + bind_scope_t bound_scope(outer_scope, args_scope); + + return func->right()->left()->calc(bound_scope, locus, depth + 1); + } else { + return func->right()->calc(args_scope, locus, depth + 1); + } + } +} + + +value_t expr_t::op_t::call(const value_t& args, scope_t& scope, + ptr_op_t * locus, const int depth) +{ + call_scope_t call_args(scope, locus, depth + 1); + call_args.set_args(args); + + if (is_function()) + return as_function()(call_args); + else if (kind == O_LAMBDA) + return call_lambda(this, scope, call_args, locus, depth); + else + return find_definition(this, scope, locus, depth) + ->calc(call_args, locus, depth); +} + +value_t expr_t::op_t::calc_call(scope_t& scope, ptr_op_t * locus, + const int depth) +{ + ptr_op_t func = left(); + string name = func->is_ident() ? func->as_ident() : "<value expr>"; + + func = find_definition(func, scope, locus, depth); + + call_scope_t call_args(scope, locus, depth + 1); + if (has_right()) + call_args.set_args(split_cons_expr(right())); + + try { + if (func->is_function()) { + return func->as_function()(call_args); + } else { + assert(func->kind == O_LAMBDA); + return call_lambda(func, scope, call_args, locus, depth); + } + } + catch (const std::exception&) { + add_error_context(_("While calling function '%1 %2':" << name + << call_args.args)); + throw; + } +} + +value_t expr_t::op_t::calc_cons(scope_t& scope, ptr_op_t * locus, + const int depth) +{ + value_t result = left()->calc(scope, locus, depth + 1); + if (has_right()) { + value_t temp; + temp.push_back(result); + + ptr_op_t next = right(); + while (next) { + ptr_op_t value_op; + if (next->kind == O_CONS) { + value_op = next->left(); + next = next->has_right() ? next->right() : NULL; + } else { + value_op = next; + next = NULL; + } + temp.push_back(value_op->calc(scope, locus, depth + 1)); + } + result = temp; + } + return result; +} + +value_t expr_t::op_t::calc_seq(scope_t& scope, ptr_op_t * locus, + const int depth) +{ + // An O_SEQ is very similar to an O_CONS except that only the last + // result value in the series is kept. O_CONS builds up a list. + // + // Another feature of O_SEQ is that it pushes a new symbol scope onto + // the stack. We evaluate the left side here to catch any + // side-effects, such as definitions in the case of 'x = 1; x'. + value_t result = left()->calc(scope, locus, depth + 1); + if (has_right()) { + ptr_op_t next = right(); + while (next) { + ptr_op_t value_op; + if (next->kind == O_SEQ) { + value_op = next->left(); + next = next->right(); + } else { + value_op = next; + next = NULL; + } + result = value_op->calc(scope, locus, depth + 1); + } + } + return result; +} + +namespace { bool print_cons(std::ostream& out, const expr_t::const_ptr_op_t op, const expr_t::op_t::context_t& context) { @@ -743,6 +881,10 @@ void expr_t::op_t::dump(std::ostream& out, const int depth) const out << " "; switch (kind) { + case PLUG: + out << "PLUG"; + break; + case VALUE: out << "VALUE: "; as_value().dump(out); @@ -803,7 +945,7 @@ void expr_t::op_t::dump(std::ostream& out, const int depth) const // An identifier is a special non-terminal, in that its left() can // hold the compiled definition of the identifier. - if (kind > TERMINALS || is_scope()) { + if (kind > TERMINALS || is_scope() || is_ident()) { if (left()) { left()->dump(out, depth + 1); if (kind > UNARY_OPERATORS && has_right()) |