Implementing graph to tree conversion using Haskell

I have become quite enamored with Haskell recently. Even if you can't use it on your current projects (because of the boss, because of legacy code, or because you just can't understand it well enough), you can always use it as an endless source of brain-teasers and puzzles.

Also, since I recently switched to XMonad, at least some knowledge of Haskell is a must. By the way, I'm extremely happy with XMonad, but that is a topic for a separate blog post :)

Currently, I'm exploring various monads and typeclasses (State, Reader, Arrow, etc.) and sometimes write small snippets using them. For example, to practice using the State monad, I implemented a function that extracts a tree from a graph using DFS. Such an operation obviously needs to keep the set of visited nodes somewhere, so State seems to be a good fit. Here's the code in Haskell, followed by the same code in Scala (written in a more "traditional" style):
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.List
import Text.Regex
import Control.Monad.State
import Control.Applicative
-- some setup for the types
type Graph a = Map.Map a [a]
data Tree a = Tree
              { root :: a
              , children :: [Tree a]
              } deriving (Eq, Ord)
-- Render a tree one node per line, indenting each level by two spaces.
instance (Show a) => Show (Tree a) where
  show tree =
    show (root tree) ++
    concatMap ("\n" ++)
      (map (intercalate "\n" . map ("  " ++) . splitRegex (mkRegex "\n") . show) (children tree))
-- Depth-first traversal; the State monad carries the set of visited nodes.
toTree :: (Ord a) => Graph a -> a -> State (Set.Set a) (Tree a)
toTree graph node = do
    visited <- get
    (Tree node . reverse) <$>
      foldl
        (\siblingTrees child ->
          siblingTrees >>= \strs -> (:strs) <$> toTree graph child)
        -- start with no subtrees and mark the current node as visited
        (return [] <* modify (Set.insert node))
        (filter (`Set.notMember` visited) $ Map.findWithDefault [] node graph)
mapFst :: (a -> b) -> (a, c) -> (b, c)
mapFst fn (a, c) = (fn a, c)
mapSnd :: (b -> c) -> (a, b) -> (a, c)
mapSnd fn (a, b) = (a, fn b)
exampleGraph :: Map.Map String [String]
exampleGraph = Map.fromList
  [ ("tree", ["branch"])
  , ("branch", ["apple", "banana"])
  , ("apple", ["tree"])
  ]
main :: IO ()
main = print $ evalState (toTree exampleGraph "tree") Set.empty

// The same traversal in Scala
case class Tree[A](root: A, children: List[Tree[A]])

def toTree[A](root: A, graph: Map[A, List[A]]): Tree[A] = {
  def visit(node: A, graph: Map[A, List[A]], visited: Set[A]): (Tree[A], Set[A]) = {
    graph.get(node) match {
      case None => (Tree(node, Nil), visited + node)
      case Some(children) =>
        val (childTrees, newVisited) =
          children
            // skip children that were already visited when this node was entered
            // (mirrors the filter in the Haskell version)
            .filterNot(visited.contains)
            .foldLeft((List.empty[Tree[A]], visited + node)){ case ((siblingTrees, visitedSoFar), child) =>
              val (childTree, newVisited) = visit(child, graph, visitedSoFar)
              (childTree :: siblingTrees, newVisited)
            }
        (Tree(node, childTrees.reverse), newVisited)
    }
  }
  visit(root, graph, Set())._1
}

Note that the Scala version is almost twice as long (if we ignore the setup code in the Haskell version).
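
One caveat about the Haskell toTree: it reads the visited set once per node (visited <- get) and filters all children against that snapshot. As far as I can tell, a node reachable through two siblings can therefore show up twice in the result. Below is a minimal sketch of a variant that consults the state right before entering each child. The names toTreeOnce and diamond are made up for this illustration, and the derived Show instance is used only to keep the block short.

```haskell
import qualified Data.Map as Map
import qualified Data.Set as Set
import Control.Monad.State

type Graph a = Map.Map a [a]

data Tree a = Tree { root :: a, children :: [Tree a] }
  deriving (Eq, Show)

-- Hypothetical variant of toTree: the visited set is checked right before
-- each child is entered, not once per parent.
toTreeOnce :: Ord a => Graph a -> a -> State (Set.Set a) (Tree a)
toTreeOnce graph node = do
  modify (Set.insert node)
  subtrees <- foldM step [] (Map.findWithDefault [] node graph)
  return (Tree node (reverse subtrees))
  where
    step acc child = do
      seen <- gets (Set.member child)
      if seen
        then return acc                          -- already placed elsewhere in the tree
        else (: acc) <$> toTreeOnce graph child  -- descend and prepend the subtree

-- Made-up graph in which "c" is reachable both from "a" and from "b".
diamond :: Graph String
diamond = Map.fromList [("a", ["b", "c"]), ("b", ["c"])]
```

Evaluating evalState (toTreeOnce diamond "a") Set.empty should place "c" only under "b", while the snapshot-based version would, by my reading, list it under "a" as well. Whether that matters depends on whether you want a spanning tree or just a cycle-free unfolding of the graph.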

As a side note: if you need to quickly understand what the Arrow typeclass means, I definitely recommend A Brutal Introduction to Arrows by Christopher Lane Hinson.
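
To connect Arrow with the listing above: the mapFst and mapSnd helpers are, for plain functions, what Control.Arrow calls first and second. Here is a small sketch comparing them, plus the two combinators (&&& and ***) that I find most useful day to day. The checks list is just a made-up name for this example.

```haskell
import Control.Arrow (first, second, (&&&), (***))

-- The two helpers from the Haskell listing, repeated so the block stands alone.
mapFst :: (a -> b) -> (a, c) -> (b, c)
mapFst fn (a, c) = (fn a, c)

mapSnd :: (b -> c) -> (a, b) -> (a, c)
mapSnd fn (a, b) = (a, fn b)

-- Every entry is expected to be True when the arrow is an ordinary function.
checks :: [Bool]
checks =
  [ first (+ 1) (1 :: Int, "x") == mapFst (+ 1) (1, "x")
  , second length (1 :: Int, "abc") == mapSnd length (1, "abc")
  , ((+ 1) &&& (* 2)) (5 :: Int) == (6, 10)            -- fan-out: one input, two results
  , ((+ 1) *** (* 2)) (5 :: Int, 7 :: Int) == (6, 14)  -- apply side by side to a pair
  ]
```

This only covers the function instance of Arrow; the tutorial mentioned above explains the general typeclass.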

