diff options
Diffstat (limited to 'src/op.cc')
-rw-r--r-- | src/op.cc | 138 |
1 files changed, 66 insertions, 72 deletions
@@ -108,10 +108,15 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth) scope.define(symbol_t::FUNCTION, left()->as_ident(), right()); break; case O_CALL: - if (left()->left()->is_ident()) - scope.define(symbol_t::FUNCTION, left()->left()->as_ident(), this); - else + if (left()->left()->is_ident()) { + ptr_op_t node(new op_t(op_t::O_LAMBDA)); + node->set_left(left()->right()); + node->set_right(right()); + + scope.define(symbol_t::FUNCTION, left()->left()->as_ident(), node); + } else { throw_(compile_error, _("Invalid function definition")); + } break; default: throw_(compile_error, _("Invalid function definition")); @@ -151,15 +156,21 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth) break; case IDENT: { - if (! left()) + ptr_op_t definition = left(); + if (! definition) { + // If no definition was pre-compiled for this identifier, look it + // up in the current scope. + definition = scope.lookup(symbol_t::FUNCTION, as_ident()); + } + if (! definition) throw_(calc_error, _("Unknown identifier '%1'") << as_ident()); // Evaluating an identifier is the same as calling its definition // directly, so we create an empty call_scope_t to reflect the scope for // this implicit call. call_scope_t call_args(scope, locus, depth); - result = left()->compile(call_args, depth + 1) - ->calc(call_args, locus, depth + 1); + result = definition->compile(call_args, depth + 1) + ->calc(call_args, locus, depth + 1); check_type_context(scope, result); break; } @@ -177,74 +188,68 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth) break; } - case O_DEFINE: { + case O_LAMBDA: { call_scope_t& call_args(downcast<call_scope_t>(scope)); - std::size_t args_count = call_args.size(); - std::size_t args_index = 0; - - assert(left()->kind == O_CALL); + std::size_t args_count(call_args.size()); + std::size_t args_index(0); + symbol_scope_t call_scope(call_args); + ptr_op_t sym(left()); - for (ptr_op_t sym = left()->right(); - sym; - sym = sym->has_right() ? sym->right() : NULL) { + for (; sym; sym = sym->has_right() ? sym->right() : NULL) { ptr_op_t varname = sym; if (sym->kind == O_CONS) varname = sym->left(); - if (! varname->is_ident()) + if (! varname->is_ident()) { throw_(calc_error, _("Invalid function definition")); - else if (args_index == args_count) - scope.define(symbol_t::FUNCTION, varname->as_ident(), - wrap_value(false)); - else - scope.define(symbol_t::FUNCTION, varname->as_ident(), - wrap_value(call_args[args_index++])); + } + else if (args_index == args_count) { + call_scope.define(symbol_t::FUNCTION, varname->as_ident(), + wrap_value(NULL_VALUE)); + } + else { + DEBUG("expr.compile", + "Defining function parameter " << varname->as_ident()); + call_scope.define(symbol_t::FUNCTION, varname->as_ident(), + wrap_value(call_args[args_index++])); + } } if (args_index < args_count) throw_(calc_error, - _("Too many arguments in function call (saw %1)") << args_count); + _("Too few arguments in function call (saw %1)") << args_count); - result = right()->calc(scope, locus, depth + 1); + result = right()->calc(call_scope, locus, depth + 1); break; } case O_LOOKUP: { context_scope_t context_scope(scope, value_t::SCOPE); + bool scope_error = true; if (value_t obj = left()->calc(context_scope, locus, depth + 1)) { - if (obj.is_scope()) { - if (obj.as_scope() == NULL) { - throw_(calc_error, _("Left operand of . operator is NULL")); - } else { - scope_t& objscope(*obj.as_scope()); - if (ptr_op_t member = - objscope.lookup(symbol_t::FUNCTION, right()->as_ident())) { - result = member->calc(objscope, NULL, depth + 1); - break; - } - } + if (obj.is_scope() && obj.as_scope() != NULL) { + bind_scope_t bound_scope(scope, *obj.as_scope()); + result = right()->calc(bound_scope, locus, depth + 1); + scope_error = false; } } - if (right()->kind != IDENT) - throw_(calc_error, - _("Right operand of . operator must be an identifier")); - else - throw_(calc_error, - _("Failed to lookup member '%1'") << right()->as_ident()); + if (scope_error) + throw_(calc_error, _("Left operand does not evaluate to an object")); break; } case O_CALL: { call_scope_t call_args(scope, locus, depth); if (has_right()) - call_args.set_args(split_cons_expr(right()->kind == O_SEQ ? - right()->left() : right())); + call_args.set_args(split_cons_expr(right())); ptr_op_t func = left(); const string& name(func->as_ident()); func = func->left(); if (! func) + func = scope.lookup(symbol_t::FUNCTION, name); + if (! func) throw_(calc_error, _("Calling unknown function '%1'") << name); if (func->is_function()) @@ -467,6 +472,9 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const string symbol; + if (kind > TERMINALS && (kind != O_CALL && kind != O_DEFINE)) + out << '('; + switch (kind) { case VALUE: as_value().dump(out, context.relaxed); @@ -481,118 +489,94 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const break; case O_NOT: - out << "!("; + out << "! "; if (left() && left()->print(out, context)) found = true; - out << ")"; break; case O_NEG: - out << "-("; + out << "- "; if (left() && left()->print(out, context)) found = true; - out << ")"; break; case O_ADD: - out << "("; if (left() && left()->print(out, context)) found = true; out << " + "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_SUB: - out << "("; if (left() && left()->print(out, context)) found = true; out << " - "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_MUL: - out << "("; if (left() && left()->print(out, context)) found = true; out << " * "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_DIV: - out << "("; if (left() && left()->print(out, context)) found = true; out << " / "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_EQ: - out << "("; if (left() && left()->print(out, context)) found = true; out << " == "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_LT: - out << "("; if (left() && left()->print(out, context)) found = true; out << " < "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_LTE: - out << "("; if (left() && left()->print(out, context)) found = true; out << " <= "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_GT: - out << "("; if (left() && left()->print(out, context)) found = true; out << " > "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_GTE: - out << "("; if (left() && left()->print(out, context)) found = true; out << " >= "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_AND: - out << "("; if (left() && left()->print(out, context)) found = true; out << " & "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_OR: - out << "("; if (left() && left()->print(out, context)) found = true; out << " | "; if (has_right() && right()->print(out, context)) found = true; - out << ")"; break; case O_QUERY: @@ -616,15 +600,13 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const break; case O_SEQ: - out << "("; found = print_seq(out, this, context); - out << ")"; break; case O_DEFINE: if (left() && left()->print(out, context)) found = true; - out << " := "; + out << " = "; if (has_right() && right()->print(out, context)) found = true; break; @@ -637,11 +619,19 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const found = true; break; + case O_LAMBDA: + if (left() && left()->print(out, context)) + found = true; + out << " -> "; + if (has_right() && right()->print(out, context)) + found = true; + break; + case O_CALL: if (left() && left()->print(out, context)) found = true; if (has_right()) { - if (right()->kind == O_SEQ) { + if (right()->kind == O_CONS) { if (right()->print(out, context)) found = true; } else { @@ -669,6 +659,9 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const break; } + if (kind > TERMINALS && (kind != O_CALL && kind != O_DEFINE)) + out << ')'; + if (! symbol.empty()) { if (commodity_pool_t::current_pool->find(symbol)) out << '@'; @@ -708,6 +701,7 @@ void expr_t::op_t::dump(std::ostream& out, const int depth) const case O_DEFINE: out << "O_DEFINE"; break; case O_LOOKUP: out << "O_LOOKUP"; break; + case O_LAMBDA: out << "O_LAMBDA"; break; case O_CALL: out << "O_CALL"; break; case O_MATCH: out << "O_MATCH"; break; |