203 lines
5.0 KiB
Go
203 lines
5.0 KiB
Go
// Copyright ©2021 The Gonum Authors. All rights reserved.
|
||
// Use of this source code is governed by a BSD-style
|
||
// license that can be found in the LICENSE file.
|
||
|
||
package gonum
|
||
|
||
import (
|
||
"math"
|
||
|
||
"gonum.org/v1/gonum/blas"
|
||
"gonum.org/v1/gonum/blas/blas64"
|
||
)
|
||
|
||
// Dpstf2 computes the Cholesky factorization with complete pivoting of an n×n
|
||
// symmetric positive semidefinite matrix A.
|
||
//
|
||
// The factorization has the form
|
||
//
|
||
// Pᵀ * A * P = Uᵀ * U , if uplo = blas.Upper,
|
||
// Pᵀ * A * P = L * Lᵀ, if uplo = blas.Lower,
|
||
//
|
||
// where U is an upper triangular matrix, L is lower triangular, and P is a
|
||
// permutation matrix.
|
||
//
|
||
// tol is a user-defined tolerance. The algorithm terminates if the pivot is
|
||
// less than or equal to tol. If tol is negative, then n*eps*max(A[k,k]) will be
|
||
// used instead.
|
||
//
|
||
// On return, A contains the factor U or L from the Cholesky factorization and
|
||
// piv contains P stored such that P[piv[k],k] = 1.
|
||
//
|
||
// Dpstf2 returns the computed rank of A and whether the factorization can be
|
||
// used to solve a system. Dpstf2 does not attempt to check that A is positive
|
||
// semi-definite, so if ok is false, the matrix A is either rank deficient or is
|
||
// not positive semidefinite.
|
||
//
|
||
// The length of piv must be n and the length of work must be at least 2*n,
|
||
// otherwise Dpstf2 will panic.
|
||
//
|
||
// Dpstf2 is an internal routine. It is exported for testing purposes.
|
||
func (Implementation) Dpstf2(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool) {
|
||
switch {
|
||
case uplo != blas.Upper && uplo != blas.Lower:
|
||
panic(badUplo)
|
||
case n < 0:
|
||
panic(nLT0)
|
||
case lda < max(1, n):
|
||
panic(badLdA)
|
||
}
|
||
|
||
// Quick return if possible.
|
||
if n == 0 {
|
||
return 0, true
|
||
}
|
||
|
||
switch {
|
||
case len(a) < (n-1)*lda+n:
|
||
panic(shortA)
|
||
case len(piv) != n:
|
||
panic(badLenPiv)
|
||
case len(work) < 2*n:
|
||
panic(shortWork)
|
||
}
|
||
|
||
// Initialize piv.
|
||
for i := range piv[:n] {
|
||
piv[i] = i
|
||
}
|
||
|
||
// Compute the first pivot.
|
||
pvt := 0
|
||
ajj := a[0]
|
||
for i := 1; i < n; i++ {
|
||
aii := a[i*lda+i]
|
||
if aii > ajj {
|
||
pvt = i
|
||
ajj = aii
|
||
}
|
||
}
|
||
if ajj <= 0 || math.IsNaN(ajj) {
|
||
return 0, false
|
||
}
|
||
|
||
// Compute stopping value if not supplied.
|
||
dstop := tol
|
||
if dstop < 0 {
|
||
dstop = float64(n) * dlamchE * ajj
|
||
}
|
||
|
||
// Set first half of work to zero, holds dot products.
|
||
dots := work[:n]
|
||
for i := range dots {
|
||
dots[i] = 0
|
||
}
|
||
work2 := work[n : 2*n]
|
||
|
||
bi := blas64.Implementation()
|
||
if uplo == blas.Upper {
|
||
// Compute the Cholesky factorization Pᵀ * A * P = Uᵀ * U.
|
||
for j := 0; j < n; j++ {
|
||
// Update dot products and compute possible pivots which are stored
|
||
// in the second half of work.
|
||
for i := j; i < n; i++ {
|
||
if j > 0 {
|
||
tmp := a[(j-1)*lda+i]
|
||
dots[i] += tmp * tmp
|
||
}
|
||
work2[i] = a[i*lda+i] - dots[i]
|
||
}
|
||
if j > 0 {
|
||
// Find the pivot.
|
||
pvt = j
|
||
ajj = work2[pvt]
|
||
for k := j + 1; k < n; k++ {
|
||
wk := work2[k]
|
||
if wk > ajj {
|
||
pvt = k
|
||
ajj = wk
|
||
}
|
||
}
|
||
// Test for exit.
|
||
if ajj <= dstop || math.IsNaN(ajj) {
|
||
a[j*lda+j] = ajj
|
||
return j, false
|
||
}
|
||
}
|
||
if j != pvt {
|
||
// Swap pivot rows and columns.
|
||
a[pvt*lda+pvt] = a[j*lda+j]
|
||
bi.Dswap(j, a[j:], lda, a[pvt:], lda)
|
||
if pvt < n-1 {
|
||
bi.Dswap(n-pvt-1, a[j*lda+(pvt+1):], 1, a[pvt*lda+(pvt+1):], 1)
|
||
}
|
||
bi.Dswap(pvt-j-1, a[j*lda+(j+1):], 1, a[(j+1)*lda+pvt:], lda)
|
||
// Swap dot products and piv.
|
||
dots[j], dots[pvt] = dots[pvt], dots[j]
|
||
piv[j], piv[pvt] = piv[pvt], piv[j]
|
||
}
|
||
ajj = math.Sqrt(ajj)
|
||
a[j*lda+j] = ajj
|
||
// Compute elements j+1:n of row j.
|
||
if j < n-1 {
|
||
bi.Dgemv(blas.Trans, j, n-j-1,
|
||
-1, a[j+1:], lda, a[j:], lda,
|
||
1, a[j*lda+j+1:], 1)
|
||
bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
|
||
}
|
||
}
|
||
} else {
|
||
// Compute the Cholesky factorization Pᵀ * A * P = L * Lᵀ.
|
||
for j := 0; j < n; j++ {
|
||
// Update dot products and compute possible pivots which are stored
|
||
// in the second half of work.
|
||
for i := j; i < n; i++ {
|
||
if j > 0 {
|
||
tmp := a[i*lda+(j-1)]
|
||
dots[i] += tmp * tmp
|
||
}
|
||
work2[i] = a[i*lda+i] - dots[i]
|
||
}
|
||
if j > 0 {
|
||
// Find the pivot.
|
||
pvt = j
|
||
ajj = work2[pvt]
|
||
for k := j + 1; k < n; k++ {
|
||
wk := work2[k]
|
||
if wk > ajj {
|
||
pvt = k
|
||
ajj = wk
|
||
}
|
||
}
|
||
// Test for exit.
|
||
if ajj <= dstop || math.IsNaN(ajj) {
|
||
a[j*lda+j] = ajj
|
||
return j, false
|
||
}
|
||
}
|
||
if j != pvt {
|
||
// Swap pivot rows and columns.
|
||
a[pvt*lda+pvt] = a[j*lda+j]
|
||
bi.Dswap(j, a[j*lda:], 1, a[pvt*lda:], 1)
|
||
if pvt < n-1 {
|
||
bi.Dswap(n-pvt-1, a[(pvt+1)*lda+j:], lda, a[(pvt+1)*lda+pvt:], lda)
|
||
}
|
||
bi.Dswap(pvt-j-1, a[(j+1)*lda+j:], lda, a[pvt*lda+(j+1):], 1)
|
||
// Swap dot products and piv.
|
||
dots[j], dots[pvt] = dots[pvt], dots[j]
|
||
piv[j], piv[pvt] = piv[pvt], piv[j]
|
||
}
|
||
ajj = math.Sqrt(ajj)
|
||
a[j*lda+j] = ajj
|
||
// Compute elements j+1:n of column j.
|
||
if j < n-1 {
|
||
bi.Dgemv(blas.NoTrans, n-j-1, j,
|
||
-1, a[(j+1)*lda:], lda, a[j*lda:], 1,
|
||
1, a[(j+1)*lda+j:], lda)
|
||
bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
|
||
}
|
||
}
|
||
}
|
||
return n, true
|
||
}
|