Caching
When programming, we sometimes reach a point where a program runs slowly because the computer performs the same computation over and over. To understand when caching can make your program faster, let us first be a little more precise about what it really means and when it can be used.
First and foremost, caching means that before running a function that may take a large amount of time, we check in a lookup table, e.g. a dictionary, whether the result has been computed previously and thus need not be computed again. If so, the result is immediately returned from the lookup table.
In order for caching to make sense, we need the function in question to satisfy certain criteria:
- It should be referentially transparent, i.e. it has to return the same result for the same input, at least for as long as the cache is valid. Good examples are mathematical functions like prime number checking, hashing, etc., but also things like reading the contents of a text file into memory, if we know the file is not going to change for the lifetime of the cache. Typical examples of functions that do not satisfy this property are random number generators, getting the current balance of your account, etc.
- It should take long to compute. This is not a precise criterion, but since we add extra work, like the lookup in a table, we make each individual call slower. So the time saved by caching must be large enough to justify that overhead.
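As a minimal sketch of this lookup-before-compute pattern (using prime checking as a stand-in for an expensive, referentially transparent function; the names here are my own, not part of the snippet this post builds up to):

```fsharp
open System.Collections.Generic

// A stand-in for an expensive, referentially transparent function.
let isPrime n =
    n > 1 && Seq.forall (fun d -> n % d <> 0) { 2 .. int (sqrt (float n)) }

// The cache: a plain dictionary from input to result.
let primeCache = Dictionary<int, bool>()

let isPrimeCached n =
    match primeCache.TryGetValue n with
    | true, result -> result        // hit: return the stored result
    | _ ->
        let result = isPrime n      // miss: compute once...
        primeCache.[n] <- result    // ...store it for next time...
        result                      // ...and return it
```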
Memoization
Memoization extends the idea of a cache from single values to whole functions. While it may at first sound like a mere detail of performance tuning, it can in fact have a deep impact on the time and space complexity of algorithms. Consider this simple function:
```fsharp
let rec fib = function
    | 0 | 1 -> 1
    | n -> fib (n-1) + fib (n-2)
```
The function as written here has a time complexity of \(O(2^n)\), because for each value it branches twice.
On the other hand, every value \(fib(n)\) only depends on the values of \(fib(i)\) for \(i < n\).
Now what would happen if \(fib\) were memoized, i.e. if each value were computed at most once?
To calculate the value \(fib(n)\), we need to calculate the value \(fib(n-1)\), which in turn depends on \(fib(n-2)\). That means that by the time we are done calculating \(fib(n-1)\), we already have \(fib(n-2)\) in the cache. So we do not branch in this case, but immediately return the result.
It turns out that the algorithm all of a sudden becomes \(O(n)\). We traded time complexity for some space complexity.
Implementing Memoization in F#
We can implement memoization using a mutable dictionary of some sort. A purely functional implementation, without a mutable dictionary, is something I might explore in a later blog post.
We need a cache and a way to get or add a function value. Our memoized function can then be written thus:
```fsharp
open System.Collections.Generic

let memo f =
    let dict = Dictionary()
    fun x ->
        match dict.TryGetValue x with
        | true, value -> value
        | _ ->
            let value = f x
            dict.Add(x, value)
            value
```
This implementation, however, has several caveats:
- it explodes if you use null (or ()) as the function argument, because Dictionary does not accept null keys
- it is not thread-safe; the operation Add explodes if the key is already present.
We can do better:
```fsharp
open System.Collections.Generic

let memo f =
    let dict = Dictionary()
    fun x ->
        match dict.TryGetValue (Some x) with
        | true, value -> value
        | _ ->
            let value = f x
            dict.[Some x] <- value
            value
```
This implementation is better. Because every key is wrapped in Some, it does not explode with null input, and we might call it semi thread-safe. That means it won't explode in case of a race condition, but it might calculate a value unnecessarily and then overwrite the cached entry with it. On the other hand, it works without locks.
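For illustration, the memo function above might be used like this; slowSquare and the sleep are just stand-ins for a genuinely expensive computation:

```fsharp
// A deliberately slow, referentially transparent function.
let slowSquare (x: int) =
    System.Threading.Thread.Sleep 100   // simulate expensive work
    x * x

let fastSquare = memo slowSquare

fastSquare 12 |> ignore   // first call: pays the full cost, caches the result
fastSquare 12 |> ignore   // second call: answered from the cache
```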
Supporting recursive functions
Unfortunately, this approach does not work for recursive functions, because the memoization is only added to the outermost call; internally the function still calls the non-memoized version. A very elegant solution is to use a memoizing \(Y\)-combinator. I shall explore the \(Y\)-combinator in a later post. In short, it is a way of writing a recursive function without recursion: instead, another argument is added that corresponds to the recursive call.
With this approach, our \(fib\) function now becomes
```fsharp
let fib fib = function
    | 0 | 1 -> 1
    | n -> fib (n-1) + fib (n-2)
```
Here, I used the same name for the recursive argument as for the function itself, so that the code still looks like a recursive function.
To memoize all the recursive versions of \(fib\), we need to be able to reuse the cache. So we really need two things: create a cache, and then use this cache in a getOrCache function that can either get the value or compute it and cache it for later use. Both can be implemented with a single function:
```fsharp
open System.Collections.Generic

let createCache () =
    let dict = Dictionary()
    fun f x ->
        match dict.TryGetValue (Some x) with
        | true, value -> value
        | _ ->
            let value = f x
            dict.[Some x] <- value
            value
```
Our memo function then becomes
```fsharp
let memo cache f =
    let cache = cache()
    cache f
```
and the memoizing \(Y\)-combinator looks thus:
```fsharp
let memoFix cache f =
    let cache = cache()
    let rec fn x = cache (f fn) x
    fn
```
The point is that we reuse the same cache for all the recursive versions of f.
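Putting the pieces together, the memoized \(fib\) might be built like this (repeating the open-recursive fib from above for completeness):

```fsharp
// The open-recursive fib from above: the recursive call is a parameter.
let fib fib = function
    | 0 | 1 -> 1
    | n -> fib (n-1) + fib (n-2)

// Tie the knot through the memoizing Y-combinator.
let memoFib = memoFix createCache fib

// Each fib value is now computed at most once, so this is O(n).
memoFib 40 |> ignore
```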
Different kinds of caches
We can use any kind of cache, as long as we can fulfil the function signature of createCache. For instance, we can use the truly thread-safe ConcurrentDictionary instead:
```fsharp
open System.Collections.Concurrent

let createCache () =
    let dict = ConcurrentDictionary()
    fun f x -> dict.GetOrAdd(Some x, lazy(f x)).Value
```
We can now create a small module for each kind of cache, so that we do not need to pass the createCache function each time.
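Such a module might look roughly like this; the module name and the exact set of functions it exposes are my own sketch:

```fsharp
module ConcurrentMemo =
    open System.Collections.Concurrent

    let createCache () =
        let dict = ConcurrentDictionary()
        fun f x -> dict.GetOrAdd(Some x, lazy(f x)).Value

    // Pre-apply the cache factory, so callers never pass it themselves.
    let memo f = memo createCache f
    let memoFix f = memoFix createCache f
```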
Multiple curried arguments
And lastly, we might want to memoize functions of more than a single argument. Tupled arguments already work, because a tuple is just one value; for curried arguments, we can apply memo repeatedly.
```fsharp
let memo f = memo createCache f
let memo2 f = memo (memo << f)
let memo3 f = memo (memo2 << f)
let memo4 f = memo (memo3 << f)
let memo5 f = memo (memo4 << f)
let memo6 f = memo (memo5 << f)
let memo7 f = memo (memo6 << f)
let memo8 f = memo (memo7 << f)
```
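As a quick sanity check, a curried two-argument function could then be memoized with memo2; slowAdd is a hypothetical stand-in for an expensive function:

```fsharp
// A curried stand-in for an expensive two-argument function.
let slowAdd a b =
    System.Threading.Thread.Sleep 100   // simulate expensive work
    a + b

let fastAdd = memo2 slowAdd

fastAdd 1 2 |> ignore   // computes and fills both cache levels
fastAdd 1 2 |> ignore   // both levels hit the cache
```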
The whole snippet is available on fssnip.net.