Numbering the nodes of a tree with a counter

We have a tree data type like this:

data Tree a = Tree a [Tree a] deriving Show

We want to write a function that assigns each node of the tree a number taken from an incrementing counter:

tag::Tree a -> Tree (a, Int)
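To make the goal concrete, here is a small sample tree and the result we would expect from tag, assuming nodes are numbered depth-first in pre-order (each node before its subtrees), which is the order the implementations below produce. The names example, expected and preorder are hypothetical helpers for illustration, and Eq is added to the deriving clause only so results can be compared.

```haskell
-- The tree type from above, with Eq added so results can be compared.
data Tree a = Tree a [Tree a] deriving (Show, Eq)

-- Hypothetical sample tree:
--        'a'
--       /   \
--     'b'   'd'
--      |     |
--     'c'   'e'
example :: Tree Char
example = Tree 'a' [Tree 'b' [Tree 'c' []], Tree 'd' [Tree 'e' []]]

-- What tag example should produce under pre-order numbering.
expected :: Tree (Char, Int)
expected = Tree ('a', 0) [Tree ('b', 1) [Tree ('c', 2) []], Tree ('d', 3) [Tree ('e', 4) []]]

-- Lists node labels in pre-order: the node first, then each subtree left to right.
preorder :: Tree a -> [a]
preorder (Tree a subtrees) = a : concatMap preorder subtrees
```

Reading expected in pre-order gives the labels of example paired with 0, 1, 2, 3, 4 in turn.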

The long way

First we'll do it the long way, because it nicely illustrates the low-level mechanics of the State monad.

-- Control.Monad.State is provided by the mtl package.
import Control.Monad.State

-- Function that numbers the nodes of a `Tree`.
tag::Tree a -> Tree (a, Int)
tag tree = 
    -- tagStep is where the action happens.  This just gets the ball
    -- rolling, with `0` as the initial counter value.
    evalState (tagStep tree) 0

-- This is one monadic "step" of the calculation.  It assumes that
-- it has access to the current counter value implicitly.
tagStep::Tree a -> State Int (Tree (a, Int))
tagStep (Tree a subtrees) = do
    -- The `get::State s s` action accesses the implicit state
    -- parameter of the State monad.  Here we bind that value to
    -- the variable `counter`.
    counter <- get 

    -- The `put::s -> State s ()` sets the implicit state parameter
    -- of the `State` monad.  The next `get` that we execute will see
    -- the value of `counter + 1` (assuming no other puts in between).
    put (counter + 1)

    -- Recurse into the subtrees.  `mapM` is a utility function
    -- for executing a monadic action (like `tagStep`) on each element
    -- of a list and collecting the list of results.  Each execution of
    -- `tagStep` runs with the counter value left behind by the
    -- previous list element's execution.
    subtrees' <- mapM tagStep subtrees  

    return $ Tree (a, counter) subtrees'
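For comparison, here is a sketch of the same numbering with the counter threaded by hand, without the State monad. The hypothetical functions tagExplicit, go and goList take the current counter as an argument and return the next unused value alongside their result; this plumbing is roughly what get, put and the State monad's bind do for us in tagStep.

```haskell
data Tree a = Tree a [Tree a] deriving (Show, Eq)

-- Numbers nodes in pre-order, starting at 0, without the State monad.
tagExplicit :: Tree a -> Tree (a, Int)
tagExplicit tree = fst (go tree 0)

-- go takes the current counter value and returns the tagged tree
-- together with the next unused counter value.
go :: Tree a -> Int -> (Tree (a, Int), Int)
go (Tree a subtrees) counter =
    let (subtrees', next) = goList subtrees (counter + 1)
    in (Tree (a, counter) subtrees', next)

-- goList feeds the counter left by each subtree into the next one,
-- which is the job `mapM tagStep` does in the State version.
goList :: [Tree a] -> Int -> ([Tree (a, Int)], Int)
goList [] counter = ([], counter)
goList (t : ts) counter =
    let (t', counter') = go t counter
        (ts', counter'') = goList ts counter'
    in (t' : ts', counter'')
```

The type of go, Tree a -> Int -> (Tree (a, Int), Int), is the unwrapped form of Tree a -> State Int (Tree (a, Int)), since State s a is essentially a function s -> (a, s).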

Refactoring

Splitting the counter out into a postIncrement action

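The part where we get the current counter and then put counter + 1 can be split off into a postIncrement action, similar to the post-increment operator (x++) that many C-style languages provide: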

postIncrement::Enum s => State s s
postIncrement = do
    result <- get
    modify succ
    return result
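A quick check of postIncrement on its own, as a sketch: each call returns the current state and leaves its successor behind. Because it only requires an Enum instance, it works for Char as well as Int. This assumes the mtl package for Control.Monad.State; threeInts and threeChars are hypothetical names for illustration.

```haskell
import Control.Monad (replicateM)
import Control.Monad.State

postIncrement :: Enum s => State s s
postIncrement = do
    result <- get
    modify succ
    return result

-- Three post-increments starting from 0: the results are the old values,
-- and the final state is the value after the last increment.
threeInts :: ([Int], Int)
threeInts = runState (replicateM 3 postIncrement) 0

-- The same action with a Char counter.
threeChars :: String
threeChars = evalState (replicateM 3 postIncrement) 'a'
```

Here threeInts evaluates to ([0,1,2],3) and threeChars to "abc".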

Splitting the tree walk out into a higher-order function

The tree-walking logic can be split out into its own function, like this:

mapTreeM::Monad m => (a -> m b) -> Tree a -> m (Tree b)
mapTreeM action (Tree a subtrees) = do
    a' <- action a
    subtrees' <- mapM (mapTreeM action) subtrees
    return $ Tree a' subtrees'
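Because mapTreeM only assumes Monad m, it is not tied to State. As a sketch, here it is used with Maybe: the hypothetical checkPositive step fails the whole walk as soon as any label is not positive.

```haskell
data Tree a = Tree a [Tree a] deriving (Show, Eq)

mapTreeM :: Monad m => (a -> m b) -> Tree a -> m (Tree b)
mapTreeM action (Tree a subtrees) = do
    a' <- action a
    subtrees' <- mapM (mapTreeM action) subtrees
    return $ Tree a' subtrees'

-- Hypothetical validation step: keep positive labels, fail on anything else.
checkPositive :: Int -> Maybe Int
checkPositive n
    | n > 0     = Just n
    | otherwise = Nothing

-- Two sample trees; `bad` contains a 0 label.
good, bad :: Tree Int
good = Tree 1 [Tree 2 [], Tree 3 [Tree 4 []]]
bad  = Tree 1 [Tree 2 [], Tree 0 [Tree 4 []]]
```

Here mapTreeM checkPositive good returns Just good, while mapTreeM checkPositive bad returns Nothing.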

With this and the postIncrement function, we can rewrite tagStep:

tagStep::Tree a -> State Int (Tree (a, Int))
tagStep = mapTreeM step
    where step::a -> State Int (a, Int)
          step a = do 
              counter <- postIncrement
              return (a, counter)
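Putting the refactored pieces together gives a self-contained module, sketched below under the same mtl assumption. Since step only pairs the label with the counter, its do block can also be written with fmap as (,) a <$> postIncrement, which is the variant used here.

```haskell
import Control.Monad.State

data Tree a = Tree a [Tree a] deriving (Show, Eq)

postIncrement :: Enum s => State s s
postIncrement = do
    result <- get
    modify succ
    return result

mapTreeM :: Monad m => (a -> m b) -> Tree a -> m (Tree b)
mapTreeM action (Tree a subtrees) = do
    a' <- action a
    subtrees' <- mapM (mapTreeM action) subtrees
    return $ Tree a' subtrees'

tagStep :: Tree a -> State Int (Tree (a, Int))
tagStep = mapTreeM step
    -- Pair each label with the counter value before it is incremented.
    where step a = (,) a <$> postIncrement

tag :: Tree a -> Tree (a, Int)
tag tree = evalState (tagStep tree) 0
```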

Using Traversable

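The mapTreeM solution above can easily be rewritten as an instance of the Traversable class (Functor and Foldable are superclasses of Traversable, so Tree needs instances of those as well):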

instance Traversable Tree where
    traverse action (Tree a subtrees) = 
        Tree <$> action a <*> traverse (traverse action) subtrees

Note that this requires us to use Applicative (the <*> operator) rather than Monad.
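Since Traversable has Functor and Foldable as superclasses, here is a hand-written sketch of all three instances for Tree. The foldMap definition is one valid minimal definition for Foldable, and the list-level traverse (traverse action) is what recurses into the subtrees; sample is a hypothetical test value.

```haskell
data Tree a = Tree a [Tree a] deriving (Show, Eq)

-- Apply f to every label, keeping the shape of the tree.
instance Functor Tree where
    fmap f (Tree a subtrees) = Tree (f a) (map (fmap f) subtrees)

-- Combine labels in pre-order: the node first, then its subtrees.
instance Foldable Tree where
    foldMap f (Tree a subtrees) = f a <> foldMap (foldMap f) subtrees

-- The outer traverse walks the list of subtrees; the inner one recurses into each Tree.
instance Traversable Tree where
    traverse action (Tree a subtrees) =
        Tree <$> action a <*> traverse (traverse action) subtrees

sample :: Tree Int
sample = Tree 1 [Tree 2 [Tree 3 []], Tree 4 []]
```

With these in place, generic functions such as sum and foldr work on trees too: sum sample is 10.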

With that in place, we can now write tag like a pro:

tag::Traversable t => t a -> t (a, Int)
tag t = evalState (traverse step t) 0
    where step a = do counter <- postIncrement
                      return (a, counter)

Note that this works for any Traversable type, not just our Tree type!
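To illustrate, here is the generic tag applied to standard Traversable containers from the Prelude. The sketch repeats postIncrement so it is self-contained and assumes mtl; taggedList and taggedJust are hypothetical names.

```haskell
import Control.Monad.State

postIncrement :: Enum s => State s s
postIncrement = do
    result <- get
    modify succ
    return result

tag :: Traversable t => t a -> t (a, Int)
tag t = evalState (traverse step t) 0
    where step a = do counter <- postIncrement
                      return (a, counter)

-- A list is numbered left to right.
taggedList :: [(Char, Int)]
taggedList = tag "abc"

-- Maybe has at most one element to number.
taggedJust :: Maybe (Char, Int)
taggedJust = tag (Just 'x')
```

Here taggedList is [('a',0),('b',1),('c',2)] and taggedJust is Just ('x',0).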

Getting rid of the Traversable boilerplate

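GHC has a DeriveTraversable extension (used here together with DeriveFunctor and DeriveFoldable) that eliminates the need to write the instance above by hand: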

{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}

data Tree a = Tree a [Tree a]
            deriving (Show, Functor, Foldable, Traversable)
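A complete sketch combining the derived instances with the generic tag, again assuming mtl. The derived traverse visits the constructor's fields left to right, so a node's label is numbered before its subtrees, giving the same pre-order numbering as the hand-written versions. Eq is added to the deriving clause only for comparisons, and sample is a hypothetical value.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}

import Control.Monad.State

data Tree a = Tree a [Tree a]
            deriving (Show, Eq, Functor, Foldable, Traversable)

postIncrement :: Enum s => State s s
postIncrement = do
    result <- get
    modify succ
    return result

tag :: Traversable t => t a -> t (a, Int)
tag t = evalState (traverse step t) 0
    where step a = do counter <- postIncrement
                      return (a, counter)

sample :: Tree String
sample = Tree "root" [Tree "left" [Tree "leaf" []], Tree "right" []]
```

Evaluating tag sample numbers "root", "left", "leaf" and "right" as 0, 1, 2 and 3.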