;; -*- Mode: Irken -*-

(require "lib/basis.scm")

;; Typed Tail CPS Transform for Irken
;; Based on Figure 8 of "Compiling with Continuations, Continued"
;; by Andrew Kennedy (ICFP 2007).
;;
;; The source language is a small ML-like fragment (close to what the
;; reader already produces as sexp trees). The target is λ^T_CPS, the
;; double-barrelled typed CPS language of Figure 7: every function takes
;; both a return continuation k and an exception-handler continuation h.
;;
;; The transform has three mutually-recursive entry points mirroring
;; the paper:
;;
;;   (cps-exp e h kappa)   -- ⟦e⟧ h κ     non-tail position
;;   (cps-tail e h k)      -- ⟪e⟫ h k     tail position
;;   (cps-def f x e)       -- ⟦f x = e⟧   function definition (returns a fundef)
;;
;; h and k are always plain symbols (continuation variables).
;; kappa is an Irken procedure symbol -> cps-term (the meta-level context).

;; ──────────────────────────────────────────────────────────────────────
;; 1. Target-language datatypes (λ^T_CPS, Figure 7)
;; ──────────────────────────────────────────────────────────────────────

;; Value types
;;   τ ::= unit | exn | τ×σ | τ+σ | τ→σ
;;
;; NOTE(review): vtype is not referenced anywhere else in this file; the
;; transform below builds untyped terms. Kept for documentation / future
;; type annotation.
(datatype vtype
  (:unit)
  (:exn)
  (:prod vtype vtype)
  (:sum vtype vtype)
  (:fun vtype vtype)   ;; τ → σ  (handler type elided for clarity)
  )

;; CPS values (CVal)
;;   V ::= () | (x,y) | in_i x | λk x.K
;;
;; λk x.K is intended for anonymous, non-recursive closures; named and
;; recursive functions are introduced by letfun (see fundef). Note that
;; λk x.K carries NO handler parameter, unlike fundef (f k h x), so it
;; does not fit the four-argument application form `f k h x` below: a
;; body built with it can only raise to a handler fixed when the
;; closure is built.
(datatype cps-val
  (:unit)                          ;; ()
  (:pair symbol symbol)            ;; (x , y)
  (:inj int symbol)                ;; in_i x    i ∈ {0,1} (0-based here)
  (:lam symbol symbol cps-term)    ;; λk x . K  (non-recursive closure)
  )

;; Function definitions  FunDef
;;   f k h x = K
;; Grouped in a list by letfun, so a group may be mutually recursive.
(datatype fundef
  (:def symbol     ;; f – function name
        symbol     ;; k – return continuation
        symbol     ;; h – handler continuation
        symbol     ;; x – argument
        cps-term   ;; K – body
        ))

;; Continuation definitions  ContDef  (for letcont, possibly mutually recursive)
;;   k x = K
(datatype contdef
  (:def symbol     ;; k – continuation name
        symbol     ;; x – parameter
        cps-term   ;; K – body
        ))

;; CPS terms (CTm)
;;   K ::= letval x = V in K
;;       | let x = π_i y in K
;;       | letcont C̄ in K         -- mutually recursive group
;;       | letfun F̄ in K          -- mutually recursive group
;;       | k x                    -- continuation application (jump / return)
;;       | f k h x                -- function application
;;       | case x of k1 | k2      -- binary sum dispatch
(datatype cps-term
  (:letval  symbol cps-val cps-term)        ;; letval x = V in K
  (:letproj symbol int symbol cps-term)     ;; let x = π_i y in K
  (:letcont (list contdef) cps-term)        ;; letcont {k x=K}* in L
  (:letfun  (list fundef) cps-term)         ;; letfun {f k h x=K}* in L
  (:appc    symbol symbol)                  ;; k x
  (:app     symbol symbol symbol symbol)    ;; f k h x
  (:case    symbol symbol symbol)           ;; case x of k1 | k2
  )

;; ──────────────────────────────────────────────────────────────────────
;; 2. Source language (the ML fragment the reader hands us as sexps)
;;
;; We consume the sexp tree produced by read.scm. The grammar:
;;
;;   e ::= x                      -- variable      (sexp:symbol)
;;       | (e e)                  -- application   (sexp:list (e e))
;;       | (fn x e)               -- lambda        (sexp:list 'fn x e)
;;       | (tuple e e)            -- pair          (sexp:list 'tuple e e)
;;       | (proj i e)             -- projection    (sexp:list 'proj i e)
;;       | (unit)                 -- unit          (sexp:list 'unit)
;;       | (inj i e)              -- injection     (sexp:list 'inj i e)
;;       | (let x e e)            -- let-val       (sexp:list 'let x e e)
;;       | (letrec (f x e) e)     -- letrec        (sexp:list 'letrec ...)
;;       | (raise e)              -- raise         (sexp:list 'raise e)
;;       | (handle e x e)         -- handler       (sexp:list 'handle e x e)
;;       | (case e x1 e1 x2 e2)   -- binary case   (sexp:list 'case ...)
;;
;; Ambiguity: (raise e) is also a two-element list, i.e. it has the shape
;; of an application (e e). A translator must therefore test the keyword
;; forms before falling back to the generic application form.
;;
;; ──────────────────────────────────────────────────────────────────────

;; ──────────────────────────────────────────────────────────────────────
;; 3. Fresh name generation
;; ──────────────────────────────────────────────────────────────────────

;; Global counter behind every fresh name. Callers may reset it to 0 to
;; get reproducible names.
(define cps-counter 0)

;; (fresh prefix) -> symbol
;; Advance cps-counter by one and return the symbol "<prefix>.<n>" for
;; the new count n, e.g. (fresh "k") => k.3, (fresh "x") => x.17.
(define (fresh prefix)
  (set! cps-counter (+ cps-counter 1))
  (string->symbol (format prefix "." (int cps-counter))))

;; Shorthands for the name classes the transform uses:
;;   k – return continuations, h – handlers, x – values,
;;   f – functions,            j – join points.
(define (fresh-k) (fresh "k"))
(define (fresh-h) (fresh "h"))
(define (fresh-x) (fresh "x"))
(define (fresh-f) (fresh "f"))
(define (fresh-j) (fresh "j"))

;; ──────────────────────────────────────────────────────────────────────
;; 4. Helper: build a singleton letcont / letfun
;; ──────────────────────────────────────────────────────────────────────

;; (letcont1 k x body rest) -> cps-term
;; Builds:  letcont k x = body in rest   (a one-element group)
(define (letcont1 k x body rest)
  (let ((group (list (contdef:def k x body))))
    (cps-term:letcont group rest)))

;; (letfun1 f k h x body rest) -> cps-term
;; Builds:  letfun f k h x = body in rest   (a one-element group)
(define (letfun1 f k h x body rest)
  (let ((group (list (fundef:def f k h x body))))
    (cps-term:letfun group rest)))

;; ──────────────────────────────────────────────────────────────────────
;; 5. The CPS transform (Figure 8)
;; ──────────────────────────────────────────────────────────────────────
;;
;; Two translation functions call each other (mutual recursion):
;;
;;   ⟦e⟧ h κ   non-tail: h is a cvar, κ : symbol -> cps-term
;;   ⟪e⟫ h k   tail:     h and k are both cvars
;;
;; They are implemented as (cps-exp e h kappa) and (cps-tail e h k),
;; and each may invoke the other wherever the paper's rules do.

;; ── 5a.
;;    Non-tail transform ⟦e⟧ h κ ─────────────────────────────────

;; (pp-cps-src e) -> string
;; Debug helper: render a source sexp as text. Only symbols and lists
;; are rendered; every other sexp (ints included) prints as "?".
(define (pp-cps-src e)
  (match e with
    (sexp:symbol x _) -> (symbol->string x)
    (sexp:list args)  -> (format "(" (string-join (map pp-cps-src args) " ") ")")
    _                 -> "?"))

;; (cps-exp e h kappa) -> cps-term
;;
;; ⟦e⟧ h κ : translate source expression e in NON-tail position.
;;   e     - source sexp (grammar in section 2)
;;   h     - symbol naming the current exception-handler continuation
;;   kappa - meta-level continuation: called with the symbol that holds
;;           e's value, returns the cps-term that consumes it
;; Each clause calls kappa at most once, so the context is never
;; duplicated in the output.
;; Signals error1 on an expression that fits no clause.
;;
;; CLAUSE ORDER MATTERS: match takes the first clause that fits, and the
;; generic application pattern (e1 e2) also fits the two-element keyword
;; form (raise e). The application clause therefore comes LAST. Otherwise
;; `raise` would be translated as a call to a free variable named raise.
(define (cps-exp e h kappa)
  ;; (printf "cps-exp: " (pp-cps-src e) "\n")
  (match e with

    ;; ⟦x⟧ h κ = κ(x)
    (sexp:symbol x _)
    -> (kappa x)

    ;; ⟦()⟧ h κ = letval x = () in κ(x)
    (sexp:list ((sexp:symbol 'unit _)))
    -> (let ((x (fresh-x)))
         (cps-term:letval x (cps-val:unit) (kappa x)))

    ;; ⟦fn x => e⟧ h κ
    ;;   = letfun f k h' x = ⟪e⟫ h' k in κ(f)
    ;; The body uses its own handler parameter h', not the definition-site h.
    (sexp:list ((sexp:symbol 'fn _) (sexp:symbol x _) body))
    -> (let ((f (fresh-f))
             (k (fresh-k))
             (h2 (fresh-h)))
         (letfun1 f k h2 x (cps-tail body h2 k) (kappa f)))

    ;; ⟦(tuple e1 e2)⟧ h κ
    ;;   = ⟦e1⟧ h (x1. ⟦e2⟧ h (x2. letval x = (x1,x2) in κ(x)))
    (sexp:list ((sexp:symbol 'tuple _) e1 e2))
    -> (cps-exp e1 h
         (lambda (x1)
           (cps-exp e2 h
             (lambda (x2)
               (let ((x (fresh-x)))
                 (cps-term:letval x (cps-val:pair x1 x2) (kappa x)))))))

    ;; ⟦(inj i e)⟧ h κ
    ;;   = ⟦e⟧ h (z. letval x = in_i z in κ(x))
    (sexp:list ((sexp:symbol 'inj _) (sexp:int i) ei))
    -> (cps-exp ei h
         (lambda (z)
           (let ((x (fresh-x)))
             (cps-term:letval x (cps-val:inj i z) (kappa x)))))

    ;; ⟦(proj i e)⟧ h κ
    ;;   = ⟦e⟧ h (z. let x = π_i z in κ(x))
    (sexp:list ((sexp:symbol 'proj _) (sexp:int i) ei))
    -> (cps-exp ei h
         (lambda (z)
           (let ((x (fresh-x)))
             (cps-term:letproj x i z (kappa x)))))

    ;; ⟦let x = e1 in e2⟧ h κ
    ;;   = letcont j x = ⟦e2⟧ h κ in ⟪e1⟫ h j
    ;;
    ;; The source binder x is used directly as the contdef parameter to
    ;; keep the output legible. No α-renaming of source variables is
    ;; done: output binders are distinct only if the source's are.
    (sexp:list ((sexp:symbol 'let _) (sexp:symbol x _) e1 e2))
    -> (let ((j (fresh-j)))
         (letcont1 j x (cps-exp e2 h kappa) (cps-tail e1 h j)))

    ;; ⟦letrec (f x e_f) in e⟧ h κ
    ;;   = letfun ⟦f x = e_f⟧ in ⟦e⟧ h κ
    (sexp:list ((sexp:symbol 'letrec _)
                (sexp:list ((sexp:symbol f _) (sexp:symbol xf _) ef))
                e))
    -> (cps-term:letfun (list (cps-def f xf ef)) (cps-exp e h kappa))

    ;; ⟦raise e⟧ h κ = ⟦e⟧ h (z. h z)
    ;;
    ;; κ's result is dead: raising never returns. It is still emitted as
    ;;   letcont k x = κ(x) in h z
    ;; so the dead continuation is explicit. A dead-continuation shrinking
    ;; pass can remove it.
    (sexp:list ((sexp:symbol 'raise _) er))
    -> (cps-exp er h
         (lambda (z)
           (let ((k (fresh-k))
                 (x (fresh-x)))
             (letcont1 k x (kappa x)   ;; dead — raise never returns
                       (cps-term:appc h z)))))

    ;; ⟦e1 handle x => e2⟧ h κ
    ;;   = letcont j x = κ(x) in
    ;;     letcont h' x = ⟪e2⟫ h j in
    ;;     ⟪e1⟫ h' j
    (sexp:list ((sexp:symbol 'handle _) e1 (sexp:symbol x _) e2))
    -> (let ((j (fresh-j))
             (xj (fresh-x))
             (h2 (fresh-h)))
         (letcont1 j xj (kappa xj)
                   (letcont1 h2 x (cps-tail e2 h j)
                             (cps-tail e1 h2 j))))

    ;; ⟦case e of in1 x1 => e1 | in2 x2 => e2⟧ h κ
    ;;   = ⟦e⟧ h (z.
    ;;       letcont j x = κ(x) in
    ;;       letcont k1 x1 = ⟪e1⟫ h j in
    ;;       letcont k2 x2 = ⟪e2⟫ h j in
    ;;       case z of k1 | k2)
    (sexp:list ((sexp:symbol 'case _) escr (sexp:symbol x1 _) e1 (sexp:symbol x2 _) e2))
    -> (cps-exp escr h
         (lambda (z)
           (let ((j (fresh-j))
                 (xj (fresh-x))
                 (k1 (fresh-k))
                 (k2 (fresh-k)))
             (letcont1 j xj (kappa xj)
                       (letcont1 k1 x1 (cps-tail e1 h j)
                                 (letcont1 k2 x2 (cps-tail e2 h j)
                                           (cps-term:case z k1 k2)))))))

    ;; ⟦(e1 e2)⟧ h κ
    ;;   = ⟦e1⟧ h (x1. ⟦e2⟧ h (x2. letcont k x = κ(x) in x1 k h x2))
    ;;
    ;; Must stay after every keyword clause (see CLAUSE ORDER above).
    ;; κ(x3) is built once, up front, and spliced into the letcont.
    (sexp:list (e1 e2))
    -> (let ((k (fresh-k))
             (x3 (fresh-x))
             (kappa-x3 (kappa x3)))
         (cps-exp e1 h
           (lambda (x1)
             (cps-exp e2 h
               (lambda (x2)
                 (letcont1 k x3 kappa-x3
                           (cps-term:app x1 k h x2)))))))

    _ -> (error1 "cps-exp: unrecognised expression" e)
    ))

;; ── 5b.
;;    Tail transform ⟪e⟫ h k ─────────────────────────────────────

;; (cps-tail e h k) -> cps-term
;;
;; ⟪e⟫ h k : translate source expression e in TAIL position.
;;   e - source sexp (grammar in section 2)
;;   h - symbol naming the current exception-handler continuation
;;   k - symbol naming the continuation that receives e's value
;; Signals error1 on an expression that fits no clause.
;;
;; Because k is already a name, results are delivered with `k x` and calls
;; pass k straight through, with no fresh letcont. This is what makes the
;; transform "tail".
;;
;; CLAUSE ORDER MATTERS: the generic application pattern (e1 e2) also fits
;; the two-element keyword form (raise e), and match takes the first
;; clause that fits. The application clause is therefore LAST.
(define (cps-tail e h k)
  ;;(print-string (format "cps-tail: " (pp-cps-src e) "\n"))
  (match e with

    ;; ⟪x⟫ h k = k x
    (sexp:symbol x _)
    -> (cps-term:appc k x)

    ;; ⟪()⟫ h k = letval x = () in k x
    (sexp:list ((sexp:symbol 'unit _)))
    -> (let ((x (fresh-x)))
         (cps-term:letval x (cps-val:unit) (cps-term:appc k x)))

    ;; ⟪fn x => e⟫ h k
    ;;   = letfun f j h' x = ⟪e⟫ h' j in k f
    ;;
    ;; The function gets its own handler parameter h', as in cps-exp.
    ;; An exception raised in the body must reach the handler supplied by
    ;; each call `f k h x`, not the handler that happened to be active where
    ;; the fn was written. (A λj x.K closure has no handler slot, so it
    ;; would silently reuse the definition-site h and would not fit the
    ;; four-argument application form.)
    (sexp:list ((sexp:symbol 'fn _) (sexp:symbol x _) body))
    -> (let ((f (fresh-f))
             (j (fresh-j))
             (h2 (fresh-h)))
         (letfun1 f j h2 x (cps-tail body h2 j) (cps-term:appc k f)))

    ;; ⟪(tuple e1 e2)⟫ h k
    ;;   = ⟦e1⟧ h (x1. ⟦e2⟧ h (x2. letval x = (x1,x2) in k x))
    (sexp:list ((sexp:symbol 'tuple _) e1 e2))
    -> (cps-exp e1 h
         (lambda (x1)
           (cps-exp e2 h
             (lambda (x2)
               (let ((x (fresh-x)))
                 (cps-term:letval x (cps-val:pair x1 x2) (cps-term:appc k x)))))))

    ;; ⟪inj i e⟫ h k
    ;;   = ⟦e⟧ h (z. letval x = in_i z in k x)
    (sexp:list ((sexp:symbol 'inj _) (sexp:int i) ei))
    -> (cps-exp ei h
         (lambda (z)
           (let ((x (fresh-x)))
             (cps-term:letval x (cps-val:inj i z) (cps-term:appc k x)))))

    ;; ⟪proj i e⟫ h k
    ;;   = ⟦e⟧ h (z. let x = π_i z in k x)
    (sexp:list ((sexp:symbol 'proj _) (sexp:int i) ei))
    -> (cps-exp ei h
         (lambda (z)
           (let ((x (fresh-x)))
             (cps-term:letproj x i z (cps-term:appc k x)))))

    ;; ⟪let x = e1 in e2⟫ h k
    ;;   = letcont j x = ⟪e2⟫ h k in ⟪e1⟫ h j
    ;;
    ;; e2 is in tail position w.r.t. k; e1 is in tail position w.r.t. j.
    (sexp:list ((sexp:symbol 'let _) (sexp:symbol x _) e1 e2))
    -> (let ((j (fresh-j)))
         (letcont1 j x (cps-tail e2 h k) (cps-tail e1 h j)))

    ;; ⟪letrec (f x e_f) in e⟫ h k
    ;;   = letfun ⟦f x = e_f⟧ in ⟪e⟫ h k
    (sexp:list ((sexp:symbol 'letrec _)
                (sexp:list ((sexp:symbol f _) (sexp:symbol xf _) ef))
                e))
    -> (cps-term:letfun (list (cps-def f xf ef)) (cps-tail e h k))

    ;; ⟪raise e⟫ h k = ⟦e⟧ h (z. h z)
    ;;
    ;; Invoke the handler; k is never used.
    (sexp:list ((sexp:symbol 'raise _) er))
    -> (cps-exp er h (lambda (z) (cps-term:appc h z)))

    ;; ⟪e1 handle x => e2⟫ h k
    ;;   = letcont h' x = ⟪e2⟫ h k in ⟪e1⟫ h' k
    ;;
    ;; Both branches return to the same k; h' catches exceptions from e1.
    (sexp:list ((sexp:symbol 'handle _) e1 (sexp:symbol x _) e2))
    -> (let ((h2 (fresh-h)))
         (letcont1 h2 x (cps-tail e2 h k) (cps-tail e1 h2 k)))

    ;; ⟪case e of in1 x1 => e1 | in2 x2 => e2⟫ h k
    ;;   = ⟦e⟧ h (z.
    ;;       letcont k1 x1 = ⟪e1⟫ h k in
    ;;       letcont k2 x2 = ⟪e2⟫ h k in
    ;;       case z of k1 | k2)
    ;;
    ;; No join point j is needed: both branches already target k directly.
    (sexp:list ((sexp:symbol 'case _) escr (sexp:symbol x1 _) e1 (sexp:symbol x2 _) e2))
    -> (cps-exp escr h
         (lambda (z)
           (let ((k1 (fresh-k))
                 (k2 (fresh-k)))
             (letcont1 k1 x1 (cps-tail e1 h k)
                       (letcont1 k2 x2 (cps-tail e2 h k)
                                 (cps-term:case z k1 k2))))))

    ;; ⟪e1 e2⟫ h k
    ;;   = ⟦e1⟧ h (x1. ⟦e2⟧ h (x2. x1 k h x2))
    ;;
    ;; The key tail-call case: k is passed *directly* instead of being
    ;; wrapped in a fresh letcont. Must stay after every keyword clause.
    (sexp:list (e1 e2))
    -> (cps-exp e1 h
         (lambda (x1)
           (cps-exp e2 h
             (lambda (x2)
               (cps-term:app x1 k h x2)))))

    _ -> (error1 "cps-tail: unrecognised expression" e)
    ))

;; ── 5c. Definition transform ⟦d⟧ ───────────────────────────────────
;;
;;   ⟦f x = e⟧ = fundef:def f k h x (⟪e⟫ h k)
;;
;; (cps-def f xf ef) -> fundef   (a fundef, not a cps-term)
;; k and h are fresh, so every definition gets its own return and
;; handler parameters.
(define (cps-def f xf ef)
  (let ((k (fresh-k))
        (h (fresh-h)))
    (fundef:def f k h xf (cps-tail ef h k))))

;; ── 5d. Top-level entry point ────────────────────────────────────────
;;
;; Translate a complete source expression into a CPS program whose only
;; free continuation variables are the distinguished 'halt (final
;; return, cf. the (prog) rule in Figure 7) and 'top-h (uncaught
;; exceptions). Both are modelled as free cvars that the runtime will
;; supply.
;; (cps-program e) -> cps-term
;;
;; Translate the whole program in TAIL position:  ⟪e⟫ top-h halt.
;; This is η-equivalent to ⟦e⟧ top-h (x. halt x), but it avoids the
;; administrative  letcont k x = halt x  that the non-tail form wraps
;; around a top-level call. The program's final call receives 'halt
;; directly (f halt top-h x).
(define (cps-program e)
  (cps-tail e 'top-h 'halt))

;; ──────────────────────────────────────────────────────────────────────
;; 6. Pretty-printer (for debugging / inspection)
;; ──────────────────────────────────────────────────────────────────────

;; (pp-cps-val v) -> string : render a CPS value.
(define (pp-cps-val v)
  (match v with
    (cps-val:unit)      -> "()"
    (cps-val:pair x y)  -> (format "(" (sym x) "," (sym y) ")")
    (cps-val:inj i x)   -> (format "in" (int i) " " (sym x))
    (cps-val:lam k x K) -> (format "λ" (sym k) " " (sym x) "." (pp-cps-term K))
    ))

;; (pp-contdef cd) -> string : "k x = K".
(define (pp-contdef cd)
  (match cd with
    (contdef:def k x K)
    -> (format (sym k) " " (sym x) " = " (pp-cps-term K))
    ))

;; (pp-fundef fd) -> string : "f k h x = K".
(define (pp-fundef fd)
  (match fd with
    (fundef:def f k h x K)
    -> (format (sym f) " " (sym k) " " (sym h) " " (sym x) " = " (pp-cps-term K))
    ))

;; (pp-cps-term t) -> string
;; Render a CPS term. Binding forms end with " in" and a newline, so the
;; output reads as a straight-line sequence of bindings.
;; Groups are joined with "\n and ".
(define (pp-cps-term t)
  (match t with
    (cps-term:letval x v rest)
    -> (format "letval " (sym x) " = " (pp-cps-val v) " in\n" (pp-cps-term rest))
    (cps-term:letproj x i y rest)
    -> (format "let " (sym x) " = π" (int i) " " (sym y) " in\n" (pp-cps-term rest))
    (cps-term:letcont cdefs rest)
    -> (format "letcont " (string-join (map pp-contdef cdefs) "\n     and ")
               "\nin " (pp-cps-term rest))
    (cps-term:letfun fdefs rest)
    -> (format "letfun " (string-join (map pp-fundef fdefs) "\n    and ")
               "\nin " (pp-cps-term rest))
    (cps-term:appc k x)
    -> (format (sym k) " " (sym x))
    (cps-term:app f k h x)
    -> (format (sym f) " " (sym k) " " (sym h) " " (sym x))
    (cps-term:case x k1 k2)
    -> (format "case " (sym x) " of " (sym k1) " | " (sym k2))
    ))

;; ──────────────────────────────────────────────────────────────────────
;; 7. Example / smoke test
;; ──────────────────────────────────────────────────────────────────────
;;
;; Translate the classic identity function:
;;   (fn x x)
;;
;; Expected (modulo fresh names): ⟪fn x => x⟫ top-h halt
;; Should produce a closure bound to f.1 followed by  halt f.1, e.g.:
;;   letval f.1 = λk.2 x.3 .
;;   k.2 x.3 in halt f.1
;;
;; Translate (let f (fn x x) (f unit))
;;   (let f (fn x x) (f (unit)))
;; In tail position the call (f (unit)) receives the return continuation
;; directly (f k h x) rather than through a fresh letcont.

;; (read-one-form-from-string s) -> (list sexp)
;; Run the reader over the string s. The result is the list of forms
;; read; callers take its car.
(define (read-one-form-from-string s)
  (let ((port (string-reader s)))
    (reader "<string>" port)))

;; (run-example)
;; Read every example source up front, then for each one reset the
;; fresh-name counter, print a banner, optionally pretty-print the source
;; with pp, and print its CPS translation.
(define (run-example)
  (let ((v-src     (read-one-form-from-string "z"))
        (id-src    (read-one-form-from-string "(fn x x)"))
        (unit-src  (read-one-form-from-string "(unit)"))
        (app-src   (read-one-form-from-string "((fn x x) (unit))"))
        (let-src   (read-one-form-from-string "(let f (fn x x) (f (unit)))"))
        (raise-src (read-one-form-from-string "(raise (unit))"))
        (hdl-src   (read-one-form-from-string "(handle (raise (unit)) e e)"))
        ;; Reset the counter so each section's names start at 1, then
        ;; print the section banner.
        (begin-section
         (lambda (banner)
           (set! cps-counter 0)
           (print-string banner)))
        ;; Translate the first form of src, print it, then the trailer.
        (show-cps
         (lambda (src trailer)
           (print-string (pp-cps-term (cps-program (car src))))
           (print-string trailer))))

    (begin-section "=== variable ===\n")
    (show-cps v-src "\n\n")

    (begin-section "=== identity function ===\n")
    (show-cps id-src "\n\n")

    (begin-section "=== (unit) ===\n")
    (pp (car unit-src) 80)
    (show-cps unit-src "\n\n")

    (begin-section "=== application (non-tail) ===\n")
    (pp (car app-src) 80)
    (show-cps app-src "\n\n")

    (begin-section "=== let + application ===\n")
    (show-cps let-src "\n\n")

    (begin-section "=== raise ===\n")
    (show-cps raise-src "\n\n")

    (begin-section "=== handle ===\n")
    (show-cps hdl-src "\n")
    ))

(run-example)