__version__ = "$Id$"

# (SMR) taken from the PyLR package, modernized a bit.

#
# Error checking.  These errors correspond to the same errors in
# Yacc/Bison.
#

# (SMR) these are not used below - conflicts are printed to stdout
#       instead - look into raising these exceptions?

class ConflictError(Exception):
    """Base class describing a conflict between two parse-table entries.

    Deriving from Exception lets instances be raised on every Python
    version.  The table builders in this module currently print conflicts
    instead of raising these classes.

    Attributes:
        indices: the pair (state_index, terminal_index) of the table cell.
        old: the entry that was already in the cell.
        new: the entry that was about to replace it.
    """

    def __init__(self, state_index, terminal_index, old, new):
        # Hand the raw values to Exception so str()/repr() stay informative.
        Exception.__init__(self, state_index, terminal_index, old, new)
        self.indices = (state_index, terminal_index)
        self.old = old
        self.new = new


class ShiftShiftError(ConflictError):
    """Shift/shift flavour of ConflictError; adds no behaviour of its own."""

class ShiftReduceError(ConflictError):
    """Shift/reduce flavour of ConflictError; adds no behaviour of its own."""

class ReduceReduceError(ConflictError):
    """Reduce/reduce flavour of ConflictError; adds no behaviour of its own."""


#
# Production -- a Grammar is really just a list of productions.
# The expected structure is a symbol for the LHS and a list of
# tokens or symbols for the RHS.
#
class Production:
    """One grammar rule ``LHS -> RHS`` with an optional action function.

    Attributes:
        LHS: the left-hand-side symbol (concatenated with a string in
            __repr__, so it is expected to be a string).
        RHS: list of right-hand-side symbols.
        func: callable associated with the rule, or None.
        funcname: name associated with the rule ("" when unset).
    """

    def __init__(self, LHS, RHS, func=None, funcname=""):
        self.LHS = LHS
        self.RHS = RHS
        self.func = func
        self.funcname = funcname

    def setfunc(self, func):
        """Attach the callable to run for this production.

        Grammar.prodinfotable() reads it back through ``self.func``.
        """
        self.func = func

    def setfuncname(self, name):
        """Record the name of the action function in ``self.funcname``."""
        self.funcname = name

    def __len__(self):
        """Return the number of symbols on the right-hand side."""
        return len(self.RHS)

    def __repr__(self):
        # repr() replaces the backtick syntax, which Python 3 removed.
        return self.LHS + " -> " + repr(self.RHS)

    def items(self):
        """Return the list of dot positions 0..len(RHS), inclusive."""
        # list() keeps the result a real list on Python 3 as well.
        return list(range(len(self.RHS) + 1))

#
# Perhaps this should be called 'LR1Grammar' instead...
# Provides methods for producing the actiontable, the gototable, and the 
# prodinfo table.  Using these functions, it can produce a python source code file
# with these tables or a parsing engine.
#
class Grammar:
    """LR(1) table builder for a grammar given as a list of productions.

    Each production needs ``LHS`` and ``RHS`` attributes (see Production).
    Symbols that never occur as a LHS are treated as terminals.  Items are
    represented as ``((production_index, dot_position), lookahead)``.

    Attributes:
        productions: private copy of the production list.
        tokens: display names used by showexp() for integer symbols.
        terms: sorted list of terminals, computed once in __init__.
        nonterms: list of nonterminals, computed once in __init__
            (augment() does not add the new start symbol to it).
    """

    EPS = "<EPS>"   # stands for the empty string
    EOF = "<EOF>"   # end-of-input lookahead
    DummyLA = -1    # placeholder lookahead; first() returns it unchanged

    def __init__(self, prods, toks=None):
        self.productions = prods[:]
        if toks:
            self.tokens = toks
        else:
            self.tokens = []
        self.terms = self.terminals()
        self.nonterms = self.nonterminals()

    def terminals(self):
        """Return the sorted list of RHS symbols that are never a LHS."""
        nonterms = self.nonterminals()
        found = []
        for p in self.productions:
            for s in p.RHS:
                if s not in found and s not in nonterms:
                    found.append(s)
        # Python 3 cannot order ints against strings; the key keeps the
        # Python 2 result (non-strings first, then strings) on both.
        found.sort(key=lambda s: (isinstance(s, str), s))
        return found

    def nonterminals(self):
        """Return every distinct LHS, in order of first appearance."""
        nt = []
        for p in self.productions:
            if p.LHS not in nt:
                nt.append(p.LHS)
        return nt

    def expansions(self, LHS):
        """Return the RHS lists (not copies) of all productions of LHS."""
        return [p.RHS for p in self.productions if p.LHS == LHS]

    def LHSproductions(self, lhs):
        """Return the production objects whose LHS equals ``lhs``."""
        return [p for p in self.productions if p.LHS == lhs]

    def _symbol_text(self, sym):
        """Return the display form of one RHS symbol for showexp()."""
        if isinstance(sym, int):
            if self.tokens:
                return self.tokens[sym]
            return str(sym)
        return sym

    def showexp(self, LHS):
        """Format all alternatives of ``LHS`` as ``LHS -> a b |\\n  c;``.

        Integer symbols are shown through ``self.tokens`` when that list
        is non-empty, otherwise through str().  The productions are only
        read: the previous version wrote the display strings back into
        the RHS lists, which corrupted the grammar.
        """
        head = LHS + " -> "
        separator = " |\n" + " " * len(head)
        alternatives = []
        for rhs in self.expansions(LHS):
            alternatives.append(" ".join([self._symbol_text(s) for s in rhs]))
        return head + separator.join(alternatives) + ";"

    def __repr__(self):
        return "".join([self.showexp(nt) + "\n\n"
                        for nt in self.nonterminals()])

    def augment(self):
        """Insert ``S' -> S`` at index 0, where S is the current start LHS.

        Primes are appended to the start symbol until the name is not
        already a LHS.  ``self.terms`` and ``self.nonterms`` are not
        recomputed.
        """
        lhss = [p.LHS for p in self.productions]
        start = self.productions[0].LHS
        newsym = start + "'"
        while newsym in lhss:
            newsym = newsym + "'"
        self.productions.insert(0, Production(newsym, [start]))

    def first(self, sym, banned=None):
        """Return a list of symbols that can start ``sym`` (may repeat).

        Terminals and DummyLA yield ``[sym]``.  ``banned`` collects the
        symbols already being expanded so left recursion terminates.

        NOTE(review): a banned symbol is skipped as if it could derive
        the empty string, and ``banned`` is shared between sibling
        productions, so the result looks like an approximation of the
        textbook FIRST set -- confirm before relying on it.
        """
        if not banned:
            banned = []
        if sym in self.terms or sym == Grammar.DummyLA:
            return [sym]
        res = []
        banned.append(sym)
        for sp in self.LHSproductions(sym):
            if sp.RHS == [Grammar.EPS] and Grammar.EPS not in res:
                res.append(Grammar.EPS)
                continue
            allhaveeps = True
            for s in sp.RHS:
                if s in banned:
                    continue
                subfirst = self.first(s, banned)
                res.extend(subfirst)
                if Grammar.EPS not in subfirst:
                    allhaveeps = False
                    break
            if allhaveeps and Grammar.EPS not in res:
                res.append(Grammar.EPS)
        return res

    def firstofstring(self, gs_list):
        """Combine first() over every symbol of ``gs_list``.

        EPS is appended only when every symbol's first() contained it.

        NOTE(review): the loop does not stop at the first symbol without
        EPS, so later symbols still contribute.  closure() depends on
        that, because first(EPS) is empty when EPS is not a terminal.
        """
        res = []
        allhaveeps = True
        for sym in gs_list:
            tmp = self.first(sym)
            if Grammar.EPS in tmp:
                tmp.remove(Grammar.EPS)
            else:
                allhaveeps = False
            res.extend(tmp)
        if allhaveeps:
            res.append(Grammar.EPS)
        return res

    def follow(self):
        """Return a dict mapping symbols to lists of follow symbols.

        The computation is a fixed sequence of two passes over the
        productions, not an iteration to a fixed point.
        """
        follow = {self.productions[0].LHS: [Grammar.EOF]}
        nts = self.nonterms
        # Pass 1: whatever can start RHS[c] may follow RHS[c - 1].
        for p in self.productions:
            rhs = p.RHS
            for c in range(len(rhs) - 1, 0, -1):
                f = self.first(rhs[c])
                if Grammar.EPS in f:
                    f.remove(Grammar.EPS)
                prev = rhs[c - 1]
                if prev in nts:
                    follow[prev] = follow.get(prev, []) + f
        # Pass 2: hand the follow list of the LHS to trailing symbols.
        for p in self.productions:
            rhs = p.RHS
            if rhs and rhs[-1] in nts:
                follow[rhs[-1]] = (follow.get(rhs[-1], [])
                                   + follow.get(p.LHS, []))
            for c in range(len(rhs) - 1, 0, -1):
                if Grammar.EPS not in self.first(rhs[c]):
                    continue
                add = follow.get(p.LHS, [])
                prev = rhs[c - 1]
                if prev in follow:
                    follow[prev] = follow[prev] + add
                elif add:
                    follow[prev] = add[:]
        # Drop duplicates, keeping the first occurrence of each symbol.
        for k in list(follow):
            unique = []
            for sym in follow[k]:
                if sym not in unique:
                    unique.append(sym)
            follow[k] = unique
        return follow

    def closure(self, items):
        """Return ``items`` extended to its LR(1) closure (a new list)."""
        res = items[:]
        while 1:
            more = []
            for (prodind, rhsind), term in res:
                rhs = self.productions[prodind].RHS
                if rhsind == len(rhs):
                    continue
                expansions = self.LHSproductions(rhs[rhsind])
                if not expansions:
                    continue
                # Symbol after the one being expanded, or EPS at the end.
                if rhsind + 1 < len(rhs):
                    newpart = rhs[rhsind + 1]
                else:
                    newpart = Grammar.EPS
                # Does not depend on the expansion, so compute it once.
                lookaheads = self.firstofstring([newpart, term])
                for p in expansions:
                    target = (self.productions.index(p), 0)
                    for t in lookaheads:
                        if (target, t) not in res:
                            more.append((target, t))
                    # first(EOF) is empty, so EOF is carried over by hand.
                    if term == Grammar.EOF and newpart == Grammar.EPS:
                        if (target, Grammar.EOF) not in res:
                            more.append((target, Grammar.EOF))
            if not more:
                break
            res = res + more
        return res

    def goto(self, items, sym):
        """Return the closure of ``items`` advanced over ``sym``."""
        itemset = []
        for (prodind, rhsind), term in items:
            try:
                at_dot = self.productions[prodind].RHS[rhsind]
            except IndexError:
                # Dot already at the end: nothing to advance over.
                continue
            advanced = ((prodind, rhsind + 1), term)
            if at_dot == sym and advanced not in itemset:
                itemset.append(advanced)
        return self.closure(itemset)

    #
    # LR1items -- this operation is extremely expensive.  Try not to call it
    # more than once per production of tables
    #
    def LR1items(self, save=0):
        """Return the list of LR(1) item sets; index 0 is the start state.

        With a true ``save`` the result is also stored on the instance
        under the name ``LR1items``, which hides this method on that
        instance.  actiontable() and gototable() cope with that.
        """
        res = [self.closure([((0, 0), Grammar.EOF)])]
        syms = self.nonterms + self.terms
        while 1:
            more = []
            for itemset in res:
                for s in syms:
                    tmp = self.goto(itemset, s)
                    if tmp and tmp not in res and tmp not in more:
                        more.append(tmp)
            if not more:
                break
            res = res + more
        if save:
            self.LR1items = res
        return res

    def _lr1_itemsets(self):
        """Return the item sets, reusing those kept by LR1items(save=1).

        A saved list is not refreshed when the productions change.
        """
        saved = self.LR1items
        if callable(saved):
            return saved()
        return saved

    def _report_conflict(self, template, state_i, term_i, old, new):
        """Print a conflict line; ``template`` carries the cell prefix."""
        print((template + " %s -> %s") % (state_i, term_i, old, new))

    def actiontable(self, items=None):
        """Build the action table: one row per state, one column per
        terminal, with EOF as the last column.

        Entries are ``("s", state)``, ``("r", production_index)``,
        ``("a", -1)`` or the error entry ``("", -1)``.  A conflicting
        entry is reported on stdout and the later one wins.
        """
        if not items:
            items = self._lr1_itemsets()
        terms = self.terms[:]
        terms.append(Grammar.EOF)
        errentry = ("", -1)
        res = []
        for state_i, state in enumerate(items):
            row = [errentry] * len(terms)
            res.append(row)
            for (prodind, rhsind), term in state:
                rhs = self.productions[prodind].RHS
                if rhsind == len(rhs):
                    # Completed item: accept for production 0, else reduce.
                    if prodind != 0:
                        new = ("r", prodind)
                    else:
                        new = ("a", -1)
                    col = terms.index(term)
                    old = row[col]
                    if old != errentry and old != new:
                        self._report_conflict("Conflict[%d,%d]:",
                                              state_i, col, old, new)
                    row[col] = new
                elif rhs[rhsind] in terms:
                    tmp = self.goto(state, rhs[rhsind])
                    if tmp in items:
                        col = terms.index(rhs[rhsind])
                        new = ("s", items.index(tmp))
                        old = row[col]
                        if old != errentry and old != new:
                            self._report_conflict("Conflict[%d, %d]:",
                                                  state_i, col, old, new)
                        row[col] = new
        return res

    def gototable(self, items=None):
        """Build the goto table: one row per state, one column per
        nonterminal, leaving out the LHS of production 0.

        A cell holds the target state index, or None without transition.
        """
        if not items:
            items = self._lr1_itemsets()
        start = self.productions[0].LHS
        # Work on a filtered copy.  The old code called remove() on
        # self.nonterms itself, which changed the shared list and raised
        # ValueError after augment(), whose start symbol is not in it.
        nonterms = [nt for nt in self.nonterms if nt != start]
        res = []
        for state in items:
            row = [None] * len(nonterms)
            res.append(row)
            for nonterm_i, nt in enumerate(nonterms):
                target = self.goto(state, nt)
                if target in items:
                    row[nonterm_i] = items.index(target)
        return res

    def setfunc(self, prodind, func):
        """Attach ``func`` to the production at index ``prodind``."""
        self.productions[prodind].setfunc(func)

    def setfuncname(self, prodinf, funcname):
        """Set the function name of the production at index ``prodinf``.

        The parameter keeps its historical spelling for compatibility;
        the body used to read an undefined ``prodind`` and always failed.
        """
        self.productions[prodinf].setfuncname(funcname)

    def default_prodfunc(self):
        """Return an action that yields its first argument."""
        return lambda *args: args[0]

    def prodinfotable(self):
        """Return ``(len(RHS), action, index of LHS in self.nonterms)``
        for every production, in production order."""
        res = []
        nonterms = self.nonterms
        for p in self.productions:
            lhsind = nonterms.index(p.LHS)
            func = p.func
            if not func:
                func = self.default_prodfunc()
            res.append((len(p.RHS), func, lhsind))
        return res

    def mkengine(self, inbufchunksize=None, stackchunksize=None):
        """Return ``(prodinfo, actiontable, gototable)`` for this grammar.

        The two size arguments are accepted for compatibility only; the
        engine that consumed them is no longer constructed here.
        """
        self.augment()
        try:
            items = self.LR1items()
            at = self.actiontable(items)
            gt = self.gototable(items)
        finally:
            # Always unaugment, even if table construction failed.
            self.productions = self.productions[1:]
        pi = self.prodinfotable()
        return (pi, at, gt)

    def _table_source(self, rows):
        """Render table rows as the bracketed text used by writefile()."""
        text = "[\n\t" + "".join([repr(row) + ",\n\t" for row in rows])
        # Drop the separator that follows the last row.
        return text[:-3] + "\n]\n\n"

    def writefile(self, filename):
        """Write the tables to ``filename`` using PyLRtabletemplate.

        The docstring of the PyLRtabletemplate module is used as a
        %-format template, filled with the keys date, filename,
        extrasource, grammar, actiontable, gototable and prodinfo.
        """
        import time
        import PyLRtabletemplate
        template = PyLRtabletemplate.__doc__
        vals = {}
        vals["date"] = time.ctime(time.time())
        vals["filename"] = filename
        vals["extrasource"] = getattr(self, "extrasource", "")
        vals["grammar"] = repr(self)
        self.augment()
        try:
            lr1items = self.LR1items()
            actiontable = self.actiontable(lr1items)
            gototable = self.gototable(lr1items)
        finally:
            # Always unaugment, even if table construction failed.
            self.productions = self.productions[1:]
        vals["actiontable"] = self._table_source(actiontable)
        vals["gototable"] = self._table_source(gototable)
        pi_s = "[\n\t"
        for plen, _func, lhsind in self.prodinfotable():
            pi_s = "%s(%d, _fdummy, %d),\n\t" % (pi_s, plen, lhsind)
        vals["prodinfo"] = pi_s + "]\n\n"
        # Format first so a bad template cannot leave a truncated file.
        text = template % vals
        fp = open(filename, "w")
        try:
            fp.write(text)
        finally:
            fp.close()


class LALRGrammar(Grammar):
    """LALR(1) table builder layered on Grammar.

    LR(0) items are ``(production_index, dot_position)`` pairs.  After
    LALR1items() has run, ``self.soi`` holds the LR(0) item sets and
    ``self.LALRitems`` the same states as lists of ``(item, lookahead)``.

    NOTE(review): goto(), actiontable() and gototable() are overridden
    with LR(0)-based versions taking no ``items`` argument, so the
    inherited writefile() does not look usable on this class -- confirm.
    """

    def __init__(self, prods, toks=None):
        Grammar.__init__(self, prods, toks)
        self.LALRitems = []
        # Filled by LALR1items(); defined here so the table builders do
        # not hit a missing attribute when called too early.
        self.soi = []

    def LALRclosure(self, itemset):
        """Return ``itemset`` extended to its LR(0) closure (a new list)."""
        res = itemset[:]
        while 1:
            more = []
            for (prodind, rhsind) in res:
                rhs = self.productions[prodind].RHS
                if rhsind == len(rhs):
                    continue
                for p in self.LHSproductions(rhs[rhsind]):
                    candidate = (self.productions.index(p), 0)
                    if candidate not in res:
                        more.append(candidate)
            if not more:
                break
            res = res + more
        return res

    def goto(self, itemset, sym):
        """Return the LR(0) closure of ``itemset`` advanced over ``sym``."""
        items = []
        for (prodind, rhsind) in itemset:
            try:
                at_dot = self.productions[prodind].RHS[rhsind]
            except IndexError:
                # Dot already at the end: nothing to advance over.
                continue
            advanced = (prodind, rhsind + 1)
            if at_dot == sym and advanced not in items:
                items.append(advanced)
        return self.LALRclosure(items)

    def LR0items(self):
        """Return the list of LR(0) item sets; index 0 is the start state."""
        res = [self.LALRclosure([(0, 0)])]
        syms = self.nonterms + self.terms
        while 1:
            more = []
            for iset in res:
                for s in syms:
                    g = self.goto(iset, s)
                    # Also check ``more`` (as Grammar.LR1items does) so a
                    # state reached twice in one pass is not added twice.
                    if g and g not in res and g not in more:
                        more.append(g)
            if not more:
                break
            res = res + more
        return res

    def kernels(self, items):
        """Return the items whose dot is past position 0, plus every
        item of production 0."""
        return [(prodind, rhsind) for (prodind, rhsind) in items
                if rhsind != 0 or prodind == 0]

    def lookaheads(self, itemset, setsofitems=None):
        """Find the lookaheads that originate in one LR(0) state.

        Returns ``(spontaneous, propagates)``:
          spontaneous -- list of ``(target_state, item, terminal)``;
          propagates  -- dict from a kernel item of ``itemset`` to a list
                         of ``(target_state, item)`` that inherit its
                         lookaheads.
        Raises ValueError if ``itemset`` is not in ``setsofitems``.
        """
        if not setsofitems:
            setsofitems = self.LR0items()
        # Only checks membership; the index itself is not needed.
        setsofitems.index(itemset)
        spontaneous = []
        propagates = {}
        for kernel in self.kernels(itemset):
            # Close the kernel item under the placeholder lookahead:
            # items that keep it propagate, all others are spontaneous.
            for (cpi, cri), t in self.closure([(kernel, Grammar.DummyLA)]):
                rhs = self.productions[cpi].RHS
                if cri >= len(rhs):
                    continue
                target = self.goto(itemset, rhs[cri])
                try:
                    newstate = setsofitems.index(target)
                except ValueError:
                    continue
                if t != Grammar.DummyLA:
                    spontaneous.append((newstate, (cpi, cri + 1), t))
                else:
                    propagates.setdefault(kernel, []).append(
                        (newstate, (cpi, cri + 1)))
        return spontaneous, propagates

    def initLALR1items(self, setsofitems):
        """Seed the lookahead table with the spontaneous lookaheads.

        Returns ``(la_table, props)``.  ``la_table[state][item]`` is the
        lookahead list of that item; ``props[state]`` is the propagation
        dict returned by lookaheads() for that state.
        """
        la_table = [[[] for _item in state] for state in setsofitems]
        la_table[0][0] = [Grammar.EOF]
        props = {}
        for state_i, itemset in enumerate(setsofitems):
            sp, pr = self.lookaheads(itemset, setsofitems)
            for ns, item, t in sp:
                la_table[ns][setsofitems[ns].index(item)].append(t)
            props[state_i] = pr
        return la_table, props

    def LALR1items(self):
        """Compute the LALR(1) items and store them on the instance.

        Sets ``self.soi`` and ``self.LALRitems``; returns the latter: one
        sorted list of ``(item, lookahead)`` pairs per state.
        """
        self.soi = soi = self.LR0items()
        la_table, props = self.initLALR1items(soi)
        # Propagate until a full pass adds nothing new.
        while 1:
            added_la = False
            for state_i, state in enumerate(la_table):
                for ii, propterms in enumerate(state):
                    if not propterms:
                        continue
                    proplist = props[state_i].get(soi[state_i][ii])
                    if proplist is None:
                        continue
                    for pstate, pitem in proplist:
                        target = la_table[pstate][soi[pstate].index(pitem)]
                        for pt in propterms:
                            if pt not in target:
                                added_la = True
                                target.append(pt)
            if not added_la:
                break
        #
        # this section just reorganizes the above data
        # to the state it's used in later...
        #
        res = []
        for state_i, state in enumerate(soi):
            inner = []
            for item_i, item in enumerate(state):
                for term in la_table[state_i][item_i]:
                    if (item, term) not in inner:
                        inner.append((item, term))
            # Lookaheads mix ints with the EOF string, which Python 3
            # refuses to compare; the key keeps the Python 2 order.
            inner.sort(key=lambda entry: (entry[0],
                                          isinstance(entry[1], str),
                                          entry[1]))
            res.append(inner)
        self.LALRitems = res
        return res

    def actiontable(self):
        """Build the action table from ``self.LALRitems`` and ``self.soi``.

        Reduce and accept entries are filled in first, with conflicts
        among them reported on stdout; addshifts() then writes the shifts.
        """
        terms = self.terms[:]
        terms.append(Grammar.EOF)
        errentry = ("", -1)
        res = []
        for state_i, state in enumerate(self.LALRitems):
            row = [errentry] * len(terms)
            res.append(row)
            for (prodind, rhsind), term in state:
                if rhsind != len(self.productions[prodind].RHS):
                    continue
                if prodind != 0:
                    new = ("r", prodind)
                else:
                    new = ("a", -1)
                col = terms.index(term)
                old = row[col]
                if old != errentry and old != new:
                    print("Conflict[%d,%d]: %s -> %s"
                          % (state_i, col, old, new))
                row[col] = new
        return self.addshifts(res)

    def addshifts(self, res):
        """Write the shift entries into ``res`` in place and return it.

        A shift replaces whatever the cell held, without any report.
        """
        for state_i, state in enumerate(self.soi):
            for pi, ri in state:
                try:
                    sym = self.productions[pi].RHS[ri]
                    if sym in self.terms:
                        gt = self.goto(state, sym)
                        res[state_i][self.terms.index(sym)] = (
                            "s", self.soi.index(gt))
                except (ValueError, IndexError):
                    # Dot at the end of the production, or no such
                    # target state: this item contributes no shift.
                    pass
        return res

    def gototable(self):
        """Build the goto table from ``self.soi``: one row per state, one
        column per entry of ``self.nonterms``, None where no transition."""
        items = self.soi
        nonterms = self.nonterms
        res = []
        for state in items:
            row = [None] * len(nonterms)
            res.append(row)
            for nonterm_i, nt in enumerate(nonterms):
                target = self.goto(state, nt)
                # (SMR) sorting the target here breaks at least the
                # sample grammar: the item set can then no longer be
                # located in ``items``.
                if target in items:
                    row[nonterm_i] = items.index(target)
        return res

    def mkengine(self, inbufchunksize=None, stackchunksize=None):
        """Return ``(prodinfo, actiontable, gototable)`` for this grammar.

        The two size arguments are accepted for compatibility only; the
        engine that consumed them is no longer constructed here.
        """
        self.augment()
        try:
            self.LALR1items()
            at = self.actiontable()
            gt = self.gototable()
        finally:
            # Always unaugment, even if table construction failed.
            self.productions = self.productions[1:]
        pi = self.prodinfotable()
        return (pi, at, gt)
	

import sys
# Shorthand for writing text to standard error.
W = sys.stderr.write

def get_state (stack):
    """Return the state stored in the top stack entry, or 0 if empty.

    Stack entries are pairs whose second element is the state.
    """
    if not stack:
        return 0
    top = stack[-1]
    return top[1]

def eval (info, action, goto, stream):
    """Run the table-driven parser over ``stream`` and return its stack.

    Args:
        info: list of ``(plen, pfun, lhsi)`` per production; a reduce
            entry ``("r", n)`` uses ``info[n - 1]``.
        action: ``action[state][tok]`` gives ``(kind, n)`` where kind is
            's' (shift to state n), 'r' (reduce), 'a' (accept) or
            anything else for an error.
        goto: ``goto[state][lhsi]`` gives the state entered after a
            reduction.
        stream: list of ``(tok, val)`` pairs.  It is consumed in place:
            every shift pops its first element.

    Returns:
        The stack, a list of ``(value, state)`` pairs, as it stands when
        an accept entry is reached.

    Raises:
        SyntaxError: the table has no shift/reduce/accept entry for the
            current state and token.
    """
    stack = []
    while 1:
        tok, val = stream[0]
        state = get_state(stack)
        kind, n = action[state][tok]
        # Compare by value: ``is`` on strings only works by accident of
        # interning.
        if kind == 's':
            stream.pop(0)
            stack.append((val, n))
        elif kind == 'r':
            plen, pfun, lhsi = info[n - 1]
            # Split by position; ``stack[-plen:]`` would take the whole
            # stack for an empty production (plen == 0).
            split = max(len(stack) - plen, 0)
            args, stack = stack[split:], stack[:split]
            result = pfun(*[entry[0] for entry in args])
            next_state = goto[get_state(stack)][lhsi]
            stack.append((result, next_state))
        elif kind == 'a':
            return stack
        else:
            raise SyntaxError(
                "no action for token %r in state %r" % (tok, state))
            

# the token interface will be defined in 
# a wrapper module and in the lexer.
# this is the format the information will be stored in

# first, define the tokens
#
# Token numbers are consecutive integers in this fixed order.
id, plus, times, minus, lparen, rparen, eof = range(7)
# Display strings indexed by token number (eof has no entry); they are
# handed to the grammar classes as ``toks``.
toks = [
    "id",
    "+",
    "*",
    "-",
    "(",
    ")",
]

class g_0:
    """Sample grammar for sums and products with parentheses.

    ``self.prods`` holds the productions
        E -> E + T | T
        T -> T * F | F
        F -> ( E ) | id
    each paired with the bound method that computes its value.
    """

    def __init__ (self):
        self.prods = [
            Production("E", ["E", plus, "T"], self.p_0),
            Production("E", ["T"], self.p_1),
            Production("T", ["T", times, "F"], self.p_2),
            Production("T", ["F"], self.p_3),
            Production("F", [lparen, "E", rparen], self.p_4),
            Production("F", [id], self.p_5),
        ]

    def p_0 (self, v0, _i0, v1):
        """E -> E + T: add both operands; the operator value is unused."""
        return v0 + v1

    def p_1 (self, v0):
        """E -> T: pass the value through."""
        return v0

    def p_2 (self, v0, _i0, v1):
        """T -> T * F: multiply both operands."""
        return v0 * v1

    def p_3 (self, v0):
        """T -> F: pass the value through."""
        return v0

    def p_4 (self, _i0, v0, _i1):
        """F -> ( E ): return the value between the parentheses."""
        return v0

    def p_5 (self, v0):
        """F -> id: pass the token value through."""
        return v0

def t0():
    """Parse ``( id * id ) + id`` (values 34, 92, 17) with grammar g_0.

    Returns whatever eval() returns for that input.
    """
    tokens = (
        (lparen, 0), (id, 34), (times, 0), (id, 92), (rparen, 0),
        (plus, 0), (id, 17),
    )
    sample = g_0()
    lalr = LALRGrammar(sample.prods, toks)
    info, action, goto = lalr.mkengine()
    # The tables are indexed by position in lalr.terms, not by token
    # number; the column after the last terminal is end-of-input.
    stream = []
    for tok, value in tokens:
        stream.append((lalr.terms.index(tok), value))
    stream.append((len(lalr.terms), 0))
    return eval(info, action, goto, stream)

class g_1:
    """Sample grammar for subtraction with parentheses.

    ``self.prods`` holds the productions
        E -> E - T | T
        T -> id | ( E )
    each paired with the bound method that computes its value.
    """

    def __init__ (self):
        self.prods = [
            Production("E", ["E", minus, "T"], self.p_0),
            Production("E", ["T"], self.p_1),
            Production("T", [id], self.p_2),
            Production("T", [lparen, "E", rparen], self.p_3),
        ]

    def p_0 (self, v0, _i0, v1):
        """E -> E - T: subtract the right operand from the left one."""
        return v0 - v1

    def p_1 (self, v0):
        """E -> T: pass the value through."""
        return v0

    def p_2 (self, v0):
        """T -> id: pass the token value through."""
        return v0

    def p_3 (self, _i0, v0, _i1):
        """T -> ( E ): return the value between the parentheses."""
        return v0

def t1():
    """Parse ``id - id - id`` (values 19, 12, 3) with grammar g_1.

    Returns whatever eval() returns for that input.
    """
    tokens = ((id, 19), (minus, 0), (id, 12), (minus, 0), (id, 3))
    sample = g_1()
    lalr = LALRGrammar(sample.prods, toks)
    info, action, goto = lalr.mkengine()
    # The tables are indexed by position in lalr.terms, not by token
    # number; the column after the last terminal is end-of-input.
    stream = []
    for tok, value in tokens:
        stream.append((lalr.terms.index(tok), value))
    stream.append((len(lalr.terms), 0))
    return eval(info, action, goto, stream)

from pprint import pprint as pp

def _test (prods):
    """Build an LALRGrammar from ``prods`` and the module-level ``toks``.

    Returns the grammar together with the tuple its mkengine() produced.
    """
    grammar = LALRGrammar(prods, toks)
    tables = grammar.mkengine()
    return grammar, tables
