# -*- Mode: Python -*-

from pdb import set_trace as trace
from pprint import pprint as pp

import nodes
#import solver
import itypes

is_a = isinstance

class register_rib:
    def __init__ (self, formals, regs):
        self.formals = formals
        self.regs = regs
        assert (len (formals) == len (regs))
    def lookup (self, name):
        lf = len (self.formals)
        for i in range (lf):
            if name == self.formals[i].name:
                return self.formals[i], self.regs[i]
        return None
    def __repr__ (self):
        return '<reg: %r %r>' % (self.formals, self.regs)

class fatbar_rib:
    def __init__ (self, name):
        self.name = name

class IncompleteMatch (Exception):
    pass

class compiler:

    def __init__ (self, context, verbose=False):
        self.context = context
        self.verbose = verbose
        self.constants = {}
        self.regalloc = register_allocator()
        self.current_function = None

    def lexical_address (self, lenv, name):
        x = 0
        while lenv:
            rib, lenv = lenv
            if is_a (rib, register_rib):
                probe = rib.lookup (name)
                if probe is not None:
                    var, reg = probe
                    return var, (None, reg), False
            elif is_a (rib, fatbar_rib):
                # ignore these for normal variable lookup
                pass
            else:
                for y in range (len (rib)):
                    if rib[y].name == name:
                        return rib[y], (x, y), self.use_top and lenv == None
                # only real 'ribs' increase lexical depth
                x += 1
        else:
            raise ValueError, "unbound variable: %r" % (name,)

    # This 'compiler' converts <exp> to CPS, with each continuation representing
    #  a target 'register' for the result of that expression.

    def compile_exp (self, tail_pos, exp, lenv, k):

        #import sys; W = sys.stdout.write
        #W ('compile_exp: [%3d] %r\n' % (exp.serial, exp,))
        if tail_pos:
            k = self.cont (k[1], self.gen_return)
        
        if exp.is_a ('varref'):
            return self.compile_varref (tail_pos, exp, lenv, k)
        elif exp.is_a ('varset'):
            return self.compile_varset (tail_pos, exp, lenv, k)
        elif exp.is_a ('literal'):
            return self.compile_literal (tail_pos, exp, lenv, k)            
        elif exp.is_a ('constructed'):
            return self.gen_constructed (self.scan_constructed (exp.value), k)
        elif exp.is_a ('sequence'):
            return self.compile_sequence (tail_pos, exp.subs, lenv, k)
        elif exp.is_a ('conditional'):
            return self.compile_conditional (tail_pos, exp, lenv, k)
        elif exp.is_a ('cexp'):
            return self.compile_primargs (exp.args, ('%cexp', exp.form, exp.type_sig), lenv, k)
        elif exp.is_a ('function'):
            return self.compile_function (tail_pos, exp, lenv, k)
        elif exp.is_a ('application'):
            return self.compile_application (tail_pos, exp, lenv, k)
        elif exp.is_a ('fix'):
            return self.compile_let_splat (tail_pos, exp, lenv, k)
        elif exp.is_a ('let_splat'):
            if self.safe_for_let_reg (tail_pos, exp, lenv, k):
                return self.compile_let_reg (tail_pos, exp, lenv, k)
            else:
                return self.compile_let_splat (tail_pos, exp, lenv, k)
        elif exp.is_a ('primapp'):
            return self.compile_primapp (tail_pos, exp, lenv, k)
        elif exp.is_a ('pvcase'):
            return self.compile_pvcase (tail_pos, exp, lenv, k)
        elif exp.is_a ('nvcase'):
            return self.compile_nvcase (tail_pos, exp, lenv, k)
        else:
            raise NotImplementedError

    def scan_constructed (self, exp):
        # add this literal to the global list
        cc = self.context.constructed

        def add (ob):
            index = len (cc)
            cc.append (ob)
            ob.index = index
            return index

        # search inside a constructed literal for other constructed literals,
        #  so we can emit them (in the correct order).
        def scan (exp):
            if exp.is_a ('primapp'):
                if exp.name == '%dtcon/symbol/t':
                    string = exp.args[0]
                    probe = self.context.symbols.get (string.value, None)
                    if probe is not None:
                        string.index, exp.index = probe
                        return exp.index
                    else:
                        index0 = add (string)
                        index1 = add (exp)
                        self.context.symbols[string.value] = (index0, index1)
                        return index1
                else:
                    for x in exp.args:
                        scan (x)
                    return None
            elif exp.is_a ('literal'):
                if exp.ltype == 'string':
                    return add (exp)
                elif exp.ltype in ('int', 'char', 'undefined'):
                    pass
                else:
                    raise ValueError ("unexpected object in constructed literal")
            else:
                raise ValueError ("huh?")

        index = scan (exp)
        if index is None:
            index = add (exp)
        return index

    # XXX a possible improvement: if we know that the body of the let
    #  makes only tail calls, then it should be safe as well.  Need to
    #  find an easy way to detect that case...

    # XXX this could be done *much* smarter.  Here's how: record the
    #   set of registers used by each and every function (transitively),
    #   which will let us know exactly which registers we can bind in
    #   a let around a call to that function.
    #
    #   For example: let's say we're about to call function X, which
    #     calls function Y.  If X uses only r0-r3, and Y only uses r0-r2,
    #     then we can safely bind to r4+.  The effect will be to let leaf-like
    #     functions use registers for binding.

    def safe_for_let_reg (self, tail_pos, exp, lenv, k):
        # we only want to use registers for bindings when
        #  1) we're in a leaf position (to avoid consuming registers
        #     too high on the stack - which means fewer registers to save
        #     around each funcall).
        #  2) there's not too many bindings (again, avoid consuming regs)
        #  3) none of the variables escape (storing a binding in a reg
        #     defeats the idea of a closure)
        if exp.leaf and len(exp.names) <= 4:
            for name in exp.names:
                if name.escapes:
                    return False
            else:
                return True
        else:
            return False

    # this optimization will mean less once we start passing arguments in registers.
    def safe_for_tr_call (self, app):
        if app.rator.is_a ('varref') and app.recursive and app.function:
            # we can only use the trcall hack when we know exactly what
            #   the stack looks like above us.  escaping funs do not provide
            #   that guarantee.
            if self.current_function.escapes:
                return False
            # XXX variables only escape if their containing function escapes,
            #   so I think this second test is redundant.
            for vardef in app.function.formals:
                if vardef.escapes:
                    return False
            return True
        else:
            return False

    def compile_application (self, tail_pos, exp, lenv, k):
        if tail_pos:
            gen_invoke = self.gen_invoke_tail
        else:
            gen_invoke = self.gen_invoke
        if tail_pos and self.safe_for_tr_call (exp):
            # special-case tail recursion to avoid consing environments
            var, addr, is_top = self.lexical_address (lenv, exp.rator.name)
            # <tr_call> needs to know how many levels of lenv to pop
            exp.depth, index = addr
            return self.compile_tr_call (exp.rands, exp, lenv, k)
        else:
            def make_application (args_reg):
                return self.compile_exp (
                    False, exp.rator, lenv, self.cont (
                        [args_reg] + k[1],
                        lambda closure_reg: gen_invoke (exp.function, closure_reg, args_reg, k)
                        )
                    )
            if len(exp.rands):
                return self.compile_rands (exp.rands, lenv, self.cont (k[1], make_application))
            else:
                return make_application (None)

    def compile_literal (self, tail_pos, exp, lenv, k):
        if exp.ltype == 'string':
            return self.gen_constructed (self.scan_constructed (exp), k)
        else:
            # immediates
            return self.gen_lit (exp, k)

    def compile_varref (self, tail_pos, exp, lenv, k):
        var, addr, is_top = self.lexical_address (lenv, exp.name)
        if addr[0] is None:
            # register variable
            return self.gen_move (addr[1], None, var.name, k)
        else:
            return self.gen_varref (addr, is_top, var, k)

    def compile_varset (self, tail_pos, exp, lenv, k):
        var, addr, is_top = self.lexical_address (lenv, exp.name)
        assert (var.name == exp.name)
        if addr[0] is None:
            # register variable
            fun = lambda reg: self.gen_move (addr[1], reg, var.name,k)
        else:
            fun = lambda reg: self.gen_assign (addr, is_top, var, reg, k)
        return self.compile_exp (False, exp.value, lenv, self.cont (k[1], fun))

    # collect_primargs is used by primops, simple_conditional, and tr_call.
    #   in order to avoid the needless consumption of registers, we re-arrange
    #   the eval order of these args - by placing the complex args first.

    def collect_primargs (self, args, regs, lenv, k, ck, reorder=True):
        args = [(args[i], i) for i in range (len (args))]
        if reorder:
            # sort args by size/complexity
            args.sort (lambda x,y: cmp (y[0].size, x[0].size))
        perm = [x[1] for x in args]
        args = [x[0] for x in args]
        #print 'collect_primargs, len(args)=', len(args)
        return self._collect_primargs (args, regs, perm, lenv, k, ck)

    def _collect_primargs (self, args, regs, perm, lenv, k, ck):
        # collect a set of arguments into registers, pass that into compiler-continuation <ck>
        if len(args) == 0:
            # undo the permutation of the args
            perm_regs = [regs[perm.index (i)] for i in range (len (perm))]
            return ck (perm_regs)
        else:
            return self.compile_exp (
                False, args[0], lenv, self.cont (
                    regs + k[1],
                    lambda reg: self._collect_primargs (args[1:], regs + [reg], perm, lenv, k, ck)
                    )
                )

    def compile_tr_call (self, args, node, lenv, k):
        return self.collect_primargs (args, [], lenv, k, lambda regs: self.gen_tr_call (node, regs))

    def compile_primargs (self, args, op, lenv, k):
        return self.collect_primargs (args, [], lenv, k, lambda regs: self.gen_primop (op, regs, k))

    def compile_primapp (self, tail_pos, exp, lenv, k):
        if exp.name.startswith ('%raccess/') or exp.name.startswith ('%rset/'):
            prim, field = exp.name.split ('/')
            # try to get constant-time field access...
            sig = itypes.get_record_sig (exp.args[0].type)
            if prim == '%raccess':
                if sig is None:
                    trace()
                return self.compile_primargs (exp.args, ('%record-get', field, sig), lenv, k)
            else:
                return self.compile_primargs (exp.args, ('%record-set', field, sig), lenv, k)                
        elif exp.name.startswith ('%rextend/'):
            return self.compile_record_literal (exp, lenv, k)
        elif exp.name.startswith ('%vector-literal/'):
            if len (exp.args) < 5:
                return self.compile_primargs (exp.args, ('%make-tuple', exp.type, 'TC_VECTOR'), lenv, k)
            else:
                return self.compile_vector_literal (exp.args, lenv, k)
        elif exp.name.startswith ('%make-vector'):
            return self.compile_primargs (exp.args, ('%make-vector',), lenv, k)
        elif exp.name.startswith ('%make-vec16'):
            return self.compile_primargs (exp.args, ('%make-vec16',), lenv, k)
        elif exp.name in ('%%array-ref', '%%product-ref'):
            # XXX need two different insns, to handle constant index
            # XXX could support strings as character arrays by passing down a hint?
            if is_a (exp.type, itypes.t_int16):
                return self.compile_primargs (exp.args, ('%vec16-ref',), lenv, k)
            else:
                return self.compile_primargs (exp.args, ('%array-ref',), lenv, k)
        elif exp.name == '%%array-set':
            if is_a (exp.args[0].type.args[0], itypes.t_int16):
                return self.compile_primargs (exp.args, ('%vec16-set',), lenv, k)
            else:
                return self.compile_primargs (exp.args, ('%array-set',), lenv, k)
        elif exp.name == '%vec16-set':
            return self.compile_primargs (exp.args, ('%vec16-set',), lenv, k)
        elif exp.name.startswith ('%vcon/'):
            ignore, label, arity = exp.name.split ('/')
            tag = self.context.variant_labels[label]
            return self.compile_primargs (exp.args, ('%make-tuple', label, tag), lenv, k)
        elif exp.name == ('&vget'):
            label, arity, index = exp.name_params
            return self.compile_primargs (exp.args, ('%vget', index), lenv, k)
        elif exp.name.startswith ('%nvget/'):
            ignore, dtype, label, index = exp.name.split ('/')
            dt = self.context.datatypes[dtype]
            if dt.uimm.has_key (label):
                return self.compile_exp (tail_pos, exp.args[0], lenv, k)
            else:
                return self.compile_primargs (exp.args, ('%vget', index), lenv, k)
        elif exp.name.startswith ('%dtcon/'):
            ignore, dtname, label = exp.name.split ('/')
            dt = self.context.datatypes[dtname]
            tag = dt.tags[label]
            if dtname == 'symbol' and exp.args[0].is_a ('literal'):
                # special case: only triggered when symbols are present in data structures
                #   that cannot be built at compile-time.
                return self.gen_constructed (self.scan_constructed (exp), k)
            elif dt.uimm.has_key (label):
                return self.compile_exp (tail_pos, exp.args[0], lenv, k)
            else:
                return self.compile_primargs (exp.args, ('%make-tuple', label, tag), lenv, k)
        elif exp.name == '%%match-error':
            return self.gen_primop (('%%match-error',), [], k)
        elif exp.name == '%%fatbar':
            # urgh, not really a primop, but rather a control feature.  I guess it should be a new node type?
            return self.compile_fatbar (tail_pos, exp.args, lenv, k)
        elif exp.name == '%%fail':
            return self.compile_fail (tail_pos, lenv, k)
        else:
            raise ValueError ("Unknown primop: %r" % (exp.name,))

    def compile_sequence (self, tail_pos, exps, lenv, k):
        if len(exps) == 0:
            raise ValueError ("illegal sequence")
        elif len(exps) == 1:
            # last expression may be in tail position
            return self.compile_exp (tail_pos, exps[0], lenv, k)
        else:
            # more than one expression
            return self.compile_exp (
                False, exps[0], lenv,
                self.dead_cont (k[1], self.compile_sequence (tail_pos, exps[1:], lenv, k))
                )

    def compile_conditional (self, tail_pos, exp, lenv, k):
        if exp.test_exp.is_a ('cexp'):
            return self.compile_simple_conditional (tail_pos, exp, lenv, k)
        else:
            return self.compile_exp (
                False, exp.test_exp, lenv, self.cont (
                    k[1],
                    lambda test_reg: self.gen_test (
                        test_reg, 
                        self.compile_exp (tail_pos, exp.then_exp, lenv, self.cont (k[1], lambda reg: self.gen_jump (reg, k))),
                        self.compile_exp (tail_pos, exp.else_exp, lenv, self.cont (k[1], lambda reg: self.gen_jump (reg, k))),
                        k
                        )
                    )
                )

    def compile_simple_conditional (self, tail_pos, exp, lenv, k):
        def finish (regs):
            return self.gen_simple_test (
                exp.test_exp.params,
                regs,
                self.compile_exp (tail_pos, exp.then_exp, lenv, self.cont (k[1], lambda reg: self.gen_jump (reg, k))),
                self.compile_exp (tail_pos, exp.else_exp, lenv, self.cont (k[1], lambda reg: self.gen_jump (reg, k))),
                k
                )
        return self.collect_primargs (exp.test_exp.args, [], lenv, k, finish)

    def compile_pvcase (self, tail_pos, exp, lenv, k):
        def finish (test_reg):
            jump_k = self.cont (k[1], lambda reg: self.gen_jump (reg, k))
            alts = [self.compile_exp (tail_pos, alt, lenv, jump_k) for alt in exp.alts]
            return self.gen_pvcase (test_reg, exp.alt_formals, alts, k)
        return self.compile_exp (False, exp.value, lenv, self.cont (k[1], finish))

    def compile_nvcase (self, tail_pos, exp, lenv, k):
        dt = self.context.datatypes[exp.vtype]
        def finish (test_reg):
            jump_k = self.cont (k[1], lambda reg: self.gen_jump (reg, k))
            alts = [self.compile_exp (tail_pos, alt, lenv, jump_k) for alt in exp.alts]
            ealt = self.compile_exp (tail_pos, exp.else_clause, lenv, jump_k)
            if len(dt.alts) != len(alts) and ealt.name == 'primop' and ealt.params[0] == '%%match-error':
                raise IncompleteMatch (exp)
            return self.gen_nvcase (test_reg, exp.vtype, exp.tags, alts, ealt, k)
        return self.compile_exp (False, exp.value, lenv, self.cont (k[1], finish))

    fatbar_counter = 0
    def compile_fatbar (self, tail_pos, (e1, e2), lenv, k):
        label = 'fatbar_%d' % (self.fatbar_counter,)
        lenv0 = (fatbar_rib (label), lenv)
        self.fatbar_counter += 1
        return self.gen_fatbar (
            label,
            self.compile_exp (tail_pos, e1, lenv0, self.cont (k[1], lambda reg: self.gen_jump (reg, k))),
            self.compile_exp (tail_pos, e2, lenv,  self.cont (k[1], lambda reg: self.gen_jump (reg, k))),
            k
            )

    def compile_fail (self, tail_pos, lenv, k):
        # lookup the closest surrounding fatbar label
        search = lenv
        # lexical depth to pop off
        d = 0
        while search:
            rib, search = search
            if is_a (rib, fatbar_rib):
                return self.gen_fail (d, rib.name, k)
            elif is_a (rib, register_rib):
                # ignore
                pass
            else:
                d += 1
        else:
            raise ValueError ("%%fail without fatbar??")

    def compile_function (self, tail_pos, exp, lenv, k):
        self.current_function = exp
        if len(exp.formals):
            # don't extend the environment if there are no args
            lenv = (exp.formals, lenv)
        return self.gen_closure (
            exp,
            self.compile_exp (True, exp.body, lenv, self.cont ([], self.gen_return)),
            k
            )

    def compile_let_splat (self, tail_pos, exp, lenv, k):
        if len (exp.inits) == 0:
            # no bindings, just compile the body
            return self.compile_exp (tail_pos, exp.body, lenv, k)
        # becomes this sequence:
        #   (new_env, push_env, store_env0, ..., <body>, pop_env)
        k_body = self.dead_cont (k[1], self.compile_exp (tail_pos, exp.body, (exp.names, lenv),
                                                         self.cont (k[1], lambda reg: self.gen_pop_env (reg, k))))
        return self.gen_new_env (
            len (exp.names),
            self.cont (
                k[1],
                lambda tuple_reg: self.gen_push_env (
                    tuple_reg,
                    self.dead_cont (
                        k[1],
                        self.compile_store_rands (
                            0, 1, exp.inits, tuple_reg,
                            [tuple_reg] + k[1],
                            (exp.names, lenv),
                            k_body)
                        )
                    )
                )
            )

    def compile_let_reg (self, tail_pos, exp, lenv, k):

        # since this is a let-*splat*, we're forced to compile this one variable at a time,
        #   which makes the register 'rib' look a little silly.  XXX redo it as 'register_var'.

        def loop (names, inits, lenv, regs):
            if len(inits) == 0:
                return self.compile_exp (tail_pos, exp.body, lenv, (k[0], k[1] + regs, k[2]))
            else:
                lenv0 = (register_rib ([names[0]], [inits[0]]), lenv)
                return self.compile_exp (
                    False, inits[0], lenv, self.cont (
                        regs + k[1],
                        lambda reg: loop (
                            names[1:],
                            inits[1:],
                            (register_rib ([names[0]], [reg]), lenv),
                            regs + [reg]
                            )
                        )
                    )

        return loop (exp.names, exp.inits, lenv, [])

    opt_collect_args_in_regs = False

    if opt_collect_args_in_regs:
        # simply collect the args into registers, then use a <build_env> insn to populate the rib.
        # Note that collect_primargs will re-order the args...
        def compile_rands (self, rands, lenv, k):
            return self.collect_primargs (rands, [], lenv, k, lambda regs: self.gen_build_env (regs, k))
    else:
        # allocate the env rib, then place each arg in turn.
        # NOTE:
        #   to change the order of evaluation to right-to-left, you need to:
        #   1) pass i+1 to compile_tuple_rands
        #   2) make "i>0" the test, and
        #   3) i-1 the iter
        # then beware of callers expecting the other behavior (like let*)
        def compile_rands (self, rands, lenv, k):
            if not rands:
                return self.gen_new_env (0, k)
            else:
                return self.gen_new_env (
                    len (rands),
                    self.cont (k[1], lambda tuple_reg: self.compile_store_rands (0, 1, rands, tuple_reg, [tuple_reg] + k[1], lenv, k))
                    )

    # if we use collect_primargs() to populate literal vectors and records, the code
    #   emitted consumes one register for each arg before finally storing all the registers
    #   in one pass.  As the literals become larger, the register usage becomes very wasteful.
    # instead, this function accumulates the args one at a time, and stores them individually
    #   into the tuple.
        
    def compile_store_rands (self, i, offset, rands, tuple_reg, free_regs, lenv, k):
        # offset is an additional offset from the beginning of the tuple - used only
        #  when storing into environment ribs (because of the <next> pointer immediately
        #  after the tag).
        return self.compile_exp (
            False, rands[i], lenv, self.cont (
                free_regs,
                lambda arg_reg: self.gen_store_tuple (
                    offset, arg_reg, tuple_reg, i, len(rands),
                    (self.dead_cont (free_regs, self.compile_store_rands (i+1, offset, rands, tuple_reg, free_regs, lenv, k)) if i+1 < len(rands) else k)
                    )
                )
            )

    def compile_vector_literal (self, rands, lenv, k):
        return self.gen_new_tuple (
            'TC_VECTOR', len (rands),
            self.cont (k[1], lambda vec_reg: self.compile_store_rands (0, 0, rands, vec_reg, [vec_reg] + k[1], lenv, k))
            )

    def get_record_tag (self, sig):
        #print 'get record tag', sig
        c = self.context
        if not c.records2.has_key (sig):
            c.records2[sig] = len (c.records2)
            for label in sig:
                if not c.labels2.has_key (label):
                    c.labels2[label] = len (c.labels2)
        return c.records2[sig]

    def compile_record_literal (self, exp, lenv, k):
        # unwind row primops into a record literal
        # (%rextend/field0 (%rextend/field1 (%rmake) ...)) => {field0=x field1=y}
        fields = []
        while 1:
            if exp.is_a ('primapp') and exp.name == '%rmake':
                # we're done...
                break
            elif exp.is_a ('primapp') and exp.name.startswith ('%rextend/'):
                ignore, field = exp.name.split ('/')
                fields.append ((field, exp.args[1]))
                exp = exp.args[0]
            else:
                return self.compile_record_extension (fields, exp, lenv, k)
        # put the names into canonical order (sorted by label)
        fields.sort (lambda a,b: cmp (a[0],b[0]))
        # lookup the runtime tag for this record
        sig = tuple ([x[0] for x in fields])
        tag = 'TC_USEROBJ+%d' % (self.get_record_tag (sig) << 2)
        # now compile the expression as a %make-tuple
        args = [x[1] for x in fields]
        return self.gen_new_tuple (
            tag, len (args),
            self.cont (k[1], lambda rec_reg: self.compile_store_rands (0, 0, args, rec_reg, [rec_reg] + k[1], lenv, k))
            )

    def compile_record_extension (self, fields, exp, lenv, k):
        # ok, we have a source record {a,b} to which we want to add
        #   one or more fields {c,d}.  We'll need to compile a
        #   'make-tuple' with args fetched from the source record
        #   mixed in with new args, all in the correct order.
        sig = itypes.get_record_sig (exp.type)
        if '...' in sig:
            raise ValueError ("can't extend record - only a partial type available")
        labels = [x[0] for x in fields]
        labels.sort()
        args = [x[1] for x in fields]
        new_sig = list(set(sig).union (set(labels)))
        new_sig.sort()
        new_sig = tuple (new_sig)
        if sig == new_sig:
            # identical, it's actually an update
            # XXX should consider doing copy+update instead, for functional cred.
            # XXX another option: consider it an error.
            # the last sounds best: principle of least surprise.
            assert (len(fields) == 1)
            return self.compile_primargs ([exp, args[0]], ('%record-set', fields[0][0], sig), lenv, k)
        else:
            new_tag = self.get_record_tag (new_sig)
            return self.compile_primargs ([exp] + args, ('%extend-tuple', labels, sig, new_tag), lenv, k)

    # --- continuations ---

    def cont (self, free_regs, generator):
        # allocate a register for this continuation, then generate the
        #   code that will create the value to go into it.
        reg = self.regalloc.allocate (free_regs)
        return (reg, free_regs, generator (reg))

    def dead_cont (self, free_regs, k):
        # a 'dead' continuation - only for a side-effect.  Doesn't need a register allocated.
        return ('dead', free_regs, k)

class register_allocator:

    def __init__ (self):
        self.max_reg = -1

    def allocate (self, free_regs):
        i = 0
        while 1:
            if i not in free_regs:
                self.max_reg = max (self.max_reg, i)
                return i
            else:
                i += 1
        
def box (n):
    return (n<<1)|1

class INSN:

    allocates = 0

    def __init__ (self, name, regs, params, k):
        self.name = name
        self.regs = regs
        self.params = params
        self.k = k
        self.subs = ()

    def print_info (self):
        if self.name == 'test':
            return '%s %r %r' % (self.name, self.regs, self.params[0])
        elif self.name == 'close':
            return '%s %r %r' % (self.name, self.regs, self.params[0].name)
        elif self.name in ('pvcase', 'nvcase'):
            return '%s %r %r' % (self.name, self.params[0], self.regs)
        elif self.name == 'fatbar':
            return '%s %r %r' % (self.name, self.params[0], self.regs)
        else:
            return '%s %r %r' % (self.name, self.regs, self.params)

    def __repr__ (self):
        return '<INSN %s>' % (self.print_info())

class cps (compiler):

    """generates 'register' CPS"""


    def gen_lit (self, lit, k):
        # these smarts probably belong in the back end.
        if lit.ltype == 'int':
            return INSN ('lit', [], box (lit.value), k)
        elif lit.ltype == 'bool':
            if lit.value == 'true':
                n = 0x106
            else:
                n = 0x6
            return INSN ('lit', [], n, k)
        elif lit.ltype == 'char':
            if lit.value == 'eof':
                # special case
                val = 257<<8|0x02
            else:
                val = ord(lit.value)<<8|0x02
            return INSN ('lit', [], val, k)
        elif lit.ltype == 'undefined':
            return INSN ('lit', [], 0x0e, k)
        elif lit.ltype == 'nil':
            return INSN ('lit', [], 0x0a, k)
        else:
            raise SyntaxError

    def gen_constructed (self, exp, k):
        return INSN ('constructed', [], exp, k)

    def gen_primop (self, primop, regs, k):
        return INSN ('primop', regs, primop, k)

    def gen_move (self, reg_var, reg_src, name, k):
        return INSN ('move', [reg_var, reg_src], name, k)

    def gen_jump (self, reg, k):
        # k[0] is the target for the whole conditional
        return INSN ('jump', [reg, k[0]], None, None)

    def gen_fatbar (self, label, e1, e2, k):
        return INSN ('fatbar', [], (label, e1, e2), k)

    def gen_fail (self, depth, label, k):
        return INSN ('fail', [], (label, depth), None)

    def gen_new_env (self, size, k):
        return INSN ('new_env', [], size, k)

    def gen_build_env (self, regs, k):
        return INSN ('build_env', regs, None, k)

    def gen_push_env (self, reg, k):
        return INSN ('push_env', [reg], None, k)

    def gen_pop_env (self, reg, k):
        return INSN ('pop_env', [reg], None, k)

    def gen_new_tuple (self, tag, size, k):
        return INSN ('new_tuple', [], (tag, size), k)

    def gen_store_tuple (self, offset, arg_reg, tuple_reg, i, n, k):
        return INSN ('store_tuple', [arg_reg, tuple_reg], (i, offset, n), k)

    def gen_varref (self, addr, is_top, var, k):
        return INSN ('varref', [], (addr, is_top, var), k)
    
    def gen_assign (self, addr, is_top, var, reg, k):
        return INSN ('varset', [reg], (addr, is_top, var), k)

    def gen_closure (self, fun, body, k):
        # track all functions for the back end
        self.context.functions.append (fun)
        return INSN ('close', [], (fun, body, k[1]), k)

    def gen_test (self, test_reg, then_code, else_code, k):
        return INSN ('test', [test_reg], (None, then_code, else_code), k)

    def gen_simple_test (self, cexp, regs, then_code, else_code, k):
        return INSN ('test', regs, (cexp, then_code, else_code), k)

    def gen_pvcase (self, test_reg, types, alts, k):
        return INSN ('pvcase', [test_reg], (types, alts), k)

    def gen_nvcase (self, test_reg, dtype, tags, alts, ealt, k):
        return INSN ('nvcase', [test_reg], (dtype, tags, alts, ealt), k)

    def gen_invoke_tail (self, fun, closure_reg, args_reg, k):
        return INSN ('invoke_tail', [closure_reg, args_reg], fun, None)

    def gen_invoke (self, fun, closure_reg, args_reg, k):
        return INSN ('invoke', [closure_reg, args_reg], (k[1], fun), k)

    def gen_tr_call (self, app_node, regs):
        return INSN ('tr_call', regs, (app_node.depth, app_node.function), None)

    def gen_return (self, val_reg):
        return INSN ('return', [val_reg], None, None)

    def go (self, exp):
        lenv = None
        # only enable the 'top lenv' hack if the top level is a fix
        self.use_top = exp.is_a ('fix')
        result = self.compile_exp (True, exp, lenv, self.cont ([], self.gen_return))
        result = flatten (result)
        #pretty_print (result)
        #remove_moves (result)
        find_allocation (result, self.verbose)
        return result

def flatten (exp):
    r = []
    while exp:
        #print exp
        if exp.k:
            target, free_regs, next = exp.k
        else:
            next = None
            target = None
            free_regs = []
        exp.k = None
        exp.target = target
        exp.free_regs = free_regs
        if exp.name == 'test':
            name, then_code, else_code = exp.params
            exp.params = name, flatten (then_code), flatten (else_code)
        elif exp.name == 'close':
            node, body, free = exp.params
            exp.params = node, flatten (body), free
        elif exp.name == 'pvcase':
            types, alts = exp.params
            exp.params = types, [flatten (x) for x in alts]
        elif exp.name == 'nvcase':
            types, tags, alts, ealt = exp.params
            exp.params = types, tags, [flatten (x) for x in alts], flatten (ealt)
        elif exp.name == 'fatbar':
            label, e1, e2 = exp.params
            exp.params = label, flatten (e1), flatten (e2)
        r.append (exp)
        exp = next
    return r

import sys
W = sys.stdout.write

def pretty_print (insns, depth=0): 
   for insn in insns:
        W ('%s' % ('    ' * depth))
        if insn.target == 'dead':
            W ('   -   ')
        elif insn.target is None:
            W ('       ')
        else:
            W ('%4d = ' % (insn.target,))
        W ('%s\n' % (insn.print_info(),))
        # special case prints
        if insn.name == 'test':
            name, then_code, else_code = insn.params
            pretty_print (then_code, depth+1)
            pretty_print (else_code, depth+1)
        elif insn.name == 'close':
            node, body, free = insn.params
            pretty_print (body, depth+1)
        elif insn.name == 'pvcase':
            types, alts = insn.params
            for alt in alts:
                pretty_print (alt, depth+1)
        elif insn.name == 'nvcase':
            types, tags, alts, ealt = insn.params
            for alt in alts + [ealt]:
                pretty_print (alt, depth+1)
        elif insn.name == 'fatbar':
            label, e1, e2 = insn.params
            pretty_print (e1, depth+1)
            pretty_print (e2, depth+1)

# when <let> expressions are in a leaf position, the bindings may be
#  be stored in registers rather than an environment tuple.  due to 
#  the way the CPS algorithm works, there are a lot of redundant move
#  insns generated that we can ignore by remapping the relevant registers.                

# Ok, this doesn't work correctly [yet].  The problem comes up when varset
#   causes regs to get remapped - tests/t_bad_inline.scm fails.

def remove_moves (insns):
    map = {}
    for insn in insns:
        name = insn.name
        if insn.name == 'move':
            # a new entry in map
            src = insn.regs[0]
            # note: <src> may already be in the map!
            while map.has_key (src):
                # follow the chain of references
                src = map[src]
            # src == target sometimes happens, don't go all infinite loop.
            if insn.target != 'dead' and insn.target != src:
                print 'map %d == %d' % (insn.target, src)
                map[insn.target] = src
        # rename any that we can
        insn.regs = [ map.get(x,x) for x in insn.regs ]
        # special case
        if insn.name == 'test':
            name, then_code, else_code = insn.params
            remove_moves (then_code)
            remove_moves (else_code)
        elif insn.name == 'close':
            node, body, free = insn.params
            remove_moves (body)
        elif insn.name in ('pvcase', 'nvcase'):
            types, alts = insn.params
            for alt in alts:
                remove_moves (alt)
        elif insn.name == 'fatbar':
            # XXX never tested, dead code
            label, e1, e2 = insn.params
            remove_moves (e1)
            remove_moves (e2)
        if insn.name != 'move' and map.has_key (insn.target):
            # remove any that are blown away
            del map[insn.target]

def walk (insns):
    "iterate the entire tree of insns"
    for insn in insns:
        yield (insn)
        if insn.name == 'test':
            name, then_code, else_code = insn.params
            for x in walk (then_code):
                yield x
            for x in walk (else_code):
                yield x
        elif insn.name == 'close':
            node, body, free = insn.params
            for x in walk (body):
                yield x
        elif insn.name == 'pvcase':
            types, alts = insn.params
            for alt in alts:
                for y in walk (alt):
                    yield y
        elif insn.name == 'nvcase':
            types, tags, alts, ealt = insn.params
            for alt in alts + [ealt]:
                for y in walk (alt):
                    yield y
        elif insn.name == 'fatbar':
            label, e1, e2 = insn.params
            for x in walk (e1):
                yield x
            for x in walk (e2):
                yield x

def walk_function (insns):
    "iterate only the insns in this function body"
    for insn in insns:
        yield (insn)
        if insn.name == 'test':
            name, then_code, else_code = insn.params
            for x in walk_function (then_code):
                yield x
            for x in walk_function (else_code):
                yield x
        elif insn.name == 'nvcase':
            types, tags, alts, ealt = insn.params
            for alt in alts + [ealt]:
                for x in walk_function (alt):
                    yield x
        elif insn.name == 'pvcase':
            types, alts = insn.params
            for alt in alts:
                for x in walk_function (alt):
                    yield x
        elif insn.name == 'fatbar':
            label, e1, e2 = insn.params
            for x in walk_function (e1):
                yield x
            for x in walk_function (e2):
                yield x

def find_allocation (insns, verbose):
    funs = [ x for x in walk (insns) if x.name == 'close' ]
    # examine each fun to see if it performs allocation
    for fun in funs:
        node, body, free = fun.params
        fun.allocates = 0
        for insn in walk_function (body):
            if insn.name == 'primop':
                if insn.params[0] == '%make-tuple' and len(insn.regs):
                    # we're looking for non-immediate constructors (i.e., list/cons but not list/nil)
                    fun.allocates += 1
                elif insn.params[0] in ('%make-vector', '%extend-tuple'):
                    fun.allocates += 1
            elif insn.name in ('new_env', 'build_env', 'new_tuple', 'invoke', 'close', 'make_string'):
                fun.allocates += 1
        if verbose:
            print 'allocates %d %s' % (fun.allocates, fun.params[0].name)