基本掃描用法

scan 用於在值列表上多次呼叫函式,該函式可能包含狀態。

scan 語法(截至 theano 0.9):

scan(
    fn,
    sequences=None,
    outputs_info=None,
    non_sequences=None,
    n_steps=None,
    truncate_gradient=-1,
    go_backwards=False,
    mode=None,
    name=None,
    profile=False,
    allow_gc=None,
    strict=False)

乍一看,這可能非常令人困惑。我們將在多個程式碼示例中解釋幾個基本但重要的 scan 用法。

以下程式碼示例假定你已執行匯入:

import numpy as np
import theano
import theano.tensor as T

sequences - 在列表上對映函式

在最簡單的情況下,掃描只是將純函式(沒有狀態的函式)對映到列表。列表在 sequences 引數中指定

  s_x = T.ivector()
  s_y, _ = theano.scan(
      fn = lambda x:x*x,
      sequences = [s_x])
  fn = theano.function([s_x], s_y)
  fn([1,2,3,4,5]) #[1,4,9,16,25]

注意 scan 有兩個返回值,前者是結果列表,後者是狀態值的更新,稍後將對此進行說明。

sequences - 在列表中壓縮函式

與上面幾乎相同,只需給 sequences 引數列出兩個元素。兩個元素的順序應該與 fn 中的引數順序相匹配

  s_x1 = T.ivector()
  s_x2 = T.ivector()
  s_y, _ = theano.scan(
      fn = lambda x1,x2:x1**x2,
      sequences = [s_x1, s_x2])
  fn = theano.function([s_x], s_y)
  fn([1,2,3,4,5],[0,1,2,3,4]) #[1,2,9,64,625]

outputs_info - 累積列表

累積涉及狀態變數。狀態變數需要初始值,該值應在 outputs_info 引數中指定。

  s_x = T.ivector()
  v_sum = th.shared(np.int32(0))
  s_y, update_sum = theano.scan(
      lambda x,y:x+y,
      sequences = [s_x],
      outputs_info = [s_sum])
  fn = theano.function([s_x], s_y, updates=update_sum)
  
  v_sum.get_value() # 0
  fn([1,2,3,4,5]) # [1,3,6,10,15]
  v_sum.get_value() # 15
  fn([-1,-2,-3,-4,-5]) # [14,12,9,5,0]
  v_sum.get_value() # 0

我們將一個共享變數放入 outputs_info,這將導致 scan 返回更新到我們的共享變數,然後可以將其放入 theano.function

non_sequencesn_steps - 物流地圖軌道 x -> lambda*x*(1-x)

你可以在 non_sequences 引數中提供在 scan 期間不會發生變化的輸入。在這種情況下,s_lambda 是一個不變的變數(但不是常量,因為它必須在執行時提供)。

  s_x = T.fscalar()
  s_lambda = T.fscalar()
  s_t = T.iscalar()
  s_y, _ = theano.scan(
      fn = lambda x,l: l*x*(1-x),
      outputs_info = [s_x],
      non_sequences = [s_lambda],
      n_steps = s_t
  )
  fn = theano.function([s_x, s_lambda, s_t], s_y)

  fn(.75, 4., 10) #a stable orbit

  #[ 0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75]

  fn(.65, 4., 10) #a chaotic orbit

  #[ 0.91000003,  0.32759991,  0.88111287,  0.41901192,  0.97376364,
  # 0.10219204,  0.3669953 ,  0.92923898,  0.2630156 ,  0.77535355]

水龍頭 - 斐波那契

狀態/輸入可能有多個時間步長。這是通過:

  • dict(input=<init_value>, taps=<list of int>) 放入 sequences 論證中。

  • dict(initial=<init_value>, taps=<list of int>) 放入 outputs_info 論證中。

在這個例子中,我們使用 outputs_info 中的兩個抽頭來計算遞迴關係 x_n = x_{n-1} + x_{n-2}

s_x0 = T.iscalar()
s_x1 = T.iscalar()
s_n = T.iscalar()
s_y, _ = theano.scan(
    fn = lambda x1,x2: x1+x2,
    outputs_info = [dict(initial=T.join(0,[s_x0, s_x1]), taps=[-2,-1])],
    n_steps = s_n
)
fn_fib = theano.function([s_x0, s_x1, s_n], s_y)
fn_fib(1,1,10)
# [2, 3, 5, 8, 13, 21, 34, 55, 89, 144]