元素代數表示式的基本表達模板

介紹和動機

表達模板 ( 在下面表示為 ET )是一種功能強大的模板超程式設計技術,用於加速有時非常昂貴的表示式的計算。它廣泛用於不同的領域,例如線上性代數庫的實現中。

對於此示例,請考慮線性代數計算的上下文。更具體地,僅涉及逐元素操作的計算。這種計算是 ET 的最基本的應用,它們可以很好地介紹 ET 如何在內部工作。

讓我們來看一個激勵人心的例子。考慮表示式的計算:

Vector vec_1, vec_2, vec_3;

// Initializing vec_1, vec_2 and vec_3.

Vector result = vec_1 + vec_2*vec_3;

這裡為了簡單起見,我假設類 Vector 和操作+(向量加:元素加操作)和操作*(這裡表示向量內積:也是元素操作)都正確實現,如它們應該如何,數學上。

在不使用 ET (或其他類似技術) 的傳統實現中,至少發生了五個 Vector 例項的構造以獲得最終的 result

  1. 對應於 vec_1vec_2vec_3 的三個例項。
  2. 一個臨時的 Vector 例項 _tmp,代表 _tmp = vec_2*vec_3; 的結果。
  3. 最後正確使用返回值優化,在 result = vec_1 + _tmp; 中構建最終的 result

使用 ET 的實現可以消除 2 中臨時 Vector _tmp 的建立,因此只留下 Vector 例項的四個構造。更有趣的是,請考慮以下更復雜的表示式:

Vector result = vec_1 + (vec_2*vec3 + vec_1)*(vec_2 + vec_3*vec_1);

總共還有四個 Vector 個例項:vec_1, vec_2, vec_3result。換句話說,在此示例中,僅涉及按元素操作,保證不會從中間計算建立臨時物件

ET 如何運作

基本上,任何代數計算的 ET 都包含兩個構建塊:

  1. 純代數表示式PAE ):它們是代數表示式的代理/抽象。純代數不進行實際計算,它們僅僅是計算工作流的抽象/建模。PAE 可以是任何代數表示式的輸入或輸出的模型。 PAE 的例項通常被認為是便宜的複製。
  2. 懶惰評估 :這是實際計算的實現。在下面的示例中,我們將看到對於僅涉及逐元素操作的表示式,延遲評估可以在最終結果的索引訪問操作內實現實際計算,從而建立按需評估方案:不執行計算只有在訪問/要求最終結果時。

那麼,具體如何在這個例子中實現 ET ?我們現在來看看吧。

請始終考慮以下程式碼段:

Vector vec_1, vec_2, vec_3;

// Initializing vec_1, vec_2 and vec_3.

Vector result = vec_1 + vec_2*vec_3;

計算結果的表示式可以進一步分解為兩個子表示式:

  1. 向量加表示式(表示為 plus_expr
  2. 向量內積表示式(表示為 innerprod_expr )。

什麼外星人做的是以下幾點:

  • ET 不是立即計算每個子表示式,而是首先使用圖形結構對整個表示式進行建模。圖中的每個節點代表 PAE 。節點的邊緣連線表示實際的計算流程。因此,對於上面的表示式,我們獲得以下圖表:

           result = plus_expr( vec_1, innerprod_expr(vec_2, vec_3) )
              /   \
             /     \
            /       \
           /   innerprod_expr( vec_2, vec_3 )
          /         /  \
         /         /    \
        /         /      \
     vec_1     vec_2    vec_3
    
  • 最後的計算是通過檢視圖層次結構來實現的 :因為這裡我們只處理逐元素操作,result 中每個索引值的計算可以獨立完成result 的最終評估可以被懶惰地推遲到元素 - 對這個元素的明智評價 19。換句話說,由於 resultelem_res 的元素的計算可以使用 vec_1elem_1),vec_2elem_2)和 vec_3elem_3)中的相應元素表示為:

    elem_res = elem_1 + elem_2*elem_3;
    

因此,不需要建立臨時 Vector 來儲存中間內積的結果: 一個元素的整個計算可以完全完成,並在索引訪問操作中編碼

以下是實際操作中的示例程式碼

檔案 vec.hh:std::vector 的包裝器,用於在呼叫構造時顯示日誌

#ifndef EXPR_VEC
# define EXPR_VEC

# include <vector>
# include <cassert>
# include <utility>
# include <iostream>
# include <algorithm>
# include <functional>

///
/// This is a wrapper for std::vector. It's only purpose is to print out a log when a
/// vector constructions in called.
/// It wraps the indexed access operator [] and the size() method, which are 
/// important for later ETs implementation.
///

// std::vector wrapper.
template<typename ScalarType> class Vector
{
public:
  explicit Vector() { std::cout << "ctor called.\n"; };
  explicit Vector(int size): _vec(size) { std::cout << "ctor called.\n"; };
  explicit Vector(const std::vector<ScalarType> &vec): _vec(vec)
  { std::cout << "ctor called.\n"; };
  
  Vector(const Vector<ScalarType> & vec): _vec{vec()}
  { std::cout << "copy ctor called.\n"; };
  Vector(Vector<ScalarType> && vec): _vec(std::move(vec()))
  { std::cout << "move ctor called.\n"; };

  Vector<ScalarType> & operator=(const Vector<ScalarType> &) = default;
  Vector<ScalarType> & operator=(Vector<ScalarType> &&) = default;

  decltype(auto) operator[](int indx) { return _vec[indx]; }
  decltype(auto) operator[](int indx) const { return _vec[indx]; }

  decltype(auto) operator()() & { return (_vec); };        
  decltype(auto) operator()() const & { return (_vec); };  
  Vector<ScalarType> && operator()() && { return std::move(*this); }

  int size() const { return _vec.size(); }
  
private:
  std::vector<ScalarType> _vec;
};

///
/// These are conventional overloads of operator + (the vector plus operation)
/// and operator * (the vector inner product operation) without using the expression
/// templates. They are later used for bench-marking purpose.
///

// + (vector plus) operator.
template<typename ScalarType>
auto operator+(const Vector<ScalarType> &lhs, const Vector<ScalarType> &rhs)
{
  assert(lhs().size() == rhs().size() &&
         "error: ops plus -> lhs and rhs size mismatch.");
  
  std::vector<ScalarType> _vec;
  _vec.resize(lhs().size());
  std::transform(std::cbegin(lhs()), std::cend(lhs()),
                 std::cbegin(rhs()), std::begin(_vec),
                 std::plus<>());
  return Vector<ScalarType>(std::move(_vec));
}

// * (vector inner product) operator.
template<typename ScalarType>
auto operator*(const Vector<ScalarType> &lhs, const Vector<ScalarType> &rhs)
{
  assert(lhs().size() == rhs().size() &&
         "error: ops multiplies -> lhs and rhs size mismatch.");
  
  std::vector<ScalarType> _vec;
  _vec.resize(lhs().size());
  std::transform(std::cbegin(lhs()), std::cend(lhs()),
                 std::cbegin(rhs()), std::begin(_vec),
                 std::multiplies<>());
  return Vector<ScalarType>(std::move(_vec));
}

#endif //!EXPR_VEC

File expr.hh:用於逐元素操作的表示式模板的實現(vector plus 和 vector inner product)

讓我們把它分解成各個部分。

  1. 第 1 節為所有表示式實現了一個基類。它採用了奇怪的重複模板模式CRTP )。
  2. 第 2 節實現了第一個 PAE :一個終端,它只是一個輸入資料結構的包裝器(const 引用),包含用於計算的實際輸入值。
  3. 第 3 節實現了第二個 PAEbinary_operation ,它是一個稍後用於 vector_plus 和 vector_innerprod 的類别範本。它由操作型別左側 PAE右側 PAE 引數化。實際計算在索引訪問運算子中編碼。
  4. 第 4 節將 vector_plus 和 vector_innerprod 操作定義為元素操作。它還會為 PAE s 過載 operator +和* :這樣這兩個操作也會返回 PAE
#ifndef EXPR_EXPR
# define EXPR_EXPR
      

/// Fwd declaration.
template<typename> class Vector;

namespace expr
{

/// -----------------------------------------
///
/// Section 1.
///
/// The first section is a base class template for all kinds of expression. It         
/// employs the Curiously Recurring Template Pattern, which enables its instantiation 
/// to any kind of expression structure inheriting from it.
///
/// -----------------------------------------

  /// Base class for all expressions.
  template<typename Expr> class expr_base
  {
  public:
    const Expr& self() const { return static_cast<const Expr&>(*this); }
    Expr& self() { return static_cast<Expr&>(*this); }

  protected:
    explicit expr_base() {};
    int size() const { return self().size_impl(); }
    auto operator[](int indx) const { return self().at_impl(indx); }
    auto operator()() const { return self()(); };
  };
  

/// -----------------------------------------
///
/// The following section 2 & 3 are abstractions of pure algebraic expressions (PAE).
/// Any PAE can be converted to a real object instance using operator(): it is in 
/// this conversion process, where the real computations are done.

///
/// Section 2. Terminal
///
/// A terminal is an abstraction wrapping a const reference to the Vector data 
/// structure. It inherits from expr_base, therefore providing a unified interface
/// wrapping a Vector into a PAE.
///
/// It provides the size() method, indexed access through at_impl() and a conversion
/// to referenced object through () operator.
/// 
/// It might no be necessary for user defined data structures to have a terminal 
/// wrapper, since user defined structure can inherit expr_base, therefore eliminates
/// the need to provide such terminal wrapper. 
///
/// -----------------------------------------

  /// Generic wrapper for underlying data structure.
  template<typename DataType> class terminal: expr_base<terminal<DataType>>
  {
  public:
    using base_type = expr_base<terminal<DataType>>;
    using base_type::size;
    using base_type::operator[];
    friend base_type;
    
    explicit terminal(const DataType &val): _val(val) {}
    int size_impl() const { return _val.size(); };
    auto at_impl(int indx) const { return _val[indx]; };
    decltype(auto) operator()() const { return (_val); }
    
  private:
    const DataType &_val;
  };

/// -----------------------------------------
///
/// Section 3. Binary operation expression.
///
/// This is a PAE abstraction of any binary expression. Similarly it inherits from 
/// expr_base.
///
/// It provides the size() method, indexed access through at_impl() and a conversion
/// to referenced object through () operator. Each call to the at_impl() method is
/// a element wise computation.
/// 
/// -----------------------------------------

  /// Generic wrapper for binary operations (that are element-wise).
  template<typename Ops, typename lExpr, typename rExpr>
  class binary_ops: public expr_base<binary_ops<Ops,lExpr,rExpr>>
  {
  public:
    using base_type = expr_base<binary_ops<Ops,lExpr,rExpr>>;
    using base_type::size;
    using base_type::operator[];
    friend base_type;
    
    explicit binary_ops(const Ops &ops, const lExpr &lxpr, const rExpr &rxpr)
      : _ops(ops), _lxpr(lxpr), _rxpr(rxpr) {};
    int size_impl() const { return _lxpr.size(); };

    /// This does the element-wise computation for index indx.
    auto at_impl(int indx) const { return _ops(_lxpr[indx], _rxpr[indx]); };

    /// Conversion from arbitrary expr to concrete data type. It evaluates
    /// element-wise computations for all indices.
    template<typename DataType> operator DataType()
    {
      DataType _vec(size());
      for(int _ind = 0; _ind < _vec.size(); ++_ind)
        _vec[_ind] = (*this)[_ind];
      return _vec;
    }
    
  private: /// Ops and expr are assumed cheap to copy.
    Ops   _ops;
    lExpr _lxpr;
    rExpr _rxpr;
  };

/// -----------------------------------------
/// Section 4.
///
/// The following two structs defines algebraic operations on PAEs: here only vector 
/// plus and vector inner product are implemented. 
///
/// First, some element-wise operations are defined : in other words, vec_plus and 
/// vec_prod acts on elements in Vectors, but not whole Vectors. 
///
/// Then, operator + & * are overloaded on PAEs, such that: + & * operations on PAEs         
/// also return PAEs.
///
/// -----------------------------------------

  /// Element-wise plus operation.
  struct vec_plus_t
  {
    constexpr explicit vec_plus_t() = default; 
    template<typename LType, typename RType>
    auto operator()(const LType &lhs, const RType &rhs) const
    { return lhs+rhs; }
  };
  
  /// Element-wise inner product operation.
  struct vec_prod_t
  {
    constexpr explicit vec_prod_t() = default; 
    template<typename LType, typename RType>
    auto operator()(const LType &lhs, const RType &rhs) const
    { return lhs*rhs; }
  };
  
  /// Constant plus and inner product operator objects.
  constexpr vec_plus_t vec_plus{};
  constexpr vec_prod_t vec_prod{};
  
  /// Plus operator overload on expressions: return binary expression.
  template<typename lExpr, typename rExpr>
  auto operator+(const lExpr &lhs, const rExpr &rhs)
  { return binary_ops<vec_plus_t,lExpr,rExpr>(vec_plus,lhs,rhs); }
  
  /// Inner prod operator overload on expressions: return binary expression.
  template<typename lExpr, typename rExpr>
  auto operator*(const lExpr &lhs, const rExpr &rhs)
  { return binary_ops<vec_prod_t,lExpr,rExpr>(vec_prod,lhs,rhs); }
  
} //!expr

#endif //!EXPR_EXPR

檔案 main.cc:測試 src 檔案

# include <chrono>
# include <iomanip>
# include <iostream>
# include "vec.hh"
# include "expr.hh"
# include "boost/core/demangle.hpp"

int main()
{
  using dtype = float;
  constexpr int size = 5e7;
  
  std::vector<dtype> _vec1(size);
  std::vector<dtype> _vec2(size);
  std::vector<dtype> _vec3(size);

  // ... Initialize vectors' contents.

  Vector<dtype> vec1(std::move(_vec1));
  Vector<dtype> vec2(std::move(_vec2));
  Vector<dtype> vec3(std::move(_vec3));

  unsigned long start_ms_no_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::system_clock::now().time_since_epoch()).count();
  std::cout << "\nNo-ETs evaluation starts.\n";
  
  Vector<dtype> result_no_ets = vec1 + (vec2*vec3);
  
  unsigned long stop_ms_no_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::system_clock::now().time_since_epoch()).count();
  std::cout << std::setprecision(6) << std::fixed
            << "No-ETs. Time eclapses: " << (stop_ms_no_ets-start_ms_no_ets)/1000.0
            << " s.\n" << std::endl;
  
  unsigned long start_ms_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::system_clock::now().time_since_epoch()).count();
  std::cout << "Evaluation using ETs starts.\n";
  
  expr::terminal<Vector<dtype>> vec4(vec1);
  expr::terminal<Vector<dtype>> vec5(vec2);
  expr::terminal<Vector<dtype>> vec6(vec3);
  
  Vector<dtype> result_ets = (vec4 + vec5*vec6);
  
  unsigned long stop_ms_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::system_clock::now().time_since_epoch()).count();
  std::cout << std::setprecision(6) << std::fixed
            << "With ETs. Time eclapses: " << (stop_ms_ets-start_ms_ets)/1000.0
            << " s.\n" << std::endl;
  
  auto ets_ret_type = (vec4 + vec5*vec6);
  std::cout << "\nETs result's type:\n";
  std::cout << boost::core::demangle( typeid(decltype(ets_ret_type)).name() ) << '\n'; 

  return 0;
}

使用 GCC 5.3 使用 -O3 -std=c++14 編譯時,這是一個可能的輸出:

ctor called.
ctor called.
ctor called.

No-ETs evaluation starts.
ctor called.
ctor called.
No-ETs. Time eclapses: 0.571000 s.

Evaluation using ETs starts.
ctor called.
With ETs. Time eclapses: 0.164000 s.

ETs result's type:
expr::binary_ops<expr::vec_plus_t, expr::terminal<Vector<float> >, expr::binary_ops<expr::vec_prod_t, expr::terminal<Vector<float> >, expr::terminal<Vector<float> > > >

觀察結果如下:

  • **在這種情況下,**使用 ET 可以實現相當顯著的效能提升 (> 3x)。 ****
  • 消除了臨時 Vector 物件的建立。與 ETs 一樣,ctor 只被呼叫一次。
  • Boost::demangle 用於視覺化轉換前 ET 返回的型別:它清楚地構建了與上面演示的完全相同的表達圖。

缺點和警告

  • ET 的一個明顯缺點是學習曲線,實施的複雜性和程式碼維護難度。在上面僅考慮元素操作的示例中,實現包含了大量的樣板,更不用說在現實世界中,每個計算中都會出現更復雜的代數表示式,並且元素方面的獨立性不再成立(例如矩陣乘法) ),難度將是指數級的。

  • 使用 ET 的另一個警告是它們與 auto 關鍵字配合得很好。如上所述, PAE 本質上是代理:並且代理基本上不能與 auto 一起使用。請考慮以下示例:

     auto result = ...;                // Some expensive expression: 
                                       // auto returns the expr graph, 
                                       // NOT the computed value.
     for(auto i = 0; i < 100; ++i)
         ScalrType value = result* ... // Some other expensive computations using result.
    

for 迴圈的每次迭代中,將重新計算結果,因為表示式圖形而不是計算值被傳遞給 for 迴圈。

實現 ET 的現有庫 ****

  • boost::proto 是一個功能強大的庫,允許你為自己的表示式定義自己的規則和語法,並使用 ET 執行。
  • Eigen 是一個線性代數庫,可以使用 ET 有效地實現各種代數計算。