KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
40 changed files with 4062 additions and 17 deletions

View File

@ -236,6 +236,18 @@ One can still use generic algebras though.
> **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-tensors](kmath-tensors)
>
>
> **Maturity**: PROTOTYPE
>
> **Features:**
> - [tensor algebra](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.)
> - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
> - [linear algebra operations](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.
<hr/>
* ### [kmath-viktor](kmath-viktor)
>
>

View File

@ -25,6 +25,7 @@ dependencies {
implementation(project(":kmath-dimensions"))
implementation(project(":kmath-ejml"))
implementation(project(":kmath-nd4j"))
implementation(project(":kmath-tensors"))
implementation(project(":kmath-for-real"))

View File

@ -0,0 +1,46 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
// Dataset normalization
fun main() {
// work in context with broadcast methods
BroadcastDoubleTensorAlgebra {
// take dataset of 5-element vectors from normal distribution
val dataset = randomNormal(intArrayOf(100, 5)) * 1.5 // all elements from N(0, 1.5)
dataset += fromArray(
intArrayOf(5),
doubleArrayOf(0.0, 1.0, 1.5, 3.0, 5.0) // rows means
)
// find out mean and standard deviation of each column
val mean = dataset.mean(0, false)
val std = dataset.std(0, false)
println("Mean:\n$mean")
println("Standard deviation:\n$std")
// also we can calculate other statistic as minimum and maximum of rows
println("Minimum:\n${dataset.min(0, false)}")
println("Maximum:\n${dataset.max(0, false)}")
// now we can scale dataset with mean normalization
val datasetScaled = (dataset - mean) / std
// find out mean and std of scaled dataset
println("Mean of scaled:\n${datasetScaled.mean(0, false)}")
println("Mean of scaled:\n${datasetScaled.std(0, false)}")
}
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
// solving linear system with LUP decomposition
fun main () {
// work in context with linear operations
BroadcastDoubleTensorAlgebra {
// set true value of x
val trueX = fromArray(
intArrayOf(4),
doubleArrayOf(-2.0, 1.5, 6.8, -2.4)
)
// and A matrix
val a = fromArray(
intArrayOf(4, 4),
doubleArrayOf(
0.5, 10.5, 4.5, 1.0,
8.5, 0.9, 12.8, 0.1,
5.56, 9.19, 7.62, 5.45,
1.0, 2.0, -3.0, -2.5
)
)
// calculate y value
val b = a dot trueX
// check out A and b
println("A:\n$a")
println("b:\n$b")
// solve `Ax = b` system using LUP decomposition
// get P, L, U such that PA = LU
val (p, l, u) = a.lu()
// check that P is permutation matrix
println("P:\n$p")
// L is lower triangular matrix and U is upper triangular matrix
println("L:\n$l")
println("U:\n$u")
// and PA = LU
println("PA:\n${p dot a}")
println("LU:\n${l dot u}")
/* Ax = b;
PAx = Pb;
LUx = Pb;
let y = Ux, then
Ly = Pb -- this system can be easily solved, since the matrix L is lower triangular;
Ux = y can be solved the same way, since the matrix L is upper triangular
*/
// this function returns solution x of a system lx = b, l should be lower triangular
fun solveLT(l: DoubleTensor, b: DoubleTensor): DoubleTensor {
val n = l.shape[0]
val x = zeros(intArrayOf(n))
for (i in 0 until n){
x[intArrayOf(i)] = (b[intArrayOf(i)] - l[i].dot(x).value()) / l[intArrayOf(i, i)]
}
return x
}
val y = solveLT(l, p dot b)
// solveLT(l, b) function can be easily adapted for upper triangular matrix by the permutation matrix revMat
// create it by placing ones on side diagonal
val revMat = u.zeroesLike()
val n = revMat.shape[0]
for (i in 0 until n) {
revMat[intArrayOf(i, n - 1 - i)] = 1.0
}
// solution of system ux = b, u should be upper triangular
fun solveUT(u: DoubleTensor, b: DoubleTensor): DoubleTensor = revMat dot solveLT(
revMat dot u dot revMat, revMat dot b
)
val x = solveUT(u, y)
println("True x:\n$trueX")
println("x founded with LU method:\n$x")
}
}

View File

@ -0,0 +1,241 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.toDoubleArray
import kotlin.math.sqrt
const val seed = 100500L
// Simple feedforward neural network with backpropagation training
// interface of network layer
interface Layer {
fun forward(input: DoubleTensor): DoubleTensor
fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor
}
// activation layer
open class Activation(
val activation: (DoubleTensor) -> DoubleTensor,
val activationDer: (DoubleTensor) -> DoubleTensor
) : Layer {
override fun forward(input: DoubleTensor): DoubleTensor {
return activation(input)
}
override fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor {
return DoubleTensorAlgebra { outputError * activationDer(input) }
}
}
fun relu(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
x.map { if (it > 0) it else 0.0 }
}
fun reluDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
x.map { if (it > 0) 1.0 else 0.0 }
}
// activation layer with relu activator
class ReLU : Activation(::relu, ::reluDer)
fun sigmoid(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
1.0 / (1.0 + (-x).exp())
}
fun sigmoidDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
sigmoid(x) * (1.0 - sigmoid(x))
}
// activation layer with sigmoid activator
class Sigmoid : Activation(::sigmoid, ::sigmoidDer)
// dense layer
class Dense(
private val inputUnits: Int,
private val outputUnits: Int,
private val learningRate: Double = 0.1
) : Layer {
private val weights: DoubleTensor = DoubleTensorAlgebra {
randomNormal(
intArrayOf(inputUnits, outputUnits),
seed
) * sqrt(2.0 / (inputUnits + outputUnits))
}
private val bias: DoubleTensor = DoubleTensorAlgebra { zeros(intArrayOf(outputUnits)) }
override fun forward(input: DoubleTensor): DoubleTensor {
return BroadcastDoubleTensorAlgebra { (input dot weights) + bias }
}
override fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
val gradInput = outputError dot weights.transpose()
val gradW = input.transpose() dot outputError
val gradBias = outputError.mean(dim = 0, keepDim = false) * input.shape[0].toDouble()
weights -= learningRate * gradW
bias -= learningRate * gradBias
gradInput
}
}
// simple accuracy equal to the proportion of correct answers
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
check(yPred.shape contentEquals yTrue.shape)
val n = yPred.shape[0]
var correctCnt = 0
for (i in 0 until n) {
if (yPred[intArrayOf(i, 0)] == yTrue[intArrayOf(i, 0)]) {
correctCnt += 1
}
}
return correctCnt.toDouble() / n.toDouble()
}
// neural network class
class NeuralNetwork(private val layers: List<Layer>) {
private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra {
val onesForAnswers = yPred.zeroesLike()
yTrue.toDoubleArray().forEachIndexed { index, labelDouble ->
val label = labelDouble.toInt()
onesForAnswers[intArrayOf(index, label)] = 1.0
}
val softmaxValue = yPred.exp() / yPred.exp().sum(dim = 1, keepDim = true)
(-onesForAnswers + softmaxValue) / (yPred.shape[0].toDouble())
}
@OptIn(ExperimentalStdlibApi::class)
private fun forward(x: DoubleTensor): List<DoubleTensor> {
var input = x
return buildList {
layers.forEach { layer ->
val output = layer.forward(input)
add(output)
input = output
}
}
}
@OptIn(ExperimentalStdlibApi::class)
private fun train(xTrain: DoubleTensor, yTrain: DoubleTensor) {
val layerInputs = buildList {
add(xTrain)
addAll(forward(xTrain))
}
var lossGrad = softMaxLoss(layerInputs.last(), yTrain)
layers.zip(layerInputs).reversed().forEach { (layer, input) ->
lossGrad = layer.backward(input, lossGrad)
}
}
fun fit(xTrain: DoubleTensor, yTrain: DoubleTensor, batchSize: Int, epochs: Int) = DoubleTensorAlgebra {
fun iterBatch(x: DoubleTensor, y: DoubleTensor): Sequence<Pair<DoubleTensor, DoubleTensor>> = sequence {
val n = x.shape[0]
val shuffledIndices = (0 until n).shuffled()
for (i in 0 until n step batchSize) {
val excerptIndices = shuffledIndices.drop(i).take(batchSize).toIntArray()
val batch = x.rowsByIndices(excerptIndices) to y.rowsByIndices(excerptIndices)
yield(batch)
}
}
for (epoch in 0 until epochs) {
println("Epoch ${epoch + 1}/$epochs")
for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
train(xBatch, yBatch)
}
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}")
}
}
fun predict(x: DoubleTensor): DoubleTensor {
return forward(x).last()
}
}
@OptIn(ExperimentalStdlibApi::class)
fun main() {
BroadcastDoubleTensorAlgebra {
val features = 5
val sampleSize = 250
val trainSize = 180
val testSize = sampleSize - trainSize
// take sample of features from normal distribution
val x = randomNormal(intArrayOf(sampleSize, features), seed) * 2.5
x += fromArray(
intArrayOf(5),
doubleArrayOf(0.0, -1.0, -2.5, -3.0, 5.5) // rows means
)
// define class like '1' if the sum of features > 0 and '0' otherwise
val y = fromArray(
intArrayOf(sampleSize, 1),
DoubleArray(sampleSize) { i ->
if (x[i].sum() > 0.0) {
1.0
} else {
0.0
}
}
)
// split train ans test
val trainIndices = (0 until trainSize).toList().toIntArray()
val testIndices = (trainSize until sampleSize).toList().toIntArray()
val xTrain = x.rowsByIndices(trainIndices)
val yTrain = y.rowsByIndices(trainIndices)
val xTest = x.rowsByIndices(testIndices)
val yTest = y.rowsByIndices(testIndices)
// build model
val layers = buildList {
add(Dense(features, 64))
add(ReLU())
add(Dense(64, 16))
add(ReLU())
add(Dense(16, 2))
add(Sigmoid())
}
val model = NeuralNetwork(layers)
// fit it with train data
model.fit(xTrain, yTrain, batchSize = 20, epochs = 10)
// make prediction
val prediction = model.predict(xTest)
// process raw prediction via argMax
val predictionLabels = prediction.argMax(1, true)
// find out accuracy
val acc = accuracy(yTest, predictionLabels)
println("Test accuracy:$acc")
}
}

View File

@ -0,0 +1,68 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import kotlin.math.abs
// OLS estimator using SVD
fun main() {
//seed for random
val randSeed = 100500L
// work in context with linear operations
DoubleTensorAlgebra {
// take coefficient vector from normal distribution
val alpha = randomNormal(
intArrayOf(5),
randSeed
) + fromArray(
intArrayOf(5),
doubleArrayOf(1.0, 2.5, 3.4, 5.0, 10.1)
)
println("Real alpha:\n$alpha")
// also take sample of size 20 from normal distribution for x
val x = randomNormal(
intArrayOf(20, 5),
randSeed
)
// calculate y and add gaussian noise (N(0, 0.05))
val y = x dot alpha
y += y.randomNormalLike(randSeed) * 0.05
// now restore the coefficient vector with OSL estimator with SVD
val (u, singValues, v) = x.svd()
// we have to make sure the singular values of the matrix are not close to zero
println("Singular values:\n$singValues")
// inverse Sigma matrix can be restored from singular values with diagonalEmbedding function
val sigma = diagonalEmbedding(singValues.map{ x -> if (abs(x) < 1e-3) 0.0 else 1.0/x })
val alphaOLS = v dot sigma dot u.transpose() dot y
println("Estimated alpha:\n" +
"$alphaOLS")
// figure out MSE of approximation
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
require(yTrue.shape.size == 1)
require(yTrue.shape contentEquals yPred.shape)
val diff = yTrue - yPred
return diff.dot(diff).sqrt().value()
}
println("MSE: ${mse(alpha, alphaOLS)}")
}
}

View File

@ -0,0 +1,78 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
// simple PCA
fun main(){
val seed = 100500L
// work in context with broadcast methods
BroadcastDoubleTensorAlgebra {
// assume x is range from 0 until 10
val x = fromArray(
intArrayOf(10),
(0 until 10).toList().map { it.toDouble() }.toDoubleArray()
)
// take y dependent on x with noise
val y = 2.0 * x + (3.0 + x.randomNormalLike(seed) * 1.5)
println("x:\n$x")
println("y:\n$y")
// stack them into single dataset
val dataset = stack(listOf(x, y)).transpose()
// normalize both x and y
val xMean = x.mean()
val yMean = y.mean()
val xStd = x.std()
val yStd = y.std()
val xScaled = (x - xMean) / xStd
val yScaled = (y - yMean) / yStd
// save means ans standard deviations for further recovery
val mean = fromArray(
intArrayOf(2),
doubleArrayOf(xMean, yMean)
)
println("Means:\n$mean")
val std = fromArray(
intArrayOf(2),
doubleArrayOf(xStd, yStd)
)
println("Standard deviations:\n$std")
// calculate the covariance matrix of scaled x and y
val covMatrix = cov(listOf(xScaled, yScaled))
println("Covariance matrix:\n$covMatrix")
// and find out eigenvector of it
val (_, evecs) = covMatrix.symEig()
val v = evecs[0]
println("Eigenvector:\n$v")
// reduce dimension of dataset
val datasetReduced = v dot stack(listOf(xScaled, yScaled))
println("Reduced data:\n$datasetReduced")
// we can restore original data from reduced data.
// for example, find 7th element of dataset
val n = 7
val restored = (datasetReduced[n] dot v.view(intArrayOf(1, 2))) * std + mean
println("Original value:\n${dataset[n]}")
println("Restored value:\n$restored")
}
}

View File

@ -750,7 +750,7 @@ public final class space/kscience/kmath/nd/BufferAlgebraNDKt {
public static final fun ring (Lspace/kscience/kmath/nd/AlgebraND$Companion;Lspace/kscience/kmath/operations/Ring;Lkotlin/jvm/functions/Function2;[I)Lspace/kscience/kmath/nd/BufferedRingND;
}
public final class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
public class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/Buffer;)V
public fun elements ()Lkotlin/sequences/Sequence;
public fun get ([I)Ljava/lang/Object;
@ -791,10 +791,9 @@ public final class space/kscience/kmath/nd/DefaultStrides : space/kscience/kmath
public fun equals (Ljava/lang/Object;)Z
public fun getLinearSize ()I
public fun getShape ()[I
public fun getStrides ()Ljava/util/List;
public fun getStrides ()[I
public fun hashCode ()I
public fun index (I)[I
public fun offset ([I)I
}
public final class space/kscience/kmath/nd/DefaultStrides$Companion {
@ -878,6 +877,22 @@ public abstract interface class space/kscience/kmath/nd/GroupND : space/kscience
public final class space/kscience/kmath/nd/GroupND$Companion {
}
public final class space/kscience/kmath/nd/MutableBufferND : space/kscience/kmath/nd/BufferND, space/kscience/kmath/nd/MutableStructureND {
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/MutableBuffer;)V
public final fun getMutableBuffer ()Lspace/kscience/kmath/structures/MutableBuffer;
public fun set ([ILjava/lang/Object;)V
}
public abstract interface class space/kscience/kmath/nd/MutableStructure1D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure1D, space/kscience/kmath/structures/MutableBuffer {
public fun set ([ILjava/lang/Object;)V
}
public abstract interface class space/kscience/kmath/nd/MutableStructure2D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure2D {
public fun getColumns ()Ljava/util/List;
public fun getRows ()Ljava/util/List;
public abstract fun set (IILjava/lang/Object;)V
}
public abstract interface class space/kscience/kmath/nd/MutableStructureND : space/kscience/kmath/nd/StructureND {
public abstract fun set ([ILjava/lang/Object;)V
}
@ -917,10 +932,10 @@ public final class space/kscience/kmath/nd/ShortRingNDKt {
public abstract interface class space/kscience/kmath/nd/Strides {
public abstract fun getLinearSize ()I
public abstract fun getShape ()[I
public abstract fun getStrides ()Ljava/util/List;
public abstract fun getStrides ()[I
public abstract fun index (I)[I
public fun indices ()Lkotlin/sequences/Sequence;
public abstract fun offset ([I)I
public fun offset ([I)I
}
public abstract interface class space/kscience/kmath/nd/Structure1D : space/kscience/kmath/nd/StructureND, space/kscience/kmath/structures/Buffer {
@ -934,6 +949,7 @@ public final class space/kscience/kmath/nd/Structure1D$Companion {
}
public final class space/kscience/kmath/nd/Structure1DKt {
public static final fun as1D (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructure1D;
public static final fun as1D (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/Structure1D;
public static final fun asND (Lspace/kscience/kmath/structures/Buffer;)Lspace/kscience/kmath/nd/Structure1D;
}
@ -954,6 +970,7 @@ public final class space/kscience/kmath/nd/Structure2D$Companion {
}
public final class space/kscience/kmath/nd/Structure2DKt {
public static final fun as2D (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructure2D;
public static final fun as2D (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/Structure2D;
}

View File

@ -19,6 +19,7 @@ import kotlin.reflect.KClass
* @param T the type of items.
*/
public typealias Matrix<T> = Structure2D<T>
public typealias MutableMatrix<T> = MutableStructure2D<T>
/**
* Alias or using [Buffer] as a point/vector in a many-dimensional space.

View File

@ -7,6 +7,8 @@ package space.kscience.kmath.nd
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.MutableBufferFactory
/**
* Represents [StructureND] over [Buffer].
@ -15,7 +17,7 @@ import space.kscience.kmath.structures.BufferFactory
* @param strides The strides to access elements of [Buffer] by linear indices.
* @param buffer The underlying buffer.
*/
public class BufferND<T>(
public open class BufferND<T>(
public val strides: Strides,
public val buffer: Buffer<T>,
) : StructureND<T> {
@ -50,4 +52,35 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
val strides = DefaultStrides(shape)
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
}
/**
* Represents [MutableStructureND] over [MutableBuffer].
*
* @param T the type of items.
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
* @param mutableBuffer The underlying buffer.
*/
public class MutableBufferND<T>(
strides: Strides,
public val mutableBuffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
override fun set(index: IntArray, value: T) {
mutableBuffer[strides.offset(index)] = value
}
}
/**
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
*/
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
factory: MutableBufferFactory<R> = MutableBuffer.Companion::auto,
crossinline transform: (T) -> R,
): MutableBufferND<R> {
return if (this is MutableBufferND<T>)
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) })
else {
val strides = DefaultStrides(shape)
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
}

View File

@ -6,6 +6,8 @@
package space.kscience.kmath.nd
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.structures.asSequence
import kotlin.jvm.JvmInline
@ -25,6 +27,16 @@ public interface Structure1D<T> : StructureND<T>, Buffer<T> {
public companion object
}
/**
* A mutable structure that is guaranteed to be one-dimensional
*/
public interface MutableStructure1D<T> : Structure1D<T>, MutableStructureND<T>, MutableBuffer<T> {
public override operator fun set(index: IntArray, value: T) {
require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
set(index[0], value)
}
}
/**
* A 1D wrapper for nd-structure
*/
@ -37,6 +49,23 @@ private value class Structure1DWrapper<T>(val structure: StructureND<T>) : Struc
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
}
/**
* A 1D wrapper for a mutable nd-structure
*/
private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure1D<T> {
override val shape: IntArray get() = structure.shape
override val size: Int get() = structure.shape[0]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
override fun get(index: Int): T = structure[index]
override fun set(index: Int, value: T) {
structure[intArrayOf(index)] = value
}
override fun copy(): MutableBuffer<T> =
structure.elements().map { it.second }.toMutableList().asMutableBuffer()
}
/**
* A structure wrapper for buffer
@ -52,6 +81,21 @@ private value class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T> {
override operator fun get(index: Int): T = buffer[index]
}
internal class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> {
override val shape: IntArray get() = intArrayOf(buffer.size)
override val size: Int get() = buffer.size
override fun elements(): Sequence<Pair<IntArray, T>> =
buffer.asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
override operator fun get(index: Int): T = buffer[index]
override fun set(index: Int, value: T) {
buffer[index] = value
}
override fun copy(): MutableBuffer<T> = buffer.copy()
}
/**
* Represent a [StructureND] as [Structure1D]. Throw error in case of dimension mismatch
*/
@ -62,6 +106,11 @@ public fun <T> StructureND<T>.as1D(): Structure1D<T> = this as? Structure1D<T> ?
}
} else error("Can't create 1d-structure from ${shape.size}d-structure")
public fun <T> MutableStructureND<T>.as1D(): MutableStructure1D<T> =
this as? MutableStructure1D<T> ?: if (shape.size == 1) {
MutableStructure1DWrapper(this)
} else error("Can't create 1d-structure from ${shape.size}d-structure")
/**
* Represent this buffer as 1D structure
*/
@ -75,3 +124,4 @@ internal fun <T : Any> Structure1D<T>.unwrap(): Buffer<T> = when {
this is Structure1DWrapper && structure is BufferND<T> -> structure.buffer
else -> this
}

View File

@ -8,6 +8,7 @@ package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.VirtualBuffer
import space.kscience.kmath.structures.MutableListBuffer
import kotlin.jvm.JvmInline
import kotlin.reflect.KClass
@ -63,6 +64,32 @@ public interface Structure2D<T> : StructureND<T> {
public companion object
}
/**
* Represents mutable [Structure2D].
*/
public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
/**
* Inserts an item at the specified indices.
*
* @param i the first index.
* @param j the second index.
* @param value the value.
*/
public operator fun set(i: Int, j: Int, value: T)
/**
* The buffer of rows of this structure. It gets elements from the structure dynamically.
*/
override val rows: List<MutableStructure1D<T>>
get() = List(rowNum) { i -> MutableBuffer1DWrapper(MutableListBuffer(colNum) { j -> get(i, j) })}
/**
* The buffer of columns of this structure. It gets elements from the structure dynamically.
*/
override val columns: List<MutableStructure1D<T>>
get() = List(colNum) { j -> MutableBuffer1DWrapper(MutableListBuffer(rowNum) { i -> get(i, j) }) }
}
/**
* A 2D wrapper for nd-structure
*/
@ -81,6 +108,33 @@ private value class Structure2DWrapper<T>(val structure: StructureND<T>) : Struc
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
}
/**
* A 2D wrapper for a mutable nd-structure
*/
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>): MutableStructure2D<T>
{
override val shape: IntArray get() = structure.shape
override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1]
override operator fun get(i: Int, j: Int): T = structure[i, j]
override fun set(index: IntArray, value: T) {
structure[index] = value
}
override operator fun set(i: Int, j: Int, value: T){
structure[intArrayOf(i, j)] = value
}
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
}
/**
* Represent a [StructureND] as [Structure1D]. Throw error in case of dimension mismatch
*/
@ -89,9 +143,18 @@ public fun <T> StructureND<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
}
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) {
2 -> MutableStructure2DWrapper(this)
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
}
/**
* Expose inner [StructureND] if possible
*/
internal fun <T> Structure2D<T>.unwrap(): StructureND<T> =
if (this is Structure2DWrapper) structure
else this
else this
internal fun <T> MutableStructure2D<T>.unwrap(): MutableStructureND<T> =
if (this is MutableStructure2DWrapper) structure else this

View File

@ -184,12 +184,15 @@ public interface Strides {
/**
* Array strides
*/
public val strides: List<Int>
public val strides: IntArray
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
/**
* Get multidimensional from linear
@ -221,7 +224,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
/**
* Strides for memory access
*/
override val strides: List<Int> by lazy {
override val strides: IntArray by lazy {
sequence {
var current = 1
yield(1)
@ -230,14 +233,9 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
current *= it
yield(current)
}
}.toList()
}.toList().toIntArray()
}
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
override fun index(offset: Int): IntArray {
val res = IntArray(shape.size)
var current = offset

View File

@ -232,7 +232,7 @@ public value class MutableListBuffer<T>(public val list: MutableList<T>) : Mutab
}
/**
* Returns an [ListBuffer] that wraps the original list.
* Returns an [MutableListBuffer] that wraps the original list.
*/
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)

40
kmath-tensors/README.md Normal file
View File

@ -0,0 +1,40 @@
# Module kmath-tensors
Common operations on tensors, the API consists of:
- [TensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic algebra operations on tensors (plus, dot, etc.)
- [TensorPartialDivisionAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt) : Emulates an algebra over a field
- [LinearOpsTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Linear algebra operations including LU, QR, Cholesky LL and SVD decompositions
- [AnalyticTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt) : Element-wise analytic operations
The library offers a multiplatform implementation for this interface over the `Double`'s. As a highlight, the user can find:
- [BroadcastDoubleTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic algebra operations implemented with broadcasting.
- [DoubleLinearOpsTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt) : Contains the power method for SVD and the spectrum of symmetric matrices.
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-tensors:0.3.0-dev-7`.
**Gradle:**
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
}
dependencies {
implementation 'space.kscience:kmath-tensors:0.3.0-dev-7'
}
```
**Gradle Kotlin DSL:**
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
}
dependencies {
implementation("space.kscience:kmath-tensors:0.3.0-dev-7")
}
```

View File

@ -0,0 +1,43 @@
plugins {
id("ru.mipt.npm.gradle.mpp")
}
kotlin.sourceSets {
all {
languageSettings.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI")
}
commonMain {
dependencies {
api(project(":kmath-core"))
api(project(":kmath-stat"))
}
}
}
tasks.dokkaHtml {
dependsOn(tasks.build)
}
readme {
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
feature(
id = "tensor algebra",
description = "Basic linear algebra operations on tensors (plus, dot, etc.)",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt"
)
feature(
id = "tensor algebra with broadcasting",
description = "Basic linear algebra operations implemented with broadcasting.",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt"
)
feature(
id = "linear algebra operations",
description = "Advanced linear algebra operations like LU decomposition, SVD, etc.",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt"
)
}

View File

@ -0,0 +1,7 @@
# Module kmath-tensors
Common linear algebra operations on tensors.
${features}
${artifact}

View File

@ -0,0 +1,131 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
/**
* Analytic operations on [Tensor].
*
* @param T the type of items closed under analytic functions in the tensors.
*/
public interface AnalyticTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
/**
* @return the mean of all elements in the input tensor.
*/
public fun Tensor<T>.mean(): T
/**
* Returns the mean of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the mean of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the standard deviation of all elements in the input tensor.
*/
public fun Tensor<T>.std(): T
/**
* Returns the standard deviation of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the standard deviation of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the variance of all elements in the input tensor.
*/
public fun Tensor<T>.variance(): T
/**
* Returns the variance of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the variance of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
/**
* Returns the covariance matrix M of given vectors.
*
* M[i, j] contains covariance of i-th and j-th given vectors
*
* @param tensors the [List] of 1-dimensional tensors with same shape
* @return the covariance matrix
*/
public fun cov(tensors: List<Tensor<T>>): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.exp.html
public fun Tensor<T>.exp(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.log.html
public fun Tensor<T>.ln(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html
public fun Tensor<T>.sqrt(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
public fun Tensor<T>.cos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
public fun Tensor<T>.acos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
public fun Tensor<T>.cosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
public fun Tensor<T>.acosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
public fun Tensor<T>.sin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
public fun Tensor<T>.asin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
public fun Tensor<T>.sinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
public fun Tensor<T>.asinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
public fun Tensor<T>.tan(): Tensor<T>
//https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
public fun Tensor<T>.atan(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
public fun Tensor<T>.tanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
public fun Tensor<T>.atanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil
public fun Tensor<T>.ceil(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
public fun Tensor<T>.floor(): Tensor<T>
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
/**
* Common linear algebra operations. Operates on [Tensor].
*
* @param T the type of items closed under division in the tensors.
*/
public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
/**
* Computes the determinant of a square matrix input, or of each square matrix in a batched input.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
*
* @return the determinant.
*/
public fun Tensor<T>.det(): Tensor<T>
/**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
* Given a square matrix `A`, return the matrix `AInv` satisfying
* `A dot AInv = AInv dot A = eye(a.shape[0])`.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
*
* @return the multiplicative inverse of a matrix.
*/
public fun Tensor<T>.inv(): Tensor<T>
/**
* Cholesky decomposition.
*
* Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
* positive-definite matrix or the Cholesky decompositions for a batch of such matrices.
* Each decomposition has the form:
* Given a tensor `input`, return the tensor `L` satisfying `input = L dot L.H`,
* where L is a lower-triangular matrix and L.H is the conjugate transpose of L,
* which is just a transpose for the case of real-valued input matrices.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
*
* @return the batch of L matrices.
*/
public fun Tensor<T>.cholesky(): Tensor<T>
/**
* QR decomposition.
*
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q dot R``,
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
*
* @return pair of Q and R tensors.
*/
public fun Tensor<T>.qr(): Pair<Tensor<T>, Tensor<T>>
/**
* LUP decomposition
*
* Computes the LUP decomposition of a matrix or a batch of matrices.
* Given a tensor `input`, return tensors (P, L, U) satisfying `P dot input = L dot U`,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
*
* * @return triple of P, L and U tensors
*/
public fun Tensor<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/**
* Singular Value Decomposition.
*
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
* The singular value decomposition is represented as a triple `(U, S, V)`,
* such that `input = U dot diagonalEmbedding(S) dot V.H`,
* where V.H is the conjugate transpose of V.
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
*
* @return triple `(U, S, V)`.
*/
public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/**
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
* represented by a pair (eigenvalues, eigenvectors).
* For more information: https://pytorch.org/docs/stable/generated/torch.symeig.html
*
* @return a pair (eigenvalues, eigenvectors)
*/
public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
}

View File

@ -0,0 +1,5 @@
package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.MutableStructureND
public typealias Tensor<T> = MutableStructureND<T>

View File

@ -0,0 +1,327 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.Algebra
/**
* Algebra over a ring on [Tensor].
* For more information: https://proofwiki.org/wiki/Definition:Algebra_over_Ring
*
* @param T the type of items in the tensors.
*/
public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
/**
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
*
* @return a nullable value of a potentially scalar tensor.
*/
public fun Tensor<T>.valueOrNull(): T?
/**
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
*
* @return the value of a scalar tensor.
*/
public fun Tensor<T>.value(): T
/**
* Each element of the tensor [other] is added to this value.
* The resulting tensor is returned.
*
* @param other tensor to be added.
* @return the sum of this value and tensor [other].
*/
public operator fun T.plus(other: Tensor<T>): Tensor<T>
/**
* Adds the scalar [value] to each element of this tensor and returns a new resulting tensor.
*
* @param value the number to be added to each element of this tensor.
* @return the sum of this tensor and [value].
*/
public operator fun Tensor<T>.plus(value: T): Tensor<T>
/**
* Each element of the tensor [other] is added to each element of this tensor.
* The resulting tensor is returned.
*
* @param other tensor to be added.
* @return the sum of this tensor and [other].
*/
public operator fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
/**
* Adds the scalar [value] to each element of this tensor.
*
* @param value the number to be added to each element of this tensor.
*/
public operator fun Tensor<T>.plusAssign(value: T): Unit
/**
* Each element of the tensor [other] is added to each element of this tensor.
*
* @param other tensor to be added.
*/
public operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit
/**
* Each element of the tensor [other] is subtracted from this value.
* The resulting tensor is returned.
*
* @param other tensor to be subtracted.
* @return the difference between this value and tensor [other].
*/
public operator fun T.minus(other: Tensor<T>): Tensor<T>
/**
* Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor.
*
* @param value the number to be subtracted from each element of this tensor.
* @return the difference between this tensor and [value].
*/
public operator fun Tensor<T>.minus(value: T): Tensor<T>
/**
* Each element of the tensor [other] is subtracted from each element of this tensor.
* The resulting tensor is returned.
*
* @param other tensor to be subtracted.
* @return the difference between this tensor and [other].
*/
public operator fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
/**
* Subtracts the scalar [value] from each element of this tensor.
*
* @param value the number to be subtracted from each element of this tensor.
*/
public operator fun Tensor<T>.minusAssign(value: T): Unit
/**
* Each element of the tensor [other] is subtracted from each element of this tensor.
*
* @param other tensor to be subtracted.
*/
public operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit
/**
* Each element of the tensor [other] is multiplied by this value.
* The resulting tensor is returned.
*
* @param other tensor to be multiplied.
* @return the product of this value and tensor [other].
*/
public operator fun T.times(other: Tensor<T>): Tensor<T>
/**
* Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor.
*
* @param value the number to be multiplied by each element of this tensor.
* @return the product of this tensor and [value].
*/
public operator fun Tensor<T>.times(value: T): Tensor<T>
/**
* Each element of the tensor [other] is multiplied by each element of this tensor.
* The resulting tensor is returned.
*
* @param other tensor to be multiplied.
* @return the product of this tensor and [other].
*/
public operator fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
/**
* Multiplies the scalar [value] by each element of this tensor.
*
* @param value the number to be multiplied by each element of this tensor.
*/
public operator fun Tensor<T>.timesAssign(value: T): Unit
/**
* Each element of the tensor [other] is multiplied by each element of this tensor.
*
* @param other tensor to be multiplied.
*/
public operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit
/**
* Numerical negative, element-wise.
*
* @return tensor negation of the original tensor.
*/
public operator fun Tensor<T>.unaryMinus(): Tensor<T>
/**
* Returns the tensor at index i
* For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html
*
* @param i index of the extractable tensor
* @return subtensor of the original tensor with index [i]
*/
public operator fun Tensor<T>.get(i: Int): Tensor<T>
/**
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
* For more information: https://pytorch.org/docs/stable/generated/torch.transpose.html
*
* @param i the first dimension to be transposed
* @param j the second dimension to be transposed
* @return transposed tensor
*/
public fun Tensor<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
/**
* Returns a new tensor with the same data as the self tensor but of a different shape.
* The returned tensor shares the same data and must have the same number of elements, but may have a different size
* For more information: https://pytorch.org/docs/stable/tensor_view.html
*
* @param shape the desired size
* @return tensor with new shape
*/
public fun Tensor<T>.view(shape: IntArray): Tensor<T>
/**
* View this tensor as the same size as [other].
* ``this.viewAs(other) is equivalent to this.view(other.shape)``.
* For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html
*
* @param other the result tensor has the same size as other.
* @return the result tensor with the same size as other.
*/
public fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T>
/**
* Matrix product of two tensors.
*
* The behavior depends on the dimensionality of the tensors as follows:
* 1. If both tensors are 1-dimensional, the dot product (scalar) is returned.
*
* 2. If both arguments are 2-dimensional, the matrix-matrix product is returned.
*
* 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
* a 1 is prepended to its dimension for the purpose of the matrix multiply.
* After the matrix multiply, the prepended dimension is removed.
*
* 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
* the matrix-vector product is returned.
*
* 5. If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2),
* then a batched matrix multiply is returned. If the first argument is 1-dimensional,
* a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
* If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
* multiple and removed after.
* The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable).
* For example, if `input` is a (j &times; 1 &times; n &times; n) tensor and `other` is a
* (k &times; n &times; n) tensor, out will be a (j &times; k &times; n &times; n) tensor.
*
* For more information: https://pytorch.org/docs/stable/generated/torch.matmul.html
*
* @param other tensor to be multiplied
* @return mathematical product of two tensors
*/
public infix fun Tensor<T>.dot(other: Tensor<T>): Tensor<T>
/**
* Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
* are filled by [diagonalEntries].
* To facilitate creating batched diagonal matrices,
* the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
*
* The argument [offset] controls which diagonal to consider:
* 1. If [offset] = 0, it is the main diagonal.
* 1. If [offset] > 0, it is above the main diagonal.
* 1. If [offset] < 0, it is below the main diagonal.
*
* The size of the new matrix will be calculated
* to make the specified diagonal of the size of the last input dimension.
* For more information: https://pytorch.org/docs/stable/generated/torch.diag_embed.html
*
* @param diagonalEntries the input tensor. Must be at least 1-dimensional.
* @param offset which diagonal to consider. Default: 0 (main diagonal).
* @param dim1 first dimension with respect to which to take diagonal. Default: -2.
* @param dim2 second dimension with respect to which to take diagonal. Default: -1.
*
* @return tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
* are filled by [diagonalEntries]
*/
public fun diagonalEmbedding(
diagonalEntries: Tensor<T>,
offset: Int = 0,
dim1: Int = -2,
dim2: Int = -1
): Tensor<T>
/**
* @return the sum of all elements in the input tensor.
*/
public fun Tensor<T>.sum(): T
/**
* Returns the sum of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the sum of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the minimum value of all elements in the input tensor.
*/
public fun Tensor<T>.min(): T
/**
* Returns the minimum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the minimum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
/**
* Returns the maximum value of all elements in the input tensor.
*/
public fun Tensor<T>.max(): T
/**
* Returns the maximum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the maximum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
/**
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
}

View File

@ -0,0 +1,55 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
/**
* Algebra over a field with partial division on [Tensor].
* For more information: https://proofwiki.org/wiki/Definition:Division_Algebra
*
* @param T the type of items closed under division in the tensors.
*/
public interface TensorPartialDivisionAlgebra<T> : TensorAlgebra<T> {
/**
* Each element of the tensor [other] is divided by this value.
* The resulting tensor is returned.
*
* @param other tensor to divide by.
* @return the division of this value by the tensor [other].
*/
public operator fun T.div(other: Tensor<T>): Tensor<T>
/**
* Divide by the scalar [value] each element of this tensor returns a new resulting tensor.
*
* @param value the number to divide by each element of this tensor.
* @return the division of this tensor by the [value].
*/
public operator fun Tensor<T>.div(value: T): Tensor<T>
/**
* Each element of the tensor [other] is divided by each element of this tensor.
* The resulting tensor is returned.
*
* @param other tensor to be divided by.
* @return the division of this tensor by [other].
*/
public operator fun Tensor<T>.div(other: Tensor<T>): Tensor<T>
/**
* Divides by the scalar [value] each element of this tensor.
*
* @param value the number to divide by each element of this tensor.
*/
public operator fun Tensor<T>.divAssign(value: T)
/**
* Each element of this tensor is divided by each element of the [other] tensor.
*
* @param other tensor to be divide by.
*/
public operator fun Tensor<T>.divAssign(other: Tensor<T>)
}

View File

@ -0,0 +1,93 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.array
import space.kscience.kmath.tensors.core.internal.broadcastTensors
import space.kscience.kmath.tensors.core.internal.broadcastTo
import space.kscience.kmath.tensors.core.internal.tensor
/**
* Basic linear algebra operations implemented with broadcasting.
* For more information: https://pytorch.org/docs/stable/notes/broadcasting.html
*/
public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[newThis.bufferStart + i] *
newOther.mutableBuffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[newOther.bufferStart + i] /
newOther.mutableBuffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
}
}

View File

@ -0,0 +1,38 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.Strides
import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.TensorLinearStructure
/**
* Represents [Tensor] over a [MutableBuffer] intended to be used through [DoubleTensor] and [IntTensor]
*/
public open class BufferedTensor<T> internal constructor(
override val shape: IntArray,
internal val mutableBuffer: MutableBuffer<T>,
internal val bufferStart: Int
) : Tensor<T> {
/**
* Buffer strides based on [TensorLinearStructure] implementation
*/
public val linearStructure: Strides
get() = TensorLinearStructure(shape)
/**
* Number of elements in tensor
*/
public val numElements: Int
get() = linearStructure.linearSize
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
override fun set(index: IntArray, value: T) {
mutableBuffer[bufferStart + linearStructure.offset(index)] = value
}
override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map {
it to this[it]
}
}

View File

@ -0,0 +1,20 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.tensors.core.internal.toPrettyString
/**
* Default [BufferedTensor] implementation for [Double] values
*/
public class DoubleTensor internal constructor(
shape: IntArray,
buffer: DoubleArray,
offset: Int = 0
) : BufferedTensor<Double>(shape, DoubleBuffer(buffer), offset) {
override fun toString(): String = toPrettyString()
}

View File

@ -0,0 +1,937 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.dotHelper
import space.kscience.kmath.tensors.core.internal.getRandomNormals
import space.kscience.kmath.tensors.core.internal.*
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
import space.kscience.kmath.tensors.core.internal.checkBufferShapeConsistency
import space.kscience.kmath.tensors.core.internal.checkEmptyDoubleBuffer
import space.kscience.kmath.tensors.core.internal.checkEmptyShape
import space.kscience.kmath.tensors.core.internal.checkShapesCompatible
import space.kscience.kmath.tensors.core.internal.checkSquareMatrix
import space.kscience.kmath.tensors.core.internal.checkTranspose
import space.kscience.kmath.tensors.core.internal.checkView
import space.kscience.kmath.tensors.core.internal.minusIndexFrom
import kotlin.math.*
/**
* Implementation of basic operations over double tensors and basic algebra operations on them.
*/
public open class DoubleTensorAlgebra :
TensorPartialDivisionAlgebra<Double>,
AnalyticTensorAlgebra<Double>,
LinearOpsTensorAlgebra<Double> {
public companion object : DoubleTensorAlgebra()
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
tensor.mutableBuffer.array()[tensor.bufferStart] else null
override fun Tensor<Double>.value(): Double =
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
/**
* Constructs a tensor with the specified shape and data.
*
* @param shape the desired shape for the tensor.
* @param buffer one-dimensional data array.
* @return tensor with the [shape] shape and [buffer] data.
*/
public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor {
checkEmptyShape(shape)
checkEmptyDoubleBuffer(buffer)
checkBufferShapeConsistency(shape, buffer)
return DoubleTensor(shape, buffer, 0)
}
/**
* Constructs a tensor with the specified shape and initializer.
*
* @param shape the desired shape for the tensor.
* @param initializer mapping tensor indices to values.
* @return tensor with the [shape] shape and data generated by the [initializer].
*/
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor =
fromArray(
shape,
TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray()
)
override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
val lastShape = tensor.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
return DoubleTensor(newShape, tensor.mutableBuffer.array(), newStart)
}
/**
* Creates a tensor of a given shape and fills all elements with a given value.
*
* @param value the value to fill the output tensor with.
* @param shape array of integers defining the shape of the output tensor.
* @return tensor with the [shape] shape and filled with [value].
*/
public fun full(value: Double, shape: IntArray): DoubleTensor {
checkEmptyShape(shape)
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
return DoubleTensor(shape, buffer)
}
/**
* Returns a tensor with the same shape as `input` filled with [value].
*
* @param value the value to fill the output tensor with.
* @return tensor with the `input` tensor shape and filled with [value].
*/
public fun Tensor<Double>.fullLike(value: Double): DoubleTensor {
val shape = tensor.shape
val buffer = DoubleArray(tensor.numElements) { value }
return DoubleTensor(shape, buffer)
}
/**
* Returns a tensor filled with the scalar value 0.0, with the shape defined by the variable argument [shape].
*
* @param shape array of integers defining the shape of the output tensor.
* @return tensor filled with the scalar value 0.0, with the [shape] shape.
*/
public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
/**
* Returns a tensor filled with the scalar value 0.0, with the same shape as a given array.
*
* @return tensor filled with the scalar value 0.0, with the same shape as `input` tensor.
*/
public fun Tensor<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0)
/**
* Returns a tensor filled with the scalar value 1.0, with the shape defined by the variable argument [shape].
*
* @param shape array of integers defining the shape of the output tensor.
* @return tensor filled with the scalar value 1.0, with the [shape] shape.
*/
public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
/**
* Returns a tensor filled with the scalar value 1.0, with the same shape as a given array.
*
* @return tensor filled with the scalar value 1.0, with the same shape as `input` tensor.
*/
public fun Tensor<Double>.onesLike(): DoubleTensor = tensor.fullLike(1.0)
/**
* Returns a 2-D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere.
*
* @param n the number of rows and columns
* @return a 2-D tensor with ones on the diagonal and zeros elsewhere.
*/
public fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n)
val buffer = DoubleArray(n * n) { 0.0 }
val res = DoubleTensor(shape, buffer)
for (i in 0 until n) {
res[intArrayOf(i, i)] = 1.0
}
return res
}
/**
* Return a copy of the tensor.
*
* @return a copy of the `input` tensor with a copied buffer.
*/
public fun Tensor<Double>.copy(): DoubleTensor {
return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
}
override fun Double.plus(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this
}
return DoubleTensor(other.shape, resBuffer)
}
override fun Tensor<Double>.plus(value: Double): DoubleTensor = value + tensor
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other.tensor)
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i]
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.plusAssign(value: Double) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += value
}
}
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other.tensor)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Double.minus(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
}
return DoubleTensor(other.shape, resBuffer)
}
override fun Tensor<Double>.minus(value: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] - value
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i]
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.minusAssign(value: Double) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value
}
}
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Double.times(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this
}
return DoubleTensor(other.shape, resBuffer)
}
override fun Tensor<Double>.times(value: Double): DoubleTensor = value * tensor
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] *
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.timesAssign(value: Double) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value
}
}
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Double.div(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
}
return DoubleTensor(other.shape, resBuffer)
}
override fun Tensor<Double>.div(value: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] / value
}
return DoubleTensor(shape, resBuffer)
}
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[other.tensor.bufferStart + i] /
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.divAssign(value: Double) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= value
}
}
override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Tensor<Double>.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus()
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor {
val ii = tensor.minusIndex(i)
val jj = tensor.minusIndex(j)
checkTranspose(tensor.dimension, ii, jj)
val n = tensor.numElements
val resBuffer = DoubleArray(n)
val resShape = tensor.shape.copyOf()
resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] }
val resTensor = DoubleTensor(resShape, resBuffer)
for (offset in 0 until n) {
val oldMultiIndex = tensor.linearStructure.index(offset)
val newMultiIndex = oldMultiIndex.copyOf()
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
val linearIndex = resTensor.linearStructure.offset(newMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + offset]
}
return resTensor
}
override fun Tensor<Double>.view(shape: IntArray): DoubleTensor {
checkView(tensor, shape)
return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart)
}
override fun Tensor<Double>.viewAs(other: Tensor<Double>): DoubleTensor =
tensor.view(other.shape)
override infix fun Tensor<Double>.dot(other: Tensor<Double>): DoubleTensor {
if (tensor.shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
}
var newThis = tensor.copy()
var newOther = other.copy()
var penultimateDim = false
var lastDim = false
if (tensor.shape.size == 1) {
penultimateDim = true
newThis = tensor.view(intArrayOf(1) + tensor.shape)
}
if (other.shape.size == 1) {
lastDim = true
newOther = other.tensor.view(other.shape + intArrayOf(1))
}
val broadcastTensors = broadcastOuterTensors(newThis.tensor, newOther.tensor)
newThis = broadcastTensors[0]
newOther = broadcastTensors[1]
val l = newThis.shape[newThis.shape.size - 2]
val m1 = newThis.shape[newThis.shape.size - 1]
val m2 = newOther.shape[newOther.shape.size - 2]
val n = newOther.shape[newOther.shape.size - 1]
check(m1 == m2) {
"Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)"
}
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
val resSize = resShape.reduce { acc, i -> acc * i }
val resTensor = DoubleTensor(resShape, DoubleArray(resSize))
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
val (a, b) = ab
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
}
if (penultimateDim) {
return resTensor.view(
resTensor.shape.dropLast(2).toIntArray() +
intArrayOf(resTensor.shape.last())
)
}
if (lastDim) {
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
}
return resTensor
}
override fun diagonalEmbedding(diagonalEntries: Tensor<Double>, offset: Int, dim1: Int, dim2: Int):
DoubleTensor {
val n = diagonalEntries.shape.size
val d1 = minusIndexFrom(n + 1, dim1)
val d2 = minusIndexFrom(n + 1, dim2)
check(d1 != d2) {
"Diagonal dimensions cannot be identical $d1, $d2"
}
check(d1 <= n && d2 <= n) {
"Dimension out of range"
}
var lessDim = d1
var greaterDim = d2
var realOffset = offset
if (lessDim > greaterDim) {
realOffset *= -1
lessDim = greaterDim.also { greaterDim = lessDim }
}
val resShape = diagonalEntries.shape.slice(0 until lessDim).toIntArray() +
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
diagonalEntries.shape.slice(lessDim until greaterDim - 1).toIntArray() +
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray()
val resTensor = zeros(resShape)
for (i in 0 until diagonalEntries.tensor.numElements) {
val multiIndex = diagonalEntries.tensor.linearStructure.index(i)
var offset1 = 0
var offset2 = abs(realOffset)
if (realOffset < 0) {
offset1 = offset2.also { offset2 = offset1 }
}
val diagonalMultiIndex = multiIndex.slice(0 until lessDim).toIntArray() +
intArrayOf(multiIndex[n - 1] + offset1) +
multiIndex.slice(lessDim until greaterDim - 1).toIntArray() +
intArrayOf(multiIndex[n - 1] + offset2) +
multiIndex.slice(greaterDim - 1 until n - 1).toIntArray()
resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex]
}
return resTensor.tensor
}
/**
* Applies the [transform] function to each element of the tensor and returns the resulting modified tensor.
*
* @param transform the function to be applied to each element of the tensor.
* @return the resulting tensor after applying the function.
*/
public fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor {
return DoubleTensor(
tensor.shape,
tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(),
tensor.bufferStart
)
}
/**
* Compares element-wise two tensors with a specified precision.
*
* @param other the tensor to compare with `input` tensor.
* @param epsilon permissible error when comparing two Double values.
* @return true if two tensors have the same shape and elements, false otherwise.
*/
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean =
tensor.eq(other) { x, y -> abs(x - y) < epsilon }
/**
* Compares element-wise two tensors.
* Comparison of two Double values occurs with 1e-5 precision.
*
* @param other the tensor to compare with `input` tensor.
* @return true if two tensors have the same shape and elements, false otherwise.
*/
public infix fun Tensor<Double>.eq(other: Tensor<Double>): Boolean = tensor.eq(other, 1e-5)
private fun Tensor<Double>.eq(
other: Tensor<Double>,
eqFunction: (Double, Double) -> Boolean
): Boolean {
checkShapesCompatible(tensor, other)
val n = tensor.numElements
if (n != other.tensor.numElements) {
return false
}
for (i in 0 until n) {
if (!eqFunction(
tensor.mutableBuffer[tensor.bufferStart + i],
other.tensor.mutableBuffer[other.tensor.bufferStart + i]
)
) {
return false
}
}
return true
}
/**
* Returns a tensor of random numbers drawn from normal distributions with 0.0 mean and 1.0 standard deviation.
*
* @param shape the desired shape for the output tensor.
* @param seed the random seed of the pseudo-random number generator.
* @return tensor of a given shape filled with numbers from the normal distribution
* with 0.0 mean and 1.0 standard deviation.
*/
public fun randomNormal(shape: IntArray, seed: Long = 0): DoubleTensor =
DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed))
/**
* Returns a tensor with the same shape as `input` of random numbers drawn from normal distributions
* with 0.0 mean and 1.0 standard deviation.
*
* @param seed the random seed of the pseudo-random number generator.
* @return tensor with the same shape as `input` filled with numbers from the normal distribution
* with 0.0 mean and 1.0 standard deviation.
*/
public fun Tensor<Double>.randomNormalLike(seed: Long = 0): DoubleTensor =
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
/**
* Concatenates a sequence of tensors with equal shapes along the first dimension.
*
* @param tensors the [List] of tensors with same shapes to concatenate
* @return tensor with concatenation result
*/
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
val resShape = intArrayOf(tensors.size) + shape
val resBuffer = tensors.flatMap {
it.tensor.mutableBuffer.array().drop(it.tensor.bufferStart).take(it.tensor.numElements)
}.toDoubleArray()
return DoubleTensor(resShape, resBuffer, 0)
}
/**
* Builds tensor from rows of input tensor
*
* @param indices the [IntArray] of 1-dimensional indices
* @return tensor with rows corresponding to rows by [indices]
*/
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor {
return stack(indices.map { this[it] })
}
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.toDoubleArray())
internal fun Tensor<Double>.foldDim(
foldFunction: (DoubleArray) -> Double,
dim: Int,
keepDim: Boolean
): DoubleTensor {
check(dim < dimension) { "Dimension $dim out of range $dimension" }
val resShape = if (keepDim) {
shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
} else {
shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
}
val resNumElements = resShape.reduce(Int::times)
val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
for (index in resTensor.linearStructure.indices()) {
val prefix = index.take(dim).toIntArray()
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i ->
tensor[prefix + intArrayOf(i) + suffix]
})
}
return resTensor
}
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.sum() }, dim, keepDim)
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x ->
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
}, dim, keepDim)
override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
override fun Tensor<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim(
{ arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" }
arr.sum() / shape[dim]
},
dim,
keepDim
)
override fun Tensor<Double>.std(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1))
}
override fun Tensor<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
{ arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim]
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1))
},
dim,
keepDim
)
override fun Tensor<Double>.variance(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements
arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)
}
override fun Tensor<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
{ arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim]
arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1)
},
dim,
keepDim
)
private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
val n = x.shape[0]
return ((x - x.mean()) * (y - y.mean())).mean() * n / (n - 1)
}
override fun cov(tensors: List<Tensor<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val n = tensors.size
val m = tensors[0].shape[0]
check(tensors.all { it.shape contentEquals intArrayOf(m) }) { "Tensors must have same shapes" }
val resTensor = DoubleTensor(
intArrayOf(n, n),
DoubleArray(n * n) { 0.0 }
)
for (i in 0 until n) {
for (j in 0 until n) {
resTensor[intArrayOf(i, j)] = cov(tensors[i].tensor, tensors[j].tensor)
}
}
return resTensor
}
override fun Tensor<Double>.exp(): DoubleTensor = tensor.map(::exp)
override fun Tensor<Double>.ln(): DoubleTensor = tensor.map(::ln)
override fun Tensor<Double>.sqrt(): DoubleTensor = tensor.map(::sqrt)
override fun Tensor<Double>.cos(): DoubleTensor = tensor.map(::cos)
override fun Tensor<Double>.acos(): DoubleTensor = tensor.map(::acos)
override fun Tensor<Double>.cosh(): DoubleTensor = tensor.map(::cosh)
override fun Tensor<Double>.acosh(): DoubleTensor = tensor.map(::acosh)
override fun Tensor<Double>.sin(): DoubleTensor = tensor.map(::sin)
override fun Tensor<Double>.asin(): DoubleTensor = tensor.map(::asin)
override fun Tensor<Double>.sinh(): DoubleTensor = tensor.map(::sinh)
override fun Tensor<Double>.asinh(): DoubleTensor = tensor.map(::asinh)
override fun Tensor<Double>.tan(): DoubleTensor = tensor.map(::tan)
override fun Tensor<Double>.atan(): DoubleTensor = tensor.map(::atan)
override fun Tensor<Double>.tanh(): DoubleTensor = tensor.map(::tanh)
override fun Tensor<Double>.atanh(): DoubleTensor = tensor.map(::atanh)
override fun Tensor<Double>.ceil(): DoubleTensor = tensor.map(::ceil)
override fun Tensor<Double>.floor(): DoubleTensor = tensor.map(::floor)
override fun Tensor<Double>.inv(): DoubleTensor = invLU(1e-9)
override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9)
/**
* Computes the LU factorization of a matrix or batches of matrices `input`.
* Returns a tuple containing the LU factorization and pivots of `input`.
*
* @param epsilon permissible error when comparing the determinant of a matrix with zero
* @return pair of `factorization` and `pivots`.
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/
public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon)
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
/**
* Computes the LU factorization of a matrix or batches of matrices `input`.
* Returns a tuple containing the LU factorization and pivots of `input`.
* Uses an error of ``1e-9`` when calculating whether a matrix is degenerate.
*
* @return pair of `factorization` and `pivots`.
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/
public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
/**
* Unpacks the data and pivots from a LU factorization of a tensor.
* Given a tensor [luTensor], return tensors (P, L, U) satisfying ``P * luTensor = L * U``,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
*
* @param luTensor the packed LU factorization data
* @param pivotsTensor the packed LU factorization pivots
* @return triple of P, L and U tensors
*/
public fun luPivot(
luTensor: Tensor<Double>,
pivotsTensor: Tensor<Int>
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
checkSquareMatrix(luTensor.shape)
check(
luTensor.shape.dropLast(2).toIntArray() contentEquals pivotsTensor.shape.dropLast(1).toIntArray() ||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
) { "Inappropriate shapes of input tensors" }
val n = luTensor.shape.last()
val pTensor = luTensor.zeroesLike()
pTensor
.matrixSequence()
.zip(pivotsTensor.tensor.vectorSequence())
.forEach { (p, pivot) -> pivInit(p.as2D(), pivot.as1D(), n) }
val lTensor = luTensor.zeroesLike()
val uTensor = luTensor.zeroesLike()
lTensor.matrixSequence()
.zip(uTensor.matrixSequence())
.zip(luTensor.tensor.matrixSequence())
.forEach { (pairLU, lu) ->
val (l, u) = pairLU
luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n)
}
return Triple(pTensor, lTensor, uTensor)
}
/**
* QR decomposition.
*
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``,
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
*
* @param epsilon permissible error when comparing tensors for equality.
* Used when checking the positive definiteness of the input matrix or matrices.
* @return pair of Q and R tensors.
*/
public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor {
checkSquareMatrix(shape)
checkPositiveDefinite(tensor, epsilon)
val n = shape.last()
val lTensor = zeroesLike()
for ((a, l) in tensor.matrixSequence().zip(lTensor.matrixSequence()))
for (i in 0 until n) choleskyHelper(a.as2D(), l.as2D(), n)
return lTensor
}
override fun Tensor<Double>.cholesky(): DoubleTensor = cholesky(1e-6)
override fun Tensor<Double>.qr(): Pair<DoubleTensor, DoubleTensor> {
checkSquareMatrix(shape)
val qTensor = zeroesLike()
val rTensor = zeroesLike()
tensor.matrixSequence()
.zip(
(qTensor.matrixSequence()
.zip(rTensor.matrixSequence()))
).forEach { (matrix, qr) ->
val (q, r) = qr
qrHelper(matrix.asTensor(), q.asTensor(), r.as2D())
}
return qTensor to rTensor
}
override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
svd(epsilon = 1e-10)
/**
* Singular Value Decomposition.
*
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
* The singular value decomposition is represented as a triple `(U, S, V)`,
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``.
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
*
* @param epsilon permissible error when calculating the dot product of vectors,
* i.e. the precision with which the cosine approaches 1 in an iterative algorithm.
* @return triple `(U, S, V)`.
*/
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = tensor.dimension
val commonShape = tensor.shape.sliceArray(0 until size - 2)
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
val sTensor = zeros(commonShape + intArrayOf(min(n, m)))
val vTensor = zeros(commonShape + intArrayOf(min(n, m), m))
tensor.matrixSequence()
.zip(
uTensor.matrixSequence()
.zip(
sTensor.vectorSequence()
.zip(vTensor.matrixSequence())
)
).forEach { (matrix, USV) ->
val matrixSize = matrix.shape.reduce { acc, i -> acc * i }
val curMatrix = DoubleTensor(
matrix.shape,
matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize)
.toDoubleArray()
)
svdHelper(curMatrix, USV, m, n, epsilon)
}
return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
}
override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
symEig(epsilon = 1e-15)
/**
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
* represented by a pair (eigenvalues, eigenvectors).
*
* @param epsilon permissible error when comparing tensors for equality
* and when the cosine approaches 1 in the SVD algorithm.
* @return a pair (eigenvalues, eigenvectors)
*/
public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(tensor, epsilon)
val (u, s, v) = tensor.svd(epsilon)
val shp = s.shape + intArrayOf(1)
val utv = u.transpose() dot v
val n = s.shape.last()
for (matrix in utv.matrixSequence())
cleanSymHelper(matrix.as2D(), n)
val eig = (utv dot s.view(shp)).view(s.shape)
return eig to v
}
/**
* Computes the determinant of a square matrix input, or of each square matrix in a batched input
* using LU factorization algorithm.
*
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
* @return the determinant.
*/
public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
checkSquareMatrix(tensor.shape)
val luTensor = tensor.copy()
val pivotsTensor = tensor.setUpPivots()
val n = shape.size
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
detTensorShape[n - 2] = 1
val resBuffer = DoubleArray(detTensorShape.reduce(Int::times)) { 0.0 }
val detTensor = DoubleTensor(
detTensorShape,
resBuffer
)
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (lu, pivots) ->
resBuffer[index] = if (luHelper(lu.as2D(), pivots.as1D(), epsilon))
0.0 else luMatrixDet(lu.as2D(), pivots.as1D())
}
return detTensor
}
/**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input
* using LU factorization algorithm.
* Given a square matrix `a`, return the matrix `aInv` satisfying
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``.
*
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
* @return the multiplicative inverse of a matrix.
*/
public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
val (luTensor, pivotsTensor) = luFactor(epsilon)
val invTensor = luTensor.zeroesLike()
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
for ((luP, invMatrix) in seq) {
val (lu, pivots) = luP
luMatrixInv(lu.as2D(), pivots.as1D(), invMatrix.as2D())
}
return invTensor
}
/**
* LUP decomposition
*
* Computes the LUP decomposition of a matrix or a batch of matrices.
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
*
* @param epsilon permissible error when comparing the determinant of a matrix with zero
* @return triple of P, L and U tensors
*/
public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val (lu, pivots) = tensor.luFactor(epsilon)
return luPivot(lu, pivots)
}
override fun Tensor<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
}

View File

@ -0,0 +1,17 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.IntBuffer
/**
* Default [BufferedTensor] implementation for [Int] values
*/
public class IntTensor internal constructor(
shape: IntArray,
buffer: IntArray,
offset: Int = 0
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)

View File

@ -0,0 +1,57 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.Strides
import kotlin.math.max
internal fun stridesFromShape(shape: IntArray): IntArray {
val nDim = shape.size
val res = IntArray(nDim)
if (nDim == 0)
return res
var current = nDim - 1
res[current] = 1
while (current > 0) {
res[current - 1] = max(1, shape[current]) * res[current]
current--
}
return res
}
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim)
var current = offset
var strideIndex = 0
while (strideIndex < nDim) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex++
}
return res
}
/**
* This [Strides] implementation follows the last dimension first convention
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
*
* @param shape the shape of the tensor.
*/
internal class TensorLinearStructure(override val shape: IntArray) : Strides {
override val strides: IntArray
get() = stridesFromShape(shape)
override fun index(offset: Int): IntArray =
indexFromOffset(offset, strides, shape.size)
override val linearSize: Int
get() = shape.reduce(Int::times)
}

View File

@ -0,0 +1,146 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.tensors.core.DoubleTensor
import kotlin.math.max
internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) {
for (linearIndex in 0 until linearSize) {
val totalMultiIndex = resTensor.linearStructure.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.linearStructure.offset(curMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex]
}
}
internal fun broadcastShapes(vararg shapes: IntArray): IntArray {
var totalDim = 0
for (shape in shapes) {
totalDim = max(totalDim, shape.size)
}
val totalShape = IntArray(totalDim) { 0 }
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
totalShape[i + offset] = max(totalShape[i + offset], curDim)
}
}
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
check(curDim == 1 || totalShape[i + offset] == curDim) {
"Shapes are not compatible and cannot be broadcast"
}
}
}
return totalShape
}
internal fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor {
require(tensor.shape.size <= newShape.size) {
"Tensor is not compatible with the new shape"
}
val n = newShape.reduce { acc, i -> acc * i }
val resTensor = DoubleTensor(newShape, DoubleArray(n))
for (i in tensor.shape.indices) {
val curDim = tensor.shape[i]
val offset = newShape.size - tensor.shape.size
check(curDim == 1 || newShape[i + offset] == curDim) {
"Tensor is not compatible with the new shape and cannot be broadcast"
}
}
multiIndexBroadCasting(tensor, resTensor, n)
return resTensor
}
internal fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
val n = totalShape.reduce { acc, i -> acc * i }
return tensors.map { tensor ->
val resTensor = DoubleTensor(totalShape, DoubleArray(n))
multiIndexBroadCasting(tensor, resTensor, n)
resTensor
}
}
internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
val onlyTwoDims = tensors.asSequence().onEach {
require(it.shape.size >= 2) {
"Tensors must have at least 2 dimensions"
}
}.any { it.shape.size != 2 }
if (!onlyTwoDims) {
return tensors.asList()
}
val totalShape = broadcastShapes(*(tensors.map { it.shape.sliceArray(0..it.shape.size - 3) }).toTypedArray())
val n = totalShape.reduce { acc, i -> acc * i }
return buildList {
for (tensor in tensors) {
val matrixShape = tensor.shape.sliceArray(tensor.shape.size - 2 until tensor.shape.size).copyOf()
val matrixSize = matrixShape[0] * matrixShape[1]
val matrix = DoubleTensor(matrixShape, DoubleArray(matrixSize))
val outerTensor = DoubleTensor(totalShape, DoubleArray(n))
val resTensor = DoubleTensor(totalShape + matrixShape, DoubleArray(n * matrixSize))
for (linearIndex in 0 until n) {
val totalMultiIndex = outerTensor.linearStructure.index(linearIndex)
var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf()
curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) { 1 } + curMultiIndex
val newTensor = DoubleTensor(curMultiIndex + matrixShape, tensor.mutableBuffer.array())
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i]
} else {
curMultiIndex[i] = 0
}
}
for (i in 0 until matrixSize) {
val curLinearIndex = newTensor.linearStructure.offset(
curMultiIndex +
matrix.linearStructure.index(i)
)
val newLinearIndex = resTensor.linearStructure.offset(
totalMultiIndex +
matrix.linearStructure.index(i)
)
resTensor.mutableBuffer.array()[resTensor.bufferStart + newLinearIndex] =
newTensor.mutableBuffer.array()[newTensor.bufferStart + curLinearIndex]
}
}
add(resTensor)
}
}
}

View File

@ -0,0 +1,64 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
internal fun checkEmptyShape(shape: IntArray) =
check(shape.isNotEmpty()) {
"Illegal empty shape provided"
}
internal fun checkEmptyDoubleBuffer(buffer: DoubleArray) =
check(buffer.isNotEmpty()) {
"Illegal empty buffer provided"
}
internal fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray) =
check(buffer.size == shape.reduce(Int::times)) {
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
}
internal fun <T> checkShapesCompatible(a: Tensor<T>, b: Tensor<T>) =
check(a.shape contentEquals b.shape) {
"Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} "
}
internal fun checkTranspose(dim: Int, i: Int, j: Int) =
check((i < dim) and (j < dim)) {
"Cannot transpose $i to $j for a tensor of dim $dim"
}
internal fun <T> checkView(a: Tensor<T>, shape: IntArray) =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
internal fun checkSquareMatrix(shape: IntArray) {
val n = shape.size
check(n >= 2) {
"Expected tensor with 2 or more dimensions, got size $n instead"
}
check(shape[n - 1] == shape[n - 2]) {
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
}
}
internal fun DoubleTensorAlgebra.checkSymmetric(
tensor: Tensor<Double>, epsilon: Double = 1e-6
) =
check(tensor.eq(tensor.transpose(), epsilon)) {
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
}
internal fun DoubleTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor, epsilon: Double = 1e-6) {
checkSymmetric(tensor, epsilon)
for (mat in tensor.matrixSequence())
check(mat.asTensor().detLU().value() > 0.0) {
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
}
}

View File

@ -0,0 +1,342 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.MutableStructure1D
import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.valueOrNull
import kotlin.math.abs
import kotlin.math.min
import kotlin.math.sign
import kotlin.math.sqrt
internal fun <T> BufferedTensor<T>.vectorSequence(): Sequence<BufferedTensor<T>> = sequence {
val n = shape.size
val vectorOffset = shape[n - 1]
val vectorShape = intArrayOf(shape.last())
for (offset in 0 until numElements step vectorOffset) {
val vector = BufferedTensor(vectorShape, mutableBuffer, bufferStart + offset)
yield(vector)
}
}
internal fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTensor<T>> = sequence {
val n = shape.size
check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n" }
val matrixOffset = shape[n - 1] * shape[n - 2]
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1])
for (offset in 0 until numElements step matrixOffset) {
val matrix = BufferedTensor(matrixShape, mutableBuffer, bufferStart + offset)
yield(matrix)
}
}
internal fun dotHelper(
a: MutableStructure2D<Double>,
b: MutableStructure2D<Double>,
res: MutableStructure2D<Double>,
l: Int, m: Int, n: Int
) {
for (i in 0 until l) {
for (j in 0 until n) {
var curr = 0.0
for (k in 0 until m) {
curr += a[i, k] * b[k, j]
}
res[i, j] = curr
}
}
}
internal fun luHelper(
lu: MutableStructure2D<Double>,
pivots: MutableStructure1D<Int>,
epsilon: Double
): Boolean {
val m = lu.rowNum
for (row in 0..m) pivots[row] = row
for (i in 0 until m) {
var maxVal = 0.0
var maxInd = i
for (k in i until m) {
val absA = abs(lu[k, i])
if (absA > maxVal) {
maxVal = absA
maxInd = k
}
}
if (abs(maxVal) < epsilon)
return true // matrix is singular
if (maxInd != i) {
val j = pivots[i]
pivots[i] = pivots[maxInd]
pivots[maxInd] = j
for (k in 0 until m) {
val tmp = lu[i, k]
lu[i, k] = lu[maxInd, k]
lu[maxInd, k] = tmp
}
pivots[m] += 1
}
for (j in i + 1 until m) {
lu[j, i] /= lu[i, i]
for (k in i + 1 until m) {
lu[j, k] -= lu[j, i] * lu[i, k]
}
}
}
return false
}
internal fun <T> BufferedTensor<T>.setUpPivots(): IntTensor {
val n = this.shape.size
val m = this.shape.last()
val pivotsShape = IntArray(n - 1) { i -> this.shape[i] }
pivotsShape[n - 2] = m + 1
return IntTensor(
pivotsShape,
IntArray(pivotsShape.reduce(Int::times)) { 0 }
)
}
internal fun DoubleTensorAlgebra.computeLU(
tensor: DoubleTensor,
epsilon: Double
): Pair<DoubleTensor, IntTensor>? {
checkSquareMatrix(tensor.shape)
val luTensor = tensor.copy()
val pivotsTensor = tensor.setUpPivots()
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
if (luHelper(lu.as2D(), pivots.as1D(), epsilon))
return null
return Pair(luTensor, pivotsTensor)
}
internal fun pivInit(
p: MutableStructure2D<Double>,
pivot: MutableStructure1D<Int>,
n: Int
) {
for (i in 0 until n) {
p[i, pivot[i]] = 1.0
}
}
internal fun luPivotHelper(
l: MutableStructure2D<Double>,
u: MutableStructure2D<Double>,
lu: MutableStructure2D<Double>,
n: Int
) {
for (i in 0 until n) {
for (j in 0 until n) {
if (i == j) {
l[i, j] = 1.0
}
if (j < i) {
l[i, j] = lu[i, j]
}
if (j >= i) {
u[i, j] = lu[i, j]
}
}
}
}
internal fun choleskyHelper(
a: MutableStructure2D<Double>,
l: MutableStructure2D<Double>,
n: Int
) {
for (i in 0 until n) {
for (j in 0 until i) {
var h = a[i, j]
for (k in 0 until j) {
h -= l[i, k] * l[j, k]
}
l[i, j] = h / l[j, j]
}
var h = a[i, i]
for (j in 0 until i) {
h -= l[i, j] * l[i, j]
}
l[i, i] = sqrt(h)
}
}
internal fun luMatrixDet(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>): Double {
if (lu[0, 0] == 0.0) {
return 0.0
}
val m = lu.shape[0]
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
}
internal fun luMatrixInv(
lu: MutableStructure2D<Double>,
pivots: MutableStructure1D<Int>,
invMatrix: MutableStructure2D<Double>
) {
val m = lu.shape[0]
for (j in 0 until m) {
for (i in 0 until m) {
if (pivots[i] == j) {
invMatrix[i, j] = 1.0
}
for (k in 0 until i) {
invMatrix[i, j] -= lu[i, k] * invMatrix[k, j]
}
}
for (i in m - 1 downTo 0) {
for (k in i + 1 until m) {
invMatrix[i, j] -= lu[i, k] * invMatrix[k, j]
}
invMatrix[i, j] /= lu[i, i]
}
}
}
internal fun DoubleTensorAlgebra.qrHelper(
matrix: DoubleTensor,
q: DoubleTensor,
r: MutableStructure2D<Double>
) {
checkSquareMatrix(matrix.shape)
val n = matrix.shape[0]
val qM = q.as2D()
val matrixT = matrix.transpose(0, 1)
val qT = q.transpose(0, 1)
for (j in 0 until n) {
val v = matrixT[j]
val vv = v.as1D()
if (j > 0) {
for (i in 0 until j) {
r[i, j] = (qT[i] dot matrixT[j]).value()
for (k in 0 until n) {
val qTi = qT[i].as1D()
vv[k] = vv[k] - r[i, j] * qTi[k]
}
}
}
r[j, j] = DoubleTensorAlgebra { (v dot v).sqrt().value() }
for (i in 0 until n) {
qM[i, j] = vv[i] / r[j, j]
}
}
}
internal fun DoubleTensorAlgebra.svd1d(a: DoubleTensor, epsilon: Double = 1e-10): DoubleTensor {
val (n, m) = a.shape
var v: DoubleTensor
val b: DoubleTensor
if (n > m) {
b = a.transpose(0, 1).dot(a)
v = DoubleTensor(intArrayOf(m), getRandomUnitVector(m, 0))
} else {
b = a.dot(a.transpose(0, 1))
v = DoubleTensor(intArrayOf(n), getRandomUnitVector(n, 0))
}
var lastV: DoubleTensor
while (true) {
lastV = v
v = b.dot(lastV)
val norm = DoubleTensorAlgebra { (v dot v).sqrt().value() }
v = v.times(1.0 / norm)
if (abs(v.dot(lastV).value()) > 1 - epsilon) {
return v
}
}
}
internal fun DoubleTensorAlgebra.svdHelper(
matrix: DoubleTensor,
USV: Pair<BufferedTensor<Double>, Pair<BufferedTensor<Double>, BufferedTensor<Double>>>,
m: Int, n: Int, epsilon: Double
) {
val res = ArrayList<Triple<Double, DoubleTensor, DoubleTensor>>(0)
val (matrixU, SV) = USV
val (matrixS, matrixV) = SV
for (k in 0 until min(n, m)) {
var a = matrix.copy()
for ((singularValue, u, v) in res.slice(0 until k)) {
val outerProduct = DoubleArray(u.shape[0] * v.shape[0])
for (i in 0 until u.shape[0]) {
for (j in 0 until v.shape[0]) {
outerProduct[i * v.shape[0] + j] = u[i].value() * v[j].value()
}
}
a = a - singularValue.times(DoubleTensor(intArrayOf(u.shape[0], v.shape[0]), outerProduct))
}
var v: DoubleTensor
var u: DoubleTensor
var norm: Double
if (n > m) {
v = svd1d(a, epsilon)
u = matrix.dot(v)
norm = DoubleTensorAlgebra { (u dot u).sqrt().value() }
u = u.times(1.0 / norm)
} else {
u = svd1d(a, epsilon)
v = matrix.transpose(0, 1).dot(u)
norm = DoubleTensorAlgebra { (v dot v).sqrt().value() }
v = v.times(1.0 / norm)
}
res.add(Triple(norm, u, v))
}
val s = res.map { it.first }.toDoubleArray()
val uBuffer = res.map { it.second }.flatMap { it.mutableBuffer.array().toList() }.toDoubleArray()
val vBuffer = res.map { it.third }.flatMap { it.mutableBuffer.array().toList() }.toDoubleArray()
for (i in uBuffer.indices) {
matrixU.mutableBuffer.array()[matrixU.bufferStart + i] = uBuffer[i]
}
for (i in s.indices) {
matrixS.mutableBuffer.array()[matrixS.bufferStart + i] = s[i]
}
for (i in vBuffer.indices) {
matrixV.mutableBuffer.array()[matrixV.bufferStart + i] = vBuffer[i]
}
}
internal fun cleanSymHelper(matrix: MutableStructure2D<Double>, n: Int) {
for (i in 0 until n)
for (j in 0 until n) {
if (i == j) {
matrix[i, j] = sign(matrix[i, j])
} else {
matrix[i, j] = 0.0
}
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.BufferedTensor
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.IntTensor
internal fun BufferedTensor<Int>.asTensor(): IntTensor =
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
BufferedTensor(
this.shape,
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
)
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
else -> this.copyToBufferedTensor()
}
internal val Tensor<Double>.tensor: DoubleTensor
get() = when (this) {
is DoubleTensor -> this
else -> this.toBufferedTensor().asTensor()
}
internal val Tensor<Int>.tensor: IntTensor
get() = when (this) {
is IntTensor -> this
else -> this.toBufferedTensor().asTensor()
}

View File

@ -0,0 +1,124 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.samplers.GaussianSampler
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.core.BufferedTensor
import space.kscience.kmath.tensors.core.DoubleTensor
import kotlin.math.*
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer] or copy the data.
*/
internal fun Buffer<Int>.array(): IntArray = when (this) {
is IntBuffer -> array
else -> this.toIntArray()
}
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data.
*/
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
is DoubleBuffer -> array
else -> this.toDoubleArray()
}
internal fun getRandomNormals(n: Int, seed: Long): DoubleArray {
val distribution = GaussianSampler(0.0, 1.0)
val generator = RandomGenerator.default(seed)
return distribution.sample(generator).nextBufferBlocking(n).toDoubleArray()
}
internal fun getRandomUnitVector(n: Int, seed: Long): DoubleArray {
val unnorm = getRandomNormals(n, seed)
val norm = sqrt(unnorm.sumOf { it * it })
return unnorm.map { it / norm }.toDoubleArray()
}
internal fun minusIndexFrom(n: Int, i: Int): Int = if (i >= 0) i else {
val ii = n + i
check(ii >= 0) {
"Out of bound index $i for tensor of dim $n"
}
ii
}
internal fun <T> BufferedTensor<T>.minusIndex(i: Int): Int = minusIndexFrom(this.dimension, i)
internal fun format(value: Double, digits: Int = 4): String = buildString {
val res = buildString {
val ten = 10.0
val approxOrder = if (value == 0.0) 0 else ceil(log10(abs(value))).toInt()
val order = if (
((value % ten) == 0.0) ||
(value == 1.0) ||
((1 / value) % ten == 0.0)
) approxOrder else approxOrder - 1
val lead = value / ten.pow(order)
if (value >= 0.0) append(' ')
append(round(lead * ten.pow(digits)) / ten.pow(digits))
when {
order == 0 -> Unit
order > 0 -> {
append("e+")
append(order)
}
else -> {
append('e')
append(order)
}
}
}
val fLength = digits + 6
append(res)
repeat(fLength - res.length) { append(' ') }
}
internal fun DoubleTensor.toPrettyString(): String = buildString {
var offset = 0
val shape = this@toPrettyString.shape
val linearStructure = this@toPrettyString.linearStructure
val vectorSize = shape.last()
append("DoubleTensor(\n")
var charOffset = 3
for (vector in vectorSequence()) {
repeat(charOffset) { append(' ') }
val index = linearStructure.index(offset)
for (ind in index.reversed()) {
if (ind != 0) {
break
}
append('[')
charOffset += 1
}
val values = vector.as1D().toMutableList().map(::format)
values.joinTo(this, separator = ", ")
append(']')
charOffset -= 1
index.reversed().zip(shape.reversed()).drop(1).forEach { (ind, maxInd) ->
if (ind != maxInd - 1) {
return@forEach
}
append(']')
charOffset -= 1
}
offset += vectorSize
if (this@toPrettyString.numElements == offset) {
break
}
append(",\n")
}
append("\n)")
}

View File

@ -0,0 +1,37 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.tensor
/**
* Casts [Tensor] of [Double] to [DoubleTensor]
*/
public fun Tensor<Double>.toDoubleTensor(): DoubleTensor = this.tensor
/**
* Casts [Tensor] of [Int] to [IntTensor]
*/
public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor
/**
* Returns [DoubleArray] of tensor elements
*/
public fun DoubleTensor.toDoubleArray(): DoubleArray {
return DoubleArray(numElements) { i ->
mutableBuffer[bufferStart + i]
}
}
/**
* Returns [IntArray] of tensor elements
*/
public fun IntTensor.toIntArray(): IntArray {
return IntArray(numElements) { i ->
mutableBuffer[bufferStart + i]
}
}

View File

@ -0,0 +1,105 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.*
import kotlin.test.Test
import kotlin.test.assertTrue
internal class TestBroadcasting {
@Test
fun testBroadcastShapes() = DoubleTensorAlgebra {
assertTrue(
broadcastShapes(
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
) contentEquals intArrayOf(1, 2, 3)
)
assertTrue(
broadcastShapes(
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7), intArrayOf(5, 1, 7)
) contentEquals intArrayOf(5, 6, 7)
)
}
@Test
fun testBroadcastTo() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val res = broadcastTo(tensor2, tensor1.shape)
assertTrue(res.shape contentEquals intArrayOf(2, 3))
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
}
@Test
fun testBroadcastTensors() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val res = broadcastTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].mutableBuffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
}
@Test
fun testBroadcastOuterTensors() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(1, 1, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 1, 1))
assertTrue(res[0].mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0))
assertTrue(res[2].mutableBuffer.array() contentEquals doubleArrayOf(500.0))
}
@Test
fun testBroadcastOuterTensorsShapes() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 1, 3, 2, 3), DoubleArray(2 * 1 * 3 * 2 * 3) {0.0})
val tensor2 = fromArray(intArrayOf(4, 2, 5, 1, 3, 3), DoubleArray(4 * 2 * 5 * 1 * 3 * 3) {0.0})
val tensor3 = fromArray(intArrayOf(1, 1), doubleArrayOf(500.0))
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(4, 2, 5, 3, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(4, 2, 5, 3, 3, 3))
assertTrue(res[2].shape contentEquals intArrayOf(4, 2, 5, 3, 1, 1))
}
@Test
fun testMinusTensor() = BroadcastDoubleTensorAlgebra.invoke {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val tensor21 = tensor2 - tensor1
val tensor31 = tensor3 - tensor1
val tensor32 = tensor3 - tensor2
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
assertTrue(tensor21.mutableBuffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue(tensor31.shape contentEquals intArrayOf(1, 2, 3))
assertTrue(
tensor31.mutableBuffer.array()
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
)
assertTrue(tensor32.shape contentEquals intArrayOf(1, 1, 3))
assertTrue(tensor32.mutableBuffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
}
}

View File

@ -0,0 +1,158 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.operations.invoke
import kotlin.math.*
import kotlin.test.Test
import kotlin.test.assertTrue
internal class TestDoubleAnalyticTensorAlgebra {
val shape = intArrayOf(2, 1, 3, 2)
val buffer = doubleArrayOf(
27.1, 20.0, 19.84,
23.123, 3.0, 2.0,
3.23, 133.7, 25.3,
100.3, 11.0, 12.012
)
val tensor = DoubleTensor(shape, buffer)
fun DoubleArray.fmap(transform: (Double) -> Double): DoubleArray {
return this.map(transform).toDoubleArray()
}
fun expectedTensor(transform: (Double) -> Double): DoubleTensor {
return DoubleTensor(shape, buffer.fmap(transform))
}
@Test
fun testExp() = DoubleTensorAlgebra {
assertTrue { tensor.exp() eq expectedTensor(::exp) }
}
@Test
fun testLog() = DoubleTensorAlgebra {
assertTrue { tensor.ln() eq expectedTensor(::ln) }
}
@Test
fun testSqrt() = DoubleTensorAlgebra {
assertTrue { tensor.sqrt() eq expectedTensor(::sqrt) }
}
@Test
fun testCos() = DoubleTensorAlgebra {
assertTrue { tensor.cos() eq expectedTensor(::cos) }
}
@Test
fun testCosh() = DoubleTensorAlgebra {
assertTrue { tensor.cosh() eq expectedTensor(::cosh) }
}
@Test
fun testAcosh() = DoubleTensorAlgebra {
assertTrue { tensor.acosh() eq expectedTensor(::acosh) }
}
@Test
fun testSin() = DoubleTensorAlgebra {
assertTrue { tensor.sin() eq expectedTensor(::sin) }
}
@Test
fun testSinh() = DoubleTensorAlgebra {
assertTrue { tensor.sinh() eq expectedTensor(::sinh) }
}
@Test
fun testAsinh() = DoubleTensorAlgebra {
assertTrue { tensor.asinh() eq expectedTensor(::asinh) }
}
@Test
fun testTan() = DoubleTensorAlgebra {
assertTrue { tensor.tan() eq expectedTensor(::tan) }
}
@Test
fun testAtan() = DoubleTensorAlgebra {
assertTrue { tensor.atan() eq expectedTensor(::atan) }
}
@Test
fun testTanh() = DoubleTensorAlgebra {
assertTrue { tensor.tanh() eq expectedTensor(::tanh) }
}
@Test
fun testCeil() = DoubleTensorAlgebra {
assertTrue { tensor.ceil() eq expectedTensor(::ceil) }
}
@Test
fun testFloor() = DoubleTensorAlgebra {
assertTrue { tensor.floor() eq expectedTensor(::floor) }
}
val shape2 = intArrayOf(2, 2)
val buffer2 = doubleArrayOf(
1.0, 2.0,
-3.0, 4.0
)
val tensor2 = DoubleTensor(shape2, buffer2)
@Test
fun testMin() = DoubleTensorAlgebra {
assertTrue { tensor2.min() == -3.0 }
assertTrue { tensor2.min(0, true) eq fromArray(
intArrayOf(1, 2),
doubleArrayOf(-3.0, 2.0)
)}
assertTrue { tensor2.min(1, false) eq fromArray(
intArrayOf(2),
doubleArrayOf(1.0, -3.0)
)}
}
@Test
fun testMax() = DoubleTensorAlgebra {
assertTrue { tensor2.max() == 4.0 }
assertTrue { tensor2.max(0, true) eq fromArray(
intArrayOf(1, 2),
doubleArrayOf(1.0, 4.0)
)}
assertTrue { tensor2.max(1, false) eq fromArray(
intArrayOf(2),
doubleArrayOf(2.0, 4.0)
)}
}
@Test
fun testSum() = DoubleTensorAlgebra {
assertTrue { tensor2.sum() == 4.0 }
assertTrue { tensor2.sum(0, true) eq fromArray(
intArrayOf(1, 2),
doubleArrayOf(-2.0, 6.0)
)}
assertTrue { tensor2.sum(1, false) eq fromArray(
intArrayOf(2),
doubleArrayOf(3.0, 1.0)
)}
}
@Test
fun testMean() = DoubleTensorAlgebra {
assertTrue { tensor2.mean() == 1.0 }
assertTrue { tensor2.mean(0, true) eq fromArray(
intArrayOf(1, 2),
doubleArrayOf(-1.0, 3.0)
)}
assertTrue { tensor2.mean(1, false) eq fromArray(
intArrayOf(2),
doubleArrayOf(1.5, 0.5)
)}
}
}

View File

@ -0,0 +1,196 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.array
import space.kscience.kmath.tensors.core.internal.svd1d
import kotlin.math.abs
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class TestDoubleLinearOpsTensorAlgebra {
@Test
fun testDetLU() = DoubleTensorAlgebra {
val tensor = fromArray(
intArrayOf(2, 2, 2),
doubleArrayOf(
1.0, 3.0,
1.0, 2.0,
1.5, 1.0,
10.0, 2.0
)
)
val expectedTensor = fromArray(
intArrayOf(2, 1),
doubleArrayOf(
-1.0,
-7.0
)
)
val detTensor = tensor.detLU()
assertTrue(detTensor.eq(expectedTensor))
}
@Test
fun testDet() = DoubleTensorAlgebra {
val expectedValue = 0.019827417
val m = fromArray(
intArrayOf(3, 3), doubleArrayOf(
2.1843, 1.4391, -0.4845,
1.4391, 1.7772, 0.4055,
-0.4845, 0.4055, 0.7519
)
)
assertTrue { abs(m.det().value() - expectedValue) < 1e-5 }
}
@Test
fun testDetSingle() = DoubleTensorAlgebra {
val expectedValue = 48.151623
val m = fromArray(
intArrayOf(1, 1), doubleArrayOf(
expectedValue
)
)
assertTrue { abs(m.det().value() - expectedValue) < 1e-5 }
}
@Test
fun testInvLU() = DoubleTensorAlgebra {
val tensor = fromArray(
intArrayOf(2, 2, 2),
doubleArrayOf(
1.0, 0.0,
0.0, 2.0,
1.0, 1.0,
1.0, 0.0
)
)
val expectedTensor = fromArray(
intArrayOf(2, 2, 2), doubleArrayOf(
1.0, 0.0,
0.0, 0.5,
0.0, 1.0,
1.0, -1.0
)
)
val invTensor = tensor.invLU()
assertTrue(invTensor.eq(expectedTensor))
}
@Test
fun testScalarProduct() = DoubleTensorAlgebra {
val a = fromArray(intArrayOf(3), doubleArrayOf(1.8, 2.5, 6.8))
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5, 2.6, 6.4))
assertEquals(a.dot(b).value(), 59.92)
}
@Test
fun testQR() = DoubleTensorAlgebra {
val shape = intArrayOf(2, 2, 2)
val buffer = doubleArrayOf(
1.0, 3.0,
1.0, 2.0,
1.5, 1.0,
10.0, 2.0
)
val tensor = fromArray(shape, buffer)
val (q, r) = tensor.qr()
assertTrue { q.shape contentEquals shape }
assertTrue { r.shape contentEquals shape }
assertTrue((q dot r).eq(tensor))
}
@Test
fun testLU() = DoubleTensorAlgebra {
val shape = intArrayOf(2, 2, 2)
val buffer = doubleArrayOf(
1.0, 3.0,
1.0, 2.0,
1.5, 1.0,
10.0, 2.0
)
val tensor = fromArray(shape, buffer)
val (p, l, u) = tensor.lu()
assertTrue { p.shape contentEquals shape }
assertTrue { l.shape contentEquals shape }
assertTrue { u.shape contentEquals shape }
assertTrue((p dot tensor).eq(l dot u))
}
@Test
fun testCholesky() = DoubleTensorAlgebra {
val tensor = randomNormal(intArrayOf(2, 5, 5), 0)
val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
)
val low = sigma.cholesky()
val sigmChol = low dot low.transpose()
assertTrue(sigma.eq(sigmChol))
}
@Test
fun testSVD1D() = DoubleTensorAlgebra {
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = svd1d(tensor2)
assertTrue(res.shape contentEquals intArrayOf(2))
assertTrue { abs(abs(res.mutableBuffer.array()[res.bufferStart]) - 0.386) < 0.01 }
assertTrue { abs(abs(res.mutableBuffer.array()[res.bufferStart + 1]) - 0.922) < 0.01 }
}
@Test
fun testSVD() = DoubleTensorAlgebra{
testSVDFor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
testSVDFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
}
@Test
fun testBatchedSVD() = DoubleTensorAlgebra {
val tensor = randomNormal(intArrayOf(2, 5, 3), 0)
val (tensorU, tensorS, tensorV) = tensor.svd()
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensor.eq(tensorSVD))
}
@Test
fun testBatchedSymEig() = DoubleTensorAlgebra {
val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0)
val tensorSigma = tensor + tensor.transpose()
val (tensorS, tensorV) = tensorSigma.symEig()
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensorSigma.eq(tensorSigmaCalc))
}
}
private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
val svd = tensor.svd()
val tensorSVD = svd.first
.dot(
diagonalEmbedding(svd.second)
.dot(svd.third.transpose())
)
assertTrue(tensor.eq(tensorSVD, epsilon))
}

View File

@ -0,0 +1,89 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.toDoubleArray
import space.kscience.kmath.tensors.core.internal.array
import space.kscience.kmath.tensors.core.internal.asTensor
import space.kscience.kmath.tensors.core.internal.matrixSequence
import space.kscience.kmath.tensors.core.internal.toBufferedTensor
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class TestDoubleTensor {
@Test
fun testValue() = DoubleTensorAlgebra {
val value = 12.5
val tensor = fromArray(intArrayOf(1), doubleArrayOf(value))
assertEquals(tensor.value(), value)
}
@Test
fun testStrides() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4))
assertEquals(tensor[intArrayOf(0, 1)], 5.8)
assertTrue(
tensor.elements().map { it.second }.toList().toDoubleArray() contentEquals tensor.mutableBuffer.toDoubleArray()
)
}
@Test
fun testGet() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(1, 2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4))
val matrix = tensor[0].as2D()
assertEquals(matrix[0, 1], 5.8)
val vector = tensor[0][1].as1D()
assertEquals(vector[0], 58.4)
matrix[0, 1] = 77.89
assertEquals(tensor[intArrayOf(0, 0, 1)], 77.89)
vector[0] = 109.56
assertEquals(tensor[intArrayOf(0, 1, 0)], 109.56)
tensor.matrixSequence().forEach {
val a = it.asTensor()
val secondRow = a[1].as1D()
val secondColumn = a.transpose(0, 1)[1].as1D()
assertEquals(secondColumn[0], 77.89)
assertEquals(secondRow[1], secondColumn[1])
}
}
@Test
fun testNoBufferProtocol() {
// create buffer
val doubleArray = DoubleBuffer(doubleArrayOf(1.0, 2.0, 3.0))
// create ND buffers, no data is copied
val ndArray = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleArray)
// map to tensors
val bufferedTensorArray = ndArray.toBufferedTensor() // strides are flipped so data copied
val tensorArray = bufferedTensorArray.asTensor() // data not contiguous so copied again
val tensorArrayPublic = ndArray.toDoubleTensor() // public API, data copied twice
val sharedTensorArray = tensorArrayPublic.toDoubleTensor() // no data copied by matching type
assertTrue(tensorArray.mutableBuffer.array() contentEquals sharedTensorArray.mutableBuffer.array())
tensorArray[intArrayOf(0)] = 55.9
assertEquals(tensorArrayPublic[intArrayOf(0)], 1.0)
tensorArrayPublic[intArrayOf(0)] = 55.9
assertEquals(sharedTensorArray[intArrayOf(0)], 55.9)
assertEquals(bufferedTensorArray[intArrayOf(0)], 1.0)
bufferedTensorArray[intArrayOf(0)] = 55.9
assertEquals(ndArray[intArrayOf(0)], 1.0)
}
}

View File

@ -0,0 +1,167 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.array
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue
internal class TestDoubleTensorAlgebra {
@Test
fun testDoublePlus() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(1.0, 2.0))
val res = 10.0 + tensor
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(11.0, 12.0))
}
@Test
fun TestDoubleDiv() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0))
val res = 2.0/tensor
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 0.5))
}
@Test
fun testDivDouble() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(10.0, 5.0))
val res = tensor / 2.5
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(4.0, 2.0))
}
@Test
fun testTranspose1x1() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0))
val res = tensor.transpose(0, 0)
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(0.0))
assertTrue(res.shape contentEquals intArrayOf(1))
}
@Test
fun testTranspose3x2() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = tensor.transpose(1, 0)
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.shape contentEquals intArrayOf(2, 3))
}
@Test
fun testTranspose1x2x3() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res01 = tensor.transpose(0, 1)
val res02 = tensor.transpose(-3, 2)
val res12 = tensor.transpose()
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
assertTrue(res01.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
}
@Test
fun testLinearStructure() = DoubleTensorAlgebra {
val shape = intArrayOf(3)
val tensorA = full(value = -4.5, shape = shape)
val tensorB = full(value = 10.9, shape = shape)
val tensorC = full(value = 789.3, shape = shape)
val tensorD = full(value = -72.9, shape = shape)
val tensorE = full(value = 553.1, shape = shape)
val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4
val expected = fromArray(
shape,
(1..3).map {
15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4
}.toDoubleArray()
)
val assignResult = zeros(shape)
tensorA *= 15.8
tensorB *= 1.5
tensorB *= -tensorD
tensorC *= 0.02
tensorC /= tensorE
assignResult += tensorA
assignResult -= tensorB
assignResult += tensorC
assignResult += -39.4
assertTrue(expected.mutableBuffer.array() contentEquals result.mutableBuffer.array())
assertTrue(expected.mutableBuffer.array() contentEquals assignResult.mutableBuffer.array())
}
@Test
fun testDot() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor11 = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 3), doubleArrayOf(-1.0, -2.0, -3.0))
val res12 = tensor1.dot(tensor2)
assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
assertTrue(res12.shape contentEquals intArrayOf(2))
val res32 = tensor3.dot(tensor2)
assertTrue(res32.mutableBuffer.array() contentEquals doubleArrayOf(-140.0))
assertTrue(res32.shape contentEquals intArrayOf(1, 1))
val res22 = tensor2.dot(tensor2)
assertTrue(res22.mutableBuffer.array() contentEquals doubleArrayOf(1400.0))
assertTrue(res22.shape contentEquals intArrayOf(1))
val res11 = tensor1.dot(tensor11)
assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
assertTrue(res11.shape contentEquals intArrayOf(2, 2))
}
@Test
fun testDiagonalEmbedding() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor3 = zeros(intArrayOf(2, 3, 4, 5))
assertTrue(diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals
intArrayOf(2, 3, 4, 5, 5))
assertTrue(diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals
intArrayOf(2, 3, 4, 6, 6))
assertTrue(diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals
intArrayOf(7, 2, 3, 7, 4))
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
assertTrue(diagonal1.shape contentEquals intArrayOf(3, 3))
assertTrue(diagonal1.mutableBuffer.array() contentEquals
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0))
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
assertTrue(diagonal1Offset.shape contentEquals intArrayOf(4, 4))
assertTrue(diagonal1Offset.mutableBuffer.array() contentEquals
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0))
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
assertTrue(diagonal2.shape contentEquals intArrayOf(4, 2, 4))
assertTrue(diagonal2.mutableBuffer.array() contentEquals
doubleArrayOf(
0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0,
0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 5.0, 0.0,
0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 6.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
}
@Test
fun testEq() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor3 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 5.0))
assertTrue(tensor1 eq tensor1)
assertTrue(tensor1 eq tensor2)
assertFalse(tensor1.eq(tensor3))
}
}

View File

@ -42,6 +42,7 @@ include(
":kmath-ast",
":kmath-ejml",
":kmath-kotlingrad",
":kmath-tensors",
":kmath-jupyter",
":examples",
":benchmarks"