;; -*- Mode: Irken -*-

(include "self/transform.scm")

;; Literal constants that can appear directly in the node tree
;;  (wrapped by node:literal, printed by literal->string).
(datatype literal
  (:string string)
  (:int int)
  (:char char)
  (:bool bool)
  (:undef) ;; the undefined value, printed as #u
  ;; vectors, records, etc...  (not represented yet)
  )

;; node type holds metadata related to the node,
;;  but sub-nodes are held with the record (the `subs` field built by make-node).
;; The sub-node layout for each kind, as built by the node/... constructors:
;;   varset: (value)            if:       (test then else)
;;   function: (body)           call:     (rator rand ...)
;;   fix/let: (init ... body)   nvcase:   (value else alt ...)
;;   subst: (body)              cexp/primapp: (arg ...)
;;   sequence: (exp ...)        varref/literal: ()
(datatype node
  (:varref symbol)
  (:varset symbol)
  (:literal literal)
  (:cexp (list type) type string) ;; generic-tvars type template
  (:nvcase symbol (list symbol))  ;; datatype alts
  (:sequence)
  (:if)
  (:function symbol (list symbol)) ;; name formals
  (:call)
  (:let (list symbol))
  (:fix (list symbol))
  (:subst symbol symbol) ;; from to
  (:primapp symbol sexp) ;; name params
  )

;; Source of node ids: make-node takes the next value for every node it builds.
(define node-counter (make-counter 0))

;; Size of a node whose sub-nodes are `l`: 1 for the node itself plus the
;;  size of every sub-node.
(define (sum-size l)
  (let loop ((nodes l)
             (total 1))
    (match nodes with
      ()        -> total
      (hd . tl) -> (loop tl (+ total hd.size)))))

;; Placeholder type given to every node by make-node before typing.
;;  Presumably builds a type:pred named '? with no arguments, which is
;;  what no-type? tests for (pred is defined elsewhere - confirm).
(define no-type (pred '? '()))

;; #t when `t` is the "not typed yet" placeholder, i.e. a predicate named '?
;;  with no arguments; #f for any other type.
(define (no-type? t)
  (match t with
    (type:pred '? () _) -> #t
    _                   -> #f))

;; Node flags: each flag is one bit of node.flags.
;; TODO: an alist might be cleaner (it would make sense if most flags
;;   are clear most of the time).
(define (node-get-flag n bit)
  ;; read flag number `bit` from the node's flag word
  (bit-get n.flags bit))

(define (node-set-flag! n bit)
  ;; store (bit-set flags bit) back into the node's flag word, in place
  (set! n.flags (bit-set n.flags bit)))

;; defined node flags: bit positions within node.flags, for use with
;;  node-get-flag / node-set-flag!
(define NFLAG-RECURSIVE 0)
(define NFLAG-ESCAPES   1)
(define NFLAG-LEAF      2)

;; Build a fresh node record of kind `t` with sub-nodes `subs`.
;;  The record gets its size from sum-size, the next id from node-counter,
;;  the placeholder type no-type, and all flags clear.
(define (make-node t subs)
  (let ((size (sum-size subs))
        (id (node-counter.inc)))
    {t=t subs=subs size=size id=id type=no-type flags=0}))

;; reference to variable `name`; no sub-nodes
(define (node/varref name)
  (make-node (node:varref name) (list:nil)))

;; assignment of `val` to variable `name`; `val` is the only sub-node
(define (node/varset name val)
  (make-node (node:varset name) (list:cons val (list:nil))))

;; the variable name of a varref node kind; any other kind is an error
(define (varref->name t)
  (match t with
    (node:varref name) -> name
    _                  -> (error "varref->name")))

;; literal constant `lit`; no sub-nodes
(define (node/literal lit)
  (make-node (node:literal lit) (list:nil)))

;; C expression: the arguments become the sub-nodes
(define (node/cexp gens type template args)
  (make-node (node:cexp gens type template) args))

;; sequence of expressions, in order
(define (node/sequence subs)
  (make-node (node:sequence) subs))

;; conditional; sub-nodes are ordered test, then, else
(define (node/if test then else)
  (make-node (node:if) (LIST test then else)))

;; function `name` with parameters `formals`; the body is the only sub-node
(define (node/function name formals body)
  (make-node (node:function name formals) (LIST body)))

;; application: the operator first, then the operands
(define (node/call rator rands)
  (make-node (node:call) (list:cons rator rands)))

;; recursive binding; sub-nodes are init0 init1 ... body (see unpack-fix)
(define (node/fix names inits body)
  (make-node (node:fix names) (append inits (LIST body))))

;; let binding; same sub-node layout as node/fix
(define (node/let names inits body)
  (make-node (node:let names) (append inits (LIST body))))

;; datatype case; sub-nodes are the value, the else branch, then the alternatives
(define (node/nvcase dt tags value alts else)
  (make-node (node:nvcase dt tags) (list:cons value (list:cons else alts))))

;; substitution of `to` for `from` within `body` (the only sub-node)
(define (node/subst from to body)
  (make-node (node:subst from to) (LIST body)))

;; primitive application: the arguments become the sub-nodes
(define (node/primapp name params args)
  (make-node (node:primapp name params) args))

;; Shallow copy of a node: a new record built by make-node (so it gets a
;;  fresh id and a recomputed size) that shares node0's kind and sub-node
;;  list, with node0's type and flags copied over.
(define (node-copy node0)
  (let ((copy (make-node node0.t node0.subs)))
    (set! copy.type node0.type)
    (set! copy.flags node0.flags)
    copy))

;; Split the sub-nodes of a fix or let node, laid out as (init0 init1 ... body),
;;  into (:fixsubs body (init0 init1 ...)).  An empty list is an error.
(define (unpack-fix subs)
  (match (reverse subs) with
    ()              -> (error "unpack-fix: no body?")
    (last . rinits) -> (:fixsubs last (reverse rinits))))

;; Render a literal in source syntax.  String contents are quoted but not
;;  escaped, and a character is written raw after #\ (so a non-printable
;;  character prints as itself).
(define (literal->string lit)
  (match lit with
    (literal:undef)    -> (format "#u")
    (literal:bool b)   -> (format (bool b))
    (literal:int n)    -> (format (int n))
    (literal:char ch)  -> (format (char #\#) (char #\\) (char ch)) ;; printable?
    (literal:string s) -> (format (char #\") s (char #\"))))

;; Binary digits of `n`, most significant first (e.g. 5 => "101");
;;  0 yields the empty string.
(define (flags-repr n)
  ;; peel bits off the low end; consing puts the high bits in front
  (let loop ((rest n)
             (digits '()))
    (if (= rest 0)
        (list->string digits)
        (loop (>> rest 1)
              (list:cons (if (= (logand rest 1) 1) #\1 #\0) digits)))))

;; Print (2*n - offset) spaces: indentation for depth `n` when `offset`
;;  columns have already been written on the line.  Prints nothing when
;;  the count is zero or negative.
(define (indent-off n offset)
  (let loop ((remaining (- (* 2 n) offset)))
    (if (> remaining 0)
        (begin
          (print-string " ")
          (loop (- remaining 1))))))

;; Pretty-print node `n` at depth `d`, then its sub-nodes at depth d+1.
;;  Each node starts on a new line with "<id> <flags>", then padding (see
;;  indent-off) so that deeper nodes are indented further, then a one-line
;;  description of the node kind followed by " : <type>".
;;  Every kind now gets the " : <type>" suffix; previously subst nodes omitted it.
(define (pp-node n d)
  (define PS print-string)
  (let ((tr (type-repr n.type))
        (head (format (int n.id) " " (flags-repr n.flags))))
    (newline)
    (PS head)
    (indent-off d (string-length head))
    (let ((desc
           (match n.t with
             (node:varref name)             -> (format "varref " (sym name))
             (node:varset name)             -> (format "varset " (sym name))
             (node:literal lit)             -> (format "literal " (p literal->string lit))
             (node:cexp gens type template) -> (format "cexp " template)
             (node:sequence)                -> "sequence"
             (node:if)                      -> "conditional"
             (node:call)                    -> "call"
             (node:function name formals)   -> (format "function " (sym name) " (" (join symbol->string " " formals) ")")
             (node:fix formals)             -> (format "fix (" (join symbol->string " " formals) ")")
             (node:nvcase dt tags)          -> (format "nvcase " (sym dt) "(" (join symbol->string " " tags) ")")
             (node:subst from to)           -> (format "subst " (sym from) "->" (sym to))
             (node:primapp name params)     -> (format "primapp " (sym name) " " (p repr params))
             (node:let formals)             -> (format "let (" (join symbol->string " " formals) ")")
             )))
      ;; one uniform suffix so every node kind shows its type
      (PS (format desc " : " tr)))
    (pp-nodes n.subs (+ 1 d))
    ))

;; Pretty-print each node of `l`, in order, at depth `d` (see pp-node).
(define (pp-nodes l d)
  (match l with
    ()         -> #u
    (n . rest) -> (begin
                    (pp-node n d)
                    (pp-nodes rest d))))

;; Symbols named by a list of formal-parameter sexps, in order.
;;  Any element that is not a symbol is an error that reports the whole list.
(define (get-formals l)
  (map (lambda (x)
         (match x with
           (sexp:symbol formal) -> formal
           _                    -> (error1 "malformed formal" l)))
       l))

;; Split a list of ((name init) ...) binding sexps into
;;  (:pair (name ...) (init ...)), preserving order.  A malformed binding is
;;  an error that reports the bindings from that one onward.
(define (unpack-bindings bindings)
  (let recur ((rest bindings))
    (match rest with
      () -> (:pair (list:nil) (list:nil))
      ((sexp:list ((sexp:symbol name) init)) . more)
      -> (match (recur more) with
           (:pair names inits)
           -> (:pair (list:cons name names) (list:cons init inits)))
      _ -> (error1 "unpack-bindings" rest)
      )))

;; Parse a %%cexp type signature into (:scheme generic-tvars type).
;;  The type variables met while parsing are collected in a fresh
;;  alist-maker; they are read out only after parsing finishes.
(define (parse-cexp-sig sig)
  (let ((tvars (alist-maker)))
    (let ((parsed (parse-type* sig tvars)))
      (:scheme (tvars::values) parsed))))

;; Convert an sexp into a node tree.
;;  - Atoms: a symbol becomes a varref; string/int/char/bool/undef become literals.
;;  - Lists are matched on their head symbol against the special forms below
;;    (%%cexp, %nvcase, begin, set!, quote, if, function, fix, let-splat,
;;    letrec, let_subst).
;;  - Any other non-empty list is an application.  An operator symbol starting
;;    with #\% makes a primapp whose first operand is kept unwalked as its
;;    params.  A constructor operator (sexp:cons dt alt) makes a %dtcon primapp.
;;    Everything else makes a call.
;;  - The empty list and any other sexp (vectors, records, ...) are syntax errors.
;;  NOTE(review): the input is presumably the transformer's output (see
;;  test-nodes) - confirm.
(define walk
  (sexp:symbol s) -> (node/varref s)
  (sexp:string s) -> (node/literal (literal:string s))
  (sexp:int n)    -> (node/literal (literal:int n))
  (sexp:char c)   -> (node/literal (literal:char c))
  (sexp:bool b)   -> (node/literal (literal:bool b))
  (sexp:undef)    -> (node/literal (literal:undef))
  ;; not handled yet (these fall through to "syntax error 2" below):
  ;; (:vector (list sexp))
  ;; (:record (list field))
  ;; (:cons symbol symbol)
  ;; (:attr sexp symbol)
  (sexp:list l)
  -> (match l with
       ;; (%%cexp sig "template" arg ...)
       ((sexp:symbol '%%cexp) sig (sexp:string template) . args)
       -> (let ((scheme (parse-cexp-sig sig)))
	    (match scheme with
	      (:scheme gens type)
	      -> (node/cexp gens type template (map walk args))))
       ;; (%nvcase dt value (tag ...) (alt ...) else-alt)
       ((sexp:symbol '%nvcase) (sexp:symbol dt) val-exp (sexp:list tags) (sexp:list alts) ealt)
       -> (node/nvcase dt (map sexp->symbol tags) (walk val-exp) (map walk alts) (walk ealt))
       ;; (begin exp ...)
       ((sexp:symbol 'begin) . exps)
       -> (node/sequence (map walk exps))
       ;; (set! name value)
       ((sexp:symbol 'set!) (sexp:symbol name) arg)
       -> (node/varset name (walk arg))
       ;; (quote x): x must walk to a literal node
       ((sexp:symbol 'quote) arg)
       -> (let ((lit (walk arg)))
	    (match lit.t with
	      (node:literal _) -> lit
	      _ -> (error1 "expected literal type" l)))
       ;; (if test then else) - exactly three operands
       ((sexp:symbol 'if) test then else)
       -> (node/if (walk test) (walk then) (walk else))
       ;; (function name (formal ...) body ...): the body is wrapped in a sequence
       ((sexp:symbol 'function) (sexp:symbol name) (sexp:list formals) . body)
       -> (node/function name (get-formals formals) (node/sequence (map walk body)))
       ;; (fix (name ...) (init ...) body ...)
       ((sexp:symbol 'fix) (sexp:list names) (sexp:list inits) . body)
       -> (node/fix (get-formals names) (map walk inits) (node/sequence (map walk body)))
       ;; (let-splat ((name init) ...) body ...) => node:let
       ((sexp:symbol 'let-splat) (sexp:list bindings) . body)
       -> (match (unpack-bindings bindings) with
	    (:pair names inits)
	    -> (node/let names (map walk inits) (node/sequence (map walk body))))
       ;; (letrec ((name init) ...) body ...) => node:fix
       ((sexp:symbol 'letrec) (sexp:list bindings) . body)
       -> (match (unpack-bindings bindings) with
	    (:pair names inits)
	    -> (node/fix names (map walk inits) (node/sequence (map walk body))))
       ;; (let_subst (from to) body)
       ((sexp:symbol 'let_subst) (sexp:list ((sexp:symbol from) (sexp:symbol to))) body)
       -> (node/subst from to (walk body))
       ;; application
       (rator . rands)
       -> (match rator with
	    (sexp:symbol name)
	    ;; a %-prefixed operator is a primitive: its first operand is its params
	    -> (if (eq? (string-ref (symbol->string name) 0) #\%)
		   (match rands with
		     (params . rands)
		     -> (node/primapp name params (map walk rands))
		     _ -> (error1 "null primapp missing params?" l))
		   (node/call (walk rator) (map walk rands)))
	    ;; datatype constructor in operator position
	    (sexp:cons dt alt)
	    -> (node/primapp '%dtcon rator (map walk rands))
	    _ -> (node/call (walk rator) (map walk rands)))
       ;; only the empty list reaches here
       _ -> (error1 "syntax error: " l)
       )
  x -> (error1 "syntax error 2: " x)
  )

;(define (walk exp)
;  (print-string "walking... ") (unread exp) (newline)
;  (walk* exp))

;; Symbol made from `name` with "_<num>" appended, e.g. (frob 'x 3) => x_3.
(define (frob name num)
  (let ((text (format (sym name) "_" (int num))))
    (string->symbol text)))

;; Fresh variable descriptor for `name`: name2 is its unique renamed form
;;  (see frob), assigns and refs start empty, and serial is kept as given.
(define (make-vardef name serial)
  {name=name
   name2=(frob name serial)
   assigns='()
   refs='()
   serial=serial})

;; A mutable symbol -> vardef tree, packaged as a record of closures:
;;   add    - make a vardef for a symbol using the next serial number, insert
;;            it into the tree, and return it
;;   lookup - tree/member on the tree for a symbol
;;   get    - the current tree
(define (make-var-map)
  (let ((vars (tree:empty))
        (serials (make-counter 0)))
    (define (lookup sym)
      (tree/member vars symbol<? sym))
    (define (add sym)
      (let ((vd (make-vardef sym (serials.inc))))
        (set! vars (tree/insert vars symbol<? sym vd))
        vd))
    (define (get) vars)
    {add=add lookup=lookup get=get}))

;; Give every bound variable in node tree `n` a unique name, rewriting the
;;  tree in place.  Each binder (function formals, fix names, let names) is
;;  registered with a var map (see make-var-map), which yields a vardef whose
;;  name2 is the new name.  Binding node kinds are rewritten to list the new
;;  names, and varref/varset nodes that resolve to a binder are rewritten to
;;  its name2.  Variables with no visible binder are left as they are.
;;
;; Scoping rules:
;;   function - the formals are visible in the body; the function's own name
;;              is rewritten only if it resolves in the enclosing scope.
;;   fix      - all names are visible in every init and in the body.
;;   let      - sequential, like let*: init i sees only names 0..i-1, the
;;              body sees all of them, and a later binding of a repeated
;;              name shadows an earlier one.
(define (rename-variables n)

  (let ((varmap (make-var-map)))

    (define (rename-all exps lenv)
      (for-each (lambda (exp) (rename exp lenv)) exps))

    ;; Rename the subs (init0 init1 ... body) of a let node whose new vardefs
    ;;  are `vds`.  Each name comes into scope only after its own init has
    ;;  been renamed, so an init that mentions an outer variable of the same
    ;;  name keeps referring to the outer one.
    (define (rename-let vds subs lenv)
      (match vds subs with
        () (body)
        -> (rename body lenv)
        (vd . vds0) (init . subs0)
        -> (begin
             (rename init lenv)
             (rename-let vds0 subs0 (list:cons (LIST vd) lenv)))
        _ _
        -> (error "rename-variables: malformed let")))

    (define (rename exp lenv)

      ;; innermost vardef named `name` in lenv (a list of ribs, innermost first)
      (define (lookup name)
        (let loop0 ((lenv lenv))
          (match lenv with
            ()           -> (maybe:no)
            (rib . next) -> (let loop1 ((l rib))
                              (match l with
                                ()        -> (loop0 next)
                                (vd . tl) -> (if (eq? name vd.name)
                                                 (maybe:yes vd)
                                                 (loop1 tl)))))))

      (match exp.t with
        (node:function name formals)
        -> (let ((rib (map varmap.add formals))
                 (name2 (match (lookup name) with
                          (maybe:no) -> name
                          (maybe:yes vd) -> vd.name2)))
             (set! exp.t (node:function name2 (map (lambda (x) x.name2) rib)))
             (rename-all exp.subs (list:cons rib lenv)))
        (node:fix names)
        ;; recursive: the new names are visible in the inits as well as the body
        -> (let ((rib (map varmap.add names)))
             (set! exp.t (node:fix (map (lambda (x) x.name2) rib)))
             (rename-all exp.subs (list:cons rib lenv)))
        (node:let names)
        -> (let ((rib (map varmap.add names)))
             (set! exp.t (node:let (map (lambda (x) x.name2) rib)))
             (rename-let rib exp.subs lenv))
        (node:varref name)
        -> (match (lookup name) with
             (maybe:no) -> #u ;; can't rename it if we don't know what it is
             (maybe:yes vd) -> (set! exp.t (node:varref vd.name2)))
        (node:varset name)
        -> (match (lookup name) with
             (maybe:no) -> #u
             (maybe:yes vd) -> (set! exp.t (node:varset vd.name2)))
        _ -> (rename-all exp.subs lenv)
        ))

    (rename n '())
    ))

;; Walk the node tree, applying subst nodes.  Each (subst from to body) node
;;  is replaced by its body, and inside that body every reference to or
;;  assignment of `from` is rewritten to `to` (itself resolved through any
;;  enclosing substs).  varref/varset kinds and sub-node lists are updated in
;;  place; the returned node is the new root (it differs from `exp` only when
;;  `exp` itself is a subst).  A substitution stops applying wherever `from`
;;  is rebound by a fix, a let, or a function's formals.
(define (apply-substs exp)

  ;; Drop the substitutions whose `from` is rebound by `names`.
  ;; could we do this more easily by flattening the environment and just using member?
  (define shadow
    names ()            -> (list:nil)
    names (pair . tail) -> (if (not (member-eq? pair.from names))
                               (list:cons pair (shadow names tail))
                               (shadow names tail)))

  ;; The replacement for `name`, or `name` itself when no substitution applies.
  (define lookup
    name ()            -> name
    name (pair . tail) -> (if (eq? name pair.from)
                              pair.to
                              (lookup name tail)))

  ;; Walk the subs (init0 init1 ... body) of a let binding `names`, returning
  ;;  the new sub-node list.  let is sequential, so init i is walked before
  ;;  name i starts shadowing substitutions.
  (define walk-let
    () (body) lenv
    -> (LIST (walk body lenv))
    (name . names) (init . subs) lenv
    -> (let ((init0 (walk init lenv)))
         (list:cons init0 (walk-let names subs (shadow (LIST name) lenv))))
    _ _ _
    -> (error "apply-substs: malformed let"))

  (define (walk exp lenv)
    (match exp.t with
      (node:subst from to)
      ;; the subst node disappears: its walked body takes its place
      -> (walk (car exp.subs) (list:cons {from=from to=(lookup to lenv)} lenv))
      (node:let names)
      -> (begin
           (set! exp.subs (walk-let names exp.subs lenv))
           exp)
      _
      -> (let ((lenv0 (match exp.t with
                        (node:fix names)        -> (shadow names lenv)
                        (node:function _ names) -> (shadow names lenv)
                        _                       -> lenv)))
           (match exp.t with
             (node:varref name) -> (set! exp.t (node:varref (lookup name lenv0)))
             (node:varset name) -> (set! exp.t (node:varset (lookup name lenv0)))
             _                  -> #u)
           (set! exp.subs (map (lambda (x) (walk x lenv0)) exp.subs))
           exp)))

  (walk exp '())
  )

;; Ad-hoc driver: run a small expression through the transformer, walk it
;;  into nodes and apply substs, then print the sexps and the node tree
;;  before and after rename-variables.
(define (test-nodes)
  (let ((ctx (make-context))
        (xform (transformer ctx))
        (src (sexp:list (read-string "((lambda (x y) (+ x y)) 3 4)")))
        (transformed (xform src))
        (tree (walk transformed))
        (substituted (apply-substs tree)))
    ;; the source and transformed sexps
    (unread src) (newline)
    (unread transformed) (newline)
    ;; a fresh walk of the transformed sexp, printed raw
    (printn (walk transformed))
    ;; the node tree, then the same tree after renaming
    (pp-node substituted 0) (newline)
    (rename-variables tree)
    (pp-node substituted 0) (newline)))

;;(test-nodes)