元素代数表达式的基本表达模板

介绍和动机

表达模板 ( 在下面表示为 ET )是一种功能强大的模板元编程技术,用于加速有时非常昂贵的表达式的计算。它广泛用于不同的领域,例如在线性代数库的实现中。

对于此示例,请考虑线性代数计算的上下文。更具体地,仅涉及逐元素操作的计算。这种计算是 ET 的最基本的应用,它们可以很好地介绍 ET 如何在内部工作。

让我们来看一个激励人心的例子。考虑表达式的计算:

Vector vec_1, vec_2, vec_3;

// Initializing vec_1, vec_2 and vec_3.

Vector result = vec_1 + vec_2*vec_3;

这里为了简单起见,我假设类 Vector 和操作+(向量加:元素加操作)和操作*(这里表示向量内积:也是元素操作)都正确实现,如它们应该如何,数学上。

在不使用 ET (或其他类似技术) 的传统实现中,至少发生了五个 Vector 实例的构造以获得最终的 result

  1. 对应于 vec_1vec_2vec_3 的三个实例。
  2. 一个临时的 Vector 实例 _tmp,代表 _tmp = vec_2*vec_3; 的结果。
  3. 最后正确使用返回值优化,在 result = vec_1 + _tmp; 中构建最终的 result

使用 ET 的实现可以消除 2 中临时 Vector _tmp 的创建,因此只留下 Vector 实例的四个构造。更有趣的是,请考虑以下更复杂的表达式:

Vector result = vec_1 + (vec_2*vec3 + vec_1)*(vec_2 + vec_3*vec_1);

总共还有四个 Vector 个实例:vec_1, vec_2, vec_3result。换句话说,在此示例中,仅涉及按元素操作,保证不会从中间计算创建临时对象

ET 如何运作

基本上,任何代数计算的 ET 都包含两个构建块:

  1. 纯代数表达式PAE ):它们是代数表达式的代理/抽象。纯代数不进行实际计算,它们仅仅是计算工作流的抽象/建模。PAE 可以是任何代数表达式的输入或输出的模型。 PAE 的实例通常被认为是便宜的复制。
  2. 懒惰评估 :这是实际计算的实现。在下面的示例中,我们将看到对于仅涉及逐元素操作的表达式,延迟评估可以在最终结果的索引访问操作内实现实际计算,从而创建按需评估方案:不执行计算只有在访问/要求最终结果时。

那么,具体如何在这个例子中实现 ET ?我们现在来看看吧。

请始终考虑以下代码段:

Vector vec_1, vec_2, vec_3;

// Initializing vec_1, vec_2 and vec_3.

Vector result = vec_1 + vec_2*vec_3;

计算结果的表达式可以进一步分解为两个子表达式:

  1. 向量加表达式(表示为 plus_expr
  2. 向量内积表达式(表示为 innerprod_expr )。

什么外星人做的是以下几点:

  • ET 不是立即计算每个子表达式,而是首先使用图形结构对整个表达式进行建模。图中的每个节点代表 PAE 。节点的边缘连接表示实际的计算流程。因此,对于上面的表达式,我们获得以下图表:

           result = plus_expr( vec_1, innerprod_expr(vec_2, vec_3) )
              /   \
             /     \
            /       \
           /   innerprod_expr( vec_2, vec_3 )
          /         /  \
         /         /    \
        /         /      \
     vec_1     vec_2    vec_3
    
  • 最后的计算是通过查看图层次结构来实现的 :因为这里我们只处理逐元素操作,result 中每个索引值的计算可以独立完成result 的最终评估可以被懒惰地推迟到元素 - 对这个元素的明智评价 19。换句话说,由于 resultelem_res 的元素的计算可以使用 vec_1elem_1),vec_2elem_2)和 vec_3elem_3)中的相应元素表示为:

    elem_res = elem_1 + elem_2*elem_3;
    

因此,不需要创建临时 Vector 来存储中间内积的结果: 一个元素的整个计算可以完全完成,并在索引访问操作中编码

以下是实际操作中的示例代码

文件 vec.hh:std::vector 的包装器,用于在调用构造时显示日志

#ifndef EXPR_VEC
# define EXPR_VEC

# include <vector>
# include <cassert>
# include <utility>
# include <iostream>
# include <algorithm>
# include <functional>

///
/// This is a wrapper for std::vector. It's only purpose is to print out a log when a
/// vector constructions in called.
/// It wraps the indexed access operator [] and the size() method, which are 
/// important for later ETs implementation.
///

// std::vector wrapper.
template<typename ScalarType> class Vector
{
public:
  explicit Vector() { std::cout << "ctor called.\n"; };
  explicit Vector(int size): _vec(size) { std::cout << "ctor called.\n"; };
  explicit Vector(const std::vector<ScalarType> &vec): _vec(vec)
  { std::cout << "ctor called.\n"; };
  
  Vector(const Vector<ScalarType> & vec): _vec{vec()}
  { std::cout << "copy ctor called.\n"; };
  Vector(Vector<ScalarType> && vec): _vec(std::move(vec()))
  { std::cout << "move ctor called.\n"; };

  Vector<ScalarType> & operator=(const Vector<ScalarType> &) = default;
  Vector<ScalarType> & operator=(Vector<ScalarType> &&) = default;

  decltype(auto) operator[](int indx) { return _vec[indx]; }
  decltype(auto) operator[](int indx) const { return _vec[indx]; }

  decltype(auto) operator()() & { return (_vec); };        
  decltype(auto) operator()() const & { return (_vec); };  
  Vector<ScalarType> && operator()() && { return std::move(*this); }

  int size() const { return _vec.size(); }
  
private:
  std::vector<ScalarType> _vec;
};

///
/// These are conventional overloads of operator + (the vector plus operation)
/// and operator * (the vector inner product operation) without using the expression
/// templates. They are later used for bench-marking purpose.
///

// + (vector plus) operator.
template<typename ScalarType>
auto operator+(const Vector<ScalarType> &lhs, const Vector<ScalarType> &rhs)
{
  assert(lhs().size() == rhs().size() &&
         "error: ops plus -> lhs and rhs size mismatch.");
  
  std::vector<ScalarType> _vec;
  _vec.resize(lhs().size());
  std::transform(std::cbegin(lhs()), std::cend(lhs()),
                 std::cbegin(rhs()), std::begin(_vec),
                 std::plus<>());
  return Vector<ScalarType>(std::move(_vec));
}

// * (vector inner product) operator.
template<typename ScalarType>
auto operator*(const Vector<ScalarType> &lhs, const Vector<ScalarType> &rhs)
{
  assert(lhs().size() == rhs().size() &&
         "error: ops multiplies -> lhs and rhs size mismatch.");
  
  std::vector<ScalarType> _vec;
  _vec.resize(lhs().size());
  std::transform(std::cbegin(lhs()), std::cend(lhs()),
                 std::cbegin(rhs()), std::begin(_vec),
                 std::multiplies<>());
  return Vector<ScalarType>(std::move(_vec));
}

#endif //!EXPR_VEC

File expr.hh:用于逐元素操作的表达式模板的实现(vector plus 和 vector inner product)

让我们把它分解成各个部分。

  1. 第 1 节为所有表达式实现了一个基类。它采用了奇怪的重复模板模式CRTP )。
  2. 第 2 节实现了第一个 PAE :一个终端,它只是一个输入数据结构的包装器(const 引用),包含用于计算的实际输入值。
  3. 第 3 节实现了第二个 PAEbinary_operation ,它是一个稍后用于 vector_plus 和 vector_innerprod 的类模板。它由操作类型左侧 PAE右侧 PAE 参数化。实际计算在索引访问运算符中编码。
  4. 第 4 节将 vector_plus 和 vector_innerprod 操作定义为元素操作。它还会为 PAE s 重载 operator +和* :这样这两个操作也会返回 PAE
#ifndef EXPR_EXPR
# define EXPR_EXPR
      

/// Fwd declaration.
template<typename> class Vector;

namespace expr
{

/// -----------------------------------------
///
/// Section 1.
///
/// The first section is a base class template for all kinds of expression. It         
/// employs the Curiously Recurring Template Pattern, which enables its instantiation 
/// to any kind of expression structure inheriting from it.
///
/// -----------------------------------------

  /// Base class for all expressions.
  template<typename Expr> class expr_base
  {
  public:
    const Expr& self() const { return static_cast<const Expr&>(*this); }
    Expr& self() { return static_cast<Expr&>(*this); }

  protected:
    explicit expr_base() {};
    int size() const { return self().size_impl(); }
    auto operator[](int indx) const { return self().at_impl(indx); }
    auto operator()() const { return self()(); };
  };
  

/// -----------------------------------------
///
/// The following section 2 & 3 are abstractions of pure algebraic expressions (PAE).
/// Any PAE can be converted to a real object instance using operator(): it is in 
/// this conversion process, where the real computations are done.

///
/// Section 2. Terminal
///
/// A terminal is an abstraction wrapping a const reference to the Vector data 
/// structure. It inherits from expr_base, therefore providing a unified interface
/// wrapping a Vector into a PAE.
///
/// It provides the size() method, indexed access through at_impl() and a conversion
/// to referenced object through () operator.
/// 
/// It might no be necessary for user defined data structures to have a terminal 
/// wrapper, since user defined structure can inherit expr_base, therefore eliminates
/// the need to provide such terminal wrapper. 
///
/// -----------------------------------------

  /// Generic wrapper for underlying data structure.
  template<typename DataType> class terminal: expr_base<terminal<DataType>>
  {
  public:
    using base_type = expr_base<terminal<DataType>>;
    using base_type::size;
    using base_type::operator[];
    friend base_type;
    
    explicit terminal(const DataType &val): _val(val) {}
    int size_impl() const { return _val.size(); };
    auto at_impl(int indx) const { return _val[indx]; };
    decltype(auto) operator()() const { return (_val); }
    
  private:
    const DataType &_val;
  };

/// -----------------------------------------
///
/// Section 3. Binary operation expression.
///
/// This is a PAE abstraction of any binary expression. Similarly it inherits from 
/// expr_base.
///
/// It provides the size() method, indexed access through at_impl() and a conversion
/// to referenced object through () operator. Each call to the at_impl() method is
/// a element wise computation.
/// 
/// -----------------------------------------

  /// Generic wrapper for binary operations (that are element-wise).
  template<typename Ops, typename lExpr, typename rExpr>
  class binary_ops: public expr_base<binary_ops<Ops,lExpr,rExpr>>
  {
  public:
    using base_type = expr_base<binary_ops<Ops,lExpr,rExpr>>;
    using base_type::size;
    using base_type::operator[];
    friend base_type;
    
    explicit binary_ops(const Ops &ops, const lExpr &lxpr, const rExpr &rxpr)
      : _ops(ops), _lxpr(lxpr), _rxpr(rxpr) {};
    int size_impl() const { return _lxpr.size(); };

    /// This does the element-wise computation for index indx.
    auto at_impl(int indx) const { return _ops(_lxpr[indx], _rxpr[indx]); };

    /// Conversion from arbitrary expr to concrete data type. It evaluates
    /// element-wise computations for all indices.
    template<typename DataType> operator DataType()
    {
      DataType _vec(size());
      for(int _ind = 0; _ind < _vec.size(); ++_ind)
        _vec[_ind] = (*this)[_ind];
      return _vec;
    }
    
  private: /// Ops and expr are assumed cheap to copy.
    Ops   _ops;
    lExpr _lxpr;
    rExpr _rxpr;
  };

/// -----------------------------------------
/// Section 4.
///
/// The following two structs defines algebraic operations on PAEs: here only vector 
/// plus and vector inner product are implemented. 
///
/// First, some element-wise operations are defined : in other words, vec_plus and 
/// vec_prod acts on elements in Vectors, but not whole Vectors. 
///
/// Then, operator + & * are overloaded on PAEs, such that: + & * operations on PAEs         
/// also return PAEs.
///
/// -----------------------------------------

  /// Element-wise plus operation.
  struct vec_plus_t
  {
    constexpr explicit vec_plus_t() = default; 
    template<typename LType, typename RType>
    auto operator()(const LType &lhs, const RType &rhs) const
    { return lhs+rhs; }
  };
  
  /// Element-wise inner product operation.
  struct vec_prod_t
  {
    constexpr explicit vec_prod_t() = default; 
    template<typename LType, typename RType>
    auto operator()(const LType &lhs, const RType &rhs) const
    { return lhs*rhs; }
  };
  
  /// Constant plus and inner product operator objects.
  constexpr vec_plus_t vec_plus{};
  constexpr vec_prod_t vec_prod{};
  
  /// Plus operator overload on expressions: return binary expression.
  template<typename lExpr, typename rExpr>
  auto operator+(const lExpr &lhs, const rExpr &rhs)
  { return binary_ops<vec_plus_t,lExpr,rExpr>(vec_plus,lhs,rhs); }
  
  /// Inner prod operator overload on expressions: return binary expression.
  template<typename lExpr, typename rExpr>
  auto operator*(const lExpr &lhs, const rExpr &rhs)
  { return binary_ops<vec_prod_t,lExpr,rExpr>(vec_prod,lhs,rhs); }
  
} //!expr

#endif //!EXPR_EXPR

文件 main.cc:测试 src 文件

# include <chrono>
# include <iomanip>
# include <iostream>
# include "vec.hh"
# include "expr.hh"
# include "boost/core/demangle.hpp"

int main()
{
  using dtype = float;
  constexpr int size = 5e7;
  
  std::vector<dtype> _vec1(size);
  std::vector<dtype> _vec2(size);
  std::vector<dtype> _vec3(size);

  // ... Initialize vectors' contents.

  Vector<dtype> vec1(std::move(_vec1));
  Vector<dtype> vec2(std::move(_vec2));
  Vector<dtype> vec3(std::move(_vec3));

  unsigned long start_ms_no_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::system_clock::now().time_since_epoch()).count();
  std::cout << "\nNo-ETs evaluation starts.\n";
  
  Vector<dtype> result_no_ets = vec1 + (vec2*vec3);
  
  unsigned long stop_ms_no_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::system_clock::now().time_since_epoch()).count();
  std::cout << std::setprecision(6) << std::fixed
            << "No-ETs. Time eclapses: " << (stop_ms_no_ets-start_ms_no_ets)/1000.0
            << " s.\n" << std::endl;
  
  unsigned long start_ms_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::system_clock::now().time_since_epoch()).count();
  std::cout << "Evaluation using ETs starts.\n";
  
  expr::terminal<Vector<dtype>> vec4(vec1);
  expr::terminal<Vector<dtype>> vec5(vec2);
  expr::terminal<Vector<dtype>> vec6(vec3);
  
  Vector<dtype> result_ets = (vec4 + vec5*vec6);
  
  unsigned long stop_ms_ets =
    std::chrono::duration_cast<std::chrono::milliseconds>
    (std::chrono::system_clock::now().time_since_epoch()).count();
  std::cout << std::setprecision(6) << std::fixed
            << "With ETs. Time eclapses: " << (stop_ms_ets-start_ms_ets)/1000.0
            << " s.\n" << std::endl;
  
  auto ets_ret_type = (vec4 + vec5*vec6);
  std::cout << "\nETs result's type:\n";
  std::cout << boost::core::demangle( typeid(decltype(ets_ret_type)).name() ) << '\n'; 

  return 0;
}

使用 GCC 5.3 使用 -O3 -std=c++14 编译时,这是一个可能的输出:

ctor called.
ctor called.
ctor called.

No-ETs evaluation starts.
ctor called.
ctor called.
No-ETs. Time eclapses: 0.571000 s.

Evaluation using ETs starts.
ctor called.
With ETs. Time eclapses: 0.164000 s.

ETs result's type:
expr::binary_ops<expr::vec_plus_t, expr::terminal<Vector<float> >, expr::binary_ops<expr::vec_prod_t, expr::terminal<Vector<float> >, expr::terminal<Vector<float> > > >

观察结果如下:

  • **在这种情况下,**使用 ET 可以实现相当显着的性能提升 (> 3x)。 ****
  • 消除了临时 Vector 对象的创建。与 ETs 一样,ctor 只被调用一次。
  • Boost::demangle 用于可视化转换前 ET 返回的类型:它清楚地构建了与上面演示的完全相同的表达图。

缺点和警告

  • ET 的一个明显缺点是学习曲线,实施的复杂性和代码维护难度。在上面仅考虑元素操作的示例中,实现包含了大量的样板,更不用说在现实世界中,每个计算中都会出现更复杂的代数表达式,并且元素方面的独立性不再成立(例如矩阵乘法) ),难度将是指数级的。

  • 使用 ET 的另一个警告是它们与 auto 关键字配合得很好。如上所述, PAE 本质上是代理:并且代理基本上不能与 auto 一起使用。请考虑以下示例:

     auto result = ...;                // Some expensive expression: 
                                       // auto returns the expr graph, 
                                       // NOT the computed value.
     for(auto i = 0; i < 100; ++i)
         ScalrType value = result* ... // Some other expensive computations using result.
    

for 循环的每次迭代中,将重新计算结果,因为表达式图形而不是计算值被传递给 for 循环。

实现 ET 的现有库 ****

  • boost::proto 是一个功能强大的库,允许你为自己的表达式定义自己的规则和语法,并使用 ET 执行。
  • Eigen 是一个线性代数库,可以使用 ET 有效地实现各种代数计算。