;; -*- Mode: Irken -*-

(require "lib/basis.scm")

;; Typed Tail CPS Transform for Irken
;; Based on Figure 8 of "Compiling with Continuations, Continued"
;; by Andrew Kennedy (ICFP 2007).
;;
;; The source language is a small ML-like fragment (close to what the
;; reader already produces as sexp trees).  The target is λ^T_CPS, the
;; double-barrelled typed CPS language of Figure 7.
;;
;; The transform has three mutually recursive entry points that mirror
;; the paper's translation functions:
;;
;;   (cps-exp  e h kappa)   -- ⟦e⟧ h κ       non-tail position
;;   (cps-tail e h k)       -- ⟪e⟫ h k       tail position
;;   (cps-def  f x e)       -- ⟦f x = e⟧     function definition (returns a fundef)
;;
;; h and k are always plain symbols (continuation variables).
;; kappa is an Irken procedure  symbol -> cps-term  (the meta-level context).

;; ──────────────────────────────────────────────────────────────────────
;; 1.  Target-language datatypes  (λ^T_CPS , Figure 7)
;; ──────────────────────────────────────────────────────────────────────

;; Value types of λ^T_CPS
;;   τ ::= unit | exn | τ×σ | τ+σ | τ→σ
;;
;; The cps-term / cps-val datatypes below carry no type annotations, and
;; the transform never consults vtype; it presumably exists to record
;; the type grammar of Figure 7.
(datatype vtype
  (:unit)
  (:exn)
  (:prod vtype vtype)       ;; τ × σ
  (:sum  vtype vtype)       ;; τ + σ
  (:fun  vtype vtype)       ;; τ → σ  (handler type elided for clarity)
  )

;; CPS values  (CVal)
;;   V ::= () | (x,y) | in_i x | λk x.K
;;
;; Every component of a value is a variable (symbol), never a nested
;; value, so compound values are always built from already-bound names.
;;
;; NOTE(review): :lam binds a return continuation k and an argument x
;; but no handler, whereas function calls (cps-term:app) always pass a
;; handler too.  The body of a λ-value can therefore only reach a
;; handler that is free in it.  Named and recursive functions are
;; introduced by letfun, whose fundef does bind a handler.
(datatype cps-val
  (:unit)
  (:pair symbol symbol)          ;; (x , y)
  (:inj  int    symbol)          ;; in_i x       i ∈ {0,1}  (0-based here)
  (:lam  symbol symbol cps-term) ;; λk x . K     (non-recursive closure)
  )

;; Function definitions  FunDef
;;   f k h x = K
;;
;; One member of a letfun group.  Calls are written  f k h x
;; (cps-term:app), supplying the return continuation, the handler and
;; the argument in the same order as the parameters below.
(datatype fundef
  (:def symbol   ;; f  – function name
        symbol   ;; k  – return continuation
        symbol   ;; h  – handler continuation
        symbol   ;; x  – argument
        cps-term ;; K  – body
        ))

;; Continuation definitions  ContDef  (for letcont, possibly mutually recursive)
;;   k x = K
;;
;; One member of a letcont group.  Control reaches it through a jump
;; k x  (cps-term:appc), or as a branch target of cps-term:case.
(datatype contdef
  (:def symbol   ;; k  – continuation name
        symbol   ;; x  – parameter
        cps-term ;; K  – body
        ))

;; CPS terms  (CTm)
;;   K ::= letval x = V in K
;;       | let x = π_i y in K
;;       | letcont C̄ in K          -- mutually recursive group
;;       | letfun  F̄ in K          -- mutually recursive group
;;       | k x                     -- continuation application (jump / return)
;;       | f k h x                 -- function application
;;       | case x of k1 | k2       -- binary sum dispatch
;;
;; All operands are symbols, so every intermediate result is named.
;; NOTE(review): the projection index i of :letproj is presumably
;; 0-based like the :inj tag of cps-val — confirm against consumers.
(datatype cps-term
  (:letval  symbol cps-val cps-term)               ;; letval x = V in K
  (:letproj symbol int symbol cps-term)            ;; let x = π_i y in K
  (:letcont (list contdef) cps-term)               ;; letcont {k x=K}* in L
  (:letfun  (list fundef)  cps-term)               ;; letfun  {f k h x=K}* in L
  (:appc    symbol symbol)                         ;; k x
  (:app     symbol symbol symbol symbol)           ;; f k h x
  (:case    symbol symbol symbol)                  ;; case x of k1 | k2
  )

;; ──────────────────────────────────────────────────────────────────────
;; 2.  Source language  (the ML fragment the reader hands us as sexps)
;;
;; We consume the sexp tree produced by read.scm.  The grammar:
;;
;;   e ::= x                              -- variable       (sexp:symbol)
;;       | (e e)                          -- application    (sexp:list (e e))
;;       | (fn x e)                       -- lambda         (sexp:list 'fn x e)
;;       | (tuple e e)                    -- pair           (sexp:list 'tuple e e)
;;       | (proj i e)                     -- projection     (sexp:list 'proj i e)
;;       | (unit)                         -- unit           (sexp:list 'unit)
;;       | (inj i e)                      -- injection      (sexp:list 'inj i e)
;;       | (let x e e)                    -- let-val        (sexp:list 'let x e e)
;;       | (letrec (f x e) e)             -- letrec         (sexp:list 'letrec ...)
;;       | (raise e)                      -- raise          (sexp:list 'raise e)
;;       | (handle e x e)                 -- handler        (sexp:list 'handle e x e)
;;       | (case e x1 e1 x2 e2)          -- binary case    (sexp:list 'case ...)
;;
;; ──────────────────────────────────────────────────────────────────────

;; ──────────────────────────────────────────────────────────────────────
;; 3.  Fresh name generation
;; ──────────────────────────────────────────────────────────────────────

;; Global counter behind `fresh`.  The smoke test resets it with
;; (set! cps-counter 0) before each example.
(define cps-counter 0)

(define (fresh prefix)
  ;; Return a new symbol <prefix>.<n> (e.g.  k.3  or  x.17 ), where n is
  ;; the counter value after incrementing it.  Names are unique only
  ;; until the counter is reset.
  (set! cps-counter (+ cps-counter 1))
  (string->symbol (format prefix "." (int cps-counter))))

;; Shorthands for the kinds of fresh names the transform introduces:
;; return continuations (k), handlers (h), value variables (x),
;; functions (f) and join points (j).
(define fresh-k (lambda () (fresh "k")))
(define fresh-h (lambda () (fresh "h")))
(define fresh-x (lambda () (fresh "x")))
(define fresh-f (lambda () (fresh "f")))
(define fresh-j (lambda () (fresh "j")))

;; ──────────────────────────────────────────────────────────────────────
;; 4.  Helper: build a singleton letcont / letfun
;; ──────────────────────────────────────────────────────────────────────

(define (letcont1 k x body rest)
  ;; Build  letcont k x = body in rest  — a letcont group of one.
  (let ((group (list (contdef:def k x body))))
    (cps-term:letcont group rest)))

(define (letfun1 f k h x body rest)
  ;; Build  letfun f k h x = body in rest  — a letfun group of one.
  (let ((group (list (fundef:def f k h x body))))
    (cps-term:letfun group rest)))

;; ──────────────────────────────────────────────────────────────────────
;; 5.  The CPS transform  (Figure 8)
;; ──────────────────────────────────────────────────────────────────────
;;
;; The paper uses two translation functions that are *mutually recursive*:
;;
;;   ⟦e⟧ h κ    non-tail: h is a cvar, κ : symbol -> cps-term
;;   ⟪e⟫ h k    tail:     h and k are both cvars
;;
;; We implement them as  (cps-exp e h kappa)  and  (cps-tail e h k).
;; They call each other freely, exactly as in the paper.

;; ── 5a.  Non-tail transform  ⟦e⟧ h κ ─────────────────────────────────

(define (pp-cps-src e)
  ;; Debug rendering of a source sexp.  Symbols and integers print as
  ;; themselves, and lists print parenthesised and space-separated.  Any
  ;; other sexp kind prints as "?".  Integers occur in the source grammar
  ;; as the index of (inj i e) and (proj i e); previously they printed
  ;; as "?".
  (match e with
    (sexp:symbol x _) -> (symbol->string x)
    (sexp:int n)      -> (format (int n))
    (sexp:list args)  -> (format "(" (string-join (map pp-cps-src args) " ") ")")
    _                 -> "?"))


(define (cps-exp e h kappa)
  ;; ⟦e⟧ h κ : CPS-convert the source expression e in non-tail position.
  ;;
  ;;   e     – source sexp (grammar in section 2)
  ;;   h     – symbol naming the exception-handler continuation in scope
  ;;   kappa – meta-level context, symbol -> cps-term; every clause calls
  ;;           it exactly once, with the variable that holds e's value
  ;;
  ;; Returns a cps-term; calls error1 on unrecognised syntax.
  ;;
  ;; Clause order matters: match tries clauses from top to bottom, and
  ;; the generic application pattern (e1 e2) also matches two-element
  ;; keyword forms such as (raise e).  Application is therefore the last
  ;; real clause.
  ;;
  ;; Source binders are not α-renamed.  Wherever the paper's rule would
  ;; place κ's term inside the scope of a source binder (let, letrec), κ
  ;; is instead bound in its own join continuation with a fresh
  ;; parameter.  This keeps the variables κ mentions from being captured.
;;  (printf "cps-exp:  " (pp-cps-src e) "\n")
  (match e with

    ;; ⟦x⟧ h κ  =  κ(x)
    (sexp:symbol x _)
    -> (kappa x)

    ;; ⟦()⟧ h κ  =  letval x = () in κ(x)
    (sexp:list ((sexp:symbol 'unit _)))
    -> (let ((x (fresh-x)))
         (cps-term:letval x (cps-val:unit) (kappa x)))

    ;; ⟦fn x => e⟧ h κ
    ;;   =  letfun f k h' x = ⟪e⟫ h' k in κ(f)
    (sexp:list ((sexp:symbol 'fn _) (sexp:symbol x _) body))
    -> (let ((f  (fresh-f))
             (k  (fresh-k))
             (h2 (fresh-h)))
         (letfun1 f k h2 x (cps-tail body h2 k)
           (kappa f)))

    ;; ⟦(tuple e1 e2)⟧ h κ
    ;;   =  ⟦e1⟧ h (x1. ⟦e2⟧ h (x2. letval x = (x1,x2) in κ(x)))
    (sexp:list ((sexp:symbol 'tuple _) e1 e2))
    -> (cps-exp e1 h
         (lambda (x1)
           (cps-exp e2 h
             (lambda (x2)
               (let ((x (fresh-x)))
                 (cps-term:letval x (cps-val:pair x1 x2) (kappa x)))))))

    ;; ⟦(inj i e)⟧ h κ
    ;;   =  ⟦e⟧ h (z. letval x = in_i z in κ(x))
    (sexp:list ((sexp:symbol 'inj _) (sexp:int i) ei))
    -> (cps-exp ei h
         (lambda (z)
           (let ((x (fresh-x)))
             (cps-term:letval x (cps-val:inj i z) (kappa x)))))

    ;; ⟦(proj i e)⟧ h κ
    ;;   =  ⟦e⟧ h (z. let x = π_i z in κ(x))
    (sexp:list ((sexp:symbol 'proj _) (sexp:int i) ei))
    -> (cps-exp ei h
         (lambda (z)
           (let ((x (fresh-x)))
             (cps-term:letproj x i z (kappa x)))))

    ;; ⟦let x = e1 in e2⟧ h κ
    ;;   paper:  letcont j x = ⟦e2⟧ h κ in ⟪e1⟫ h j
    ;;   here:   letcont j' y = κ(y) in
    ;;           letcont j  x = ⟪e2⟫ h j' in
    ;;           ⟪e1⟫ h j
    ;;
    ;; The source binder x is reused as j's parameter to keep the output
    ;; legible.  In the paper's form, κ's term would sit inside x's scope.
    ;; A variable named x that κ refers to from an enclosing scope would
    ;; then be captured: (tuple x (let x (unit) x)) would pair the inner
    ;; x with itself.  Giving κ its own join continuation j' with a fresh
    ;; parameter, as the handle and case clauses do, avoids this.
    (sexp:list ((sexp:symbol 'let _) (sexp:symbol x _) e1 e2))
    -> (let ((jk (fresh-j))
             (y  (fresh-x))
             (j  (fresh-j)))
         (letcont1 jk y (kappa y)
           (letcont1 j x (cps-tail e2 h jk)
             (cps-tail e1 h j))))

    ;; ⟦letrec (f x e_f) in e⟧ h κ
    ;;   paper:  letfun ⟦f x = e_f⟧ in ⟦e⟧ h κ
    ;;   here:   letcont j y = κ(y) in letfun ⟦f x = e_f⟧ in ⟪e⟫ h j
    ;;
    ;; This avoids the same capture problem as let: κ's term must not
    ;; sit inside the scope of the source name f.
    (sexp:list ((sexp:symbol 'letrec _)
                (sexp:list ((sexp:symbol f _) (sexp:symbol xf _) ef))
                e))
    -> (let ((j (fresh-j))
             (y (fresh-x)))
         (letcont1 j y (kappa y)
           (cps-term:letfun
             (list (cps-def f xf ef))
             (cps-tail e h j))))

    ;; ⟦raise e⟧ h κ  =  ⟦e⟧ h (z. h z)
    ;;
    ;; Raising never returns, so κ's term is dead.  The rule above simply
    ;; discards κ.  This implementation instead keeps κ's term, bound to
    ;; an unused continuation  letcont k x = κ(x) in h z , so it remains
    ;; visible in the output until a dead-continuation shrinking
    ;; reduction removes it.
    (sexp:list ((sexp:symbol 'raise _) er))
    -> (cps-exp er h
         (lambda (z)
           (let ((k (fresh-k))
                 (x (fresh-x)))
             (letcont1 k x (kappa x)          ;; dead — raise never returns
               (cps-term:appc h z)))))

    ;; ⟦e1 handle x => e2⟧ h κ
    ;;   =  letcont j  x = κ(x) in
    ;;      letcont h' x = ⟪e2⟫ h j in
    ;;      ⟪e1⟫ h' j
    (sexp:list ((sexp:symbol 'handle _) e1 (sexp:symbol x _) e2))
    -> (let ((j  (fresh-j))
             (xj (fresh-x))
             (h2 (fresh-h)))
         (letcont1 j xj (kappa xj)
           (letcont1 h2 x (cps-tail e2 h j)
             (cps-tail e1 h2 j))))

    ;; ⟦case e of in1 x1 => e1 | in2 x2 => e2⟧ h κ
    ;;   =  ⟦e⟧ h (z.
    ;;        letcont j  x  = κ(x) in
    ;;        letcont k1 x1 = ⟪e1⟫ h j in
    ;;        letcont k2 x2 = ⟪e2⟫ h j in
    ;;        case z of k1 | k2)
    (sexp:list ((sexp:symbol 'case _) escr
                (sexp:symbol x1 _) e1
                (sexp:symbol x2 _) e2))
    -> (cps-exp escr h
         (lambda (z)
           (let ((j  (fresh-j))
                 (xj (fresh-x))
                 (k1 (fresh-k))
                 (k2 (fresh-k)))
             (letcont1 j xj (kappa xj)
               (letcont1 k1 x1 (cps-tail e1 h j)
                 (letcont1 k2 x2 (cps-tail e2 h j)
                   (cps-term:case z k1 k2)))))))

    ;; ⟦(e1 e2)⟧ h κ
    ;;   =  ⟦e1⟧ h (x1. ⟦e2⟧ h (x2. letcont k x = κ(x) in x1 k h x2))
    ;;
    ;; Must stay after every keyword form (see the note at the top).
    (sexp:list (e1 e2))
    -> (cps-exp e1 h
         (lambda (x1)
           (cps-exp e2 h
             (lambda (x2)
               (let ((k (fresh-k))
                     (x (fresh-x)))
                 (letcont1 k x (kappa x)
                   (cps-term:app x1 k h x2)))))))

    _ -> (error1 "cps-exp: unrecognised expression" e)
    ))

;; ── 5b.  Tail transform  ⟪e⟫ h k ─────────────────────────────────────

(define (cps-tail e h k)
  ;; ⟪e⟫ h k : CPS-convert the source expression e in tail position.
  ;;
  ;;   e – source sexp (grammar in section 2)
  ;;   h – symbol naming the exception-handler continuation in scope
  ;;   k – symbol naming the continuation that receives e's value
  ;;
  ;; Returns a cps-term; calls error1 on unrecognised syntax.
  ;;
  ;; Clause order matters: match tries clauses from top to bottom, and
  ;; the generic application pattern (e1 e2) also matches two-element
  ;; keyword forms such as (raise e).  Application is therefore the last
  ;; real clause.
  ;;(print-string (format "cps-tail: " (pp-cps-src e) "\n"))
  (match e with

    ;; ⟪x⟫ h k  =  k x
    (sexp:symbol x _)
    -> (cps-term:appc k x)

    ;; ⟪()⟫ h k  =  letval x = () in k x
    (sexp:list ((sexp:symbol 'unit _)))
    -> (let ((x (fresh-x)))
         (cps-term:letval x (cps-val:unit) (cps-term:appc k x)))

    ;; ⟪fn x => e⟫ h k
    ;;   =  letfun f k' h' x = ⟪e⟫ h' k' in k f
    ;;
    ;; A function receives its return continuation and its handler at
    ;; every call, because applications are built as  f k h x .  The body
    ;; must therefore be translated against the function's own
    ;; parameters k' and h', never against the h in scope where the
    ;; function is created.  Using the definition-site h would send an
    ;; exception raised inside the function to the wrong handler, e.g.
    ;; in  (let g (fn x (raise x)) (handle (g (unit)) e e)) .  This is the
    ;; term ⟦fn x => e⟧ produces with κ = (f. k f).  f is fresh and
    ;; unused in the body, so the letfun is not actually recursive.
    (sexp:list ((sexp:symbol 'fn _) (sexp:symbol x _) body))
    -> (let ((f  (fresh-f))
             (k2 (fresh-k))
             (h2 (fresh-h)))
         (letfun1 f k2 h2 x (cps-tail body h2 k2)
           (cps-term:appc k f)))

    ;; ⟪(tuple e1 e2)⟫ h k
    ;;   =  ⟦e1⟧ h (x1. ⟦e2⟧ h (x2. letval x = (x1,x2) in k x))
    (sexp:list ((sexp:symbol 'tuple _) e1 e2))
    -> (cps-exp e1 h
         (lambda (x1)
           (cps-exp e2 h
             (lambda (x2)
               (let ((x (fresh-x)))
                 (cps-term:letval x (cps-val:pair x1 x2)
                   (cps-term:appc k x)))))))

    ;; ⟪inj i e⟫ h k
    ;;   =  ⟦e⟧ h (z. letval x = in_i z in k x)
    (sexp:list ((sexp:symbol 'inj _) (sexp:int i) ei))
    -> (cps-exp ei h
         (lambda (z)
           (let ((x (fresh-x)))
             (cps-term:letval x (cps-val:inj i z)
               (cps-term:appc k x)))))

    ;; ⟪proj i e⟫ h k
    ;;   =  ⟦e⟧ h (z. let x = π_i z in k x)
    (sexp:list ((sexp:symbol 'proj _) (sexp:int i) ei))
    -> (cps-exp ei h
         (lambda (z)
           (let ((x (fresh-x)))
             (cps-term:letproj x i z (cps-term:appc k x)))))

    ;; ⟪let x = e1 in e2⟫ h k
    ;;   =  letcont j x = ⟪e2⟫ h k in ⟪e1⟫ h j
    ;;
    ;; Both e2 and e1 are in tail position (w.r.t. k and j respectively).
    (sexp:list ((sexp:symbol 'let _) (sexp:symbol x _) e1 e2))
    -> (let ((j (fresh-j)))
         (letcont1 j x (cps-tail e2 h k)
           (cps-tail e1 h j)))

    ;; ⟪letrec (f x e_f) in e⟫ h k
    ;;   =  letfun ⟦f x = e_f⟧ in ⟪e⟫ h k
    (sexp:list ((sexp:symbol 'letrec _)
                (sexp:list ((sexp:symbol f _) (sexp:symbol xf _) ef))
                e))
    -> (cps-term:letfun
         (list (cps-def f xf ef))
         (cps-tail e h k))

    ;; ⟪raise e⟫ h k  =  ⟦e⟧ h (z. h z)
    ;;
    ;; Invoke the handler; k is dead.
    (sexp:list ((sexp:symbol 'raise _) er))
    -> (cps-exp er h (lambda (z) (cps-term:appc h z)))

    ;; ⟪e1 handle x => e2⟫ h k
    ;;   =  letcont h' x = ⟪e2⟫ h k in ⟪e1⟫ h' k
    ;;
    ;; Both branches share the same k; h' catches exceptions from e1.
    (sexp:list ((sexp:symbol 'handle _) e1 (sexp:symbol x _) e2))
    -> (let ((h2 (fresh-h)))
         (letcont1 h2 x (cps-tail e2 h k)
           (cps-tail e1 h2 k)))

    ;; ⟪case e of in1 x1 => e1 | in2 x2 => e2⟫ h k
    ;;   =  ⟦e⟧ h (z.
    ;;        letcont k1 x1 = ⟪e1⟫ h k in
    ;;        letcont k2 x2 = ⟪e2⟫ h k in
    ;;        case z of k1 | k2)
    ;;
    ;; No join point j needed — both branches already target k directly.
    (sexp:list ((sexp:symbol 'case _) escr
                (sexp:symbol x1 _) e1
                (sexp:symbol x2 _) e2))
    -> (cps-exp escr h
         (lambda (z)
           (let ((k1 (fresh-k))
                 (k2 (fresh-k)))
             (letcont1 k1 x1 (cps-tail e1 h k)
               (letcont1 k2 x2 (cps-tail e2 h k)
                 (cps-term:case z k1 k2))))))

    ;; ⟪e1 e2⟫ h k
    ;;   =  ⟦e1⟧ h (x1. ⟦e2⟧ h (x2. x1 k h x2))
    ;;
    ;; The key tail-call case: k is passed *directly* instead of being
    ;; wrapped in a fresh letcont.  This is what makes the transform
    ;; "tail".  Must stay after every keyword form (see the note at the
    ;; top).
    (sexp:list (e1 e2))
    -> (cps-exp e1 h
         (lambda (x1)
           (cps-exp e2 h
             (lambda (x2)
               (cps-term:app x1 k h x2)))))

    _ -> (error1 "cps-tail: unrecognised expression" e)
    ))

;; ── 5c.  Definition transform  ⟦d⟧ ───────────────────────────────────
;;
;; ⟦f x = e⟧  =  fundef:def f k h x (⟪e⟫ h k)
;;
;; Returns a fundef, not a cps-term.

(define (cps-def f xf ef)
  ;; ⟦f x = e⟧ : build the fundef  f k h x = ⟪e⟫ h k .  k and h are new
  ;; names for the function's return continuation and handler; the body
  ;; is translated in tail position against them.
  (let ((ret (fresh-k)))
    (let ((hdl (fresh-h)))
      (fundef:def f ret hdl xf (cps-tail ef hdl ret)))))

;; ── 5d.  Top-level entry point ────────────────────────────────────────
;;
;; Translate a complete source expression into a CPS program.
;; The program's value is passed to the top-level return continuation
;; 'halt, and exceptions go to the top-level handler continuation
;; 'top-h (compare the (prog) rule in Figure 7).  Both are free
;; continuation variables of the result, to be supplied by whatever runs
;; the program; presumably 'top-h reports an uncaught exception.

(define (cps-program e)
  ;; Top-level translation  ⟦e⟧ top-h (x. halt x) .  It runs the
  ;; non-tail transform with the distinguished handler 'top-h, and a
  ;; context that passes the final value to the distinguished
  ;; continuation 'halt.
  (let ((return-to-halt (lambda (result) (cps-term:appc 'halt result))))
    (cps-exp e 'top-h return-to-halt)))

;; ──────────────────────────────────────────────────────────────────────
;; 6.  Pretty-printer  (for debugging / inspection)
;; ──────────────────────────────────────────────────────────────────────

(define (pp-cps-val v)
  ;; Render a CPS value as  () ,  (x,y) ,  in<i> x  or  λk x.<body> .
  (match v with
    (cps-val:lam kv arg body) -> (format "λ" (sym kv) " " (sym arg) "." (pp-cps-term body))
    (cps-val:inj tag arg)     -> (format "in" (int tag) " " (sym arg))
    (cps-val:pair a b)        -> (format "(" (sym a) "," (sym b) ")")
    (cps-val:unit)            -> "()"
    ))

(define (pp-contdef cd)
  ;; Render one continuation definition as  k x = body .
  (match cd with
    (contdef:def name param body)
    -> (let ((lhs (format (sym name) " " (sym param))))
         (format lhs " = " (pp-cps-term body)))
    ))

(define (pp-fundef fd)
  ;; Render one function definition as  f k h x = body .
  (match fd with
    (fundef:def name ret hdl arg body)
    -> (let ((lhs (format (sym name) " " (sym ret) " " (sym hdl) " " (sym arg))))
         (format lhs " = " (pp-cps-term body)))
    ))

(define (pp-cps-term t)
  ;; Render a CPS term as text.  letval / let-projection print their
  ;; binding, then " in" and a newline, then the body.  The members of a
  ;; letcont or letfun group are joined with "\n    and ", and the
  ;; group's body follows "\nin ".
  (let ((pp-group
          (lambda (keyword members body)
            (format keyword
                    (string-join members "\n    and ")
                    "\nin " (pp-cps-term body)))))
    (match t with
      (cps-term:letval name v body)
      -> (format "letval " (sym name) " = " (pp-cps-val v) " in\n" (pp-cps-term body))

      (cps-term:letproj name i src body)
      -> (format "let " (sym name) " = π" (int i) " " (sym src) " in\n" (pp-cps-term body))

      (cps-term:letcont cdefs body)
      -> (pp-group "letcont " (map pp-contdef cdefs) body)

      (cps-term:letfun fdefs body)
      -> (pp-group "letfun " (map pp-fundef fdefs) body)

      (cps-term:appc kv arg)
      -> (format (sym kv) " " (sym arg))

      (cps-term:app fv kv hv arg)
      -> (format (sym fv) " " (sym kv) " " (sym hv) " " (sym arg))

      (cps-term:case scrut k1 k2)
      -> (format "case " (sym scrut) " of " (sym k1) " | " (sym k2))
      )))

;; ──────────────────────────────────────────────────────────────────────
;; 7.  Example / smoke test
;; ──────────────────────────────────────────────────────────────────────
;;
;; Translate the classic identity function:
;;   (fn x x)
;;
;; cps-program uses the non-tail transform ⟦(fn x x)⟧ top-h (x. halt x),
;; so the function becomes a letfun.  Expect something like (modulo
;; fresh names):
;;     letfun f.1 k.2 h.3 x = k.2 x
;;     in halt f.1
;;
;; Translate
;;   (let f (fn x x) (f (unit)))
;; Because cps-program starts in ⟦·⟧, the call (f (unit)) receives a
;; return continuation introduced by a letcont rather than 'halt itself.

(define (read-one-form-from-string s)
  ;; Parse the text s with the library reader, using "<string>" as the
  ;; source name.  Despite the function's name, the result is the
  ;; reader's list of forms; callers take its car to get the first one.
  (let ((src (string-reader s)))
    (reader "<string>" src)))

(define (run-example)
  ;; Smoke test.  Parse several small programs up front, then print the
  ;; CPS translation of each under a banner.  The fresh-name counter is
  ;; reset before every example, so each listing numbers its names from
  ;; 1.  The (unit) and application examples also echo the source form
  ;; via pp first.
  (let ((v-src     (read-one-form-from-string "z"))
        (id-src    (read-one-form-from-string "(fn x x)"))
        (unit-src  (read-one-form-from-string "(unit)"))
        (app-src   (read-one-form-from-string "((fn x x) (unit))"))
        (let-src   (read-one-form-from-string "(let f (fn x x) (f (unit)))"))
        (raise-src (read-one-form-from-string "(raise (unit))"))
        (hdl-src   (read-one-form-from-string "(handle (raise (unit)) e e)")))
    (let ((banner
            ;; restart fresh names, then print the section title
            (lambda (title)
              (set! cps-counter 0)
              (print-string title)))
          (show
            ;; print the translation of the first parsed form, then a trailer
            (lambda (forms trailer)
              (print-string (pp-cps-term (cps-program (car forms))))
              (print-string trailer))))
      (banner "=== variable ===\n")
      (show v-src "\n\n")
      (banner "=== identity function ===\n")
      (show id-src "\n\n")
      (banner "=== (unit) ===\n")
      (pp (car unit-src) 80)
      (show unit-src "\n\n")
      (banner "=== application (non-tail) ===\n")
      (pp (car app-src) 80)
      (show app-src "\n\n")
      (banner "=== let + application ===\n")
      (show let-src "\n\n")
      (banner "=== raise ===\n")
      (show raise-src "\n\n")
      (banner "=== handle ===\n")
      (show hdl-src "\n"))))

;; Top-level entry point: running the compiled file prints every example
;; defined in run-example.
(run-example)