Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 131 additions & 105 deletions src/operator/tensor/la_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* Copyright (c) 2017 by Contributors
* \file la_op-inl.h
* \brief Operators for advanced linear algebra.
* \note See https://arxiv.org/pdf/1710.08717.pdf for details of gradient computations.
*/
#ifndef MXNET_OPERATOR_TENSOR_LA_OP_INL_H_
#define MXNET_OPERATOR_TENSOR_LA_OP_INL_H_
Expand All @@ -32,20 +33,29 @@ namespace op {

using namespace mshadow;

// Helper functions.
struct CopyLowerToUpper {
// Copies lower/upper triangular part to upper/lower, i.e. to the opposite side.
struct CopyTriangularToOppositeSide {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) {
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data, bool to_lower) {
// Below computation works even when we are dealing with a batch of matrices.
const int row((i % matrix_size) / stride), col(i % stride);
if ( row > col ) data[i + (col - row) * (stride - 1)] = data[i];
if (row > col) {
if (to_lower) {
data[i] = data[i + (col - row) * (stride - 1)];
} else {
data[i + (col - row) * (stride - 1)] = data[i];
}
}
}
};
struct ZeroUpper {

// Zero's lower/upper triangular part of a matrix.
struct ZeroTriangular {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) {
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data,
bool zero_lower) {
const int row((i % matrix_size) / stride), col(i % stride);
if ( row < col ) data[i] = 0;
if ((!zero_lower && (row < col)) || (zero_lower && (row > col))) data[i] = 0;
}
};
struct Scale {
Expand Down Expand Up @@ -103,87 +113,91 @@ struct gemm2 {
}
};

// L = potrf(A).
// B = potrf(A).
struct potrf {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& L,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
if ( A.dptr_ != L.dptr_ ) Copy(L, A, s);
linalg_batch_potrf(L, true, s);
const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
linalg_batch_potrf(B, param.lower, s);
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, L.MSize(), L.size(1)*L.stride_, L.stride_, L.dptr_);
Kernel<ZeroTriangular, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_, B.stride_,
B.dptr_, !param.lower);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& L,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(A, L, s, attrs);
op(A, B, s, attrs);
}
};

// A = potri(L).
// A = potri(B).
struct potri {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& A,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
if ( A.dptr_ != L.dptr_ ) Copy(A, L, s);
linalg_batch_potri(A, true, s);
const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
if ( A.dptr_ != B.dptr_ ) Copy(A, B, s);
linalg_batch_potri(A, param.lower, s);
using namespace mxnet_op;
Kernel<CopyLowerToUpper, xpu>::Launch(s, A.MSize(), A.size(1)*A.stride_, A.stride_, A.dptr_);
Kernel<CopyTriangularToOppositeSide, xpu>::Launch(s, A.MSize(), A.size(1)*A.stride_, A.stride_,
A.dptr_, !param.lower);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& A,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(L, A, s, attrs);
op(B, A, s, attrs);
}
};

// B = trsm(L,A)
// C = trsm(A,B)
struct trsm {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B,
DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
linalg_batch_trsm(L, B, alpha, rightside, true, transpose, s);
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& C,
DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) {
linalg_batch_trsm(A, C, alpha, rightside, lower, transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
if ( B.dptr_ != C.dptr_ ) Copy(C, B, s);
const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
op(A, C, DType(param.alpha), param.rightside, param.lower, param.transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(L, A, B, s, attrs);
op(A, B, C, s, attrs);
}
};

// B = trmm(L,A)
// C = trmm(A,B)
struct trmm {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B,
DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
linalg_batch_trmm(L, B, alpha, rightside, true, transpose, s);
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& C,
DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) {
linalg_batch_trmm(A, C, alpha, rightside, lower, transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, Stream<xpu> *s,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C, Stream<xpu> *s,
const nnvm::NodeAttrs& attrs) {
if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
if ( B.dptr_ != C.dptr_ ) Copy(C, B, s);
const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
op(A, C, DType(param.alpha), param.rightside, param.lower, param.transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const OpContext& ctx,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C, const OpContext& ctx,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(L, A, B, s, attrs);
op(A, B, C, s, attrs);
}
};

Expand Down Expand Up @@ -223,8 +237,8 @@ struct syrk {
linalg_batch_syrk(A, B, alpha, beta, tA, s);
// Symmetric B is in lower triangle: Copy to upper
using namespace mxnet_op;
Kernel<CopyLowerToUpper, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_,
B.stride_, B.dptr_);
Kernel<CopyTriangularToOppositeSide, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_,
B.stride_, B.dptr_, false);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
Expand Down Expand Up @@ -276,8 +290,8 @@ struct gelqf {
Tensor<xpu, 2, DType> QLeft(Qi.dptr_, Shape2(m, m), Qi.stride_, s);
Copy(Li, QLeft, s);
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, Li.MSize(), m*Li.stride_, Li.stride_,
Li.dptr_);
Kernel<ZeroTriangular, xpu>::Launch(s, Li.MSize(), m*Li.stride_, Li.stride_,
Li.dptr_, false);
// Call orglq: Input is Qi and part of work. Overwrites Qi by final Q
// matrix (conversion from internal representation)
linalg_orglq(Qi, work, s);
Expand Down Expand Up @@ -395,117 +409,129 @@ struct gemm2_backward {

struct potrf_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& L,
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& dA,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
// Backward of L = potrf(A).
// dA = 0.5 * L**T * copyLTU(L**T * dL) * L**(-1)
// Backward of B = potrf(A).
// dA = 0.5 * B**T * copyLTU(B**T * dB) * B**(-1)
// Here, copyLTU(M) creates a symmetric matrix from the square matrix M
// by setting the upper triangle to be equal to the lower triangle, leaving
// lower triangle and diagonal unchanged.
if ( dL.dptr_ != dA.dptr_ ) {
Copy(dA, dL, s);
// The function also handles the case when B is upper triangular by appropriate
// transpositions.
const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
if ( dB.dptr_ != dA.dptr_ ) {
Copy(dA, dB, s);
}
trmm::op(L, dA, DType(1.0), false, true, s);
trmm::op(B, dA, DType(1.0), !param.lower, param.lower, true, s);
using namespace mxnet_op;
Kernel<CopyLowerToUpper, xpu>::Launch
(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_);
trsm::op(L, dA, DType(1.0), false, true, s);
trsm::op(L, dA, DType(0.5), true, false, s);
Kernel<CopyTriangularToOppositeSide, xpu>::Launch
(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_, !param.lower);
trsm::op(B, dA, DType(1.0), false, param.lower, param.lower, s);
trsm::op(B, dA, DType(0.5), true, param.lower, !param.lower, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& L,
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& dA,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dL, L, dA, s, attrs);
op(dB, B, dA, s, attrs);
}
};

struct potri_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dB,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
// Backward of A = potri(L).
// dL = -tril( A * (dA + dA**T) * L**(-T)), where tril() extracts lower triangle
// Backward of A = potri(B).
// dB = -tril( A * (dA + dA**T) * B**(-T)), where tril() extracts lower triangle
// and diagonal. We must not assume that dA is symmetric.
// The function also handles the case when B is upper triangular by appropriate
// transpositions.
// Note: Calling gemm twice here is a bit wasteful, but otherwise the symmetrization
// of dA would require temporary memory.
gemm::op(A, dA, dL, DType(1.), DType(0.), false, false, s);
gemm::op(A, dA, dL, DType(1.), DType(1.), false, true, s);
trsm::op(L, dL, DType(-1.), true, true, s);
const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
if (param.lower) {
gemm::op(A, dA, dB, DType(1.), DType(0.), false, false, s);
gemm::op(A, dA, dB, DType(1.), DType(1.), false, true, s);
} else {
gemm::op(dA, A, dB, DType(1.), DType(0.), false, false, s);
gemm::op(dA, A, dB, DType(1.), DType(1.), true, false, s);
}
trsm::op(B, dB, DType(-1.), param.lower, param.lower, true, s);
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_,
dL.dptr_);
Kernel<ZeroTriangular, xpu>::Launch(s, dB.MSize(), dB.size(1)*dB.stride_, dB.stride_,
dB.dptr_, !param.lower);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dB,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dA, L, A, dL, s, attrs);
op(dA, B, A, dB, s, attrs);
}
};

struct trsm_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA,
static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
// Backward of B = trsm(L,A).
// Backward of C = trsm(A,B).
const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
// Compute dB
if ( dB.dptr_ != dC.dptr_ ) Copy(dB, dC, s);
trsm::op(A, dB, DType(param.alpha), param.rightside, param.lower, !param.transpose, s);
// Compute dA
if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB, s);
trsm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose, s);
// Compute dL
const bool da_left(param.rightside == param.transpose);
DType scale(-1.0/param.alpha);
(da_left ? gemm::op(dA, B, dL, scale, DType(0), param.transpose, !param.transpose, s)
: gemm::op(B, dA, dL, scale, DType(0), !param.transpose, param.transpose, s));
(da_left ? gemm::op(dB, C, dA, scale, DType(0), param.transpose, !param.transpose, s)
: gemm::op(C, dB, dA, scale, DType(0), !param.transpose, param.transpose, s));
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_);
Kernel<ZeroTriangular, xpu>::Launch(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_,
dA.dptr_, !param.lower);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA,
static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dB, L, A, B, dL, dA, s, attrs);
op(dC, A, B, C, dA, dB, s, attrs);
}
};

struct trmm_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
const Tensor<xpu, 3, DType>& dA, Stream<xpu>* s,
static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
const Tensor<xpu, 3, DType>& dB, Stream<xpu>* s,
const nnvm::NodeAttrs& attrs) {
// Backward of B = trmm(L,A).
// Backward of C = trmm(A,B).
const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
// Compute dL
// Compute dA
DType scale(param.alpha);
if (param.rightside == param.transpose) {
gemm::op(dB, A, dL, scale, DType(0.), param.transpose, !param.transpose, s);
gemm::op(dC, B, dA, scale, DType(0.), param.transpose, !param.transpose, s);
} else {
gemm::op(A, dB, dL, scale, DType(0.), !param.transpose, param.transpose, s);
gemm::op(B, dC, dA, scale, DType(0.), !param.transpose, param.transpose, s);
}
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_,
dL.dptr_);
// Compute dA
if (dA.dptr_ != dB.dptr_) Copy(dA, dB, s);
trmm::op(L, dA, scale, param.rightside, !param.transpose, s);
Kernel<ZeroTriangular, xpu>::Launch(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_,
dA.dptr_, !param.lower);
// Compute dB
if (dB.dptr_ != dC.dptr_) Copy(dB, dC, s);
trmm::op(A, dB, scale, param.rightside, param.lower, !param.transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
const Tensor<xpu, 3, DType>& dA, const OpContext& ctx,
static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
const Tensor<xpu, 3, DType>& dB, const OpContext& ctx,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dB, L, A, dL, dA, s, attrs);
op(dC, A, B, dA, dB, s, attrs);
}
};

Expand Down Expand Up @@ -586,13 +612,13 @@ struct gelqf_backward {
Tensor<xpu, 3, DType> tempM = ctx.requested[0]
.get_space_typed<xpu, 3, DType>(dL.shape_, s);
Copy(tempM, dL, s);
trmm::op(L, tempM, DType(1.0), false, true, s);
trmm::op(L, tempM, DType(1.0), false, true, true, s);
gemm::op(dA, Q, tempM, DType(-1.0), DType(1.0), false, true, s);
Kernel<CopyLowerToUpper, xpu>::Launch
Kernel<CopyTriangularToOppositeSide, xpu>::Launch
(s, tempM.MSize(), tempM.size(1)*tempM.stride_, tempM.stride_,
tempM.dptr_);
tempM.dptr_, false);
gemm::op(tempM, Q, dA, DType(1.0), DType(1.0), false, false, s);
trsm::op(L, dA, DType(1.0), false, true, s);
trsm::op(L, dA, DType(1.0), false, true, true, s);
}
};

Expand Down
Loading