127 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			127 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Go
		
	
	
| // Copyright ©2014 The Gonum Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package mat
 | |
| 
 | |
| import (
 | |
| 	"gonum.org/v1/gonum/blas"
 | |
| 	"gonum.org/v1/gonum/blas/blas64"
 | |
| 	"gonum.org/v1/gonum/internal/asm/f64"
 | |
| )
 | |
| 
 | |
| // Inner computes the generalized inner product
 | |
| //
 | |
| //	xᵀ A y
 | |
| //
 | |
| // between the vectors x and y with matrix A, where x and y are treated as
 | |
| // column vectors.
 | |
| //
 | |
| // This is only a true inner product if A is symmetric positive definite, though
 | |
| // the operation works for any matrix A.
 | |
| //
 | |
| // Inner panics if x.Len != m or y.Len != n when A is an m x n matrix.
 | |
| func Inner(x Vector, a Matrix, y Vector) float64 {
 | |
| 	m, n := a.Dims()
 | |
| 	if x.Len() != m {
 | |
| 		panic(ErrShape)
 | |
| 	}
 | |
| 	if y.Len() != n {
 | |
| 		panic(ErrShape)
 | |
| 	}
 | |
| 	if m == 0 || n == 0 {
 | |
| 		return 0
 | |
| 	}
 | |
| 
 | |
| 	var sum float64
 | |
| 
 | |
| 	switch a := a.(type) {
 | |
| 	case RawSymmetricer:
 | |
| 		amat := a.RawSymmetric()
 | |
| 		if amat.Uplo != blas.Upper {
 | |
| 			// Panic as a string not a mat.Error.
 | |
| 			panic(badSymTriangle)
 | |
| 		}
 | |
| 		var xmat, ymat blas64.Vector
 | |
| 		if xrv, ok := x.(RawVectorer); ok {
 | |
| 			xmat = xrv.RawVector()
 | |
| 		} else {
 | |
| 			break
 | |
| 		}
 | |
| 		if yrv, ok := y.(RawVectorer); ok {
 | |
| 			ymat = yrv.RawVector()
 | |
| 		} else {
 | |
| 			break
 | |
| 		}
 | |
| 		for i := 0; i < x.Len(); i++ {
 | |
| 			xi := x.AtVec(i)
 | |
| 			if xi != 0 {
 | |
| 				if ymat.Inc == 1 {
 | |
| 					sum += xi * f64.DotUnitary(
 | |
| 						amat.Data[i*amat.Stride+i:i*amat.Stride+n],
 | |
| 						ymat.Data[i:],
 | |
| 					)
 | |
| 				} else {
 | |
| 					sum += xi * f64.DotInc(
 | |
| 						amat.Data[i*amat.Stride+i:i*amat.Stride+n],
 | |
| 						ymat.Data[i*ymat.Inc:], uintptr(n-i),
 | |
| 						1, uintptr(ymat.Inc),
 | |
| 						0, 0,
 | |
| 					)
 | |
| 				}
 | |
| 			}
 | |
| 			yi := y.AtVec(i)
 | |
| 			if i != n-1 && yi != 0 {
 | |
| 				if xmat.Inc == 1 {
 | |
| 					sum += yi * f64.DotUnitary(
 | |
| 						amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
 | |
| 						xmat.Data[i+1:],
 | |
| 					)
 | |
| 				} else {
 | |
| 					sum += yi * f64.DotInc(
 | |
| 						amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
 | |
| 						xmat.Data[(i+1)*xmat.Inc:], uintptr(n-i-1),
 | |
| 						1, uintptr(xmat.Inc),
 | |
| 						0, 0,
 | |
| 					)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		return sum
 | |
| 	case RawMatrixer:
 | |
| 		amat := a.RawMatrix()
 | |
| 		var ymat blas64.Vector
 | |
| 		if yrv, ok := y.(RawVectorer); ok {
 | |
| 			ymat = yrv.RawVector()
 | |
| 		} else {
 | |
| 			break
 | |
| 		}
 | |
| 		for i := 0; i < x.Len(); i++ {
 | |
| 			xi := x.AtVec(i)
 | |
| 			if xi != 0 {
 | |
| 				if ymat.Inc == 1 {
 | |
| 					sum += xi * f64.DotUnitary(
 | |
| 						amat.Data[i*amat.Stride:i*amat.Stride+n],
 | |
| 						ymat.Data,
 | |
| 					)
 | |
| 				} else {
 | |
| 					sum += xi * f64.DotInc(
 | |
| 						amat.Data[i*amat.Stride:i*amat.Stride+n],
 | |
| 						ymat.Data, uintptr(n),
 | |
| 						1, uintptr(ymat.Inc),
 | |
| 						0, 0,
 | |
| 					)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		return sum
 | |
| 	}
 | |
| 	for i := 0; i < x.Len(); i++ {
 | |
| 		xi := x.AtVec(i)
 | |
| 		for j := 0; j < y.Len(); j++ {
 | |
| 			sum += xi * a.At(i, j) * y.AtVec(j)
 | |
| 		}
 | |
| 	}
 | |
| 	return sum
 | |
| }
 |