Loop fusion in Haskell

Roman Leshchinskiy
Programming Languages and Systems
University of New South Wales

What is this about?

What I do: Data Parallel Haskell compiles nested data-parallel programs to flat data-parallel ones. Lots of arrays and collective operations are involved:

    zipWith (-) (zipWith (*) (zipWith (-) (replicate_s segd as1) xs)
                             (zipWith (-) (replicate_s segd bs1) ys))
                (zipWith (*) (zipWith (-) (replicate_s segd bs2) ys)
                             (zipWith (-) (replicate_s segd as2) xs))

What other people do: array programs with lots of collective operations, for example:

    return . foldl' hash 5381 . map toLower . filter isAlpha =<< readFile f

What everybody wants: no temporary arrays, fused loops, C-like speed.
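The one-liner above leaves `hash` and `f` unspecified. A runnable sketch, assuming `hash` is the djb2 step `h * 33 + c` (5381 is djb2's usual starting value, but the slide does not say which hash is meant) and using `Data.ByteString.Char8` from the bytestring package, whose `foldl'`, `map` and `filter` mirror the list functions. Unless the library fuses them, each stage of the pipeline is a separate traversal that builds a temporary string.

```haskell
module Main where

import qualified Data.ByteString.Char8 as B
import Data.Char (isAlpha, ord, toLower)

-- Assumed hash: one djb2 step, h * 33 + character code (wraps on overflow).
hash :: Word -> Char -> Word
hash h c = h * 33 + fromIntegral (ord c)

-- Pure core of the pipeline: keep letters, lower-case them, fold the hash.
hashText :: B.ByteString -> Word
hashText = B.foldl' hash 5381 . B.map toLower . B.filter isAlpha

-- The slide's shape: read a file, then hash its contents.
hashFile :: FilePath -> IO Word
hashFile f = return . hashText =<< B.readFile f

main :: IO ()
main = print (hashText (B.pack "Hello, World!"))
```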


RULES
  "map/map"           map f (map g xs)             = map (f . g) xs
  "filter/filter"     filter f (filter g xs)       = filter (λx → g x && f x) xs
  "map/filter"        map f (filter g xs)          = mapFilter f g xs
  "map/mapFilter"     map f (mapFilter g h xs)     = mapFilter (f . g) h xs
  "mapFilter/filter"  mapFilter f g (filter h xs)  = mapFilter f (λx → h x && g x) xs
  ...
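In GHC, rules like these go into a RULES pragma. A minimal sketch over lists, assuming hypothetical combinators `mapL`, `filterL` and `mapFilter` (standing in for the slide's array operations) that are marked NOINLINE so the rules have something to match. The fused predicates test the inner predicate first, so the outer one only sees elements that passed it, exactly as in the unfused code.

```haskell
module Main where

-- Hypothetical stand-ins for the slide's combinators.
mapL :: (a -> b) -> [a] -> [b]
mapL = map
{-# NOINLINE mapL #-}

filterL :: (a -> Bool) -> [a] -> [a]
filterL = filter
{-# NOINLINE filterL #-}

-- mapFilter f g xs: apply f to the elements of xs that satisfy g.
mapFilter :: (a -> b) -> (a -> Bool) -> [a] -> [b]
mapFilter f g xs = [f x | x <- xs, g x]
{-# NOINLINE mapFilter #-}

-- One rule per pair of combinators; each new combinator needs more.
{-# RULES
"mapL/mapL"         forall f g xs.   mapL f (mapL g xs)            = mapL (f . g) xs
"filterL/filterL"   forall f g xs.   filterL f (filterL g xs)      = filterL (\x -> g x && f x) xs
"mapL/filterL"      forall f g xs.   mapL f (filterL g xs)         = mapFilter f g xs
"mapL/mapFilter"    forall f g h xs. mapL f (mapFilter g h xs)     = mapFilter (f . g) h xs
"mapFilter/filterL" forall f g h xs. mapFilter f g (filterL h xs)  = mapFilter f (\x -> h x && g x) xs
  #-}

main :: IO ()
main = print (mapL (+ 1) (filterL even [1 .. 10 :: Int]))
```

Every further combinator would need rules against all of these, which is the explosion the slide is pointing at.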




The challenge

- use a constant number of rewrite rules
- don’t require new rules for new combinators
- make adding new combinators easy
- fuse everything!
- don’t require specialised compiler support
- handle both sequential and parallel loops

Sequential loops

Streams

    data Step s a = Yield a s | Done
    data Stream a = ∃s. Stream (s → Step s a) s

- stepper produces next element and state from current state
- similar to an iterator
- actually encodes an anamorphism (unfold)
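The remark that a stream encodes an unfold can be checked directly: turning a stream into a list is `Data.List.unfoldr` once `Step` is mapped onto `Maybe`. A minimal sketch; GHC spells the slide's `∃s.` as a `forall` before the constructor (with ExistentialQuantification), and `enumFromToS` and `toListS` are hypothetical helpers, not part of the slides.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
module Main where

import Data.List (unfoldr)

data Step s a = Yield a s | Done

-- The state type s is hidden: each stream picks its own.
data Stream a = forall s. Stream (s -> Step s a) s

-- Hypothetical producer: the integers from lo to hi, with an Int as state.
enumFromToS :: Int -> Int -> Stream Int
enumFromToS lo hi = Stream step lo
  where
    step i | i <= hi   = Yield i (i + 1)
           | otherwise = Done

-- A stream is an unfold: run the stepper from the seed until Done.
toListS :: Stream a -> [a]
toListS (Stream step s0) = unfoldr next s0
  where
    next s = case step s of
      Yield x s' -> Just (x, s')
      Done       -> Nothing

main :: IO ()
main = print (toListS (enumFromToS 1 5))
```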

    sumS :: Num a ⇒ Stream a → a
    sumS (Stream step s) = go 0 s
      where
        go z s = case step s of
                   Yield x s' → go (z+x) s'
                   Done       → z

    stream :: Array a → Stream a
    stream arr = Stream step 0
      where
        step i | i < length arr = Yield (arr ! i) (i+1)
               | otherwise      = Done

    mapS :: (a → b) → Stream a → Stream b
    mapS f (Stream step s) = Stream step' s
      where
        step' s = case step s of
                    Yield x s' → Yield (f x) s'
                    Done       → Done

    unstream :: Stream a → Array a
    unstream (Stream step s) =
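The slide stops at the left-hand side of `unstream`. One way to complete it, as a sketch only: assume `Array` is `Data.Array`'s `Array Int` with indices starting at 0, drain the stream into a list, and build the array with `listArray`. A production library would write into a mutable buffer instead; this version only shows how `stream`, `mapS`, `sumS` and `unstream` fit together.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
module Main where

import Data.Array (Array, listArray, (!))

data Step s a = Yield a s | Done
data Stream a = forall s. Stream (s -> Step s a) s

-- Read a 0-based array element by element, using the index as state.
stream :: Array Int a -> Stream a
stream arr = Stream step 0
  where
    step i | i < length arr = Yield (arr ! i) (i + 1)
           | otherwise      = Done

mapS :: (a -> b) -> Stream a -> Stream b
mapS f (Stream step s0) = Stream step' s0
  where
    step' s = case step s of
      Yield x s' -> Yield (f x) s'
      Done       -> Done

sumS :: Num a => Stream a -> a
sumS (Stream step s0) = go 0 s0
  where
    go z s = case step s of
      Yield x s' -> go (z + x) s'
      Done       -> z

-- Assumed implementation: collect the elements into a list,
-- then build a 0-based array from it.
unstream :: Stream a -> Array Int a
unstream (Stream step s0) = listArray (0, length xs - 1) xs
  where
    xs = collect s0
    collect s = case step s of
      Yield x s' -> x : collect s'
      Done       -> []

main :: IO ()
main = do
  let arr = listArray (0, 3) [1, 2, 3, 4] :: Array Int Int
  print (sumS (stream (unstream (mapS (* 2) (stream arr)))))
```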

Stream fusion in three easy steps

Step 1: implement array operations in terms of streams

    sum :: Num a ⇒ Array a → a
    sum = sumS . stream

    map :: (a → b) → Array a → Array b
    map f = unstream . mapS f . stream

Step 2: inline them

    sumsq :: Num a ⇒ Array a → a
    sumsq = sum . map (λx → x*x)
          = sumS . stream . unstream . mapS (λx → x*x) . stream

Step 3: eliminate stream/unstream pairs

    RULES "stream/unstream"  stream (unstream s) = s

    sumsq = sumS . mapS (λx → x*x) . stream
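Steps 1 to 3 can be wired together with an INLINE/NOINLINE split and a RULES pragma. A minimal sketch under the same assumptions as before (0-based `Data.Array`, `unstream` via an intermediate list); `sumA` and `mapA` are hypothetical names chosen to avoid clashing with the Prelude. `stream` and `unstream` are NOINLINE so the "stream/unstream" rule can still see them after `sumA` and `mapA` have been inlined into `sumsq`; a real library would use phase annotations so that they are inlined once fusion is done. Compiled with `-O` the rule should fire; without optimisation the program computes the same answer with the temporary array.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
module Main where

import Data.Array (Array, listArray, (!))

data Step s a = Yield a s | Done
data Stream a = forall s. Stream (s -> Step s a) s

stream :: Array Int a -> Stream a
stream arr = Stream step 0
  where
    step i | i < length arr = Yield (arr ! i) (i + 1)
           | otherwise      = Done
{-# NOINLINE stream #-}

-- Assumed implementation, as in the earlier sketch.
unstream :: Stream a -> Array Int a
unstream (Stream step s0) = listArray (0, length xs - 1) xs
  where
    xs = collect s0
    collect s = case step s of
      Yield x s' -> x : collect s'
      Done       -> []
{-# NOINLINE unstream #-}

mapS :: (a -> b) -> Stream a -> Stream b
mapS f (Stream step s0) = Stream step' s0
  where
    step' s = case step s of
      Yield x s' -> Yield (f x) s'
      Done       -> Done

sumS :: Num a => Stream a -> a
sumS (Stream step s0) = go 0 s0
  where
    go z s = case step s of
      Yield x s' -> go (z + x) s'
      Done       -> z

-- Step 3: converting a stream to an array and straight back is a no-op.
{-# RULES
"stream/unstream" forall s. stream (unstream s) = s
  #-}

-- Step 1: array operations via streams; INLINE makes step 2 happen.
sumA :: Num a => Array Int a -> a
sumA = sumS . stream
{-# INLINE sumA #-}

mapA :: (a -> b) -> Array Int a -> Array Int b
mapA f = unstream . mapS f . stream
{-# INLINE mapA #-}

sumsq :: Num a => Array Int a -> a
sumsq = sumA . mapA (\x -> x * x)

main :: IO ()
main = print (sumsq (listArray (0, 2) [1, 2, 3] :: Array Int Int))
```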

Optimising stream operations

    sumsq xs = sumS (mapS square (stream xs))

Inline stream:

    stream :: Array a → Stream a
    stream arr = Stream step 0
      where
        step i | i < length arr = Yield (arr ! i) (i+1)
               | otherwise      = Done

    sumsq xs = sumS (mapS square (Stream step1 0))
      where
        step1 i = case i < length xs of
                    True  → Yield (xs ! i) (i+1)
                    False → Done

Inline mapS:

    mapS :: (a → b) → Stream a → Stream b
    mapS f (Stream step s) = Stream step' s
      where
        step' s = case step s of
                    Yield x s' → Yield (f x) s'
                    Done       → Done

    sumsq xs = sumS (Stream step2 0)
      where
        step1 i = case i < length xs of
                    True  → Yield (xs ! i) (i+1)
                    False → Done
        step2 i = case step1 i of
                    Yield x i' → Yield (square x) i'
                    Done       → Done

Inline sumS:

    sumS :: Num a ⇒ Stream a → a
    sumS (Stream step s) = go 0 s
      where
        go z s = case step s of
                   Yield x s' → go (z+x) s'
                   Done       → z

    sumsq xs = go 0 0
      where
        step1 i = case i < length xs of
                    True  → Yield (xs ! i) (i+1)
                    False → Done
        step2 i = case step1 i of
                    Yield x i' → Yield (square x) i'
                    Done       → Done
        go z i  = case step2 i of
                    Yield x i' → go (z+x) i'
                    Done       → z

Inline step2:

    sumsq xs = go 0 0
      where
        step1 i = case i < length xs of
                    True  → Yield (xs ! i) (i+1)
                    False → Done
        go z i  = case (case step1 i of
                          Yield x i' → Yield (square x) i'
                          Done       → Done) of
                    Yield x i' → go (z+x) i'
                    Done       → z

Apply case-of-case:

    sumsq xs = go 0 0
      where
        step1 i = case i < length xs of
                    True  → Yield (xs ! i) (i+1)
                    False → Done
        go z i  = case step1 i of
                    Yield x i' → go (z + square x) i'
                    Done       → z

Inline step1:

    sumsq xs = go 0 0
      where
        go z i = case (case i < length xs of
                         True  → Yield (xs ! i) (i+1)
                         False → Done) of
                   Yield x i' → go (z + square x) i'
                   Done       → z

Apply case-of-case:

    sumsq xs = go 0 0
      where
        go z i = case i < length xs of
                   True  → go (z + square (xs ! i)) (i+1)
                   False → z

- optimal loop
- no Stream or Step values ever created
- only general-purpose optimisations
- will be optimised further (unboxing etc.)
- requires a great compiler (thanks GHC team!)
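The intermediate and final forms of `sumsq` can be run side by side. A minimal sketch, again assuming a 0-based `Data.Array`: `sumsqSteps` keeps the inlined step functions and the `Step` values they build, while `sumsqLoop` is the final loop from the derivation. The bang pattern on the accumulator is an addition, not on the slides, so the sum is not built up as a chain of thunks when optimisations are off. Both compute the same result; with `-O`, GHC is expected to turn the first into essentially the second.

```haskell
{-# LANGUAGE BangPatterns #-}
module Main where

import Data.Array (Array, listArray, (!))

data Step s a = Yield a s | Done

square :: Num a => a -> a
square x = x * x

-- After inlining stream, mapS and sumS: two non-recursive step
-- functions and one recursive loop, still passing Step values around.
sumsqSteps :: Num a => Array Int a -> a
sumsqSteps xs = go 0 0
  where
    step1 i = case i < length xs of
      True  -> Yield (xs ! i) (i + 1)
      False -> Done
    step2 i = case step1 i of
      Yield x i' -> Yield (square x) i'
      Done       -> Done
    go z i = case step2 i of
      Yield x i' -> go (z + x) i'
      Done       -> z

-- After case-of-case and inlining: no Step values at all.
sumsqLoop :: Num a => Array Int a -> a
sumsqLoop xs = go 0 0
  where
    go !z i = case i < length xs of
      True  -> go (z + square (xs ! i)) (i + 1)
      False -> z

main :: IO ()
main = do
  let arr = listArray (0, 4) [1 .. 5] :: Array Int Integer
  print (sumsqSteps arr, sumsqLoop arr)
```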

Why does it work?

    sumsq xs = go 0 0
      where
        step1 i = case i < length xs of             -- non-recursive
                    True  → Yield (xs ! i) (i+1)
                    False → Done
        step2 i = case step1 i of                   -- non-recursive
                    Yield x i' → Yield (square x) i'
                    Done       → Done
        go z i  = case step2 i of                   -- recursive
                    Yield x i' → go (z+x) i'
                    Done       → z

Only go is recursive, so GHC is free to inline the non-recursive step functions into it.

    filterS :: (a → Bool) → Stream a → Stream a
    filterS f (Stream step s) = Stream step' s
      where
        step' s = case step s of
                    Yield x s' | f x       → Yield x s'
                               | otherwise → step' s'      -- recursive!
                    Done                   → Done

Extending streams

Idea: allow a loop iteration not to produce an element

    data Step s a = Yield a s | Skip s | Done

    filterS :: (a → Bool) → Stream a → Stream a
    filterS f (Stream step s) = Stream step' s
      where
        step' s = case step s of                     -- non-recursive
                    Yield x s' | f x       → Yield x s'
                               | otherwise → Skip s'
                    Skip s'                → Skip s'
                    Done                   → Done
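With `Skip`, producers never need to recurse; only consumers loop, and they treat `Skip` as "continue". A minimal sketch using the slide's extended `Step` type; `fromListS` and `toListS` are hypothetical list-based helpers added so the result can be inspected, and `mapS` and `sumS` are the earlier definitions extended with a `Skip` case.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
module Main where

data Step s a = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step s a) s

-- Hypothetical producer: stream the elements of a list.
fromListS :: [a] -> Stream a
fromListS = Stream step
  where
    step (x : xs) = Yield x xs
    step []       = Done

-- Non-recursive: a rejected element becomes Skip, not a recursive call.
filterS :: (a -> Bool) -> Stream a -> Stream a
filterS f (Stream step s0) = Stream step' s0
  where
    step' s = case step s of
      Yield x s' | f x       -> Yield x s'
                 | otherwise -> Skip s'
      Skip s'                -> Skip s'
      Done                   -> Done

mapS :: (a -> b) -> Stream a -> Stream b
mapS f (Stream step s0) = Stream step' s0
  where
    step' s = case step s of
      Yield x s' -> Yield (f x) s'
      Skip s'    -> Skip s'
      Done       -> Done

-- Consumers are the only recursive functions; they step over Skip.
sumS :: Num a => Stream a -> a
sumS (Stream step s0) = go 0 s0
  where
    go z s = case step s of
      Yield x s' -> go (z + x) s'
      Skip s'    -> go z s'
      Done       -> z

toListS :: Stream a -> [a]
toListS (Stream step s0) = go s0
  where
    go s = case step s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []

main :: IO ()
main = do
  print (toListS (filterS even (fromListS [1 .. 10 :: Int])))
  print (sumS (mapS (* 3) (filterS odd (fromListS [1 .. 10 :: Int]))))
```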

- encode loops by streams
- implement array operations in terms of streams
- eliminate stream/unstream pairs (temporaries)
- stream producers are non-recursive
- standard optimisations remove overhead (loop fusion)

Standard optimisations: inlining, case-of-case, worker/wrapper transformation, SpecConstr, LiberateCase, specialisation ...

Parallel loops

DPH on multicores

Evaluation strategy after vectorisation:
- operations are data parallel and flat
- executed by a gang of worker threads
- essentially fork-join parallelism

    mapP :: (a → b) → Array a → Array b
    mapP f xs = <split xs across workers>
                <map f over each chunk>        -- f is sequential

    sumP :: Num a ⇒ Array a → a
    sumP xs = <split xs across workers>
              <sum each chunk>
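The fork-join strategy behind `mapP` and `sumP` can be sketched with `forkIO` and `MVar` from base. Assumptions, none of them from the slides: lists stand in for arrays, the worker count is an explicit parameter, `chunksOf` and `forkJoin` are hypothetical helpers, and a fresh thread is forked per chunk instead of reusing a long-lived gang. Exceptions in a worker are not propagated in this sketch. Compile with `-threaded` and run with `+RTS -N` for real parallelism; the results are the same without it.

```haskell
module Main where

import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Exception (evaluate)

-- Hypothetical helper: split a list into pieces of at most n elements.
chunksOf :: Int -> [a] -> [[a]]
chunksOf _ [] = []
chunksOf n xs = c : chunksOf n rest
  where (c, rest) = splitAt n xs

-- Fork one thread per chunk, then join by waiting for each result in order.
forkJoin :: Int -> ([a] -> b) -> [a] -> IO [b]
forkJoin workers work xs = do
  let w      = max 1 workers
      size   = max 1 ((length xs + w - 1) `div` w)
      chunks = chunksOf size xs
  vars <- mapM spawn chunks
  mapM takeMVar vars
  where
    spawn c = do
      v <- newEmptyMVar
      _ <- forkIO (evaluate (work c) >>= putMVar v)
      return v

-- Force every element so the work is done by the worker, not the caller.
forceList :: [b] -> [b]
forceList ys = foldr seq () ys `seq` ys

-- <split xs across workers> <map f over each chunk>; f itself is sequential.
mapP :: Int -> (a -> b) -> [a] -> IO [b]
mapP workers f xs = concat <$> forkJoin workers (forceList . map f) xs

-- <split xs across workers> <sum each chunk>, then add the partial sums.
sumP :: Num a => Int -> [a] -> IO a
sumP workers xs = sum <$> forkJoin workers sum xs

main :: IO ()
main = do
  ys <- mapP 4 (* 2) [1 .. 10 :: Int]
  print ys
  s <- sumP 4 [1 .. 100 :: Int]
  print s
```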

    sumsqP = sumP . mapP square

    sumsqP xs = <split xs across workers>
                <map square over each chunk>
                <split results across workers>
                <sum each chunk>

Distributed types

Idea: let’s make the evaluation strategy explicit! (Keller 1999)

    data Dist a       a is distributed across threads
    Dist (Array a)    each thread has a local array (chunk)
    Dist Double       each thread has a local Double

    splitD    distribute an array across threads
    joinD     collect thread-local chunks
    mapD      execute a sequential operation in each thread
    sumD      compute sum of local values

    splitD :: Array a → Dist (Array a)
    joinD  :: Dist (Array a) → Array a
    mapD   :: (a → b) → Dist a → Dist b
    sumD   :: Num a ⇒ Dist a → a

    mapP f xs = <split xs across workers>
                <map f over each chunk>

    mapP f = joinD          -- collect
           . mapD (map f)   -- map f over chunks
           . splitD         -- split

    sumP xs = <split xs across workers>
              <sum each chunk>

    sumP = sumD             -- reduce
         . mapD sum         -- sum each chunk
         . splitD           -- split
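A purely sequential model makes the four operations and the two definitions above concrete. A minimal sketch, assuming `Dist` is represented as a list holding one value per thread, lists stand in for arrays, and the gang size is a fixed constant `threads`; these are modelling choices, not DPH's representation, and nothing here actually runs in parallel.

```haskell
module Main where

-- Hypothetical representation: one local value per thread.
newtype Dist a = Dist [a]

-- Assumed fixed gang size for the model.
threads :: Int
threads = 4

-- Distribute a list across threads in contiguous, nearly equal chunks.
splitD :: [a] -> Dist [a]
splitD xs = Dist (go threads xs)
  where
    go 1 ys = [ys]
    go k ys = let n         = (length ys + k - 1) `div` k
                  (c, rest) = splitAt n ys
              in c : go (k - 1) rest

-- Collect the thread-local chunks in order.
joinD :: Dist [a] -> [a]
joinD (Dist cs) = concat cs

-- Run a sequential operation on every thread's local value.
mapD :: (a -> b) -> Dist a -> Dist b
mapD f (Dist vs) = Dist (map f vs)

-- Combine the local values into one result.
sumD :: Num a => Dist a -> a
sumD (Dist vs) = sum vs

mapP :: (a -> b) -> [a] -> [b]
mapP f = joinD . mapD (map f) . splitD

sumP :: Num a => [a] -> a
sumP = sumD . mapD sum . splitD

main :: IO ()
main = do
  print (mapP (+ 1) [1 .. 10 :: Int])
  print (sumP [1 .. 10 :: Int])
```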

    sumsqP = sumP . mapP square

    sumsqP = sumD                 -- reduce
           . mapD sum             -- sum each chunk
           . splitD               -- split
           . joinD                -- collect
           . mapD (map square)    -- map square over chunks
           . splitD               -- split

    RULES splitD (joinD xs) = xs

    sumsqP = sumD                 -- reduce
           . mapD sum             -- sum each chunk
           . mapD (map square)    -- map square over chunks
           . splitD               -- split

    RULES splitD (joinD xs)  = xs
          mapD f (mapD g xs) = mapD (f . g) xs

    sumsqP = sumD                     -- reduce
           . mapD (sum . map square)  -- work (stream fusion)
           . splitD                   -- split
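The two rules can be stated as a GHC RULES pragma over the same kind of sequential model (the list-based `Dist` and the constant `threads` are assumptions, repeated so the block stands alone). `splitD`, `joinD` and `mapD` are NOINLINE so the rules can match them, and `mapP` and `sumP` are INLINE so that `sumsqP` exposes `splitD (joinD ...)` after inlining. In this model "splitD/joinD" is exact whenever `d` came from `splitD` and only length-preserving work such as `map` happened in between, because `splitD`'s chunk sizes depend only on the length; that is the case in `sumsqP`.

```haskell
module Main where

newtype Dist a = Dist [a]

threads :: Int
threads = 4

splitD :: [a] -> Dist [a]
splitD xs = Dist (go threads xs)
  where
    go 1 ys = [ys]
    go k ys = let n         = (length ys + k - 1) `div` k
                  (c, rest) = splitAt n ys
              in c : go (k - 1) rest
{-# NOINLINE splitD #-}

joinD :: Dist [a] -> [a]
joinD (Dist cs) = concat cs
{-# NOINLINE joinD #-}

mapD :: (a -> b) -> Dist a -> Dist b
mapD f (Dist vs) = Dist (map f vs)
{-# NOINLINE mapD #-}

sumD :: Num a => Dist a -> a
sumD (Dist vs) = sum vs

-- splitD/joinD removes communication; mapD/mapD removes a synchronisation.
{-# RULES
"splitD/joinD" forall d.     splitD (joinD d)  = d
"mapD/mapD"    forall f g d. mapD f (mapD g d) = mapD (f . g) d
  #-}

mapP :: (a -> b) -> [a] -> [b]
mapP f = joinD . mapD (map f) . splitD
{-# INLINE mapP #-}

sumP :: Num a => [a] -> a
sumP = sumD . mapD sum . splitD
{-# INLINE sumP #-}

-- With -O the rules are meant to leave sumD . mapD (sum . map square) . splitD.
sumsqP :: Num a => [a] -> a
sumsqP = sumP . mapP square
  where square x = x * x

main :: IO ()
main = print (sumsqP [1 .. 10 :: Int])
```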

Distributed types on multicores

    data Dist a     a is distributed across threads

    splitD          distribute xs across threads
    joinD           collect thread-local chunks
    mapD            execute a sequential operation in each thread

    splitD/joinD    eliminate communication
    mapD/mapD       eliminate synchronisation

Distributed types on clusters

    data Dist a     a is distributed across nodes

    splitD          scatter
    joinD           gather
    mapD            execute operation on each node

    splitD/joinD    eliminate communication
    mapD/mapD       eliminate synchronisation

Distributed types on GPUs

    data Dist a     a is in GPU memory

    splitD          CPU → GPU transfer
    joinD           GPU → CPU transfer
    mapD            execute kernel on the GPU

    splitD/joinD    eliminate memory transfers (communication)
    mapD/mapD       fuse kernels (synchronisation)

Distributed types – summary

- encode parallel loops as split/work/join
- eliminate unnecessary split/join pairs
- fuse sequential work (stream fusion)
- very general mechanism for fusing parallel computations
- applicable to a wide range of architectures
- again, no specialised compiler support




Obligatory benchmark


[Figure: Runtime @ greyarea; series: sumsq, dotp and smvm, each in Haskell and in C]



Parting thoughts

- it’s nice, it’s easy to use, it works
- high-level functional programs compiled to highly efficient code
- even parallel ones!
- rewrite rules + great optimiser = win
- DPH doesn’t require any special-purpose optimisations
- try this in an imperative language...

Stream fusion: dph, bytestring, vector, uvector
Distributed types: dph

