map and analytic funcions

This commit is contained in:
Andrei Kislitsyn 2021-03-21 19:05:11 +03:00
parent b7e1349ead
commit fa78ed1f45
7 changed files with 122 additions and 77 deletions

View File

@ -14,10 +14,10 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
public fun TensorType.qr(): TensorType
//https://pytorch.org/docs/stable/generated/torch.lu.html
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
public fun TensorType.lu(tol: T): Pair<TensorType, IndexTensorType>
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
public fun luPivot(lu: TensorType, pivots: IndexTensorType): Triple<TensorType, TensorType, TensorType>
public fun luPivot(luTensor: TensorType, pivotsTensor: IndexTensorType): Triple<TensorType, TensorType, TensorType>
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>

View File

@ -3,6 +3,11 @@ package space.kscience.kmath.tensors
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun TensorType.map(transform: (T) -> T): TensorType
public fun TensorType.eq(other: TensorType, eqFunction: (T, T) -> Boolean): Boolean
public fun TensorType.contentEquals(other: TensorType, eqFunction: (T, T) -> Boolean): Boolean
//https://pytorch.org/docs/stable/generated/torch.full.html
public fun full(value: T, shape: IntArray): TensorType

View File

@ -5,6 +5,7 @@ import space.kscience.kmath.nd.*
import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.TensorStrides
import space.kscience.kmath.tensors.TensorStructure
import kotlin.math.atanh
public open class BufferedTensor<T>(

View File

@ -7,77 +7,40 @@ public class DoubleAnalyticTensorAlgebra:
AnalyticTensorAlgebra<Double, DoubleTensor>,
DoubleOrderedTensorAlgebra()
{
override fun DoubleTensor.exp(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.exp(): DoubleTensor = this.map(::exp)
override fun DoubleTensor.log(): DoubleTensor {
TODO("Not yet implemented")
}
// todo log with other base????
override fun DoubleTensor.log(): DoubleTensor = this.map(::ln)
override fun DoubleTensor.sqrt(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.sqrt(): DoubleTensor = this.map(::sqrt)
override fun DoubleTensor.cos(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.cos(): DoubleTensor = this.map(::cos)
override fun DoubleTensor.acos(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.acos(): DoubleTensor = this.map(::acos)
override fun DoubleTensor.cosh(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.cosh(): DoubleTensor = this.map(::cosh)
override fun DoubleTensor.acosh(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.acosh(): DoubleTensor = this.map(::acosh)
override fun DoubleTensor.sin(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.sin(): DoubleTensor = this.map(::sin)
override fun DoubleTensor.asin(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.asin(): DoubleTensor = this.map(::asin)
override fun DoubleTensor.sinh(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.sinh(): DoubleTensor = this.map(::sinh)
override fun DoubleTensor.asinh(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.asinh(): DoubleTensor = this.map(::asinh)
override fun DoubleTensor.tan(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.tan(): DoubleTensor = this.map(::tan)
override fun DoubleTensor.atan(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.atan(): DoubleTensor = this.map(::atan)
override fun DoubleTensor.tanh(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.tanh(): DoubleTensor = this.map(::tanh)
override fun DoubleTensor.atanh(): DoubleTensor {
return DoubleTensor(
this.shape,
this.buffer.array().map(::atanh).toDoubleArray(),
this.bufferStart
)
}
override fun DoubleTensor.atanh(): DoubleTensor = this.map(::atanh)
override fun DoubleTensor.ceil(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.ceil(): DoubleTensor = this.map(::ceil)
override fun DoubleTensor.floor(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.floor(): DoubleTensor = this.map(::floor)
override fun DoubleTensor.clamp(min: Double, max: Double): DoubleTensor {
TODO("Not yet implemented")

View File

@ -1,6 +1,7 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
import kotlin.math.sqrt
public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
@ -10,47 +11,47 @@ public class DoubleLinearOpsTensorAlgebra :
TODO("Not yet implemented")
}
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
override fun DoubleTensor.lu(tol: Double): Pair<DoubleTensor, IntTensor> {
// todo checks
checkSquareMatrix(shape)
val luTensor = this.copy()
val luTensor = copy()
val n = this.shape.size
val m = this.shape.last()
val pivotsShape = IntArray(n - 1) { i -> this.shape[i] }
val n = shape.size
val m = shape.last()
val pivotsShape = IntArray(n - 1) { i -> shape[i] }
val pivotsTensor = IntTensor(
pivotsShape,
IntArray(pivotsShape.reduce(Int::times)) { 0 } //todo default???
IntArray(pivotsShape.reduce(Int::times)) { 0 }
)
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence())){
for (row in 0 until m) pivots[row] = row
for (i in 0 until m) {
var maxA = -1.0
var iMax = i
var maxVal = -1.0
var maxInd = i
for (k in i until m) {
val absA = kotlin.math.abs(lu[k, i])
if (absA > maxA) {
maxA = absA
iMax = k
if (absA > maxVal) {
maxVal = absA
maxInd = k
}
}
//todo check singularity
if (iMax != i) {
if (maxInd != i) {
val j = pivots[i]
pivots[i] = pivots[iMax]
pivots[iMax] = j
pivots[i] = pivots[maxInd]
pivots[maxInd] = j
for (k in 0 until m) {
val tmp = lu[i, k]
lu[i, k] = lu[iMax, k]
lu[iMax, k] = tmp
lu[i, k] = lu[maxInd, k]
lu[maxInd, k] = tmp
}
}
@ -71,6 +72,9 @@ public class DoubleLinearOpsTensorAlgebra :
override fun luPivot(luTensor: DoubleTensor, pivotsTensor: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
//todo checks
checkSquareMatrix(luTensor.shape)
check(luTensor.shape.dropLast(1).toIntArray() contentEquals pivotsTensor.shape) { "Bed shapes (("} //todo rewrite
val n = luTensor.shape.last()
val pTensor = luTensor.zeroesLike()
for ((p, pivot) in pTensor.matrixSequence().zip(pivotsTensor.vectorSequence())){
@ -104,7 +108,30 @@ public class DoubleLinearOpsTensorAlgebra :
}
override fun DoubleTensor.cholesky(): DoubleTensor {
TODO("Not yet implemented")
// todo checks
checkSquareMatrix(shape)
val n = shape.last()
val lTensor = zeroesLike()
for ((a, l) in this.matrixSequence().zip(lTensor.matrixSequence())) {
for (i in 0 until n) {
for (j in 0 until i) {
var h = a[i, j]
for (k in 0 until j) {
h -= l[i, k] * l[j, k]
}
l[i, j] = h / l[j, j]
}
var h = a[i, i]
for (j in 0 until i) {
h -= l[i, j] * l[i, j]
}
l[i, i] = sqrt(h)
}
}
return lTensor
}
override fun DoubleTensor.qr(): DoubleTensor {

View File

@ -1,7 +1,7 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra
import kotlin.math.abs
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
@ -277,6 +277,43 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
TODO("Not yet implemented")
}
override fun DoubleTensor.map(transform: (Double) -> Double): DoubleTensor {
return DoubleTensor(
this.shape,
this.buffer.array().map { transform(it) }.toDoubleArray(),
this.bufferStart
)
}
public fun DoubleTensor.contentEquals(other: DoubleTensor, delta: Double = 1e-5): Boolean {
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
}
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double = 1e-5): Boolean {
return this.eq(other) { x, y -> abs(x - y) < delta }
}
override fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
if (!(this.shape contentEquals other.shape)){
return false
}
return this.eq(other, eqFunction)
}
override fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
// todo broadcasting checking
val n = this.strides.linearSize
if (n != other.strides.linearSize){
return false
}
for (i in 0 until n){
if (!eqFunction(this.buffer[this.bufferStart + i], other.buffer[other.bufferStart + i])) {
return false
}
}
return true
}
}

View File

@ -64,4 +64,16 @@ internal inline fun <T, TensorType : TensorStructure<T>,
internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkSquareMatrix(shape: IntArray): Unit {
val n = shape.size
check(n >= 2) {
"Expected tensor with 2 or more dimensions, got size $n instead"
}
check(shape[n - 1] == shape[n - 2]) {
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
}
}