'''
Implements the Straw Man (SM) language.  SM is intended to explore
a minimal scripting language.  This module is a prototype and has undergone
heavy tweaking, and has many known bugs.  However,
the test cases at the end do all work more or less as promised.

To get a feel for SM, run this module and examine the output of each test
case.
'''
from copy import copy
from functools import wraps
import re

# A token of the straw-man language; the alternatives are tried in order.
tokenRE = re.compile(
    r'".*?"'         # double-quoted string literal (no escapes, no newlines)
    r'|\(|\)|\}|\{'  # round and curly brackets
    r'|;'            # statement separator
    r'|[^(){}\s]+'   # any other run of non-space, non-bracket characters
)
# Used (with .match, i.e. anchored at the token's start) to recognise
# symbols; every other non-bracket token is treated as a literal.
symbolRE = re.compile(r'[a-z][a-z0-9_]*')

class Node(object):
    '''A node of the SM syntax tree.

    sym  -- the symbol name; evaluate() dispatches on it.
    args -- list of child nodes (the call's arguments).
    val  -- the value used when printing; for a plain symbol node it is the
            symbol name itself.
    '''
    def __init__(self, symbol):
        self.sym = symbol
        self.args = []
        self.val = symbol

    def __str__(self):
        # A childless node prints as its value, otherwise as val(arg arg ...).
        if not self.args:
            return str(self.val)
        return "%s(%s)" % (self.val, " ".join(str(arg) for arg in self.args))


class DoNode(Node):
    'An empty "do" symbol node, which uses the {} syntactic sugar.'
    def __init__(self):
        # Initialise through the base class so a DoNode carries every Node
        # attribute (sym, args and val); code that reads .val from an
        # evaluated node, such as evalVal, then sees 'do' rather than
        # failing with an AttributeError.
        Node.__init__(self, 'do')

    def __str__(self):
        return "{%s}" % "; ".join(str(arg) for arg in self.args)


class ValueNode(Node):
    'A literal node bound to a particular value at creation.'
    def __init__(self, value):
        Node.__init__(self, 'value')
        # Replace the symbol-derived val set by the base class with the
        # literal itself.  The value is exposed only as this attribute; a
        # method named "val" would be shadowed by it and never callable.
        self.val = value

class Stack(list):
    '''A list used as a stack: push() onto the end, peek() at the end.

    pop() is the inherited list.pop(), which removes and returns the top.
    '''
    def push(self, top):
        'Put *top* on the stack.'
        self.append(top)

    def peek(self):
        'Return the top element without removing it (IndexError if empty).'
        return self[len(self) - 1]


def parse(string):
    '''Parse SM source text and return its AST.

    The result is a Node('do') whose args are the top-level expressions of
    *string*, so evaluating it runs the program as one "do" block.

    Raises SyntaxError for: a ';' outside a "do" context; a '(' with no
    node to attach to (at the very start, or directly after '(' or '{');
    a '}' that does not close a "do" node; and a ')' or '}' with no
    matching opener.  Unclosed brackets are not reported.
    '''
    main = Node('do')
    # I love this trick; when working with any kind of tree, we can often
    # avoid treating the root as a special case by simply starting with a
    # dummy node above the root.
    context = Stack()
    context.push(main)
    # Most recently created node; a following '(' makes it the parent of
    # the tokens up to the matching ')'.  None until a node is created.
    ast = None

    for match in tokenRE.finditer(string):
        token = match.group(0)
        if token == ";":
            if context.peek().sym != 'do':
                raise SyntaxError("misplaced semicolon!")
        elif token == '(':
            # Either nothing precedes the '(' or the preceding node already
            # has its argument list open (e.g. "f((").
            if ast is None or context.peek() is ast:
                raise SyntaxError("'(' does not follow symbol!")
            context.push(ast)
        elif token == ')':
            # Never pop the dummy root: doing so would make the next peek()
            # fail with an IndexError (or silently accept a stray ')').
            if len(context) == 1:
                raise SyntaxError("unmatched ')'!")
            context.pop()
        elif token == '{':
            ast = DoNode()
            context.peek().args.append(ast)
            context.push(ast)
        elif token == '}':
            if context.peek().sym != 'do':
                raise SyntaxError("'}' terminating non-do block!")
            # The root is itself a 'do' node, so separately check that a
            # block is actually open.
            if len(context) == 1:
                raise SyntaxError("unmatched '}'!")
            context.pop()
        elif symbolRE.match(token):
            ast = Node(token)
            context.peek().args.append(ast)
        else:  # must be a literal
            # NOTE(review): eval() executes arbitrary Python for any token
            # that is not a symbol or bracket -- never parse untrusted SM
            # source with this.  The ast module's literal_eval would be
            # safer, but it rejects non-literal tokens that eval accepts.
            ast = ValueNode(eval(token))
            context.peek().args.append(ast)
    return main

### eval functions for each primitive ###
def indexSafe(f):
    '''Decorator: make an eval primitive return ValueNode('') on any error.

    Missing arguments (IndexError), non-numeric operands (ValueError) and
    similar failures yield an empty value instead of aborting the program.
    Only Exception subclasses are caught, so KeyboardInterrupt and
    SystemExit still propagate.  A bare except would swallow Ctrl-C inside
    the innermost primitive and keep a runaway SM loop running.
    '''
    @wraps(f)
    def newF(*x, **y):
        try:
            return f(*x, **y)
        except Exception:
            return ValueNode('')
    return newF

@indexSafe
def evalIf(ast,env):
    '''if(condition then else): evaluate and return only the chosen branch.

    The condition's Python truthiness decides.  Node defines no __len__ or
    __nonzero__, so any node counts as true; falsy results such as '' (from
    gt) or False (from eq) select the else branch.  A missing branch raises
    IndexError, which indexSafe turns into ValueNode('').
    '''
    chosen = ast.args[1] if evaluate(ast.args[0], env) else ast.args[2]
    return evaluate(chosen, env)
@indexSafe
def evalSet(ast,env):
    '''set(name expr): bind name to expr's value in env and return the value.

    The first argument is not evaluated; only its symbol is used as the name.
    '''
    target = ast.args[0].sym
    env[target] = bound = evaluate(ast.args[1], env)
    return bound

@indexSafe
def evalGt(ast,env):
    '''gt(a b): 'true' when int(a) > int(b), otherwise the empty string.

    Operands are evaluated left to right and converted with int() from
    their .val; a non-numeric operand makes indexSafe return ValueNode('').
    '''
    lhs = int(evaluate(ast.args[0], env).val)
    rhs = int(evaluate(ast.args[1], env).val)
    if lhs > rhs:
        return 'true'
    return ''

@indexSafe
def evalPrint(ast,env):
    '''print(...): evaluate each argument and print it, space-separated.

    Each argument is evaluated and printed before the next one is
    evaluated, so output produced while evaluating an argument (e.g. a
    nested print) is interleaved in source order.  Returns the last
    argument's value, or '' when called with no arguments.
    '''
    ret = ''
    for arg in ast.args:
        ret = evaluate(arg, env)
        # Python 2 print statement: the trailing comma suppresses the
        # newline, so successive items end up separated by spaces.
        print ret,
    # Terminate the output line.
    print
    return ret

@indexSafe
def evalDo(ast,env):
    '''do(...) / {...}: evaluate statements in order and return the last value.

    A return(expr) statement that is a direct child of this block stops it
    early and yields expr's value ('' for a bare return()).  An empty block
    yields ''.
    '''
    last = ''
    for statement in ast.args:
        if statement.sym != 'return':
            last = evaluate(statement, env)
            continue
        if not statement.args:
            return ''
        return evaluate(statement.args[0], env)
    return last

@indexSafe
def evalAdd(ast,env):
    '''add(a b): the integer sum of two operands, as a string ValueNode.

    Operands are evaluated left to right and converted with int() from
    their .val.
    '''
    total = sum(int(evaluate(ast.args[i], env).val) for i in (0, 1))
    return ValueNode(str(total))

@indexSafe
def evalEqual(ast,env):
    '''eq(a b): True if a and b evaluate to equal values, otherwise False.

    The operands are compared by their printed form (str()).  Node defines
    no __eq__, so comparing the nodes themselves only matched the very same
    object; two separately written "Love" literals never compared equal.
    Comparing str() matches literals by value, treats the literal 2 and the
    string result of add(1 1) as equal, and compares data()/quoted trees
    structurally.
    '''
    left = evaluate(ast.args[0], env)
    right = evaluate(ast.args[1], env)
    return str(left) == str(right)

@indexSafe
def evalQuote(ast, env):
    '''quote(expr) / q(expr): return expr itself, unevaluated.

    The result is a shallow copy, so it shares its args list with the parse
    tree.
    '''
    quoted = ast.args[0]
    return copy(quoted)

@indexSafe
def evalEval(ast, env):
    '''eval(expr): evaluate expr, then evaluate the resulting node again.

    When expr names a variable holding quoted or parsed code, the second
    evaluation runs that code.
    '''
    code = evaluate(ast.args[0], env)
    return evaluate(code, env)

@indexSafe
def evalString(ast, env):
    '''string(...): return the text of this whole call as a ValueNode.

    NOTE(review): this stringifies the string(...) node itself, so the
    result includes the "string(" wrapper and the arguments are not
    evaluated.  Converting only the argument may have been intended -- confirm.
    '''
    text = str(ast)
    return ValueNode(text)

@indexSafe
def evalParse(ast, env):
    '''parse(expr): parse the string value of expr into an SM syntax tree.

    The tree is returned unevaluated (pass it to eval() to run it).  A
    SyntaxError from the parser is swallowed by indexSafe.
    '''
    source = evaluate(ast.args[0], env).val
    return parse(str(source))

@indexSafe
def evalLambda(ast, env):
    '''lambda(params body): build a function node without evaluating anything.

    The children of the first argument (e.g. the x y in quote(x y)) become
    the parameter symbols; the second argument is kept as the body.  The
    result is function(parameters(...) body), which evaluate() calls when
    it is bound to a name and that name is used.
    '''
    param_list = Node('parameters')
    param_list.args = ast.args[0].args  # not evaluated!
    fn = Node('function')
    fn.args = [param_list, ast.args[1]]
    return fn

def evalData(ast, env):
    '''data(...) / d(...): return the node itself, unevaluated.

    Its children therefore act as a literal list (see foreach).  This
    primitive is not wrapped in indexSafe; it does no indexing that could fail.
    '''
    return ast

@indexSafe
def evalProduct(ast, env):
    '''product(a b): the integer product of two operands, as a string ValueNode.

    Operands are evaluated left to right and converted with int() from
    their .val.
    '''
    factors = [int(evaluate(ast.args[i], env).val) for i in (0, 1)]
    return ValueNode(str(factors[0] * factors[1]))

@indexSafe
def evalVal(ast, env):
    '''val(expr): a new childless ValueNode holding expr's .val.

    For example, this yields the label of a data node (such as "Friends"
    in a tree) without its children.
    '''
    result = evaluate(ast.args[0], env)
    return ValueNode(result.val)

@indexSafe
def evalFor(ast, env):
    '''for(init condition increment body): a C-style loop.

    init runs once.  Then, while condition evaluates truthy, body runs
    followed by increment.  Returns the body's last value, or
    ValueNode('') if the body never ran.
    '''
    evaluate(ast.args[0], env)
    result = ValueNode('')
    while True:
        if not evaluate(ast.args[1], env):
            break
        result = evaluate(ast.args[3], env)
        evaluate(ast.args[2], env)
    return result

@indexSafe
def evalForEach(ast, env):
    '''foreach(name collection body): run body once per child of collection.

    name (not evaluated) is bound in env to each child of the evaluated
    collection in turn, and the binding remains after the loop.  Returns
    the body's last value, or ValueNode('') for an empty collection.
    '''
    loop_var = ast.args[0].sym
    collection = evaluate(ast.args[1], env)
    body = ast.args[2]
    result = ValueNode('')
    for item in collection.args:
        env[loop_var] = item
        result = evaluate(body, env)
    return result

# Dispatch table for evaluate(): each built-in symbol maps to the function
# implementing it.  Every handler is called as handler(ast, env).
keywords = {
    # sequencing, control flow, binding and output
    'do': evalDo,
    'if': evalIf,
    'for': evalFor,
    'foreach': evalForEach,
    'set': evalSet,
    'print': evalPrint,
    # arithmetic and comparison
    'add': evalAdd,
    'product': evalProduct,
    'gt': evalGt,
    'eq': evalEqual,
    # code as data ('q' and 'd' are short aliases)
    'quote': evalQuote,
    'q': evalQuote,
    'data': evalData,
    'd': evalData,
    'val': evalVal,
    'eval': evalEval,
    'string': evalString,
    'parse': evalParse,
    'lambda': evalLambda,
}
def evaluate(ast, env=None):
    '''Evaluate an SM node in an environment and return the result.

    ast -- the node to evaluate.
    env -- dict mapping symbol names to values; primitives such as set and
           foreach bind names in it.  When omitted, a new empty dict is
           used for this call only, so separate calls never share bindings.

    Resolution order:
      * a literal ValueNode evaluates to itself;
      * a keyword symbol is handed to its primitive in ``keywords``;
      * a symbol bound in env yields its value, or, when that value is a
        'function' node, the result of calling it with this node's args;
      * anything else (e.g. an unbound symbol) evaluates to itself.
    '''
    if env is None:
        env = {}
    if isinstance(ast, ValueNode):
        # Every literal carries sym 'value'; without this check a program
        # variable named "value" would replace every literal.
        return ast
    if ast.sym in keywords:
        return keywords[ast.sym](ast, env)
    if ast.sym in env:
        value = env[ast.sym]
        if getattr(value, 'sym', None) == 'function':
            return functionCall(value, ast, env)
        return value
    return ast

def functionCall(function, call, env):
    '''Apply a user-defined function node to the arguments of a call.

    function -- a Node('function') built by lambda: args[0] holds the
                parameter symbols and args[1] the body.
    call     -- the call-site node; its args are evaluated in env.
    env      -- the caller's environment.

    The body runs in a copy of env with the parameters bound on top, so a
    parameter shadows any caller variable of the same name, and set()
    inside the body does not rebind the caller's variables.
    '''
    arguments = [evaluate(arg, env) for arg in call.args]
    parameters = [param.sym for param in function.args[0].args]
    body = function.args[1]

    # Parameters must win over caller bindings, so apply them last.
    localEnv = dict(env)
    localEnv.update(zip(parameters, arguments))
    return evaluate(body, localEnv)

def run(codeString, env=None):
    '''Parse and evaluate an SM program, returning the evaluated result.

    codeString -- SM source text.
    env        -- optional dict of initial bindings; it is updated in place
                  with the variables the program sets.  When omitted, a new
                  dict is used for this call only (a shared default would
                  carry variables over from one run to the next).
    '''
    if env is None:
        env = {}
    return evaluate(parse(codeString), env)

# Sample SM programs.  Each string is run by the loop below in its own
# fresh environment.
tests = []

# A for loop printing successive powers of two.
tests.append( '''
set(x 1)
for(set(i 0) gt(8 i) set(i add(i 1)) {
  print("2 to the " i " = " x)
  set(x add(x x))
})
''')

# Variables, if() with {} blocks as branches, and arithmetic.
tests.append( '''
set(x 17);
set(y 8);
print( "x=" x "y=" y);
print( "changing x...");
set(x 9);
print( "x=" x "y=" y);
if( gt(x y)
    { print("x is greater than y."); set(z x); }
    { print("x is not greater than y."); set(z y); }
   );
print( "greatest of x and y:" z );
print( "total of x and y:" add( x y) );
''')

# eq() on strings, and return() ending the top-level block early.
tests.append( '''
set(string "Love");
if ( eq(string "Love")
    { set(string "Hate");}
   )
print("string = " string)
return(z);
print( "not reached!" );
''')

# quote(): storing code in a variable and running it later with eval().
tests.append( '''
set(code quote(print ("hello world!")) )
print ("code = '" code "'")
print ("when I evaluate the code, I should see 'hello world' as a side effect:")
eval(code)''')

# A quoted expression sees the bindings current when it is evaluated.
tests.append( '''
set(x 1)
set(y 2)
set( sum quote(add(x y))  )
print ( "sum = " sum )
print ( "add(x y) = " add(x y))
print ( "eval(sum) = " eval(sum) )
print ( "setting x to 10..." )
set(x 10)
print ( "eval(sum) = " eval(sum) )
''')

# parse(): turning a string into code that eval() can run.
tests.append( '''
set(code_string "print(hello world!)")
set(code parse(code_string))
print (code_string)
print (code)
eval(code)
''')

# lambda(): user-defined functions of one and two parameters.
tests.append( '''
set(f lambda ( quote(x) add(x 1) ))
print( f(2))
set(metric lambda ( quote(x y) add(product(x x) product(y y)) ) )
print( metric(5 10) )
print( metric(3 4) )
''')

# A simple counting loop.
tests.append( '''
for( set(i 0) gt(10 i) set(i add(i 1)) {
  print("i=" i)  
})
''')

# data() as a literal list, iterated with foreach().
tests.append( '''
set(list data(1 2 3))
print(list)
foreach(element list {
  print(element)
})
''')

# Nested data (a literal after which '(' opens its children) and nested foreach.
tests.append( '''
set(tree data("Friends" ("John" "Jill") "Pets" ("Spot" "Tiger") ))
print(tree)
foreach(branch tree {
  print(val(branch) ":")
  foreach(leaf branch {
    print("  " leaf)
  })
})
''')

# Run every sample program in a fresh environment, echoing its source before
# its output.  This is top-level code, so it also runs when the module is
# imported, not only when it is executed directly.
for t in tests:
    print "code:",t
    print "output:"
    run(t,{})
    # Blank line between test cases.
    print