Monad for lazy evaluation Andre van Tonder 21 Aug 2003 06:28 UTC

I have found a nice abstraction for the CPS-style approach to lazy
computation, which makes the definitions very simple and transparent.

The abstraction relies on a monad M representing delayed computations that
  - hides the continuations that appear in the CPS solution.
  - makes it impossible to define a well-typed function M a -> a.
    Since functions of this type would be exactly those that returned
    instead of calling their continuations, we now have a uniform
    way of ensuring that we only write safe functions.

All the examples below run in constant space in MzScheme.  It would be
good to test in other Schemes.

The definition of a leak-free stream-filter now becomes very simple (see
below for full definitions):

(define (stream-filter p? s)
  (stream-match s
    [()      stream-nil]
    [(h . t) (if (p? h)
                 (stream-cons h (stream-filter p? t))
                 (stream-filter p? t))]))

Notice that we can now use stream constructors, and we do not have to
break the definition into two separate functions as in the previous
CPS-style attempt.

The definition of (a basic) unfold is equally simple:

; unfold : (b -> ((M #f) | (stream-cons (a | 'drop) b))) b -> stream a

(define (unfold f seed)
  (letm ([res (f seed)])
    (cond [res (match res
                 [()      (return 'error)]
                 [(h . t) (if (eq? h 'drop)
                              (unfold f t)
                              (stream-cons h (unfold f t)))])]
          [else stream-nil])))

which is pretty transparent as well.  One point to mention here is
that typing forces us to use a monadic let, called letm here.

Defining a stream of integers is easy:

(define (integers-from n)
  (stream-cons n (integers-from (+ n 1))))

Lazy computations are invoked with "run", for example:

(run (stream-car (integers-from 5)))    ; ===> 5

(run (times3 5))  ; ===> 15   (times3 is defined in the tests below)

The function "run" is only meant to be used at the toplevel.  Elsewhere it
is unsafe for space consumption (since it causes a function to return
instead of calling its continuation).

Alternative definitions of stream-filter include, for example, one that
does not use stream-match:

(define (stream-filter1 p? s)
  (ifm (stream-null? s)
       stream-nil
       (letm ([h (stream-car s)])
         (if (p? h)
             (stream-cons h (stream-filter1 p? (stream-cdr s)))
             (stream-filter1 p? (stream-cdr s))))))

or one in terms of unfold:

(define (stream-filter3 p? s)
  (unfold (lambda (s)
            (stream-match s
              [()      (return #f)]
              [(h . t) (if (p? h)
                           (stream-cons h t)
                           (stream-cons 'drop t))]))
          s))

Regards
Andre

; Monad for lazy evaluation in Scheme:

; Copyright 2003 Andre van Tonder

; Principle:

; Based on a CPS-style memoizing delay, which appears to be safe for space
; as long as the entire computation stays in CPS style.
;
; This is enforced by defining a monad M for lazy evaluation which
;   - hides the continuation
;   - makes it impossible to define a well-typed function
;       M a -> a

; Since functions of type M a -> a would be exactly the functions that
; returned instead of calling their continuation, in this way we ensure
; that the computation stays CPS.

; The only unsafe function is run : M a -> a,
; which is meant to be called at top-level only.

;=======================================================================
; CPS-style delay

; (codelay thunk-cps) builds a memoizing, CPS-style delayed computation.
; thunk-cps is an expression denoting a procedure of one continuation.
; The result is a procedure of one continuation k*:
;   - on a force before any value is stored, thunk-cps is called with a
;     continuation (from make-memoizer) that stores the value and then
;     passes it to k*;
;   - once a value is stored, it is passed to k* directly.
; All calls are in tail position, so forcing stays in CPS style.
(define-syntax codelay
  (syntax-rules ()
    [(_ thunk-cps)
     (let ([cell (cons #f #f)])          ; (stored? . value)
       (lambda (k*)
         (cond [(not (car cell))
                (thunk-cps (make-memoizer cell k*))]
               [else
                (k* (cdr cell))])))]))

; make-memoizer : memo-pair (a -> r) -> (a -> r)
; Returns the continuation that codelay passes to a promise's computation.
; When called with the computed value x, it records x in memo-pair
; (car := #t, cdr := x) and passes x on to k.
; If memo-pair already holds a value when it is called, that value is kept
; and passed to k instead of x. This happens, for example, when the
; computation forced the same promise re-entrantly and the inner force
; finished first. Keeping the first value means every force of one promise
; yields the same value, as R7RS/SRFI-45 force does.
(define (make-memoizer memo-pair k)
  (lambda (x)
    (if (car memo-pair)
        (k (cdr memo-pair))              ; first computed value wins
        (begin
          (set-car! memo-pair #t)
          (set-cdr! memo-pair x)
          (k x)))))

;========================================================================
; Lazy monad - similar to continuation monad.
; see Wadler, "How to Declare an Imperative", for the latter.

; We are defining (Monad M) where M = make-lazy.

; Constructor for elements of the monad (M a):
; (make-lazy exp) wraps exp in a memoizing CPS delay.  exp is evaluated
; only when the result is forced with a continuation, and the value is
; then handed to that continuation.
(define-syntax make-lazy
  (syntax-rules ()
    [(_ exp)
     (codelay (lambda (cont) (cont exp)))]))

; >>= : (M a) (a -> (M b)) -> (M b)
; Monadic bind.  The result, when given a continuation j, runs ma; the
; value x it produces is passed to f:a->mb, and the (M b) returned by
; that call is run with j.
(define (>>= ma f:a->mb)
  (define (bound j)
    (define (after-ma x)
      ((f:a->mb x) j))
    (ma after-ma))
  bound)

; return : a -> (M a)
; Lifts an expression into the monad as a lazy, memoized computation.
; It is a synonym for make-lazy.
(define-syntax return
  (syntax-rules ()
    [(_ value-expr) (make-lazy value-expr)]))

; Syntax allowing infix >>=:
;   (in-monad m)                 ==> m
;   (in-monad m >>= f)           ==> (>>= m f)
;   (in-monad m >>= f >>= g ...) ==> (>>= (>>= m f) g) ...
; in-monad : (M a) (a -> M b) (b -> M c) ... -> (M z)
; >>= associates to the left (as in Haskell).  Every operand after the
; first is a function returning a monadic value.  A right-nested
; expansion, (>>= m (>>= f g)), would wrongly treat f itself as a
; monadic value.
(define-syntax in-monad
  (syntax-rules (>>=)
    [(in-monad exp1 >>= exp2 >>= rest ...)
     (in-monad (>>= exp1 exp2) >>= rest ...)]
    [(in-monad exp1 >>= exp2)
     (>>= exp1 exp2)]
    [(in-monad exp) exp]))

; run : (M a) -> a
; Forces a lazy computation by supplying the identity continuation, and
; returns the resulting value.
; This is the only unsafe operation.  It makes a function return instead
; of calling its continuation, so using run anywhere other than at
; toplevel may lead to space leaks.
(define (run computation)
  (let ([identity (lambda (v) v)])
    (computation identity)))

; Monadic let: (letm ((name initializer)) expression)
; Runs initializer, binds its value to name, and continues with
; expression, in which name is in scope.
;   initializer : (M a)
;   expression  : (M b)
;   letm        : (M b)
; Exactly one binding is accepted; letm* handles several.
(define-syntax letm
  (syntax-rules ()
    [(_ ([var init]) body)
     (>>= init
          (lambda (var) body))]))

; Nested (sequential) monadic let.
; Expands into one letm per binding, so each initializer can refer to the
; names bound before it.  With no bindings the result is expr itself.
(define-syntax letm*
  (syntax-rules ()
    [(_ (first-binding more-bindings ...) body)
     (letm (first-binding)
           (letm* (more-bindings ...) body))]
    [(_ () body) body]))

; Monadic if:
;   ifm (M a) (M b) (M c)
; Runs the test computation.  If its value is true, the result is
; then-branch; otherwise it is else-branch.  Only the chosen branch is
; evaluated.
(define-syntax ifm
  (syntax-rules ()
    [(_ test then-branch else-branch)
     (>>= test
          (lambda (test-value)
            (if test-value then-branch else-branch)))]))

;=========================================================================
; Convenient list deconstructor:
;   (match lst
;     [()      exp1]
;     [(h . t) exp2])
; Evaluates lst exactly once.
;   - If the value is the empty list, the result is exp1.
;   - If it is a pair, h is bound to its car and t to its cdr, and the
;     result is exp2.
;   - Otherwise the result is the symbol match-error.
(define-syntax match
  (syntax-rules ()
    [(match lst
       [()      exp1]
       [(h . t) exp2])
     (let ([value lst])                  ; evaluate the scrutinee only once
       (cond [(null? value) exp1]
             [(pair? value) (let ([h (car value)]
                                  [t (cdr value)])
                              exp2)]
             [else 'match-error]))]))

;=========================================================================
; Stream primitives:

; define-type: Stream = M List

; stream-nil : Stream
; The empty stream: a lazy computation whose value is the empty list.
(define stream-nil
  (make-lazy '()))

; stream-cons : a Stream -> Stream
; Builds a stream cell with head h and tail stream t.  Because the pair is
; built inside return, neither h nor t is evaluated until the cell is
; forced, and the resulting pair is memoized.
(define-syntax stream-cons
  (syntax-rules ()
    [(_ head tail)
     (return (cons head tail))]))

; Stream deconstructor:
;   (stream-match s
;     [()      exp1]
;     [(h . t) exp2])
; Forces s, then dispatches on the forced value with match:
;   s    : Stream
;   exp1 : M b        result when s is empty
;   exp2 : M c        result when s is a cell, with
;      h : a          bound to its first element and
;      t : Stream     bound to the rest of the stream
; stream-match : M b | M c
(define-syntax stream-match
  (syntax-rules ()
    [(_ s
        [()      exp1]
        [(h . t) exp2])
     (>>= s
          (lambda (forced)
            (match forced
              [()      exp1]
              [(h . t) exp2])))]))

; stream-car : Stream -> M b

(define (stream-car s)
  (stream-match s
    [()      (return 'error-stream-car)]
    [(h . t) (return h)]))

; stream-cdr : Stream -> Stream
; The stream following the first element of s.  If s is empty, the
; result is (return 'error-stream-cdr).
(define (stream-cdr s)
  (stream-match s
    [()            (return 'error-stream-cdr)]
    [(head . rest) rest]))

; stream-null? : Stream -> M Bool
; A lazy boolean: #t when s is empty, #f when it is a cell.
(define (stream-null? s)
  (stream-match s
    [()            (return #t)]
    [(head . rest) (return #f)]))

;=========================================================================
; Stream functions - LEAK-FREE:

; stream-filter : (a -> bool) Stream -> Stream
; The stream of the elements of s that satisfy p?, in their original
; order.  It is written entirely in monadic style (no run), which is what
; the constant-space test times3 relies on.
(define (stream-filter p? s)
  (stream-match s
    [()         stream-nil]
    [(x . more) (if (not (p? x))
                    (stream-filter p? more)
                    (stream-cons x (stream-filter p? more)))]))

; unfold : (b -> ((M #f) | (stream-cons (a | 'drop) b))) b -> Stream a
; Builds a stream by repeatedly applying f, starting from seed.
; Each call of f yields one of:
;   - (return #f): the stream ends;
;   - (stream-cons x next-seed): if x is the symbol drop, nothing is
;     emitted; otherwise x is emitted.  In both cases unfolding continues
;     from next-seed.
; If f yields the empty list, the result is (return 'error).
(define (unfold f seed)
  (letm ([step (f seed)])
    (if (not step)
        stream-nil
        (match step
          [()              (return 'error)]
          [(x . next-seed) (if (eq? x 'drop)
                               (unfold f next-seed)
                               (stream-cons x (unfold f next-seed)))]))))

; integers-from : int -> Stream
; The infinite stream n, n+1, n+2, ...  The tail is built only when the
; cell is forced, because stream-cons delays its arguments.
(define (integers-from n)
  (stream-cons n
               (integers-from (+ 1 n))))

; stream-ref : int Stream -> M a
; The element at 0-based position index in s, as a lazy computation.
; If s ends first, forcing the result signals (error 'stream-ref).
(define (stream-ref index s)
  (stream-match s
    [()         (return (error 'stream-ref))]
    [(x . more) (if (not (zero? index))
                    (stream-ref (- index 1) more)
                    (return x))]))

;========================================================================
; TESTS

;-----------------------------------------------------------------------
; Check memoization

; s is a lazy computation that displays evaluating and yields 1.
(define s (return (begin (display 'evaluating) 1)))

; Force s twice; the second run should reuse the memoized value.
(run s)
(run s)
; ==> Should display 'evaluating only once

; lazy-add : (M number) (M number) -> (M number)
; Forces x, then y, and returns their sum as a lazy computation.
(define (lazy-add x y)
  (letm ([a x])
    (letm ([b y])
      (return (+ a b)))))

; Both arguments of lazy-add are the same lazy value, so computing the sum
; forces that one value twice.
(run (let ([s (return (begin (display 'evaluating1) 1))])
       (lazy-add s s)))

; ==> Should display 'evaluating1 only once

;-----------------------------------------------------------------------
; Check for off-by-one error:

; The first input element, 0, already satisfies zero?, so index 0 of the
; filtered stream is the boundary case for stream-filter and stream-ref.
(run (stream-ref 0 (stream-filter
                    zero?
                    (integers-from 0))))

;==> should terminate giving 0

;--------------------------------------------------------------------
; Check for leak in stream-filter:

; times3 : int -> M int
; The element at index 3 of the stream of multiples of n drawn from
; 0, 1, 2, ... (15 for n = 5).  Uses stream-filter.
(define (times3 n)
  (let ([multiple-of-n? (lambda (x) (zero? (modulo x n)))])
    (stream-ref 3 (stream-filter multiple-of-n? (integers-from 0)))))

(run (times3 5))

;(run (times3 100000000))    ; ==> runs in constant space

;===================================================================
; Alternative definition testing stream primitives:

; stream-filter1 : (a -> bool) Stream -> Stream
; The elements of s that satisfy p?, like stream-filter, but written with
; the primitives stream-null?, stream-car and stream-cdr instead of
; stream-match.
(define (stream-filter1 p? s)
  (ifm (stream-null? s)
       stream-nil
       (letm ([x (stream-car s)])
         (if (not (p? x))
             (stream-filter1 p? (stream-cdr s))
             (stream-cons x (stream-filter1 p? (stream-cdr s)))))))

; times3-1 : int -> M int
; Same as times3, but filters with stream-filter1.
(define (times3-1 n)
  (let ([multiple-of-n? (lambda (x) (zero? (modulo x n)))])
    (stream-ref 3 (stream-filter1 multiple-of-n? (integers-from 0)))))

(run (times3-1 10))

;(run (times3-1 100000000))  ;==> Runs in constant space

;-----------------------------------------------------------------
; Yet another definition, testing unfold:

; stream-filter3 : (a -> bool) Stream -> Stream
; stream-filter expressed with unfold.  The seed is the remaining input
; stream.  Each step ends the output when the input is empty.  Otherwise
; it passes on the head if it satisfies p?, or the symbol drop so that
; unfold skips it, together with the rest of the input as the next seed.
(define (stream-filter3 p? s)
  (define (step remaining)
    (stream-match remaining
      [()         (return #f)]
      [(x . more) (let ([emitted (if (p? x) x 'drop)])
                    (stream-cons emitted more))]))
  (unfold step s))

; times3-3 : int -> M int
; Same as times3, but filters with stream-filter3 (unfold-based).
(define (times3-3 n)
  (let ([multiple-of-n? (lambda (x) (zero? (modulo x n)))])
    (stream-ref 3 (stream-filter3 multiple-of-n? (integers-from 0)))))

(run (times3-3 17))

;(run (times3-3 100000000))   ; unfold runs in constant space