Levin Tree Search with Context Models
(require lts-cm) | package: levintreesearch_cm
See the README for introductory examples of the LTS+CM algorithm on several domains such as Rubik’s cube and Sokoban.
1 Context Models
In Levin Tree Search with Context Models (LTS+CM), context models are crucial for guiding the search. Contexts are pieces of information extracted from the environment or state, which are then used to predict which actions to take (the policy). Contexts are grouped into mutually exclusive sets called mutex sets. The number of mutex sets is assumed fixed during the search, but the number of contexts per mutex set does not need to be known.
For efficiency though, each context is encoded into a Racket fixnum.
Isn’t encoding into a fixnum restrictive? Not really. Having too many contexts per mutex set can be detrimental to learning. For example, suppose that we have N datapoints and that each mutex set has C contexts. Because the contexts of a mutex set are mutually exclusive, each datapoint is associated with exactly one context per mutex set. Hence, each context receives about N/C datapoints on average. As a crude rule of thumb, each context may need a few dozen datapoints for its associated β parameters to learn good values. In practice, context occurrences within a mutex set are more likely to follow a Zipf distribution, which changes the argument a little, but the main idea still stands. Thus, if C is large, N must be large too.
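To put rough numbers on this argument, here is a small sketch. It assumes either a uniform spread of datapoints or an idealised Zipf law with exponent 1; both are simplifying assumptions, and `expected-count/uniform` and `expected-count/zipf` are hypothetical helpers that are not part of lts-cm.

```racket
#lang racket/base

;; Expected number of datapoints per context when n datapoints are spread
;; uniformly over the c contexts of one mutex set.
(define (expected-count/uniform n c)
  (/ n c))

;; Same quantity for the context of rank k (1 = most frequent) under an
;; idealised Zipf law with exponent 1: the share of rank k is (1/k) / H_c,
;; where H_c is the c-th harmonic number.
(define (expected-count/zipf n c k)
  (define harmonic (for/sum ([i (in-range 1 (+ c 1))]) (/ 1.0 i)))
  (/ n (* k harmonic)))

(exact->inexact (expected-count/uniform 100000 4096)) ; about 24.4
(expected-count/zipf 100000 4096 1)     ; most frequent context
(expected-count/zipf 100000 4096 4096)  ; least frequent context
```

Under the Zipf assumption the least frequent contexts receive fewer than N/C datapoints, which only strengthens the case for keeping C moderate.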
Example: Relative Tiling with board-relative-tiling/collect
Tiling is a common strategy for generating contexts, especially in grid-based environments. The lts-cm/byte-board module provides dedicated functions for this. The following example demonstrates how to use board-relative-tiling/collect to extract contexts from a small board.
> (define board-rows '("XXXXX" "X P X" "X  GX" "XXXXX"))
> (define n-cols (string-length (first board-rows))) ; = 5
; Let's convert the strings into a board of integers:
> (define board-as-list
    (for*/list ([row (in-list board-rows)]
                [chr (in-string row)])
      (case chr
        [(#\X) 0]
        [(#\space) 1]
        [(#\P) 2]
        [(#\G) 3]
        [else (error "unknown character" chr)])))
> (define a-board (list->board board-as-list n-cols))
> (define max-val 3) ; 0: wall, 1: empty, 2: player, 3: gold
> (define pad-val 0) ; = wall
; Let's find the player P=2:
> (define-values (p-row p-col) (board-find a-board 2)) ; = (1, 2)
> (board-relative-tiling/collect a-board
                                 #:collect! displayln
                                 #:row p-row
                                 #:col p-col
                                 #:max-value max-val
                                 #:pad-value pad-val
                                 #:row-dist 1
                                 #:col-dist 1
                                 #:row-span 2
                                 #:col-span 2)
6
9
101
151
Tile 1  Tile 2  Tile 3  Tile 4
┌──┐    ┌──┐    ┌──┐    ┌──┐
│XX│    │XX│    │ P│    │P │
│ P│    │P │    │  │    │ G│
└──┘    └──┘    └──┘    └──┘
fixnum of tile 1 = (((wall × size) + wall) × size + empty) × size + player
                 = (((0 × 4) + 0) × 4 + 1) × 4 + 2
                 = 6
> (define wall 0)
> (define empty-cell 1)
> (define player 2)
> (define gold 3)
> (define size 4)
; Tile 1:
> (naturals->fixnum* [wall size] [wall size] [empty-cell size] [player size])
6
; Tile 2:
> (naturals->fixnum* [wall size] [wall size] [player size] [empty-cell size])
9
Note that encodings are local to each mutex set: even if two active contexts of different tiles have the same fixnum, they map to different rows in the β matrix.
2 Collectors
(require lts-cm/collector) | package: levintreesearch_cm
> (define collect! (make-list-collector))
> (board-relative-tiling/collect a-board
                                 #:collect! collect!
                                 #:row p-row
                                 #:col p-col
                                 #:max-value max-val
                                 #:pad-value pad-val
                                 #:row-dist 1
                                 #:col-dist 1
                                 #:row-span 2
                                 #:col-span 2)
; Obtain the collected fixnums:
> (collect!)
'(6 9 101 151)
procedure
(make-list-collector) → (-> procedure?)
> (define collect! (make-list-collector))
> (collect! 'a)
> (collect! 1)
> (collect! '(x y z))
; Obtain the collected elements:
> (collect!)
'(a 1 (x y z))
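The calling convention used above (one argument to store an element, no argument to retrieve them all) can be sketched with a closure. This is only an illustration of the idea: `make-simple-list-collector` is a hypothetical name, and whether the real collector resets or keeps its contents after retrieval is not stated here, so the sketch simply keeps them.

```racket
#lang racket/base

;; Hypothetical stand-in for a list collector: calling the returned
;; procedure with one argument stores it; calling it with no argument
;; returns everything stored so far, in insertion order.
(define (make-simple-list-collector)
  (define acc '())                 ; most recent element first
  (case-lambda
    [() (reverse acc)]
    [(x) (set! acc (cons x acc))]))

(define collect! (make-simple-list-collector))
(collect! 'a)
(collect! 1)
(collect! '(x y z))
(collect!) ; '(a 1 (x y z))
```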
procedure
(make-fxvector-collector vec) → procedure?
vec : fxvector?
> (define vec (make-fxvector 3))
> (define collect! (make-fxvector-collector vec))
> (collect! 1)
> (collect! 2)
> (collect! 3)
; Obtain the collected elements:
> vec
(fxvector 1 2 3)
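A collector that fills a preallocated fxvector can be sketched in the same way. The name `make-simple-fxvector-collector` is hypothetical, and the behaviour when more elements are collected than the vector can hold is an assumption: this sketch lets `fxvector-set!` raise its usual range error.

```racket
#lang racket/base
(require racket/fixnum)

;; Hypothetical stand-in: writes each collected fixnum into the next free
;; slot of vec, starting at index 0.
(define (make-simple-fxvector-collector vec)
  (define i 0)
  (λ (x)
    (fxvector-set! vec i x)
    (set! i (+ i 1))))

(define vec (make-fxvector 3))
(define collect! (make-simple-fxvector-collector vec))
(collect! 1)
(collect! 2)
(collect! 3)
vec ; (fxvector 1 2 3)
```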
procedure
> (define collect! (make-fxvector-collector/auto))
> (collect! 1)
> (collect! 2)
> (collect! 3)
; Obtain the collected elements:
> (collect!)
(fxvector 1 2 3)
3 Fixnum Encoding
(require lts-cm/encode) | package: levintreesearch_cm
This module provides utilities for encoding lists of natural numbers into a single fixnum, and vice-versa. This is essential for creating compact representations that can be used as keys in hash tables or for other efficient processing. Technically, the encoding scheme is akin to representing a number in a mixed radix system, where each position can have a different base (size).
procedure
(naturals->fixnum ints sizes [n]) → fixnum?
  ints : (listof natural?)
  sizes : (listof exact-positive-integer?)
  n : fixnum? = 0
Each integer i from ints must be less than its corresponding s in sizes (i.e., 0 <= i < s). The function folds from left to right, effectively computing (((n * size_0 + int_0) * size_1 + int_1) * ...).
> (naturals->fixnum '(0 2 1 2 2 0 1 0) '(2 3 4 5 6 7 8 9))
143145
> (define base-code (naturals->fixnum '(0 2 1 2 2) '(2 3 4 5 6)))
> (naturals->fixnum '(0 1 0) '(7 8 9) base-code)
143145
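The left-to-right fold described above can be written out directly. The following is a minimal sketch of the mixed-radix idea, not the library's implementation; `encode-mixed-radix` is a hypothetical name, and the range check is an illustrative choice.

```racket
#lang racket/base

;; Hypothetical re-implementation of the fold
;; (((n * size_0 + int_0) * size_1 + int_1) * ...).
(define (encode-mixed-radix ints sizes [n 0])
  (for/fold ([code n])
            ([i (in-list ints)]
             [s (in-list sizes)])
    ;; Each digit must satisfy 0 <= i < s, otherwise decoding is ambiguous.
    (unless (and (exact-nonnegative-integer? i) (< i s))
      (raise-argument-error 'encode-mixed-radix "natural below its size" i))
    (+ (* code s) i)))

(encode-mixed-radix '(0 2 1 2 2 0 1 0) '(2 3 4 5 6 7 8 9)) ; 143145
;; Continuing from the code of the first five digits, which is 284:
(encode-mixed-radix '(0 1 0) '(7 8 9) 284)                  ; 143145
```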
procedure
(fixnum->naturals n-orig sizes [#:remainder remainder])
 → (if (eq? remainder #true) (values (listof natural?) fixnum?) (listof natural?))
  n-orig : fixnum?
  sizes : (listof exact-positive-integer?)
  remainder : (or/c #t #f 'check-0 'cons) = 'check-0
'check-0 (default): Raises an error if the remainder is not zero, ensuring the fixnum is fully decoded by the given sizes.
#f: The remainder is discarded.
#t: Returns two values: the list of decoded naturals and the remainder.
'cons: The remainder is cons’ed onto the beginning of the resulting list of naturals.
> (fixnum->naturals 143145 '(2 3 4 5 6 7 8 9))
'(0 2 1 2 2 0 1 0)
> (fixnum->naturals 134145 '(2 3 4) #:remainder #t)
'(0 2 1)
5589
> (fixnum->naturals 38 '(7) #:remainder #f)
'(3)
> (fixnum->naturals 38 '(7) #:remainder 'cons)
'(5 3)
> (fixnum->naturals 23 '(2 2 2 2 2))
'(1 0 1 1 1)
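Decoding reverses the fold: digits are peeled off starting from the last size, and whatever is left over is the remainder. The sketch below is consistent with the examples above but is not the library's code; `decode-mixed-radix` is a hypothetical name, and it always returns the remainder as a second value rather than offering the four modes.

```racket
#lang racket/base

;; Hypothetical decoder: take (remainder n s) for the last size first, then
;; continue with (quotient n s). Returns the digits and the leftover code.
(define (decode-mixed-radix code sizes)
  (for/fold ([n code] [digits '()] #:result (values digits n))
            ([s (in-list (reverse sizes))])
    (values (quotient n s) (cons (remainder n s) digits))))

(decode-mixed-radix 143145 '(2 3 4 5 6 7 8 9)) ; '(0 2 1 2 2 0 1 0) and 0
(decode-mixed-radix 38 '(7))                   ; '(3) and 5
```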
syntax
(naturals->fixnum* [n fixnum? 0] [[val natint?] [size posint?]] ...+)
> (naturals->fixnum* [0 2] [2 3] [7 12] [6 10])
316
; Starting with a base value:
> (define n0 (naturals->fixnum* [0 2] [2 3]))
> (naturals->fixnum* n0 [7 12] [6 10])
316
4 Byte Board Utilities
(require lts-cm/byte-board) | package: levintreesearch_cm
This module provides utilities for working with 2D boards represented by byte strings. A board is a structure holding a flat byte string along with its dimensions.
struct
  vec : bytes?
  n-rows : exact-positive-integer?
  n-cols : exact-positive-integer?
procedure
(list->board lst n-cols) → board?
lst : (listof byte?) n-cols : exact-positive-integer?
procedure
(board-find aboard x)
 → (or/c false/c exact-integer?) (or/c false/c exact-integer?)
  aboard : board?
  x : byte?
procedure
(board->string aboard) → string?
aboard : board?
> (define brd (list->board (range 9) 3))
> (displayln (board->string brd))
┌─┬─┬─┐
│0│1│2│
├─┼─┼─┤
│3│4│5│
├─┼─┼─┤
│6│7│8│
└─┴─┴─┘
procedure
(board->list aboard) → (listof byte?)
aboard : board?
procedure
(board-in-bounds? brd row col) → boolean?
brd : board? row : exact-integer? col : exact-integer?
procedure
(board-set! aboard row col val) → void?
aboard : board? row : exact-integer? col : exact-integer? val : byte?
syntax
(board-index aboard row col)
procedure
(board-copy brd) → board?
brd : board?
procedure
(board->bytes aboard) → bytes?
aboard : board?
procedure
(board-relative-tiling/collect brd
                               #:collect! collect!
                               #:row row0
                               #:col col0
                              [#:max-value max-value
                               #:pad-value pad-value
                               #:row-dist row-dist
                               #:col-dist col-dist
                               #:row-span row-span
                               #:col-span col-span]) → void?
  brd : board?
  collect! : (-> fixnum? any/c)
  row0 : exact-integer?
  col0 : exact-integer?
  max-value : byte? = 255
  pad-value : byte? = max-value
  row-dist : exact-positive-integer? = 1
  col-dist : exact-positive-integer? = row-dist
  row-span : exact-positive-integer? = 2
  col-span : exact-positive-integer? = row-span
The number of tiles (mutex sets) generated by such a tiling is (row-dist × 2 + 2 - row-span) × (col-dist × 2 + 2 - col-span). For example, with row-dist = col-dist = 1 and row-span = col-span = 2, this gives 2 × 2 = 4 tiles.
For each tile, the cells of the tile are encoded into a single fixnum using naturals->fixnum with size = (+ max-value 1). If a cell of a tile is outside the boundaries of the board, the pad-value is used in place of the cell’s value. The resulting code is passed to collect!.
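The tiling and padding rules can be illustrated with a self-contained sketch on a vector-of-vectors board. It is not the library's implementation: `relative-tiling` is a hypothetical name, it uses a single `dist` and `span` for both axes, and it assumes that tiles are visited in row-major order and that the cells within a tile are also encoded in row-major order. These assumptions reproduce the codes 6, 9, 101, 151 of the example in Section 1.

```racket
#lang racket/base

;; Hypothetical sketch of relative tiling around (row0, col0).
(define (relative-tiling board row0 col0
                         #:max-value max-value
                         #:pad-value pad-value
                         #:dist [dist 1]
                         #:span [span 2])
  (define n-rows (vector-length board))
  (define n-cols (vector-length (vector-ref board 0)))
  (define size (+ max-value 1))
  ;; Cell value, or pad-value when (r, c) lies outside the board.
  (define (cell r c)
    (if (and (< -1 r n-rows) (< -1 c n-cols))
        (vector-ref (vector-ref board r) c)
        pad-value))
  ;; Top-left corners range over [x0 - dist, x0 + dist - span + 1].
  (for*/list ([r (in-range (- row0 dist) (+ row0 dist (- span) 2))]
              [c (in-range (- col0 dist) (+ col0 dist (- span) 2))])
    ;; Mixed-radix encoding of the span x span cells of this tile.
    (for*/fold ([code 0])
               ([dr (in-range span)]
                [dc (in-range span)])
      (+ (* code size) (cell (+ r dr) (+ c dc))))))

;; Same board as in Section 1: 0 = wall, 1 = empty, 2 = player, 3 = gold.
(define a-board
  (vector (vector 0 0 0 0 0)
          (vector 0 1 2 1 0)
          (vector 0 1 1 3 0)
          (vector 0 0 0 0 0)))

(relative-tiling a-board 1 2 #:max-value 3 #:pad-value 0) ; '(6 9 101 151)
```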
5 Line search for convex minimization
(require lts-cm/delta-secant) | package: levintreesearch_cm
This module implements the Δ-Secant line search algorithm from the paper “Line Search for Convex Minimization”.
The function convex-line-search returns the lowest point found of a given convex function between two initial points when a stopping criterion is satisfied.
The function quasi-exact-line-search builds upon convex-line-search to ensure sufficient progress is made, and is intended to be used within an optimization algorithm such as gradient descent or Frank-Wolfe.
procedure
(convex-line-search f
                    xleft
                    xright
                   [#:yleft yleft
                    #:xq xq
                    #:yq yq
                    #:y-tolerance y-tolerance
                    #:stop-when stop-when
                    #:callback callback]) → dict?
  f : (-> real? real?)
  xleft : real?
  xright : real?
  yleft : real? = (f xleft)
  xq : real? = (* 0.5 (+ xleft xright))
  yq : real? = (f xq)
  y-tolerance : real? = 1e-10
  stop-when : (-> dict? any/c) = (λ (dic) (<= (dict-ref dic 'ygap) y-tolerance))
  callback : (-> dict? any/c) = (λ (dic) (void))
'iter: Number of iterations performed.
'xlow and 'ylow: lowest point found — usually these are the quantities of interest.
'xgap and 'ygap: upper bounds on |xlow - x*| and |ylow - y*|.
'x- and 'x+: x-interval containing x*.
'ya and 'yb: The minimum of these two values is a lower bound on y*.
'pts: The 5 points around x*. See paper.
The arguments yleft and yq MUST be equal to (f xleft) and (f xq).
The argument xq is the first point within [xleft, xright] to be queried.
The argument stop-when controls when the algorithm should terminate. By default, it terminates when the y-distance to the minimum ('ygap) is provably less than y-tolerance.
The argument callback can be used to monitor the progress of the line search.
> (convex-line-search (λ (x) (sqr (- x 1))) -2 5)
(list
'(iter . 23)
(list
'pts
(pt 0.9999892887337545 1.1473122458246348e-10)
(pt 0.9999966814872121 1.1012527123440112e-11)
(pt 1.0000016825306512 2.8309093923890705e-12)
(pt 1.0000135688568499 1.8411387621275733e-10)
(pt 1.000055351641258 3.063804189957421e-9))
'(x- . 0.9999972646480525)
'(x+ . 1.0000109385368174)
'(xgap . 1.3673888764942355e-5)
'(ya . -2.9452989812498634e-11)
'(yb . -1.1960640832144743e-11)
'(ygap . 3.2283899204887706e-11)
'(xlow . 1.0000016825306512)
'(ylow . 2.8309093923890705e-12))
> (define (keep-keys dic keys)
    (filter (λ (l) (memq (car l) keys)) dic))
> (keep-keys (convex-line-search (λ (x) (sqr (- x 1))) -2 5 #:y-tolerance 0.01)
             '(iter xlow ylow))
'((iter . 7) (xlow . 1.0095288532116586) (ylow . 9.079904352933715e-5))
> (keep-keys (convex-line-search (λ (x) (max (sqr (- x 1)) (sqr (+ x 1)))) -2 5) '(iter xlow ylow xgap ygap))
'((iter . 15)
(xgap . 1.5371377058688084e-11)
(ygap . 9.722889160457271e-12)
(xlow . -2.796517824676535e-12)
(ylow . 1.0000000000055929))
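Since stop-when receives the dictionary described above, a stopping rule is simply a predicate on its keys. The sketch below builds two such predicates and applies them to an association list whose values are copied from the first example output of this section; `ygap-below` and `xgap-below` are hypothetical helper names.

```racket
#lang racket/base
(require racket/dict)

;; A result-like dictionary with values taken from the first example above.
(define dic
  '((iter . 23)
    (xgap . 1.3673888764942355e-5)
    (ygap . 3.2283899204887706e-11)))

;; Same shape as the default stop-when: stop once 'ygap is small enough.
(define ((ygap-below tol) dic)
  (<= (dict-ref dic 'ygap) tol))

;; Alternative rule on the x-interval instead of the y-gap.
(define ((xgap-below tol) dic)
  (<= (dict-ref dic 'xgap) tol))

((ygap-below 1e-10) dic) ; #t
((xgap-below 1e-6) dic)  ; #f
```

Given the documented contract stop-when : (-> dict? any/c), a predicate such as (xgap-below 1e-6) could be passed as #:stop-when to stop on the width of the x-interval rather than on the y-gap.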
procedure
(quasi-exact-line-search f
                        [xleft
                         xright
                         #:yleft yleft
                         #:xq xq
                         #:yq yq
                         #:jac^2 jac^2
                         #:c c
                         #:callback callback]) → dict?
  f : (-> real? real?)
  xleft : real? = 0.0
  xright : real? = 1.0
  yleft : real? = (f xleft)
  xq : real? = (* 0.5 (+ xleft xright))
  yq : real? = (f xq)
  jac^2 : (or/c #f positive-real?) = #f
  c : positive-real? = 1.0
  callback : (-> dict? any/c) = (λ (dic) (void))
Moreover, in contrast to convex-line-search, if the minimum is found to be at xright, the range [xleft, xright] is quadrupled to the right and the line search continues, and so on. This means that, for example, the call (quasi-exact-line-search / 1 2) loops forever. To prevent this quadrupling behaviour, one can force the function f to be increasing at xright, for example with (λ (x) (if (< x 2) (/ x) +inf.0)) instead of /.
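The wrapping trick generalizes to any function and bound. The sketch below defines a hypothetical helper `bound-right` (not part of lts-cm) that returns +inf.0 from a given point onwards, so that the wrapped function is increasing at that point.

```racket
#lang racket/base

;; Hypothetical helper: behaves like f strictly left of x-max and returns
;; +inf.0 at x-max and beyond, as in (λ (x) (if (< x 2) (/ x) +inf.0)).
(define ((bound-right f x-max) x)
  (if (< x x-max) (f x) +inf.0))

(define g (bound-right / 2))
(g 1) ; 1
(g 2) ; +inf.0
```

The wrapped function g could then be passed to quasi-exact-line-search in place of /.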
The argument jac^2, if provided, should be the squared 2-norm of the jacobian (aka the gradient or derivative) at xleft. This information may be used to speed up the search.
See convex-line-search for the description of the returned dictionary, and of the other arguments.
> (for/list ([c '(1 10 100)]) (keep-keys (quasi-exact-line-search (λ (x) (sqr (- x 1))) -2 5 #:c c) '(iter xlow ylow)))
'(((iter . 2) (xlow . 1.47265625) (ylow . 0.2234039306640625))
((iter . 4) (xlow . 0.7788681457319648) (ylow . 0.04889929697201954))
((iter . 6) (xlow . 1.0095288532116586) (ylow . 9.079904352933715e-5)))