;; -*- Mode: Irken -*-

(include "self/nodes.scm")

(datatype tag
  (:bare int)
  (:uobj int)
  )

;; RTL instructions
(datatype insn
  (:return int)						;; return register
  (:literal literal cont)				;; <value> <k>
  (:cexp type string (list int) cont)			;; <sig> <template> <args> <k>
  (:test int insn insn cont)				;; <reg> <then> <else> <k>
  (:testcexp (list int) type string insn insn cont)	;; <regs> <sig> <template> <then> <else> <k>
  (:jump int int)					;; <reg> <target>
  (:close symbol insn cont)				;; <name> <body> <k>
  (:varref int int cont)				;; <depth> <index> <k>
  (:varset int int int cont)				;; <depth> <index> <reg> <k>
  (:new-env int cont)					;; <size> <k>
  (:alloc tag int cont)					;; <tag> <size> <k>
  (:store int int int int cont)				;; <offset> <arg> <tuple> <i> <k>
  (:invoke (maybe symbol) int int cont)			;; <name> <closure> <args> <k>
  (:tail (maybe symbol) int int)			;; <name> <closure> <args>
  (:trcall int symbol (list int))			;; <depth> <name> <args>
  (:push int cont)					;; <env>
  (:pop int cont)					;; <result>
  (:primop symbol sexp (list int) cont)			;; <name> <params> <args> <k>
  (:move int int cont)					;; <var> <src> <k>
  (:fatbar int insn insn cont)				;; <label> <alt0> <alt1> <k>
  (:fail int int)					;; <label> <npop>
  (:nvcase int symbol (list symbol) (list insn) insn cont)
  )

;; continuation
;; XXX wonder if this would make more sense as a record?
(datatype cont
  (:k int (list int) insn) ;; <target-register> <free-registers> <code>
  (:nil)
  )

(datatype cpsenv
  (:nil)
  (:rib (list symbol) cpsenv)
  (:reg int cpsenv)
  (:fat int cpsenv)
  )

(define (make-register-allocator)
  (let ((max-reg -1))
    (define (allocate free)
      (let loop ((i 0))
	(if (member? i free =)
	    (loop (+ i 1))
	    (begin (set! max-reg (max i max-reg)) i))))
    (define (get-max) max-reg)
    {alloc=allocate get-max=get-max}
    ))

(define k/free
  (cont:k _ free _)   -> free
  (cont:nil) -> (error "k/free"))

(define k/target
  (cont:k target _ _) -> target
  (cont:nil) -> (error "k/target")
  )

(define (compile exp context)

  (let ((current-funs '(top)))

    (define (set-flag! flag)
      (vars-set-flag! context (car current-funs) flag))

    (define (cont free generator)
      (let ((reg (context.regalloc.alloc free)))
	(cont:k reg free (generator reg))))

    (define (dead free k)
      (cont:k -1 free k))

    (define (compile tail? exp lenv k)

      ;; override continuation when in tail position
      (if tail?
	  (set! k (cont (k/free k) gen-return)))
  
      (match exp.t with
	(node:literal lit)	      -> (c-literal lit k)
	(node:sequence)		      -> (c-sequence tail? exp.subs lenv k)
	(node:if)		      -> (c-conditional tail? exp lenv k)
	(node:function name formals)  -> (c-function name formals (car exp.subs) lenv k)
	(node:varref name)	      -> (c-varref name lenv k)
	(node:varset name)	      -> (c-varset name (car exp.subs) lenv k)
	(node:cexp gens sig template) -> (c-cexp sig template exp.subs lenv k)
	(node:call)		      -> (c-call tail? exp lenv k)
	(node:fix formals)	      -> (c-let-splat tail? formals exp.subs lenv k)
	(node:let formals)            -> (c-let-splat tail? formals exp.subs lenv k)
	(node:primapp name params)    -> (c-primapp tail? name params exp lenv k)
	(node:nvcase dt alts)         -> (c-nvcase tail? dt alts exp.subs lenv k)
	_ -> (begin (pp-node exp 0) (error1 "NYI" exp))
	)
      )

    (define (c-literal val k) (insn:literal val k))

    (define (c-sequence tail? nodes lenv k)
      (match nodes with
	()	     -> (error "empty sequence?")
	(exp)	     -> (compile tail? exp lenv k)
	(exp . exps) -> (compile #f exp lenv (dead (k/free k) (c-sequence tail? exps lenv k)))
	))

    ;; XXX consider redoing with fatbar?
    (define (c-conditional tail? exp lenv k)
      (let ((target (k/target k))
	    (free (k/free k)))
	(match exp.subs with
	  (test then else)
	  -> (match test.t with
	       (node:cexp _ sig template)
	       -> (c-simple-conditional tail? test then else sig template lenv k)
	       _ -> (compile
		     #f test lenv
		     (cont free
			   (lambda (reg)
			     (insn:test
			      reg
			      (compile tail? then lenv (cont free (lambda (reg) (insn:jump reg target))))
			      (compile tail? else lenv (cont free (lambda (reg) (insn:jump reg target))))
			      k))
			   )))
	  _ -> (error1 "c-conditional" exp)
	  )))

    (define (c-simple-conditional tail? test then else sig template lenv k)
      (let ((free (k/free k))
	    (target (k/target k)))
	(define (finish regs)
	  ;; <regs> <sig> <template> <then> <else> <k>
	  (insn:testcexp
	   regs sig template
	   (compile tail? then lenv (cont free (lambda (reg) (insn:jump reg target))))
	   (compile tail? else lenv (cont free (lambda (reg) (insn:jump reg target))))
	   k))
	(collect-primargs test.subs '() lenv k finish)))

    (define extend-lenv
      () lenv -> lenv ;; don't extend with an empty rib
      fs lenv -> (cpsenv:rib fs lenv)
      )

    (define (c-function name formals body lenv k)
      (set-flag! VFLAG-ALLOCATES)
      (PUSH current-funs name)
      (let ((r
	     (insn:close
	      name
	      (compile #t
		       body
		       (extend-lenv formals lenv)
		       (cont '() gen-return))
	      k)))
	(pop current-funs)
	r))

    (define search-rib
      name0 _ ()		  -> (maybe:no)
      name0 i (name1 . names) -> (if (eq? name0 name1)
				     (maybe:yes i)
				     (search-rib name0 (+ i 1) names)))

    (define lexical-address
      name _ (cpsenv:nil) -> (error1 "unbound variable" name)
      name d (cpsenv:rib names lenv) -> (match (search-rib name 0 names) with
					  (maybe:yes i) -> (:pair d i)
					  (maybe:no)    -> (lexical-address name (+ d 1) lenv))
      name d (cpsenv:fat _ lenv) -> (lexical-address name d lenv)
      name d (cpsenv:reg r lenv) -> (:reg r)
      )

    (define (c-varref name lenv k)
      (match (lexical-address name 0 lenv) with
	(:reg r) -> (insn:move r -1 k)
	(:pair depth index) -> (insn:varref depth index k)
	))

    (define (c-varset name exp lenv k)
      (let ((kfun
	     (match (lexical-address name 0 lenv) with
	       (:pair depth index)
	       -> (lambda (reg) (insn:varset depth index reg k))
	       (:reg index)
	       -> (lambda (reg) (insn:move reg index k)))))
	(compile #f exp lenv (cont (k/free k) kfun))))

    (define (c-primapp tail? name params exp lenv k)
      (let ((args exp.subs))
	(match name with
	  '%fail   -> (c-fail tail? lenv k)
	  '%fatbar -> (c-fatbar tail? args lenv k)
	  '%dtcon  -> (begin (if (> (length args) 0)
				 (set-flag! VFLAG-ALLOCATES))
			     (c-primargs args name params lenv k))
	  '%rextend -> (c-record-literal exp lenv k)
	  '%raccess -> (let ((arg0 (nth args 0))
			     (sig (get-record-sig-sexp arg0.type)))
			 (c-primargs args '%record-get
				     (sexp params sig) ;; (field sig)
				     lenv k))
	  _ -> (c-primargs args name params lenv k))))
   
    (define (c-cexp sig template args lenv k)
      (collect-primargs args '() lenv k
			(lambda (regs)
			  (insn:cexp sig template regs k))))

    (define (collect-primargs args regs lenv k ck)
      (match args with
	()        -> (ck (reverse regs))
	(hd . tl) -> (compile #f hd lenv
			      (cont (append (k/free k) regs)
				    (lambda (reg) (collect-primargs tl (cons reg regs) lenv k ck))))
	))

    (define (c-primargs args op parm lenv k)
      (collect-primargs args '() lenv k
			(lambda (regs) (insn:primop op parm regs k))))

    (define (safe-for-tr-call exp fun)
      (match fun with
	(node:varref name)
	-> (and (node-get-flag exp NFLAG-RECURSIVE)
		(not (vars-get-flag context (car current-funs) VFLAG-ESCAPES)))
	_ -> #f))

    (define (c-trcall depth name args lenv k)
      (collect-primargs args '() lenv k
			(lambda (regs) (insn:trcall depth name regs))))

    (define (c-call tail? exp lenv k)
      (match exp.subs with
	(fun . args)
	-> (if (and tail? (safe-for-tr-call exp fun.t))
	       (let ((name (varref->name fun.t)))
		 (match (lexical-address name 0 lenv) with
		   (:reg _) -> (error "c-call function in register?")
		   (:pair depth index) -> (c-trcall depth name args lenv k)
		   ))
	       (let ((gen-invoke (if tail? gen-tail gen-invoke))
		     (name (match fun.t with
			     (node:varref name) -> (maybe:yes name)
			     _ -> (maybe:no))))
		 (define (make-call args-reg)
		   (compile #f fun lenv (cont (cons args-reg (k/free k))
					      (lambda (closure-reg) (gen-invoke name closure-reg args-reg k)))))
		 (if (> (length args) 0)
		     (compile-args args lenv (cont (k/free k) make-call))
		     (make-call -1))))
	() -> (error "c-call: no function?")
	))

    (define (compile-args args lenv k)
      (set-flag! VFLAG-ALLOCATES)
      (match args with
	() -> (insn:new-env 0 k)
	_  -> (let ((nargs (length args)))
		(insn:new-env
		 nargs
		 (cont (k/free k)
		       (lambda (tuple-reg)
			 (compile-store-args 0 1 args tuple-reg
					     (cons tuple-reg (k/free k)) lenv k)))))
	))

    (define (compile-store-args i offset args tuple-reg free-regs lenv k)
      (compile
       #f (car args) lenv
       (cont free-regs
	     (lambda (arg-reg)
	       (insn:store
		offset arg-reg tuple-reg i
		(if (null? (cdr args)) ;; was this the last argument?
		    k
		    (dead
		     free-regs
		     (compile-store-args (+ i 1) offset (cdr args) tuple-reg free-regs lenv k))))))))

    (define (c-let-splat tail? formals subs lenv k)
      (let ((rsubs (reverse subs)) ;; subs = (init0 init1 ... body)
	    (body (car rsubs))
	    (inits (reverse (cdr rsubs)))
	    (nargs (length formals))
	    (free (k/free k))
	    (k-body (dead free
			  (compile tail? body (extend-lenv formals lenv)
				   (cont (k/free k) (lambda (reg) (insn:pop reg k)))))))
	(insn:new-env
	 nargs
	 (cont free
	       (lambda (tuple-reg)
		 (insn:push
		  tuple-reg
		  (dead free
			(compile-store-args 0 1 inits tuple-reg
					    (list:cons tuple-reg free)
					    (extend-lenv formals lenv)
					    k-body))))))))

    (define (c-nvcase tail? dtname alt-formals subs lenv k)
      (let ((free (k/free k)))
	;; nvcase subs = <value>, <else-clause>, <alt0>, ...
	(match (alist/lookup context.datatypes dtname) with
	  (maybe:no) -> (error1 "no such datatype" dtname)
	  (maybe:yes dt)
	  -> (let ((value (nth subs 0))
		   (eclause (nth subs 1))
		   (alts (cdr (cdr subs))))
	       (define (finish test-reg)
		 (let ((jump-k (cont free (lambda (reg) (insn:jump reg (k/target k)))))
		       (alts (map (lambda (alt) (compile tail? alt lenv jump-k)) alts))
		       (ealt (compile tail? eclause lenv jump-k)))
		   (if (not (= (dt.get-nalts) (length alts)))
		       (match eclause.t with
			 (node:primapp '%match-error _)
			 -> (error1 "incomplete match" alt-formals)
			 _ -> #u))
		   (insn:nvcase test-reg dtname alt-formals alts ealt k)))
	       (compile #f value lenv (cont free finish))))))

    (define fatbar-counter (make-counter 0))

    (define (c-fatbar tail? subs lenv k)
      (let ((label (fatbar-counter.inc))
	    (lenv0 (cpsenv:fat label lenv))
	    (free (k/free k))
	    (target (k/target k)))
	(insn:fatbar label
		     (compile tail? (nth subs 0) lenv0 (cont free (lambda (reg) (insn:jump reg target))))
		     (compile tail? (nth subs 1) lenv  (cont free (lambda (reg) (insn:jump reg target))))
		     k)))

    (define (c-fail tail? lenv k)
      ;; lookup the closest surrounding fatbar label
      (let loop ((depth 0)
		 (lenv lenv))
	(match lenv with
	  (cpsenv:nil)	   -> (error "%fail without fatbar?")
	  (cpsenv:rib _ lenv)  -> (loop (+ depth 1) lenv)
	  (cpsenv:reg _ lenv)  -> (loop depth lenv)
	  (cpsenv:fat label _) -> (insn:fail depth label))))

    (define (c-record-literal exp lenv k)
      (let loop ((exp exp)
		 (fields '()))
        ;; (%rextend field0 (%rextend field1 (%rmake) ...)) => {field0=x field1=y}
	(match exp.t with
	  (node:primapp '%rextend (sexp:symbol field)) ;; add another field
	  -> (match exp.subs with
	       (exp0 val)
	       -> (loop exp0 (list:cons (:pair field val) fields))
	       _ -> (error1 "malformed %rextend" exp))
	  (node:primapp '%rmake _) ;; done - put the names in canonical order
	  -> (let ((field0 (sort 
			    (lambda (a b)
			      (match a b with
				(:pair f0 _) (:pair f1 _)
				-> (symbol<? f0 f1)))
			    fields))
		   (sig (map pair-first fields))
		   (args (map pair-second fields))
		   (tag (get-record-tag sig))
		   (free (k/free k)))
	       (insn:alloc (tag:uobj tag)
			   (length args)
			   (cont free
				 (lambda (reg)
				   (compile-store-args
				    0 0 args reg
				    (list:cons reg free)
				    lenv k)))))
	  _ -> (c-record-extension fields exp lenv k))))
		       
    (define (c-record-extension fields exp lenv k)
      (error "c-record-extension: NYI"))

    (define (record-label-tag label)
      (let loop ((l context.labels))
	(match l with
	  ((:pair key val) . tl)
	  -> (if (eq? key label)
		 #u
		 (loop tl))
	  () -> (let ((index (length context.labels)))
		  (PUSH context.labels (:pair label index)))
	  )))

    (define (sig=? sig0 sig1)
      (and (= (length sig0) (length sig1))
	   (every2? eq? sig0 sig1)))

    (define (get-record-tag sig)
      (let loop ((l context.records))
	(match l with
	  ((:pair key val) . tl)
	  -> (if (sig=? key sig)
		 val
		 (loop tl))
	  ;; create a new entry
	  () -> (let ((index (length context.records)))
		  (for-each record-label-tag sig)
		  (PUSH context.records (:pair sig index))
		  index)
	  )))

    (define (gen-return reg)
      (insn:return reg))
    (define (gen-invoke name closure-reg args-reg k)
      (set-flag! VFLAG-ALLOCATES)
      (insn:invoke name closure-reg args-reg k))
    (define (gen-tail name closure-reg args-reg k)
      (insn:tail name closure-reg args-reg))

    (compile #t exp (cpsenv:nil) (cont '() gen-return))
    ))

;;; XXX redo this with the new format macro - this function is horrible.
(define (print-insn insn d)
  (define (print-line print-info k)
    (match k with
      (cont:k target free k0)
      -> (begin
	   (newline)
	   (indent d)
	   (if (= target -1) (print-string "-") (print target))
	   ;;(print-string " ") (print free)
	   (print-string " ")
	   (print-info)
	   (print-insn k0 d)
	   )
      _ -> #u
      ))
  (define (ps x) (print x) (print-string " "))
  (define (ps2 x) (print-string x) (print-string " "))
  (match insn with
    (insn:return target)	    -> (begin (newline) (indent d) (ps2 "- ret") (print target))
    (insn:tail n c a)		    -> (print-line (lambda () (ps2 "tail") (ps n) (ps c) (ps a)) (cont:nil))
    (insn:trcall d n args)	    -> (print-line (lambda () (ps2 "trcall") (ps d) (ps n) (ps args)) (cont:nil))
    (insn:literal lit k)	    -> (print-line (lambda () (ps2 "lit") (ps2 (literal->string lit))) k)
    (insn:cexp sig template args k) -> (print-line (lambda () (ps2 "cexp") (ps2 (type-repr sig)) (ps template) (ps args)) k)
    (insn:test reg then else k)	    -> (print-line (lambda () (ps2 "test") (print reg) (print-insn then (+ d 1)) (print-insn else (+ d 1))) k)
    (insn:testcexp r s t k0 k1 k)   -> (print-line (lambda () (ps2 "testcexp") (ps r) (ps2 (type-repr s)) (ps t) (print-insn k0 (+ d 1)) (print-insn k1 (+ d 1))) k)
    (insn:jump reg trg)		    -> (print-line (lambda () (ps2 "jmp") (print trg)) (cont:nil))
    (insn:close name body k)	    -> (print-line (lambda () (ps2 "close") (print name) (print-insn body (+ d 1))) k)
    (insn:varref d i k)		    -> (print-line (lambda () (ps2 "ref") (ps d) (ps i)) k)
    (insn:varset d i v k)	    -> (print-line (lambda () (ps2 "set") (ps d) (ps i) (ps v)) k)
    (insn:store o a t i k)	    -> (print-line (lambda () (ps2 "stor") (ps o) (ps a) (ps t) (ps i)) k)
    (insn:invoke n c a k)	    -> (print-line (lambda () (ps2 "invoke") (ps n) (ps c) (ps a)) k)
    (insn:new-env n k)		    -> (print-line (lambda () (ps2 "env") (ps n)) k)
    (insn:alloc tag size k)         -> (print-line (lambda () (ps2 "alloc") (ps tag) (ps size)) k)
    (insn:push r k)                 -> (print-line (lambda () (ps2 "push") (ps r)) k)
    (insn:pop r k)                  -> (print-line (lambda () (ps2 "pop") (ps r)) k)
    (insn:primop name p args k)     -> (print-line (lambda () (ps2 "primop") (ps name) (ps2 (repr p)) (ps args)) k)
    (insn:move var src k)           -> (print-line (lambda () (ps2 "move") (ps var) (ps src)) k)
    (insn:fatbar lab k0 k1 k)       -> (print-line (lambda () (ps2 "fatbar") (ps lab) (print-insn k0 (+ d 1)) (print-insn k1 (+ d 1))) k)
    (insn:fail lab npop)            -> (print-line (lambda () (ps2 "fail") (ps lab) (ps npop)) (cont:nil))
    (insn:nvcase tr dt formals alts ealt k)
    -> (print-line (lambda () (ps2 "nvcase")
			   (ps tr) (ps dt) (ps formals)
			   (for-each (lambda (insn) (print-insn insn (+ d 1))) alts)
			   (print-insn ealt (+ d 1)))
		   k)
    ))

(define (walk-insns p insn)
  (define (walk insn d)
    (p insn d)
    (let ((k
	   (match insn with
	     ;; no continuation
	     (insn:return target) -> (cont:nil)
	     (insn:tail _ _ _)	  -> (cont:nil)
	     (insn:trcall _ _ _)  -> (cont:nil)
	     (insn:jump _ _)	  -> (cont:nil)
	     (insn:fail _ _)      -> (cont:nil)
	     ;; these insns contain sub-bodies...
	     (insn:fatbar _ k0 k1 k)	     -> (begin (walk k0 (+ d 1)) (walk k1 (+ d 1)) k)
	     (insn:close _ body k)	     -> (begin (walk body (+ d 1)) k)
	     (insn:test _ then else k)	     -> (begin (walk then (+ d 1)) (walk else (+ d 1)) k)
	     (insn:testcexp _ _ _ k0 k1 k)   -> (begin (walk k0 (+ d 1)) (walk k1 (+ d 1)) k)
	     (insn:nvcase _ _ _ alts ealt k) -> (begin (for-each (lambda (x) (walk x (+ d 1))) alts)
						       (walk ealt (+ d 1)) k)
	     ;; ... the rest just have one continuation
	     (insn:literal _ k)	    -> k
	     (insn:cexp _ _ _ k)    -> k
	     (insn:varref _ _ k)    -> k
	     (insn:varset _ _ _ k)  -> k
	     (insn:store _ _ _ _ k) -> k
	     (insn:invoke _ _ _ k)  -> k
	     (insn:new-env _ k)	    -> k
	     (insn:alloc _ _ k)	    -> k
	     (insn:push _ k)        -> k
	     (insn:pop _ k)         -> k
	     (insn:primop _ _ _ k)  -> k
	     (insn:move _ _ k)      -> k
	     )))
      (match k with
	(cont:k target free insn) -> (walk insn d)
	(cont:nil) -> #u)))
  (walk insn 0))

(define done-insn (:pair (insn:return -1) -1))

(define (make-insn-generator insn)
  (make-generator
   (lambda (consumer)
     (walk-insns (lambda (insn depth) (consumer (:pair insn depth))) insn)
     (let loop ()
       (consumer done-insn)
       (loop)))))
   
(define (iterate-insns insn)
  (let ((g (make-insn-generator insn)))
    (for-range
	i 10
	(match (g) with
	  (:pair insn depth)
	  -> (begin (indent depth)
		    (printn (%%cexp ('a -> int) "get_case(%0)" insn)))))
    ))