Levin Tree Search with Context Models

Laurent Orseau

 (require lts-cm) package: levintreesearch_cm

See the README for introductory examples of the LTS+CM algorithm on several domains such as Rubik’s cube and Sokoban.

1 Context Models

In Levin Tree Search with Context Models (LTS+CM), context models are crucial for guiding the search. Contexts are pieces of information extracted from the environment or state, which are then used to predict which actions to take (the policy). Contexts are grouped into mutually exclusive sets called mutex sets. The number of mutex sets is assumed fixed during the search, but the number of contexts per mutex set does not need to be known.

For efficiency, though, each context is encoded into a Racket fixnum, called the context’s fixnum. The learnable parameters β of the context models are stored in a Context DataBase (CDB), as a matrix where each row corresponds to a context fixnum and each column to an action. Each mutex set is associated with a hash table where the key is the context fixnum and the value is the row in the β matrix.
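The lookup structure described above can be sketched in plain Racket. This is only an illustration under assumptions: the names cdb, make-cdb and cdb-row! are hypothetical and are not part of lts-cm, and the real CDB also stores the β parameters, which are left out here.

```racket
#lang racket/base
;; Hypothetical sketch of a context database index; not the lts-cm implementation.
;; One hash table per mutex set maps a context fixnum to a row of the β matrix.
(struct cdb (tables [n-rows #:mutable]))

(define (make-cdb n-mutex-sets)
  (cdb (build-vector n-mutex-sets (λ (_) (make-hasheqv))) 0))

;; Returns the row for context `ctx` of mutex set `mutex-idx`,
;; allocating a fresh row the first time this context is seen.
(define (cdb-row! db mutex-idx ctx)
  (hash-ref! (vector-ref (cdb-tables db) mutex-idx)
             ctx
             (λ ()
               (define row (cdb-n-rows db))
               (set-cdb-n-rows! db (add1 row))
               row)))

(define db (make-cdb 4))
(cdb-row! db 0 6) ; first context seen in mutex set 0
(cdb-row! db 1 6) ; same fixnum in another mutex set: a different row
(cdb-row! db 0 6) ; already known: same row as the first call
```

Because each mutex set has its own table, the same fixnum in two mutex sets gets two distinct rows, while a repeated lookup in the same mutex set returns the row allocated earlier.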

Isn’t encoding into a fixnum restrictive? Not really. Having too many contexts per mutex set can be detrimental to learning. For example, suppose that we have N datapoints and that each mutex set has C contexts. Due to the mutually exclusive nature of mutex sets, each datapoint is associated with exactly one context per mutex set. Hence, each context may receive about N/C datapoints. As a crude rule of thumb, each context may need a few dozen datapoints for its associated β parameters to learn good values. In practice, context occurrences within a mutex set are more likely to follow a Zipf distribution, which changes the argument a little, but the main idea still stands. Thus, if C is large, N must be large too.
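To make the N/C argument concrete, here is a small arithmetic sketch. It assumes a context built from a fixed number of cells, each taking one of a fixed number of values (as in the tiling example below), and a hypothetical dataset of N = 10000 datapoints spread uniformly. The function names are illustrative only.

```racket
#lang racket/base
;; Illustrative arithmetic only; the names below are hypothetical.
;; A context built from `n-cells` cells, each taking one of `n-values`
;; values, can take n-values^n-cells distinct values.
(define (n-contexts n-values n-cells)
  (expt n-values n-cells))

;; Average number of datapoints per context, assuming a uniform spread.
(define (datapoints-per-context n-datapoints c)
  (exact->inexact (/ n-datapoints c)))

(datapoints-per-context 10000 (n-contexts 4 4)) ; 2x2 tile: 256 contexts
(datapoints-per-context 10000 (n-contexts 4 9)) ; 3x3 tile: 262144 contexts
```

With 4 cell values, a 2x2 tile leaves about 39 datapoints per context, which is in the range of the rule of thumb, whereas a 3x3 tile leaves far fewer than one datapoint per context on average.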

Example: Relative Tiling with board-relative-tiling/collect

Tiling is a common strategy for generating contexts, especially in grid-based environments. The lts-cm/byte-board module provides dedicated functions for this. The following example demonstrates how to use board-relative-tiling/collect to extract contexts from a small board.

Consider a very simple game where the player ’P’ must get the gold ’G’ on a 2D grid. The grid has 4 rows and 5 columns and is surrounded by walls (’X’). Each cell can thus take 4 different values (including empty ’ ’).

Examples:
> (define board-rows
    '("XXXXX"
      "X P X"
      "X  GX"
      "XXXXX"))
> (define n-cols (string-length (first board-rows))) ; = 5
; Let's convert the string into a board of integers:
> (define board-as-list
    (for*/list ([row (in-list board-rows)] [chr (in-string row)])
       (case chr
         [(#\X)     0]
         [(#\space) 1]
         [(#\P)     2]
         [(#\G)     3]
         [else (error "unknown character" chr)])))
> (define a-board (list->board board-as-list n-cols))

Let’s use board-relative-tiling/collect to extract contexts for a 2x2 tiling scheme centered around the player at (p-row, p-col), with a look-around distance of 1. This call uses #:collect! displayln, causing it to print the 4 context fixnums (one per mutex set, that is, one per tile of size 2x2):
> (define max-val 3) ; 0: wall, 1: empty, 2: player, 3: gold
> (define pad-val 0) ; = wall
; Let's find the player P=2:
> (define-values (p-row p-col) (board-find a-board 2)) ;  = (1, 2)
> (board-relative-tiling/collect a-board
                                 #:collect! displayln
                                 #:row p-row #:col p-col
                                 #:max-value max-val #:pad-value pad-val
                                 #:row-dist 1 #:col-dist 1
                                 #:row-span 2 #:col-span 2)

6
9
101
151

This prints the 4 fixnums of the active contexts of each tile (mutex set):

Tile 1  Tile 2  Tile 3  Tile 4
┌──┐    ┌──┐    ┌──┐    ┌──┐
│XX│    │XX│    │ P│    │P │
│ P│    │P │    │  │    │ G│
└──┘    └──┘    └──┘    └──┘

Let’s see how to calculate these fixnums by hand. If size = max-val + 1, the fixnum of the context of the first tile is

fixnum of tile 1 = (((wall × size) + wall) × size + empty) × size + player

                 = (((0 × 4) + 0) × 4 + 1) × 4 + 2

                 = 6

The other context fixnums are calculated similarly. The same calculation can be done more simply using naturals->fixnum*:
> (define wall 0)
> (define empty-cell 1)
> (define player 2)
> (define gold 3)
> (define size 4)
; Tile 1:
> (naturals->fixnum* [wall size] [wall size] [empty-cell size] [player size])

6

; Tile 2:
> (naturals->fixnum* [wall size] [wall size] [player size] [empty-cell size])

9

Note that encodings are local to each mutex set: even if two active contexts of different tiles have the same fixnum, they map to different rows in the β matrix.

2 Collectors

 (require lts-cm/collector) package: levintreesearch_cm

Instead of #:collect! displayln, LTS+CM uses a collector that is typically passed as argument to the domain’s custom collect-context, for example:
> (define collect! (make-list-collector))
> (board-relative-tiling/collect a-board
                                 #:collect! collect!
                                 #:row p-row #:col p-col
                                 #:max-value max-val #:pad-value pad-val
                                 #:row-dist 1 #:col-dist 1
                                 #:row-span 2 #:col-span 2)
; Obtain the collected fixnums:
> (collect!)

'(6 9 101 151)

procedure

(make-list-collector)  (-> procedure?)

Creates a collector that gathers items into a list.

Examples:
> (define collect! (make-list-collector))
> (collect! 'a)
> (collect! 1)
> (collect! '(x y z))
; Obtain the collected elements:
> (collect!)

'(a 1 (x y z))
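The calling convention shown above (one argument to collect, no argument to retrieve) can be obtained with a closure and case-lambda. The following is a minimal sketch under that assumption; make-list-collector-sketch is a hypothetical name and the actual lts-cm implementation may differ, for instance in whether retrieval resets the collector.

```racket
#lang racket/base
;; Hypothetical sketch of a list collector; not the lts-cm implementation.
(define (make-list-collector-sketch)
  (define items '()) ; collected items, most recent first
  (case-lambda
    [() (reverse items)]                 ; no argument: return items in collection order
    [(x) (set! items (cons x items))]))  ; one argument: collect x

(define collect! (make-list-collector-sketch))
(collect! 'a)
(collect! 1)
(collect! '(x y z))
(collect!) ; '(a 1 (x y z))
```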

procedure

(make-fxvector-collector vec)  procedure?

  vec : fxvector?
Creates a collector designed to populate a pre-allocated fxvector to collect fixnums.

Examples:
> (define vec (make-fxvector 3))
> (define collect! (make-fxvector-collector vec))
> (collect! 1)
> (collect! 2)
> (collect! 3)
; Obtain the collected elements:
> vec

(fxvector 1 2 3)

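make-fxvector-collector/auto is similar to make-fxvector-collector, but for an fxvector of initially unknown size. As in the example below, calling the collector with no argument returns the collected fxvector.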

Examples:
> (define collect! (make-fxvector-collector/auto))
> (collect! 1)
> (collect! 2)
> (collect! 3)
; Obtain the collected elements:
> (collect!)

(fxvector 1 2 3)
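One common way to collect into an fxvector of unknown final size is to grow a buffer geometrically. The sketch below assumes this strategy; it is not taken from the lts-cm source, and make-growing-collector-sketch and its capacity argument are hypothetical.

```racket
#lang racket/base
(require racket/fixnum)

;; Hypothetical sketch of an auto-growing fxvector collector.
(define (make-growing-collector-sketch [capacity 4])
  (define buf (make-fxvector capacity)) ; current storage
  (define n 0)                          ; number of fixnums collected so far
  (case-lambda
    ;; No argument: return a fresh fxvector holding exactly the collected fixnums.
    [() (fxvector-copy buf 0 n)]
    ;; One argument: append x, doubling the buffer when it is full.
    [(x)
     (when (= n (fxvector-length buf))
       (define new-buf (make-fxvector (* 2 (fxvector-length buf))))
       (for ([i (in-range n)])
         (fxvector-set! new-buf i (fxvector-ref buf i)))
       (set! buf new-buf))
     (fxvector-set! buf n x)
     (set! n (add1 n))]))

(define collect! (make-growing-collector-sketch 2))
(for ([x (in-list '(1 2 3 4 5))]) (collect! x))
(collect!) ; an fxvector containing 1 2 3 4 5
```

Doubling keeps the amortized cost of each collect! call constant, at the price of one final copy when the result is retrieved.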

3 Fixnum Encoding

 (require lts-cm/encode) package: levintreesearch_cm

This module provides utilities for encoding lists of natural numbers into a single fixnum, and vice-versa. This is essential for creating compact representations that can be used as keys in hash tables or for other efficient processing. Technically, the encoding scheme is akin to representing a number in a mixed radix system, where each position can have a different base (size).

procedure

(naturals->fixnum ints sizes [n])  fixnum?

  ints : (listof natural?)
  sizes : (listof exact-positive-integer?)
  n : fixnum? = 0
Encodes a list of natural numbers, ints, into a single fixnum. The sizes list specifies the maximum value (plus one) for the corresponding integer in ints. The encoding is performed sequentially, and an optional initial fixnum n can be provided to chain encodings.

Each integer i from ints must be less than its corresponding s in sizes (i.e., 0 <= i < s). The function folds from left to right, effectively computing (((n * size_0 + int_0) * size_1 + int_1) * ...).

Examples:
> (naturals->fixnum '(0 2 1 2 2 0 1 0) '(2 3 4 5 6 7 8 9))

143145

> (define base-code (naturals->fixnum '(0 2 1 2 2) '(2 3 4 5 6)))
> (naturals->fixnum '(0 1 0) '(7 8 9) base-code)

143145
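The left-to-right fold described above can be written in a few lines of plain Racket. This is a sketch for illustration, with the hypothetical name encode-sketch; it omits any check that the result still fits in a fixnum.

```racket
#lang racket/base
;; Hypothetical sketch of the mixed-radix fold; not the lts-cm implementation.
(define (encode-sketch ints sizes [n 0])
  (for/fold ([n n]) ([i (in-list ints)] [s (in-list sizes)])
    (unless (and (<= 0 i) (< i s))
      (error 'encode-sketch "~a is not in [0, ~a)" i s))
    ;; Shift the code so far by one "digit" of base s, then add the new digit.
    (+ (* n s) i)))

(encode-sketch '(0 2 1 2 2 0 1 0) '(2 3 4 5 6 7 8 9)) ; 143145

;; Chaining: encoding a prefix first and passing it as n gives the same code.
(define base-code (encode-sketch '(0 2 1 2 2) '(2 3 4 5 6))) ; 284
(encode-sketch '(0 1 0) '(7 8 9) base-code) ; 143145
```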

procedure

(fixnum->naturals n-orig 
  sizes 
  [#:remainder remainder]) 
  
(if (eq? remainder #true)
  (values (listof natural?) fixnum?)
  (listof natural?))
  n-orig : fixnum?
  sizes : (listof exact-positive-integer?)
  remainder : (or/c #t #f 'check-0 'cons) = 'check-0
Decodes a fixnum, n-orig, back into a list of natural numbers, given the list of sizes used for encoding. The sizes list is processed in reverse order for decoding, corresponding to how naturals->fixnum performs the encoding.

The remainder argument controls how any remaining value of n-orig after decoding with the given sizes is handled:
  • 'check-0 (default): Raises an error if the remainder is not zero, ensuring the fixnum is fully decoded by the given sizes.

  • #f: The remainder is discarded.

  • #t: Returns two values: the list of decoded naturals and the remainder.

  • 'cons: The remainder is cons’ed onto the beginning of the resulting list of naturals.

Examples:
> (fixnum->naturals 143145 '(2 3 4 5 6 7 8 9))

'(0 2 1 2 2 0 1 0)

> (fixnum->naturals 134145 '(2 3 4) #:remainder #t)

'(0 2 1)

5589

> (fixnum->naturals 38 '(7) #:remainder #f)

'(3)

> (fixnum->naturals 38 '(7) #:remainder 'cons)

'(5 3)

> (fixnum->naturals 23 '(2 2 2 2 2))

'(1 0 1 1 1)
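Decoding reverses the fold: the sizes are processed from last to first, each step taking a remainder (the digit) and a quotient (what is left to decode). The sketch below uses the hypothetical name decode-sketch and always returns both the naturals and the remainder, which corresponds to the #:remainder #t behaviour described above.

```racket
#lang racket/base
;; Hypothetical sketch of mixed-radix decoding; not the lts-cm implementation.
;; Returns two values: the decoded naturals and the leftover remainder.
(define (decode-sketch n sizes)
  (define-values (leftover naturals)
    (for/fold ([n n] [acc '()]) ([s (in-list (reverse sizes))])
      ;; The last-encoded digit is the remainder modulo its size.
      (values (quotient n s) (cons (remainder n s) acc))))
  (values naturals leftover))

(decode-sketch 143145 '(2 3 4 5 6 7 8 9)) ; '(0 2 1 2 2 0 1 0) and 0
(decode-sketch 134145 '(2 3 4))           ; '(0 2 1) and 5589
(decode-sketch 38 '(7))                   ; '(3) and 5
```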

syntax

(naturals->fixnum* [n fixnum? 0] [[val natint?] [size posint?]] ...+)

A convenience syntax (macro) for encoding sequences of [value size] pairs, equivalent to (naturals->fixnum (list val ...) (list size ...) n).

Examples:
> (naturals->fixnum* [0 2] [2 3] [7 12] [6 10])

316

; Starting with a base value:
> (define n0 (naturals->fixnum* [0 2] [2 3]))
> (naturals->fixnum* n0 [7 12] [6 10])

316

4 Byte Board Utilities

 (require lts-cm/byte-board) package: levintreesearch_cm

This module provides utilities for working with 2D boards represented by byte strings. A board is a structure holding a flat byte string along with its dimensions.

struct

(struct board (vec n-rows n-cols))

  vec : bytes?
  n-rows : exact-positive-integer?
  n-cols : exact-positive-integer?
Represents a 2D board. The vec field stores the board’s cell values in a flat byte string (row-major order). Cell values are thus restricted to 0-255.

procedure

(list->board lst n-cols)  board?

  lst : (listof byte?)
  n-cols : exact-positive-integer?
Creates a board from a flat list of byte values, lst. The board will have n-cols columns. If the length of lst is not a multiple of n-cols, the list is effectively truncated to the largest multiple of n-cols that fits, and the remaining elements are ignored.

procedure

(board-find aboard x)  
(or/c false/c exact-integer?)
(or/c false/c exact-integer?)
  aboard : board?
  x : byte?
Finds the first occurrence of the byte value x in aboard, searching in row-major order (left-to-right, then top-to-bottom). Returns two values: the row and column of the first occurrence of x. If x is not found in the board, it returns (values #f #f).

procedure

(board->string aboard)  string?

  aboard : board?
Converts the aboard to a multi-line string representation, suitable for printing to the console.

Examples:
> (define brd (list->board (range 9) 3))
> (displayln (board->string brd))

┌─┬─┬─┐

│0│1│2│

├─┼─┼─┤

│3│4│5│

├─┼─┼─┤

│6│7│8│

└─┴─┴─┘

procedure

(board->list aboard)  (listof byte?)

  aboard : board?
Converse of list->board.

procedure

(board-in-bounds? brd row col)  boolean?

  brd : board?
  row : exact-integer?
  col : exact-integer?
Checks if the given row and col coordinates are within the valid bounds of the board brd. Returns #t if (0 <= row < (board-n-rows brd)) and (0 <= col < (board-n-cols brd)), and #f otherwise.

procedure

(board-set! aboard row col val)  void?

  aboard : board?
  row : exact-integer?
  col : exact-integer?
  val : byte?
Sets the value of the cell at (row, col) in aboard to val.

syntax

(board-index aboard row col)

A macro for calculating the 1D index into the board’s internal flat byte vector that corresponds to the 2D coordinates (row, col). aboard must be an instance of board?, and row and col must be exact integers. This is primarily an internal utility but can be useful for optimized board manipulations.
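For a flat row-major byte string, the usual mapping between 2D coordinates and the 1D index is index = row × n-cols + col. The sketch below illustrates this mapping on plain bytes; it assumes that board-index follows this standard formula, which the text above suggests but does not state, and the names rc->index and index->rc are hypothetical.

```racket
#lang racket/base
;; Hypothetical helpers illustrating row-major indexing on a flat byte string.
(define (rc->index row col n-cols)
  (+ (* row n-cols) col))

(define (index->rc idx n-cols)
  (values (quotient idx n-cols) (remainder idx n-cols)))

;; The 4x5 board of Section 1 (0: wall, 1: empty, 2: player, 3: gold).
(define cells (bytes 0 0 0 0 0
                     0 1 2 1 0
                     0 1 1 3 0
                     0 0 0 0 0))

(bytes-ref cells (rc->index 1 2 5)) ; 2, the player at row 1, column 2
(index->rc 13 5)                    ; 2 and 3, the position of the gold
```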

procedure

(board-copy brd)  board?

  brd : board?
Creates and returns a new board that is a (deep) copy of the input brd.

procedure

(board->bytes aboard)  bytes?

  aboard : board?
Returns the internal byte vector (a bytes? object) that stores the cell data for aboard. Important: This function returns the actual internal byte string, not a copy. Therefore, modifications to the returned byte string will directly affect the aboard from which it was obtained. For a safe copy, use (bytes-copy (board->bytes aboard)) or create a new board via board-copy.

procedure

(board-relative-tiling/collect brd    
  #:collect! collect!    
  #:row row0    
  #:col col0    
  [#:max-value max-value    
  #:pad-value pad-value    
  #:row-dist row-dist    
  #:col-dist col-dist    
  #:row-span row-span    
  #:col-span col-span])  void?
  brd : board?
  collect! : (-> fixnum? any/c)
  row0 : exact-integer?
  col0 : exact-integer?
  max-value : byte? = 255
  pad-value : byte? = max-value
  row-dist : exact-positive-integer? = 1
  col-dist : exact-positive-integer? = row-dist
  row-span : exact-positive-integer? = 2
  col-span : exact-positive-integer? = row-span
Collects context fixnums generated by applying a relative tiling scheme on the given brd around a central point (row0, col0). The context fixnums are collected through repeated calls to the collect! procedure. The arguments are best described with a picture:

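The number of tiles (mutex sets) generated by such a tiling is (row-dist × 2 + 2 - row-span) × (col-dist × 2 + 2 - col-span). For example, with row-dist = col-dist = 1 and row-span = col-span = 2, this gives 2 × 2 = 4 tiles, as in the example of Section 1.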

For each tile, the cells of the tile are encoded into a single fixnum using naturals->fixnum with size = (+ max-value 1). If a cell of a tile is outside the boundaries of the board, the pad-value is used in place of the cell’s value. The resulting code is passed to collect!.
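The tiling scheme can be reproduced in plain Racket to check the four fixnums of the example in Section 1. The sketch below is a re-implementation for illustration only, not the lts-cm code: it works on a list of strings rather than a board, the names cell-value and tile-fixnums are hypothetical, and it assumes that tiles and the cells within a tile are both visited in row-major order, which is consistent with the documented output.

```racket
#lang racket/base
;; Hypothetical re-implementation for illustration; not the lts-cm code.
(define rows '("XXXXX" "X P X" "X  GX" "XXXXX"))

;; Value of the cell at (r, c), or `pad` when (r, c) is outside the board.
(define (cell-value rows r c pad)
  (if (and (< -1 r (length rows))
           (< -1 c (string-length (car rows))))
      (case (string-ref (list-ref rows r) c)
        [(#\X) 0] [(#\space) 1] [(#\P) 2] [(#\G) 3]
        [else (error "unknown character")])
      pad))

;; One fixnum per span x span tile inside the window of half-width `dist`
;; centered at (row0, col0).
(define (tile-fixnums rows row0 col0 dist span max-value pad)
  (define size (add1 max-value))
  (for*/list ([r (in-range (- row0 dist) (+ row0 dist 2 (- span)))]
              [c (in-range (- col0 dist) (+ col0 dist 2 (- span)))])
    ;; Mixed-radix encoding of the tile's cells, all with the same size.
    (for*/fold ([n 0]) ([dr (in-range span)] [dc (in-range span)])
      (+ (* n size) (cell-value rows (+ r dr) (+ c dc) pad)))))

(tile-fixnums rows 1 2 1 2 3 0) ; '(6 9 101 151)
```

The arguments of the last call are, in order, the player position (1, 2), a distance of 1, a span of 2, max-value 3 and pad-value 0, matching the keyword arguments used in Section 1.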

5 Line search for convex minimization

 (require lts-cm/delta-secant) package: levintreesearch_cm

This module implements the Δ-Secant line search algorithm from the paper “Line Search for Convex Minimization”.

The function convex-line-search returns the lowest point found of a given convex function between two initial points when a stopping criterion is satisfied.

The function quasi-exact-line-search builds upon convex-line-search to ensure that sufficient progress is made, and is intended to be used within an optimization algorithm such as gradient descent or Frank-Wolfe.

procedure

(convex-line-search f    
  xleft    
  xright    
  [#:yleft yleft    
  #:xq xq    
  #:yq yq    
  #:y-tolerance y-tolerance
  #:stop-when stop-when    
  #:callback callback])  dict?
  f : (-> real? real?)
  xleft : real?
  xright : real?
  yleft : real? = (f xleft)
  xq : real? = (* 0.5 (+ xleft xright))
  yq : real? = (f xq)
  y-tolerance : real? = 1e-10
  stop-when : (-> dict? any/c)
   = (λ (dic) (<= (dict-ref dic 'ygap) y-tolerance))
  callback : (-> dict? any/c) = (λ (dic) (void))
The function f is assumed convex between xleft and xright, and the behaviour is undefined otherwise. Let x* be a minimizer of f on the interval [xleft, xright], and let y* = f(x*) = min_x f(x) be the corresponding minimum.

This function returns a dictionary of values with the following keys:
  • 'iter: Number of iterations performed.

  • 'xlow and 'ylow: lowest point found — usually these are the quantities of interest.

  • 'xgap and 'ygap: upper bounds on |xlow - x*| and |ylow - y*|.

  • 'x- and 'x+: x-interval containing x*.

  • 'ya and 'yb: The minimum of these two values is a lower bound on y*.

  • 'pts: The 5 points around x*. See paper.

The arguments yleft and yq MUST be equal to (f xleft) and (f xq).

The argument xq is the first point within [xleft, xright] to be queried.

The argument stop-when controls when the algorithm should terminate. By default, it terminates when the y-distance to the minimum ('ygap) is provably less than y-tolerance.

The argument callback can be used to monitor the progress of the line search.

Examples:
> (convex-line-search (λ (x) (sqr (- x 1))) -2 5)

(list
 '(iter . 23)
 (list
  'pts
  (pt 0.9999892887337545 1.1473122458246348e-10)
  (pt 0.9999966814872121 1.1012527123440112e-11)
  (pt 1.0000016825306512 2.8309093923890705e-12)
  (pt 1.0000135688568499 1.8411387621275733e-10)
  (pt 1.000055351641258 3.063804189957421e-9))
 '(x- . 0.9999972646480525)
 '(x+ . 1.0000109385368174)
 '(xgap . 1.3673888764942355e-5)
 '(ya . -2.9452989812498634e-11)
 '(yb . -1.1960640832144743e-11)
 '(ygap . 3.2283899204887706e-11)
 '(xlow . 1.0000016825306512)
 '(ylow . 2.8309093923890705e-12))

> (define (keep-keys dic keys) (filter (λ (l) (memq (car l) keys)) dic))
> (keep-keys (convex-line-search (λ (x) (sqr (- x 1))) -2 5 #:y-tolerance 0.01)
             '(iter xlow ylow))

'((iter . 7) (xlow . 1.0095288532116586) (ylow . 9.079904352933715e-5))

> (keep-keys
   (convex-line-search (λ (x) (max (sqr (- x 1)) (sqr (+ x 1)))) -2 5)
   '(iter xlow ylow xgap ygap))

'((iter . 15)
  (xgap . 1.5371377058688084e-11)
  (ygap . 9.722889160457271e-12)
  (xlow . -2.796517824676535e-12)
  (ylow . 1.0000000000055929))

procedure

(quasi-exact-line-search f    
  [xleft    
  xright    
  #:yleft yleft    
  #:xq xq    
  #:yq yq    
  #:jac^2 jac^2    
  #:c c    
  #:callback callback])  dict?
  f : (-> real? real?)
  xleft : real? = 0.0
  xright : real? = 1.0
  yleft : real? = (f xleft)
  xq : real? = (* 0.5 (+ xleft xright))
  yq : real? = (f xq)
  jac^2 : (or/c #f positive-real?) = #f
  c : positive-real? = 1.0
  callback : (-> dict? any/c) = (λ (dic) (void))
Like convex-line-search but the argument c controls how close to the minimum the returned value ylow (within the returned dictionary) should be compared to the initial value yleft; more precisely, we have ylow - y* ≤ c(yleft - ylow).

Moreover, in contrast to convex-line-search, if the minimum is found to be at xright, the range [xleft, xright] is quadrupled to the right and the line search continues, and so on. This means that, for example, the call (quasi-exact-line-search / 1 2) loops forever. To prevent this quadrupling behaviour, one can force the function f to be increasing at xright, for example with (λ (x) (if (< x 2) (/ x) +inf.0)) instead of /.

The argument jac^2, if provided, should be the squared 2-norm of the jacobian (aka the gradient or derivative) at xleft. This information may be used to speed up the search.

See convex-line-search for the description of the returned dictionary, and of the other arguments.

Example:
> (for/list ([c '(1 10 100)])
    (keep-keys
     (quasi-exact-line-search (λ (x) (sqr (- x 1))) -2 5 #:c c)
     '(iter xlow ylow)))

'(((iter . 2) (xlow . 1.47265625) (ylow . 0.2234039306640625))

  ((iter . 4) (xlow . 0.7788681457319648) (ylow . 0.04889929697201954))

  ((iter . 6) (xlow . 1.0095288532116586) (ylow . 9.079904352933715e-5)))
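To illustrate how a line search fits inside gradient descent, here is a self-contained sketch in which the step size of each iteration is chosen by a pluggable line-search procedure. To keep it runnable without lts-cm, the line search used here is a plain ternary search on [0, 1], which is a stand-in and not the Δ-Secant algorithm. All names (ternary-line-search, gradient-descent, f, grad) are hypothetical. With lts-cm installed, the stand-in could presumably be replaced by (λ (phi) (dict-ref (quasi-exact-line-search phi) 'xlow)); this is an untested assumption based on the documented defaults and return keys.

```racket
#lang racket/base
;; Stand-in line search: plain ternary search on [lo, hi]. NOT Δ-Secant.
;; Returns an approximate minimizer of the convex 1-D function phi.
(define (ternary-line-search phi [lo 0.0] [hi 1.0] [iters 60])
  (for/fold ([lo lo] [hi hi] #:result (* 0.5 (+ lo hi)))
            ([_ (in-range iters)])
    (define third (/ (- hi lo) 3.0))
    (define m1 (+ lo third))
    (define m2 (- hi third))
    ;; Keep the two thirds of the interval that must contain the minimum.
    (if (< (phi m1) (phi m2))
        (values lo m2)
        (values m1 hi))))

;; Gradient descent where `line-search` takes the 1-D function
;; t -> f(w - t * grad(w)) and returns the step size t to use.
(define (gradient-descent f grad w0 line-search [n-steps 30])
  (for/fold ([w w0]) ([_ (in-range n-steps)])
    (define g (grad w))
    (define (move t) (map (λ (wi gi) (- wi (* t gi))) w g))
    (move (line-search (λ (t) (f (move t)))))))

;; Example: f(w) = (w0 - 1)^2 + 2 (w1 + 2)^2, minimized at (1, -2).
(define (f w)
  (+ (expt (- (car w) 1.0) 2) (* 2.0 (expt (+ (cadr w) 2.0) 2))))
(define (grad w)
  (list (* 2.0 (- (car w) 1.0)) (* 4.0 (+ (cadr w) 2.0))))

(gradient-descent f grad '(0.0 0.0) ternary-line-search) ; close to '(1.0 -2.0)
```

The fixed interval [0, 1] is sufficient here because the best step size for this quadratic lies between 1/4 and 1/2. The range-extension behaviour of quasi-exact-line-search described above removes the need for such prior knowledge.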

struct

(struct pt (x y))

  x : real?
  y : real?

struct

(struct ptg pt (g))

  g : real?
Points without and with gradient. May be used in the 'pts entry of the return dictionaries of convex-line-search and quasi-exact-line-search.