summaryrefslogtreecommitdiff
path: root/src/op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/op.cc')
-rw-r--r--src/op.cc138
1 files changed, 66 insertions, 72 deletions
diff --git a/src/op.cc b/src/op.cc
index 95ad5abd..86057f66 100644
--- a/src/op.cc
+++ b/src/op.cc
@@ -108,10 +108,15 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth)
scope.define(symbol_t::FUNCTION, left()->as_ident(), right());
break;
case O_CALL:
- if (left()->left()->is_ident())
- scope.define(symbol_t::FUNCTION, left()->left()->as_ident(), this);
- else
+ if (left()->left()->is_ident()) {
+ ptr_op_t node(new op_t(op_t::O_LAMBDA));
+ node->set_left(left()->right());
+ node->set_right(right());
+
+ scope.define(symbol_t::FUNCTION, left()->left()->as_ident(), node);
+ } else {
throw_(compile_error, _("Invalid function definition"));
+ }
break;
default:
throw_(compile_error, _("Invalid function definition"));
@@ -151,15 +156,21 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth)
break;
case IDENT: {
- if (! left())
+ ptr_op_t definition = left();
+ if (! definition) {
+ // If no definition was pre-compiled for this identifier, look it
+ // up in the current scope.
+ definition = scope.lookup(symbol_t::FUNCTION, as_ident());
+ }
+ if (! definition)
throw_(calc_error, _("Unknown identifier '%1'") << as_ident());
// Evaluating an identifier is the same as calling its definition
// directly, so we create an empty call_scope_t to reflect the scope for
// this implicit call.
call_scope_t call_args(scope, locus, depth);
- result = left()->compile(call_args, depth + 1)
- ->calc(call_args, locus, depth + 1);
+ result = definition->compile(call_args, depth + 1)
+ ->calc(call_args, locus, depth + 1);
check_type_context(scope, result);
break;
}
@@ -177,74 +188,68 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth)
break;
}
- case O_DEFINE: {
+ case O_LAMBDA: {
call_scope_t& call_args(downcast<call_scope_t>(scope));
- std::size_t args_count = call_args.size();
- std::size_t args_index = 0;
-
- assert(left()->kind == O_CALL);
+ std::size_t args_count(call_args.size());
+ std::size_t args_index(0);
+ symbol_scope_t call_scope(call_args);
+ ptr_op_t sym(left());
- for (ptr_op_t sym = left()->right();
- sym;
- sym = sym->has_right() ? sym->right() : NULL) {
+ for (; sym; sym = sym->has_right() ? sym->right() : NULL) {
ptr_op_t varname = sym;
if (sym->kind == O_CONS)
varname = sym->left();
- if (! varname->is_ident())
+ if (! varname->is_ident()) {
throw_(calc_error, _("Invalid function definition"));
- else if (args_index == args_count)
- scope.define(symbol_t::FUNCTION, varname->as_ident(),
- wrap_value(false));
- else
- scope.define(symbol_t::FUNCTION, varname->as_ident(),
- wrap_value(call_args[args_index++]));
+ }
+ else if (args_index == args_count) {
+ call_scope.define(symbol_t::FUNCTION, varname->as_ident(),
+ wrap_value(NULL_VALUE));
+ }
+ else {
+ DEBUG("expr.compile",
+ "Defining function parameter " << varname->as_ident());
+ call_scope.define(symbol_t::FUNCTION, varname->as_ident(),
+ wrap_value(call_args[args_index++]));
+ }
}
if (args_index < args_count)
throw_(calc_error,
- _("Too many arguments in function call (saw %1)") << args_count);
+ _("Too few arguments in function call (saw %1)") << args_count);
- result = right()->calc(scope, locus, depth + 1);
+ result = right()->calc(call_scope, locus, depth + 1);
break;
}
case O_LOOKUP: {
context_scope_t context_scope(scope, value_t::SCOPE);
+ bool scope_error = true;
if (value_t obj = left()->calc(context_scope, locus, depth + 1)) {
- if (obj.is_scope()) {
- if (obj.as_scope() == NULL) {
- throw_(calc_error, _("Left operand of . operator is NULL"));
- } else {
- scope_t& objscope(*obj.as_scope());
- if (ptr_op_t member =
- objscope.lookup(symbol_t::FUNCTION, right()->as_ident())) {
- result = member->calc(objscope, NULL, depth + 1);
- break;
- }
- }
+ if (obj.is_scope() && obj.as_scope() != NULL) {
+ bind_scope_t bound_scope(scope, *obj.as_scope());
+ result = right()->calc(bound_scope, locus, depth + 1);
+ scope_error = false;
}
}
- if (right()->kind != IDENT)
- throw_(calc_error,
- _("Right operand of . operator must be an identifier"));
- else
- throw_(calc_error,
- _("Failed to lookup member '%1'") << right()->as_ident());
+ if (scope_error)
+ throw_(calc_error, _("Left operand does not evaluate to an object"));
break;
}
case O_CALL: {
call_scope_t call_args(scope, locus, depth);
if (has_right())
- call_args.set_args(split_cons_expr(right()->kind == O_SEQ ?
- right()->left() : right()));
+ call_args.set_args(split_cons_expr(right()));
ptr_op_t func = left();
const string& name(func->as_ident());
func = func->left();
if (! func)
+ func = scope.lookup(symbol_t::FUNCTION, name);
+ if (! func)
throw_(calc_error, _("Calling unknown function '%1'") << name);
if (func->is_function())
@@ -467,6 +472,9 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const
string symbol;
+ if (kind > TERMINALS && (kind != O_CALL && kind != O_DEFINE))
+ out << '(';
+
switch (kind) {
case VALUE:
as_value().dump(out, context.relaxed);
@@ -481,118 +489,94 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const
break;
case O_NOT:
- out << "!(";
+ out << "! ";
if (left() && left()->print(out, context))
found = true;
- out << ")";
break;
case O_NEG:
- out << "-(";
+ out << "- ";
if (left() && left()->print(out, context))
found = true;
- out << ")";
break;
case O_ADD:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " + ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_SUB:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " - ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_MUL:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " * ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_DIV:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " / ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_EQ:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " == ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_LT:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " < ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_LTE:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " <= ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_GT:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " > ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_GTE:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " >= ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_AND:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " & ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_OR:
- out << "(";
if (left() && left()->print(out, context))
found = true;
out << " | ";
if (has_right() && right()->print(out, context))
found = true;
- out << ")";
break;
case O_QUERY:
@@ -616,15 +600,13 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const
break;
case O_SEQ:
- out << "(";
found = print_seq(out, this, context);
- out << ")";
break;
case O_DEFINE:
if (left() && left()->print(out, context))
found = true;
- out << " := ";
+ out << " = ";
if (has_right() && right()->print(out, context))
found = true;
break;
@@ -637,11 +619,19 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const
found = true;
break;
+ case O_LAMBDA:
+ if (left() && left()->print(out, context))
+ found = true;
+ out << " -> ";
+ if (has_right() && right()->print(out, context))
+ found = true;
+ break;
+
case O_CALL:
if (left() && left()->print(out, context))
found = true;
if (has_right()) {
- if (right()->kind == O_SEQ) {
+ if (right()->kind == O_CONS) {
if (right()->print(out, context))
found = true;
} else {
@@ -669,6 +659,9 @@ bool expr_t::op_t::print(std::ostream& out, const context_t& context) const
break;
}
+ if (kind > TERMINALS && (kind != O_CALL && kind != O_DEFINE))
+ out << ')';
+
if (! symbol.empty()) {
if (commodity_pool_t::current_pool->find(symbol))
out << '@';
@@ -708,6 +701,7 @@ void expr_t::op_t::dump(std::ostream& out, const int depth) const
case O_DEFINE: out << "O_DEFINE"; break;
case O_LOOKUP: out << "O_LOOKUP"; break;
+ case O_LAMBDA: out << "O_LAMBDA"; break;
case O_CALL: out << "O_CALL"; break;
case O_MATCH: out << "O_MATCH"; break;