Friday, April 25, 2014

Partial Lenses in F#

Based on Mauricio's article about Lenses in F#, I recently tried to use lenses in our project, but very soon ran into some fundamental limitations. It turned out to be an interesting problem, and in the end we use a modified version of lenses in our project and are happy with the result.

What are lenses

As a quick reminder, lenses are bi-directional transformations which allow you to make a well-behaved copy-and-update operation on (potentially deeply) nested data structures.

They offer two functions, get : 'a -> 'b and set : 'b -> 'a -> 'a. The idea is that you can "zoom in" on a property and perform local transformations that properly propagate through the object graph. Based on those two, we can define a function update : ('b -> 'b) -> 'a -> 'a that first applies get, then maps a function over the value, and finally applies set, thereby performing a map somewhere in the object graph.
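
As a small illustration of these signatures, here is a minimal sketch of a total lens over records. The types Address and Person, the record TotalLens and the values city and address are hypothetical names chosen for this example; they do not come from Mauricio's article.

```fsharp
// Hypothetical record types, used only for illustration.
type Address = { Street : string; City : string }
type Person = { Name : string; Address : Address }

// A total lens: Get always succeeds.
type TotalLens<'a, 'b> =
    { Get : 'a -> 'b
      Set : 'b -> 'a -> 'a }

// update is derived from Get and Set: read, transform, write back.
let update (l : TotalLens<'a, 'b>) (f : 'b -> 'b) (a : 'a) : 'a =
    l.Set (f (l.Get a)) a

let city : TotalLens<Address, string> =
    { Get = fun (a : Address) -> a.City
      Set = fun v (a : Address) -> { a with City = v } }

let address : TotalLens<Person, Address> =
    { Get = fun (p : Person) -> p.Address
      Set = fun v (p : Person) -> { p with Address = v } }

let alice = { Name = "Alice"; Address = { Street = "Main St"; City = "Bern" } }

// Nesting two updates changes the city deep inside the person.
let moved = update address (update city (fun c -> c.ToUpper())) alice
```

Here moved is a copy of alice whose city is upper-cased, while alice itself is left unchanged.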

In particular, lenses can be composed and still behave as one would expect for a copy-and-update operation. For a more detailed introduction, read Mauricio's article.

The Problem

While lenses, formulated the way Mauricio presents them, work very well for product types like tuples or records, they do not work for sum types like discriminated unions. Consider the following example:

type Shape =
    | Rectangle of width : float * length : float
    | Circle of radius : float
    | Prism of width : float * float * height : float

In this case, we cannot create a lens for, say, the prism, let alone for the height of the prism: a total get has nothing to return when the shape is a Rectangle or a Circle.

As it turns out, this is an important feature, because in F# in general, and in our project in particular, we use a lot of discriminated unions. So the question was whether we could somehow extend the idea of a lens so that it also works for sum types, as opposed to only product types.

Towards partial lenses

The idea for making lenses work well with sum types is to consider partial functions instead of total ones: get : 'a -> 'b option and set : 'b -> 'a -> 'a.

The first point of interest is that only the get function has a different signature. This is because we may not get a value out of our lens: it is a partial function. When we try to set a value, though, the original object already exists, so if the lens does not apply, set just returns the original object unchanged. Our derived update function also keeps its type.

The next question is how to compose such lenses. It turns out that composing the get functions is essentially a monadic bind in the option monad.

type Lens<'a, 'b> =
    { Get : 'a -> 'b option
      Set : 'b -> 'a -> 'a }
    member l.Update f a =
        match l.Get a with
        | Some x -> l.Set (f x) a
        | None -> a

let compose (l1 : Lens<_, _>) (l2 : Lens<_, _>) =
    { Get = fun a -> Option.bind l1.Get (l2.Get a)
      Set = l1.Set >> l2.Update }

Now we define our partial lenses for discriminated unions thus:

type MyType =
    | MyCase of string
    | MyOther of int

let myCase =
    { Get = function
        | MyCase value -> Some value
        | _ -> None
      Set = fun newValue -> function
        | MyCase _ -> MyCase newValue
        | a -> a }
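
To see how such a lens behaves on both the matching and the non-matching case, here is a small usage sketch. It repeats the definitions so that it runs on its own; the lens is written with explicit match expressions, and the value names (hit, miss, replaced, untouched, upper) are only illustrative.

```fsharp
type Lens<'a, 'b> =
    { Get : 'a -> 'b option
      Set : 'b -> 'a -> 'a }
    member l.Update f a =
        match l.Get a with
        | Some x -> l.Set (f x) a
        | None -> a

type MyType =
    | MyCase of string
    | MyOther of int

let myCase : Lens<MyType, string> =
    { Get = (fun a ->
                match a with
                | MyCase value -> Some value
                | _ -> None)
      Set = (fun newValue a ->
                match a with
                | MyCase _ -> MyCase newValue
                | other -> other) }

// The lens applies: we can read, replace and transform the payload.
let hit = myCase.Get (MyCase "abc")
let replaced = myCase.Set "xyz" (MyCase "abc")
let upper = myCase.Update (fun (s : string) -> s.ToUpper()) (MyCase "abc")

// The lens does not apply: Get yields None and Set leaves the value alone.
let miss = myCase.Get (MyOther 1)
let untouched = myCase.Set "xyz" (MyOther 1)
```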

For normal (total) lenses, we can define get by simply wrapping the value with Some.
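
The following sketch shows one way this could look, and how a lifted total lens composes with a partial one. The helper ofTotal and the types Payload and Message are assumptions made up for this example. The Set of compose is spelled out as a lambda here, which should be equivalent to l1.Set >> l2.Update.

```fsharp
type Lens<'a, 'b> =
    { Get : 'a -> 'b option
      Set : 'b -> 'a -> 'a }
    member l.Update f a =
        match l.Get a with
        | Some x -> l.Set (f x) a
        | None -> a

// l2 is the outer lens, l1 the inner one.
let compose (l1 : Lens<'a, 'b>) (l2 : Lens<'c, 'a>) : Lens<'c, 'b> =
    { Get = fun c -> Option.bind l1.Get (l2.Get c)
      Set = fun v c -> l2.Update (l1.Set v) c }

// Hypothetical helper: lift a total get/set pair into a partial lens.
let ofTotal (get : 'a -> 'b) (set : 'b -> 'a -> 'a) : Lens<'a, 'b> =
    { Get = get >> Some
      Set = set }

// Hypothetical types: a record (product) containing a union (sum).
type Payload =
    | Text of string
    | Number of int

type Message = { Id : int; Payload : Payload }

let payload : Lens<Message, Payload> =
    ofTotal (fun (m : Message) -> m.Payload)
            (fun v (m : Message) -> { m with Payload = v })

let text : Lens<Payload, string> =
    { Get = (fun p ->
                match p with
                | Text s -> Some s
                | _ -> None)
      Set = (fun v p ->
                match p with
                | Text _ -> Text v
                | other -> other) }

// Zoom from a Message into the string inside a Text payload.
let messageText = compose text payload

let greeting = { Id = 1; Payload = Text "hi" }
let count = { Id = 2; Payload = Number 7 }

let read = messageText.Get greeting
let written = messageText.Set "yo" greeting
let skipped = messageText.Set "yo" count
```

Since payload is total, the composed lens only fails to apply when the inner lens does, as for count above.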

Making sure the lenses are well-behaved

In order to be well-behaved, lenses should fulfil the following lens laws:

  • get-set-law: set (get a) a = a
  • set-get-law: get (set v a) = v
  • set-set-law: set v' (set v a) = set v' a

Now, we cannot immediately fulfil those laws, because our get function is partial, or rather yields an option value. However, we can show that the laws hold in the case where the lens actually yields a value, so this is a conservative extension.

We redefine get as get' = get >> Option.get, that is, we apply get and immediately unwrap the option. This is no longer a total function, i.e. it throws in the cases where get returned None, but in the cases where the original get yielded a Some value, it yields that value.

By definition, for all total lenses, our partial lens yields the same value, wrapped with Some, so for those cases get' behaves exactly like the original total lens and therefore fulfils the same laws.

And for partial lenses, we only want to show that the laws hold in the case where the lens yields a value. Therefore, without loss of generality, assume a = MyCase x for some x, and consider the lens for the case MyCase.

set (get' a) a
    = set ((get >> Option.get) a) a
    = set (Option.get (get (MyCase x))) a
    = set (Option.get (Some x)) a
    = set x a
    = set x (MyCase x)
    = MyCase x
    = a

get' (set v a)
    = get' (set v (MyCase x))
    = get' (MyCase v)
    = Option.get (Some v)
    = v

set v' (set v a)
    = set v' (set v (MyCase x))
    = set v' (MyCase v)
    = MyCase v'
    = set v' (MyCase x)
    = set v' a
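
The derivation can also be spot-checked by running it. The sketch below evaluates the three laws for the MyCase lens on sample values chosen for illustration; it is a sanity check for particular inputs, not a proof.

```fsharp
type Lens<'a, 'b> =
    { Get : 'a -> 'b option
      Set : 'b -> 'a -> 'a }

type MyType =
    | MyCase of string
    | MyOther of int

let myCase : Lens<MyType, string> =
    { Get = (fun a ->
                match a with
                | MyCase value -> Some value
                | _ -> None)
      Set = (fun newValue a ->
                match a with
                | MyCase _ -> MyCase newValue
                | other -> other) }

let get = myCase.Get
let set = myCase.Set
let get' = get >> Option.get   // throws if get returns None

// Sample values (arbitrary).
let a = MyCase "x"
let v = "v"
let v' = "w"

let getSetHolds = (set (get' a) a) = a
let setGetHolds = (get' (set v a)) = v
let setSetHolds = (set v' (set v a)) = (set v' a)

// Outside the lens's case, set is a no-op.
let missIsNoOp = (set v (MyOther 1)) = MyOther 1
```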

Possible improvement

One possible improvement would be to distinguish between partial lenses (for sum types) and total lenses (for product types). While I think it would be possible to propagate statically which lenses are total and which are partial, we did not go down that route, for two reasons: to get a value, a single pattern match at the call site is enough, and our objects rarely consist only of product types. Therefore the more general partial lenses were good enough for our use case.

If one wants to make this distinction properly, one needs four overloads of each function that combines two lenses: one where both sides are total lenses, one where only the first is partial, one where only the second is partial, and one where both are partial.


Friday, April 4, 2014

Memoization

Caching

When programming, we sometimes get to a point where a program runs slowly because the computer performs the same computation over and over. To understand when caching can make your program faster, let us first be a little more precise about what it really means and when it can be used.

First and foremost, caching means that before running a function that may take a large amount of time, we check in a lookup table, e.g. a dictionary, whether the result has been computed previously and thus need not be computed again. If so, the result is immediately returned from the lookup table.

In order for caching to make sense, we need the function in question to satisfy certain criteria:

  • It should be referentially transparent, i.e. it has to return the same result for the same input, for at least as long as the cache is valid. Good examples would be mathematical functions like prime number checking, hashing, etc., but also things like reading the contents of a text file into memory, if we know the file is not going to change for the lifetime of the cache. Typical examples of functions that do not satisfy this property would be a random number generator, getting the current balance of your account, etc.
  • It should take long. This is not a precise definition, but since we add additional complexity, like the lookup from a table, we make each individual call slower. So we need to gain enough by caching for this to be justified.
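
The lookup-before-compute idea can be sketched as follows. The function slowSquare stands in for an expensive, referentially transparent computation; it and the counter evaluations are hypothetical names used only to make the effect of the cache visible.

```fsharp
open System.Collections.Generic

// Counts how often the expensive function really runs.
let evaluations = ref 0

// Stand-in for an expensive, referentially transparent function.
let slowSquare (n : int) =
    evaluations.Value <- evaluations.Value + 1
    n * n

// The lookup table.
let table = Dictionary<int, int>()

let cachedSquare n =
    match table.TryGetValue n with
    | true, result -> result            // computed before: return it
    | _ ->
        let result = slowSquare n       // not yet known: compute ...
        table.[n] <- result             // ... and remember
        result

let first = cachedSquare 12
let second = cachedSquare 12
```

After both calls, slowSquare has run only once; the second call is answered from the table.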

Memoization

Memoization extends the idea of a cache from single values to whole functions. While it may at first sound like just a detail for performance tuning, it can have, in fact, a deep impact on the time and space complexity of algorithms. Consider this simple function:

let rec fib = function
    | 0 | 1 -> 1
    | n -> fib (n-1) + fib (n-2)

The function as it is written here has a time complexity of \(O(2^n)\), because each call branches into two recursive calls.

On the other hand, every value \(fib(n)\) only depends on the values of \(fib(i)\) for \(i < n\).

Now what would happen if \(fib\) were memoized, i.e. each value were computed at most once?

To calculate the value \(fib(n)\), we need to calculate the value \(fib(n-1)\), which in turn depends on \(fib(n-2)\). That means that by the time we are done calculating \(fib(n-1)\), we already have \(fib(n-2)\) in the cache. And so we do not branch in this case, but immediately return the result.

Well, it turns out that the algorithm all of a sudden becomes \(O(n)\). We traded time complexity for some space complexity.
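
To make the difference tangible, the sketch below counts how many times the function body is evaluated with and without a cache. The memoized variant is hand-written with an explicit dictionary for this illustration; the counters and the names naiveFib and memoFib are assumptions of the example.

```fsharp
open System.Collections.Generic

// Plain recursion: every call branches twice.
let naiveCalls = ref 0
let rec naiveFib n =
    naiveCalls.Value <- naiveCalls.Value + 1
    match n with
    | 0 | 1 -> 1
    | n -> naiveFib (n - 1) + naiveFib (n - 2)

// Hand-memoized recursion: each value is computed at most once.
let memoCalls = ref 0
let cache = Dictionary<int, int>()
let rec memoFib n =
    match cache.TryGetValue n with
    | true, v -> v
    | _ ->
        memoCalls.Value <- memoCalls.Value + 1
        let v =
            match n with
            | 0 | 1 -> 1
            | n -> memoFib (n - 1) + memoFib (n - 2)
        cache.[n] <- v
        v

let a = naiveFib 20   // 10946, after 21891 evaluations of the body
let b = memoFib 20    // 10946, after 21 evaluations of the body
```

For \(n = 20\) the plain version evaluates its body 21891 times, the memoized one 21 times, once for each of the values \(0, \dots, 20\).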

Implementing Memoization in F#

We can implement memoization using a mutable dictionary of some sort. A purely functional implementation without a mutable dictionary is something I might explore in a later blog post.

We need a cache and a way to get or add a function value. Our memoized function can then be written thus:

open System.Collections.Generic
let memo f =
    let dict = Dictionary()
    fun x -> match dict.TryGetValue x with
             | true, value -> value
             | _ -> let value = f x
                    dict.Add(x, value)
                    value

This implementation, however, has several caveats:

  • it explodes if you use null (or ()) as the function argument, because a Dictionary does not accept null keys
  • it is not thread-safe; the operation Add explodes if the key is already present.

We can do better:

open System.Collections.Generic
let memo f =
    let dict = Dictionary()
    fun x -> match dict.TryGetValue (Some x) with
             | true, value -> value
             | _ -> let value = f x
                    dict.[Some x] <- value
                    value

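This implementation is better. It does not explode with null input, and we might call it semi thread-safe. This means that it won't fail with a duplicate-key exception in case of a race condition, but it might calculate a value unnecessarily and then overwrite the cache entry with it. On the other hand, it works without locks. Note, however, that Dictionary itself is not designed for concurrent writers, so this is no substitute for a truly thread-safe cache.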
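
The null-key caveat is easy to reproduce with a function that takes unit, since () is represented as null at runtime. The sketch below puts both variants side by side; the names memoPlain, memoSome and answer are chosen for this illustration only.

```fsharp
open System
open System.Collections.Generic

// First version: the argument itself is the key.
let memoPlain f =
    let dict = Dictionary()
    fun x ->
        match dict.TryGetValue x with
        | true, value -> value
        | _ ->
            let value = f x
            dict.Add(x, value)
            value

// Second version: the key is wrapped in Some, so it is never null.
let memoSome f =
    let dict = Dictionary()
    fun x ->
        match dict.TryGetValue (Some x) with
        | true, value -> value
        | _ ->
            let value = f x
            dict.[Some x] <- value
            value

let answer () = 42

// The plain version rejects the null key ...
let plainFailed =
    try
        memoPlain answer () |> ignore
        false
    with :? ArgumentNullException -> true

// ... while the wrapped version works.
let wrapped = memoSome answer ()
```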

Supporting recursive functions

Unfortunately, this approach does not work for recursive functions, because the memoization is only added to the outermost function, while internally the function calls the non-memoized version. A very elegant solution is to use a memoizing \(Y\)-combinator. I shall explore the \(Y\)-combinator in a later post. In short, it is a way of writing a recursive function without recursion. Instead, another argument is added that corresponds to the recursive definition.

With this approach our \(fib\)-function now becomes

let fib fib = function
    | 0 | 1 -> 1
    | n -> fib (n-1) + fib (n-2)

Here, I used the same name as the function for the recursive argument, so that the code still looks like a recursive function.

To memoize all the versions of \(fib\), we need to be able to reuse the cache. So we really need two things: a way to create a cache, and a getOrCache function that uses this cache to either get the value or compute it and cache it for later use. Both things can be implemented with a simple function:

open System.Collections.Generic
let createCache () =
    let dict = Dictionary()
    fun f x -> match dict.TryGetValue (Some x) with
               | true, value -> value
               | _ -> let value = f x
                      dict.[Some x] <- value
                      value

Our memo function then becomes

let memo cache f =
    let cache = cache()
    cache f

And the memoizing \(Y\)-combinator looks thus:

let memoFix cache f =
    let cache = cache()
    let rec fn x = cache (f fn) x
    fn

The point is that we reuse the same cache for all the recursive versions of f.
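
Putting the pieces together, the sketch below applies memoFix to the open-recursive fib and counts how often its body runs. The counter calls and the name fastFib are additions for this illustration.

```fsharp
open System.Collections.Generic

let createCache () =
    let dict = Dictionary()
    fun f x ->
        match dict.TryGetValue (Some x) with
        | true, value -> value
        | _ ->
            let value = f x
            dict.[Some x] <- value
            value

let memoFix cache f =
    let cache = cache ()
    let rec fn x = cache (f fn) x
    fn

// Counts evaluations of the body of fib.
let calls = ref 0

// Open recursion: the first argument plays the role of the recursive call.
let fib fib n =
    calls.Value <- calls.Value + 1
    match n with
    | 0 | 1 -> 1
    | n -> fib (n - 1) + fib (n - 2)

// Tie the knot through the shared cache.
let fastFib = memoFix createCache fib

let result = fastFib 30
```

Every recursive call goes through fn and therefore through the one shared cache, so the body of fib runs once for each of the 31 arguments \(0, \dots, 30\).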

Different kinds of caches

We can use any kind of cache, as long as we can fulfil the function signature of createCache. For instance, we can use the truly thread-safe ConcurrentDictionary instead:

open System.Collections.Concurrent
let createCache () =
    let dict = ConcurrentDictionary()
    fun f x -> dict.GetOrAdd(Some x, lazy(f x)).Value

We can now create a small module for each kind of cache, so that we do not need to pass the createCache function each time.
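
One possible layout for such modules is sketched below. The module names DictionaryMemo and ConcurrentMemo and the helper names memoWith and memoFixWith are assumptions of this example, not the names used in the published snippet.

```fsharp
open System
open System.Collections.Generic
open System.Collections.Concurrent

// Cache-agnostic building blocks (the memo and memoFix from above, renamed).
let memoWith createCache f =
    let cache = createCache ()
    cache f

let memoFixWith createCache f =
    let cache = createCache ()
    let rec fn x = cache (f fn) x
    fn

// Backed by a plain Dictionary.
module DictionaryMemo =
    let createCache () =
        let dict = Dictionary()
        fun f x ->
            match dict.TryGetValue (Some x) with
            | true, value -> value
            | _ ->
                let value = f x
                dict.[Some x] <- value
                value
    let memo f = memoWith createCache f
    let memoFix f = memoFixWith createCache f

// Backed by a ConcurrentDictionary of lazy values.
module ConcurrentMemo =
    let createCache () =
        let dict = ConcurrentDictionary<_, Lazy<_>>()
        fun f x -> dict.GetOrAdd(Some x, lazy (f x)).Value
    let memo f = memoWith createCache f
    let memoFix f = memoFixWith createCache f

// Usage: pick the module, no need to pass createCache around.
let evaluations = ref 0
let slowDouble (n : int) =
    evaluations.Value <- evaluations.Value + 1
    n * 2

let fastDouble = ConcurrentMemo.memo slowDouble
let first = fastDouble 21
let second = fastDouble 21
```

Callers then choose a cache by choosing a module, e.g. DictionaryMemo.memo or ConcurrentMemo.memoFix.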

Multiple curried arguments

And lastly, we might want to be able to memoize more than just one single argument. For tupled arguments, we already have the feature, and for curried arguments, we can just call memo repeatedly.

let memo f = memo createCache f
let memo2 f = memo (memo << f)
let memo3 f = memo (memo2 << f)
let memo4 f = memo (memo3 << f)
let memo5 f = memo (memo4 << f)
let memo6 f = memo (memo5 << f)
let memo7 f = memo (memo6 << f)
let memo8 f = memo (memo7 << f)
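
As a usage sketch for the curried case, the example below memoizes a two-argument function with memo2 and counts how often it actually runs. The function add and the counter calls are hypothetical; createCache is the Dictionary-based version from earlier.

```fsharp
open System.Collections.Generic

let createCache () =
    let dict = Dictionary()
    fun f x ->
        match dict.TryGetValue (Some x) with
        | true, value -> value
        | _ ->
            let value = f x
            dict.[Some x] <- value
            value

let memo f = createCache () f

// The outer memo caches, per first argument, a memoized function
// of the second argument.
let memo2 f = memo (memo << f)

let calls = ref 0
let add (a : int) (b : int) =
    calls.Value <- calls.Value + 1
    a + b

let fastAdd = memo2 add

let x1 = fastAdd 1 2   // computed
let x2 = fastAdd 1 2   // served from the inner cache
let x3 = fastAdd 1 3   // new second argument: computed
```

Only two of the three calls reach add.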

The whole snippet is available on fssnip.net.
