resolve conflict
This commit is contained in:
commit
6e85d496f2
@ -2654,10 +2654,6 @@ public final class space/kscience/kmath/tensors/LinearOpsTensorAlgebra$DefaultIm
|
|||||||
public static synthetic fun symEig$default (Lspace/kscience/kmath/tensors/LinearOpsTensorAlgebra;Lspace/kscience/kmath/nd/MutableStructureND;ZILjava/lang/Object;)Lkotlin/Pair;
|
public static synthetic fun symEig$default (Lspace/kscience/kmath/tensors/LinearOpsTensorAlgebra;Lspace/kscience/kmath/nd/MutableStructureND;ZILjava/lang/Object;)Lkotlin/Pair;
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract interface class space/kscience/kmath/tensors/ReduceOpsTensorAlgebra : space/kscience/kmath/tensors/TensorAlgebra {
|
|
||||||
public abstract fun value (Lspace/kscience/kmath/nd/MutableStructureND;)Ljava/lang/Object;
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract interface class space/kscience/kmath/tensors/TensorAlgebra {
|
public abstract interface class space/kscience/kmath/tensors/TensorAlgebra {
|
||||||
public abstract fun copy (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun copy (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public abstract fun diagonalEmbedding (Lspace/kscience/kmath/nd/MutableStructureND;III)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun diagonalEmbedding (Lspace/kscience/kmath/nd/MutableStructureND;III)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
@ -2686,6 +2682,7 @@ public abstract interface class space/kscience/kmath/tensors/TensorAlgebra {
|
|||||||
public abstract fun timesAssign (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)V
|
public abstract fun timesAssign (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)V
|
||||||
public abstract fun transpose (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun transpose (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public abstract fun unaryMinus (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun unaryMinus (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
|
public abstract fun value (Lspace/kscience/kmath/nd/MutableStructureND;)Ljava/lang/Object;
|
||||||
public abstract fun view (Lspace/kscience/kmath/nd/MutableStructureND;[I)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun view (Lspace/kscience/kmath/nd/MutableStructureND;[I)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public abstract fun viewAs (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun viewAs (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public abstract fun zeroesLike (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun zeroesLike (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
@ -2731,8 +2728,6 @@ public class space/kscience/kmath/tensors/core/BufferedTensor : space/kscience/k
|
|||||||
public fun <init> ([ILspace/kscience/kmath/structures/MutableBuffer;I)V
|
public fun <init> ([ILspace/kscience/kmath/structures/MutableBuffer;I)V
|
||||||
public fun elements ()Lkotlin/sequences/Sequence;
|
public fun elements ()Lkotlin/sequences/Sequence;
|
||||||
public fun equals (Ljava/lang/Object;)Z
|
public fun equals (Ljava/lang/Object;)Z
|
||||||
public final fun forEachMatrix (Lkotlin/jvm/functions/Function1;)V
|
|
||||||
public final fun forEachVector (Lkotlin/jvm/functions/Function1;)V
|
|
||||||
public fun get ([I)Ljava/lang/Object;
|
public fun get ([I)Ljava/lang/Object;
|
||||||
public final fun getBuffer ()Lspace/kscience/kmath/structures/MutableBuffer;
|
public final fun getBuffer ()Lspace/kscience/kmath/structures/MutableBuffer;
|
||||||
public fun getDimension ()I
|
public fun getDimension ()I
|
||||||
@ -2740,37 +2735,7 @@ public class space/kscience/kmath/tensors/core/BufferedTensor : space/kscience/k
|
|||||||
public final fun getNumel ()I
|
public final fun getNumel ()I
|
||||||
public fun getShape ()[I
|
public fun getShape ()[I
|
||||||
public fun hashCode ()I
|
public fun hashCode ()I
|
||||||
public final fun matrixSequence ()Lkotlin/sequences/Sequence;
|
|
||||||
public fun set ([ILjava/lang/Object;)V
|
public fun set ([ILjava/lang/Object;)V
|
||||||
public final fun vectorSequence ()Lkotlin/sequences/Sequence;
|
|
||||||
}
|
|
||||||
|
|
||||||
public final class space/kscience/kmath/tensors/core/BufferedTensor1D : space/kscience/kmath/tensors/core/BufferedTensor, space/kscience/kmath/nd/MutableStructure1D {
|
|
||||||
public fun copy ()Lspace/kscience/kmath/structures/MutableBuffer;
|
|
||||||
public fun get (I)Ljava/lang/Object;
|
|
||||||
public fun get ([I)Ljava/lang/Object;
|
|
||||||
public fun getDimension ()I
|
|
||||||
public fun getSize ()I
|
|
||||||
public fun iterator ()Ljava/util/Iterator;
|
|
||||||
public fun set (ILjava/lang/Object;)V
|
|
||||||
public fun set ([ILjava/lang/Object;)V
|
|
||||||
}
|
|
||||||
|
|
||||||
public final class space/kscience/kmath/tensors/core/BufferedTensor2D : space/kscience/kmath/tensors/core/BufferedTensor, space/kscience/kmath/nd/MutableStructure2D {
|
|
||||||
public fun elements ()Lkotlin/sequences/Sequence;
|
|
||||||
public fun get (II)Ljava/lang/Object;
|
|
||||||
public fun get ([I)Ljava/lang/Object;
|
|
||||||
public fun getColNum ()I
|
|
||||||
public fun getColumns ()Ljava/util/List;
|
|
||||||
public fun getRowNum ()I
|
|
||||||
public fun getRows ()Ljava/util/List;
|
|
||||||
public fun getShape ()[I
|
|
||||||
public fun set (IILjava/lang/Object;)V
|
|
||||||
}
|
|
||||||
|
|
||||||
public final class space/kscience/kmath/tensors/core/BufferedTensorKt {
|
|
||||||
public static final fun as1D (Lspace/kscience/kmath/tensors/core/BufferedTensor;)Lspace/kscience/kmath/tensors/core/BufferedTensor1D;
|
|
||||||
public static final fun as2D (Lspace/kscience/kmath/tensors/core/BufferedTensor;)Lspace/kscience/kmath/tensors/core/BufferedTensor2D;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public final class space/kscience/kmath/tensors/core/DoubleAnalyticTensorAlgebra : space/kscience/kmath/tensors/core/DoubleTensorAlgebra, space/kscience/kmath/tensors/AnalyticTensorAlgebra {
|
public final class space/kscience/kmath/tensors/core/DoubleAnalyticTensorAlgebra : space/kscience/kmath/tensors/core/DoubleTensorAlgebra, space/kscience/kmath/tensors/AnalyticTensorAlgebra {
|
||||||
@ -2841,16 +2806,6 @@ public final class space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebr
|
|||||||
public static final fun DoubleLinearOpsTensorAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
public static final fun DoubleLinearOpsTensorAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
||||||
}
|
}
|
||||||
|
|
||||||
public final class space/kscience/kmath/tensors/core/DoubleReduceOpsTensorAlgebra : space/kscience/kmath/tensors/core/DoubleTensorAlgebra, space/kscience/kmath/tensors/ReduceOpsTensorAlgebra {
|
|
||||||
public fun <init> ()V
|
|
||||||
public synthetic fun value (Lspace/kscience/kmath/nd/MutableStructureND;)Ljava/lang/Object;
|
|
||||||
public fun value (Lspace/kscience/kmath/tensors/core/DoubleTensor;)Ljava/lang/Double;
|
|
||||||
}
|
|
||||||
|
|
||||||
public final class space/kscience/kmath/tensors/core/DoubleReduceOpsTensorAlgebraKt {
|
|
||||||
public static final fun DoubleReduceOpsTensorAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
|
||||||
}
|
|
||||||
|
|
||||||
public final class space/kscience/kmath/tensors/core/DoubleTensor : space/kscience/kmath/tensors/core/BufferedTensor {
|
public final class space/kscience/kmath/tensors/core/DoubleTensor : space/kscience/kmath/tensors/core/BufferedTensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2876,7 +2831,6 @@ public class space/kscience/kmath/tensors/core/DoubleTensorAlgebra : space/kscie
|
|||||||
public synthetic fun eq (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;Ljava/lang/Object;)Z
|
public synthetic fun eq (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;Ljava/lang/Object;)Z
|
||||||
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Z
|
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Z
|
||||||
public fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;D)Z
|
public fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;D)Z
|
||||||
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;Lkotlin/jvm/functions/Function2;)Z
|
|
||||||
public synthetic fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND;
|
public synthetic fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public fun eye (I)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
public fun eye (I)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
||||||
public final fun fromArray ([I[D)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
public final fun fromArray ([I[D)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
||||||
@ -2925,6 +2879,8 @@ public class space/kscience/kmath/tensors/core/DoubleTensorAlgebra : space/kscie
|
|||||||
public fun transpose (Lspace/kscience/kmath/tensors/core/DoubleTensor;II)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
public fun transpose (Lspace/kscience/kmath/tensors/core/DoubleTensor;II)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
||||||
public synthetic fun unaryMinus (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public synthetic fun unaryMinus (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public fun unaryMinus (Lspace/kscience/kmath/tensors/core/DoubleTensor;)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
public fun unaryMinus (Lspace/kscience/kmath/tensors/core/DoubleTensor;)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
||||||
|
public synthetic fun value (Lspace/kscience/kmath/nd/MutableStructureND;)Ljava/lang/Object;
|
||||||
|
public fun value (Lspace/kscience/kmath/tensors/core/DoubleTensor;)Ljava/lang/Double;
|
||||||
public synthetic fun view (Lspace/kscience/kmath/nd/MutableStructureND;[I)Lspace/kscience/kmath/nd/MutableStructureND;
|
public synthetic fun view (Lspace/kscience/kmath/nd/MutableStructureND;[I)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public fun view (Lspace/kscience/kmath/tensors/core/DoubleTensor;[I)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
public fun view (Lspace/kscience/kmath/tensors/core/DoubleTensor;[I)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
||||||
public synthetic fun viewAs (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public synthetic fun viewAs (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
package space.kscience.kmath.tensors
|
|
||||||
|
|
||||||
public interface ReduceOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
|
||||||
TensorAlgebra<T, TensorType> {
|
|
||||||
public fun TensorType.value(): T
|
|
||||||
|
|
||||||
}
|
|
@ -5,6 +5,8 @@ import space.kscience.kmath.tensors.core.DoubleTensor
|
|||||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||||
|
|
||||||
|
public fun TensorType.value(): T
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.full.html
|
//https://pytorch.org/docs/stable/generated/torch.full.html
|
||||||
public fun full(value: T, shape: IntArray): TensorType
|
public fun full(value: T, shape: IntArray): TensorType
|
||||||
|
|
||||||
|
@ -35,35 +35,34 @@ public open class BufferedTensor<T>(
|
|||||||
|
|
||||||
override fun hashCode(): Int = 0
|
override fun hashCode(): Int = 0
|
||||||
|
|
||||||
public fun vectorSequence(): Sequence<BufferedTensor1D<T>> = sequence {
|
internal fun vectorSequence(): Sequence<BufferedTensor<T>> = sequence {
|
||||||
check(shape.size >= 1) { "todo" }
|
|
||||||
val n = shape.size
|
val n = shape.size
|
||||||
val vectorOffset = shape[n - 1]
|
val vectorOffset = shape[n - 1]
|
||||||
val vectorShape = intArrayOf(shape.last())
|
val vectorShape = intArrayOf(shape.last())
|
||||||
for (offset in 0 until numel step vectorOffset) {
|
for (offset in 0 until numel step vectorOffset) {
|
||||||
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D()
|
val vector = BufferedTensor(vectorShape, buffer, offset)
|
||||||
yield(vector)
|
yield(vector)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun matrixSequence(): Sequence<BufferedTensor2D<T>> = sequence {
|
internal fun matrixSequence(): Sequence<BufferedTensor<T>> = sequence {
|
||||||
check(shape.size >= 2) { "todo" }
|
check(shape.size >= 2) { "todo" }
|
||||||
val n = shape.size
|
val n = shape.size
|
||||||
val matrixOffset = shape[n - 1] * shape[n - 2]
|
val matrixOffset = shape[n - 1] * shape[n - 2]
|
||||||
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) //todo better way?
|
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1])
|
||||||
for (offset in 0 until numel step matrixOffset) {
|
for (offset in 0 until numel step matrixOffset) {
|
||||||
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D()
|
val matrix = BufferedTensor(matrixShape, buffer, offset)
|
||||||
yield(matrix)
|
yield(matrix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun forEachVector(vectorAction: (BufferedTensor1D<T>) -> Unit): Unit {
|
internal inline fun forEachVector(vectorAction: (BufferedTensor<T>) -> Unit): Unit {
|
||||||
for (vector in vectorSequence()) {
|
for (vector in vectorSequence()) {
|
||||||
vectorAction(vector)
|
vectorAction(vector)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun forEachMatrix(matrixAction: (BufferedTensor2D<T>) -> Unit): Unit {
|
internal inline fun forEachMatrix(matrixAction: (BufferedTensor<T>) -> Unit): Unit {
|
||||||
for (matrix in matrixSequence()) {
|
for (matrix in matrixSequence()) {
|
||||||
matrixAction(matrix)
|
matrixAction(matrix)
|
||||||
}
|
}
|
||||||
@ -71,7 +70,6 @@ public open class BufferedTensor<T>(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public class IntTensor internal constructor(
|
public class IntTensor internal constructor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: IntArray,
|
buffer: IntArray,
|
||||||
@ -112,90 +110,7 @@ public class DoubleTensor internal constructor(
|
|||||||
this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart)
|
this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun BufferedTensor<Int>.asTensor(): IntTensor = IntTensor(this)
|
||||||
public class BufferedTensor2D<T> internal constructor(
|
internal fun BufferedTensor<Long>.asTensor(): LongTensor = LongTensor(this)
|
||||||
private val tensor: BufferedTensor<T>,
|
internal fun BufferedTensor<Float>.asTensor(): FloatTensor = FloatTensor(this)
|
||||||
) : BufferedTensor<T>(tensor), MutableStructure2D<T> {
|
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor = DoubleTensor(this)
|
||||||
init {
|
|
||||||
check(shape.size == 2) {
|
|
||||||
"Shape ${shape.toList()} not compatible with DoubleTensor2D"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override val shape: IntArray
|
|
||||||
get() = tensor.shape
|
|
||||||
|
|
||||||
override val rowNum: Int
|
|
||||||
get() = shape[0]
|
|
||||||
override val colNum: Int
|
|
||||||
get() = shape[1]
|
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T = tensor[intArrayOf(i, j)]
|
|
||||||
|
|
||||||
override fun get(index: IntArray): T = tensor[index]
|
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = tensor.elements()
|
|
||||||
|
|
||||||
override fun set(i: Int, j: Int, value: T) {
|
|
||||||
tensor[intArrayOf(i, j)] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override val rows: List<BufferedTensor1D<T>>
|
|
||||||
get() = List(rowNum) { i ->
|
|
||||||
BufferedTensor1D(
|
|
||||||
BufferedTensor(
|
|
||||||
shape = intArrayOf(colNum),
|
|
||||||
buffer = VirtualMutableBuffer(colNum) { j -> get(i, j) },
|
|
||||||
bufferStart = 0
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
override val columns: List<BufferedTensor1D<T>>
|
|
||||||
get() = List(colNum) { j ->
|
|
||||||
BufferedTensor1D(
|
|
||||||
BufferedTensor(
|
|
||||||
shape = intArrayOf(rowNum),
|
|
||||||
buffer = VirtualMutableBuffer(rowNum) { i -> get(i, j) },
|
|
||||||
bufferStart = 0
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public class BufferedTensor1D<T> internal constructor(
|
|
||||||
private val tensor: BufferedTensor<T>
|
|
||||||
) : BufferedTensor<T>(tensor), MutableStructure1D<T> {
|
|
||||||
init {
|
|
||||||
check(shape.size == 1) {
|
|
||||||
"Shape ${shape.toList()} not compatible with DoubleTensor1D"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun get(index: IntArray): T = tensor[index]
|
|
||||||
|
|
||||||
override fun set(index: IntArray, value: T) {
|
|
||||||
tensor[index] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override val size: Int
|
|
||||||
get() = tensor.linearStructure.size
|
|
||||||
|
|
||||||
override fun get(index: Int): T = tensor[intArrayOf(index)]
|
|
||||||
|
|
||||||
override fun set(index: Int, value: T) {
|
|
||||||
tensor[intArrayOf(index)] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<T> = tensor.buffer.copy()
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
internal fun BufferedTensor<Int>.asIntTensor(): IntTensor = IntTensor(this)
|
|
||||||
internal fun BufferedTensor<Long>.asLongTensor(): LongTensor = LongTensor(this)
|
|
||||||
internal fun BufferedTensor<Float>.asFloatTensor(): FloatTensor = FloatTensor(this)
|
|
||||||
internal fun BufferedTensor<Double>.asDoubleTensor(): DoubleTensor = DoubleTensor(this)
|
|
||||||
|
|
||||||
|
|
||||||
public fun <T> BufferedTensor<T>.as2D(): BufferedTensor2D<T> = BufferedTensor2D(this)
|
|
||||||
public fun <T> BufferedTensor<T>.as1D(): BufferedTensor1D<T> = BufferedTensor1D(this)
|
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.MutableStructure1D
|
||||||
|
import space.kscience.kmath.nd.MutableStructure2D
|
||||||
|
import space.kscience.kmath.nd.as1D
|
||||||
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
@ -11,23 +15,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
override fun DoubleTensor.det(): DoubleTensor = detLU()
|
override fun DoubleTensor.det(): DoubleTensor = detLU()
|
||||||
|
|
||||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
private inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) {
|
||||||
|
|
||||||
checkSquareMatrix(shape)
|
|
||||||
|
|
||||||
val luTensor = copy()
|
|
||||||
|
|
||||||
val n = shape.size
|
|
||||||
val m = shape.last()
|
|
||||||
val pivotsShape = IntArray(n - 1) { i -> shape[i] }
|
|
||||||
pivotsShape[n - 2] = m + 1
|
|
||||||
|
|
||||||
val pivotsTensor = IntTensor(
|
|
||||||
pivotsShape,
|
|
||||||
IntArray(pivotsShape.reduce(Int::times)) { 0 }
|
|
||||||
)
|
|
||||||
|
|
||||||
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence())){
|
|
||||||
for (row in 0 until m) pivots[row] = row
|
for (row in 0 until m) pivots[row] = row
|
||||||
|
|
||||||
for (i in 0 until m) {
|
for (i in 0 until m) {
|
||||||
@ -69,29 +57,45 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||||
|
|
||||||
|
checkSquareMatrix(shape)
|
||||||
|
|
||||||
|
val luTensor = copy()
|
||||||
|
|
||||||
|
val n = shape.size
|
||||||
|
val m = shape.last()
|
||||||
|
val pivotsShape = IntArray(n - 1) { i -> shape[i] }
|
||||||
|
pivotsShape[n - 2] = m + 1
|
||||||
|
|
||||||
|
val pivotsTensor = IntTensor(
|
||||||
|
pivotsShape,
|
||||||
|
IntArray(pivotsShape.reduce(Int::times)) { 0 }
|
||||||
|
)
|
||||||
|
|
||||||
|
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
|
||||||
|
luHelper(lu.as2D(), pivots.as1D(), m)
|
||||||
|
|
||||||
return Pair(luTensor, pivotsTensor)
|
return Pair(luTensor, pivotsTensor)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun luPivot(luTensor: DoubleTensor, pivotsTensor: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
private inline fun pivInit(
|
||||||
//todo checks
|
p: MutableStructure2D<Double>,
|
||||||
checkSquareMatrix(luTensor.shape)
|
pivot: MutableStructure1D<Int>,
|
||||||
check(luTensor.shape.dropLast(1).toIntArray() contentEquals pivotsTensor.shape) { "Bed shapes (("} //todo rewrite
|
n: Int
|
||||||
|
) {
|
||||||
val n = luTensor.shape.last()
|
for (i in 0 until n) {
|
||||||
val pTensor = luTensor.zeroesLike()
|
|
||||||
for ((p, pivot) in pTensor.matrixSequence().zip(pivotsTensor.vectorSequence())){
|
|
||||||
for (i in 0 until n){
|
|
||||||
p[i, pivot[i]] = 1.0
|
p[i, pivot[i]] = 1.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val lTensor = luTensor.zeroesLike()
|
private inline fun luPivotHelper(
|
||||||
val uTensor = luTensor.zeroesLike()
|
l: MutableStructure2D<Double>,
|
||||||
|
u: MutableStructure2D<Double>,
|
||||||
for ((pairLU, lu) in lTensor.matrixSequence().zip(uTensor.matrixSequence()).zip(luTensor.matrixSequence())){
|
lu: MutableStructure2D<Double>,
|
||||||
val (l, u) = pairLU
|
n: Int
|
||||||
|
) {
|
||||||
for (i in 0 until n) {
|
for (i in 0 until n) {
|
||||||
for (j in 0 until n) {
|
for (j in 0 until n) {
|
||||||
if (i == j) {
|
if (i == j) {
|
||||||
@ -107,18 +111,39 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun luPivot(
|
||||||
|
luTensor: DoubleTensor,
|
||||||
|
pivotsTensor: IntTensor
|
||||||
|
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
|
//todo checks
|
||||||
|
checkSquareMatrix(luTensor.shape)
|
||||||
|
check(
|
||||||
|
luTensor.shape.dropLast(1).toIntArray() contentEquals pivotsTensor.shape
|
||||||
|
) { "Bed shapes ((" } //todo rewrite
|
||||||
|
|
||||||
|
val n = luTensor.shape.last()
|
||||||
|
val pTensor = luTensor.zeroesLike()
|
||||||
|
for ((p, pivot) in pTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
|
||||||
|
pivInit(p.as2D(), pivot.as1D(), n)
|
||||||
|
|
||||||
|
val lTensor = luTensor.zeroesLike()
|
||||||
|
val uTensor = luTensor.zeroesLike()
|
||||||
|
|
||||||
|
for ((pairLU, lu) in lTensor.matrixSequence().zip(uTensor.matrixSequence())
|
||||||
|
.zip(luTensor.matrixSequence())) {
|
||||||
|
val (l, u) = pairLU
|
||||||
|
luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n)
|
||||||
|
}
|
||||||
|
|
||||||
return Triple(pTensor, lTensor, uTensor)
|
return Triple(pTensor, lTensor, uTensor)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
private inline fun choleskyHelper(
|
||||||
// todo checks
|
a: MutableStructure2D<Double>,
|
||||||
checkSquareMatrix(shape)
|
l: MutableStructure2D<Double>,
|
||||||
|
n: Int
|
||||||
val n = shape.last()
|
) {
|
||||||
val lTensor = zeroesLike()
|
|
||||||
|
|
||||||
for ((a, l) in this.matrixSequence().zip(lTensor.matrixSequence())) {
|
|
||||||
for (i in 0 until n) {
|
for (i in 0 until n) {
|
||||||
for (j in 0 until i) {
|
for (j in 0 until i) {
|
||||||
var h = a[i, j]
|
var h = a[i, j]
|
||||||
@ -135,6 +160,16 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||||
|
// todo checks
|
||||||
|
checkSquareMatrix(shape)
|
||||||
|
|
||||||
|
val n = shape.last()
|
||||||
|
val lTensor = zeroesLike()
|
||||||
|
|
||||||
|
for ((a, l) in this.matrixSequence().zip(lTensor.matrixSequence()))
|
||||||
|
for (i in 0 until n) choleskyHelper(a.as2D(), l.as2D(), n)
|
||||||
|
|
||||||
return lTensor
|
return lTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,9 +193,11 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
TODO("ANDREI")
|
TODO("ANDREI")
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun luMatrixDet(lu: BufferedTensor2D<Double>, pivots: BufferedTensor1D<Int>): Double {
|
private fun luMatrixDet(luTensor: MutableStructure2D<Double>, pivotsTensor: MutableStructure1D<Int>): Double {
|
||||||
|
val lu = luTensor.as2D()
|
||||||
|
val pivots = pivotsTensor.as1D()
|
||||||
val m = lu.shape[0]
|
val m = lu.shape[0]
|
||||||
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
||||||
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,28 +214,27 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
resBuffer
|
resBuffer
|
||||||
)
|
)
|
||||||
|
|
||||||
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (luMatrix, pivots) ->
|
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (lu, pivots) ->
|
||||||
resBuffer[index] = luMatrixDet(luMatrix, pivots)
|
resBuffer[index] = luMatrixDet(lu.as2D(), pivots.as1D())
|
||||||
}
|
}
|
||||||
|
|
||||||
return detTensor
|
return detTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun luMatrixInv(
|
private fun luMatrixInv(
|
||||||
lu: BufferedTensor2D<Double>,
|
lu: MutableStructure2D<Double>,
|
||||||
pivots: BufferedTensor1D<Int>,
|
pivots: MutableStructure1D<Int>,
|
||||||
invMatrix : BufferedTensor2D<Double>
|
invMatrix: MutableStructure2D<Double>
|
||||||
): Unit {
|
) {
|
||||||
//todo check
|
|
||||||
val m = lu.shape[0]
|
val m = lu.shape[0]
|
||||||
|
|
||||||
for (j in 0 until m) {
|
for (j in 0 until m) {
|
||||||
for (i in 0 until m) {
|
for (i in 0 until m) {
|
||||||
if (pivots[i] == j){
|
if (pivots[i] == j) {
|
||||||
invMatrix[i, j] = 1.0
|
invMatrix[i, j] = 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
for (k in 0 until i){
|
for (k in 0 until i) {
|
||||||
invMatrix[i, j] -= lu[i, k] * invMatrix[k, j]
|
invMatrix[i, j] -= lu[i, k] * invMatrix[k, j]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -214,13 +250,12 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
public fun DoubleTensor.invLU(): DoubleTensor {
|
public fun DoubleTensor.invLU(): DoubleTensor {
|
||||||
val (luTensor, pivotsTensor) = lu()
|
val (luTensor, pivotsTensor) = lu()
|
||||||
val n = shape.size
|
|
||||||
val invTensor = luTensor.zeroesLike()
|
val invTensor = luTensor.zeroesLike()
|
||||||
|
|
||||||
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
||||||
for ((luP, invMatrix) in seq) {
|
for ((luP, invMatrix) in seq) {
|
||||||
val (lu, pivots) = luP
|
val (lu, pivots) = luP
|
||||||
luMatrixInv(lu, pivots, invMatrix)
|
luMatrixInv(lu.as2D(), pivots.as1D(), invMatrix.as2D())
|
||||||
}
|
}
|
||||||
|
|
||||||
return invTensor
|
return invTensor
|
||||||
|
@ -1,18 +0,0 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
|
||||||
|
|
||||||
import space.kscience.kmath.tensors.ReduceOpsTensorAlgebra
|
|
||||||
|
|
||||||
public class DoubleReduceOpsTensorAlgebra:
|
|
||||||
DoubleTensorAlgebra(),
|
|
||||||
ReduceOpsTensorAlgebra<Double, DoubleTensor> {
|
|
||||||
|
|
||||||
override fun DoubleTensor.value(): Double {
|
|
||||||
check(this.shape contentEquals intArrayOf(1)) {
|
|
||||||
"Inconsistent value for tensor of shape ${shape.toList()}"
|
|
||||||
}
|
|
||||||
return this.buffer.array()[this.bufferStart]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <R> DoubleReduceOpsTensorAlgebra(block: DoubleReduceOpsTensorAlgebra.() -> R): R =
|
|
||||||
DoubleReduceOpsTensorAlgebra().block()
|
|
@ -1,13 +1,19 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.linear.Matrix
|
|
||||||
import space.kscience.kmath.nd.MutableStructure2D
|
import space.kscience.kmath.nd.MutableStructure2D
|
||||||
import space.kscience.kmath.nd.Structure2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra
|
import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
|
|
||||||
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
||||||
|
|
||||||
|
override fun DoubleTensor.value(): Double {
|
||||||
|
check(this.shape contentEquals intArrayOf(1)) {
|
||||||
|
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||||
|
}
|
||||||
|
return this.buffer.array()[this.bufferStart]
|
||||||
|
}
|
||||||
|
|
||||||
public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor {
|
public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor {
|
||||||
checkEmptyShape(shape)
|
checkEmptyShape(shape)
|
||||||
checkEmptyDoubleBuffer(buffer)
|
checkEmptyDoubleBuffer(buffer)
|
||||||
@ -224,6 +230,23 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
return this.view(other.shape)
|
return this.view(other.shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private inline fun dotHelper(
|
||||||
|
a: MutableStructure2D<Double>,
|
||||||
|
b: MutableStructure2D<Double>,
|
||||||
|
res: MutableStructure2D<Double>,
|
||||||
|
l: Int, m: Int, n: Int
|
||||||
|
) {
|
||||||
|
for (i in 0 until l) {
|
||||||
|
for (j in 0 until n) {
|
||||||
|
var curr = 0.0
|
||||||
|
for (k in 0 until m) {
|
||||||
|
curr += a[i, k] * b[k, j]
|
||||||
|
}
|
||||||
|
res[i, j] = curr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
||||||
if (this.shape.size == 1 && other.shape.size == 1) {
|
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||||
return DoubleTensor(intArrayOf(1), doubleArrayOf(this.times(other).buffer.array().sum()))
|
return DoubleTensor(intArrayOf(1), doubleArrayOf(this.times(other).buffer.array().sum()))
|
||||||
@ -240,7 +263,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
if (other.shape.size == 1) {
|
if (other.shape.size == 1) {
|
||||||
lastDim = true
|
lastDim = true
|
||||||
newOther = other.view(other.shape + intArrayOf(1) )
|
newOther = other.view(other.shape + intArrayOf(1))
|
||||||
}
|
}
|
||||||
|
|
||||||
val broadcastTensors = broadcastOuterTensors(newThis, newOther)
|
val broadcastTensors = broadcastOuterTensors(newThis, newOther)
|
||||||
@ -248,7 +271,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
newOther = broadcastTensors[1]
|
newOther = broadcastTensors[1]
|
||||||
|
|
||||||
val l = newThis.shape[newThis.shape.size - 2]
|
val l = newThis.shape[newThis.shape.size - 2]
|
||||||
val m1= newThis.shape[newThis.shape.size - 1]
|
val m1 = newThis.shape[newThis.shape.size - 1]
|
||||||
val m2 = newOther.shape[newOther.shape.size - 2]
|
val m2 = newOther.shape[newOther.shape.size - 2]
|
||||||
val n = newOther.shape[newOther.shape.size - 1]
|
val n = newOther.shape[newOther.shape.size - 1]
|
||||||
if (m1 != m2) {
|
if (m1 != m2) {
|
||||||
@ -262,21 +285,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
|
|
||||||
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
|
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
|
||||||
val (a, b) = ab
|
val (a, b) = ab
|
||||||
|
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m, n)
|
||||||
for (i in 0 until l) {
|
|
||||||
for (j in 0 until n) {
|
|
||||||
var curr = 0.0
|
|
||||||
for (k in 0 until m) {
|
|
||||||
curr += a[i, k] * b[k, j]
|
|
||||||
}
|
|
||||||
res[i, j] = curr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (penultimateDim) {
|
if (penultimateDim) {
|
||||||
return resTensor.view(resTensor.shape.dropLast(2).toIntArray() +
|
return resTensor.view(
|
||||||
intArrayOf(resTensor.shape.last()))
|
resTensor.shape.dropLast(2).toIntArray() +
|
||||||
|
intArrayOf(resTensor.shape.last())
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if (lastDim) {
|
if (lastDim) {
|
||||||
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
||||||
@ -307,15 +323,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
|
|
||||||
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
|
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
|
||||||
|
|
||||||
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
|
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean =
|
||||||
if (!(this.shape contentEquals other.shape)) {
|
this.eq(other, eqFunction)
|
||||||
return false
|
|
||||||
}
|
|
||||||
return this.eq(other, eqFunction)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
|
private fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
|
||||||
// todo broadcasting checking
|
checkShapesCompatible(this, other)
|
||||||
val n = this.linearStructure.size
|
val n = this.linearStructure.size
|
||||||
if (n != other.linearStructure.size) {
|
if (n != other.linearStructure.size) {
|
||||||
return false
|
return false
|
||||||
|
@ -11,14 +11,14 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
|||||||
"Illegal empty shape provided"
|
"Illegal empty shape provided"
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun < TensorType : TensorStructure<Double>,
|
internal inline fun <TensorType : TensorStructure<Double>,
|
||||||
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
|
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
|
||||||
TorchTensorAlgebraType.checkEmptyDoubleBuffer(buffer: DoubleArray): Unit =
|
TorchTensorAlgebraType.checkEmptyDoubleBuffer(buffer: DoubleArray): Unit =
|
||||||
check(buffer.isNotEmpty()) {
|
check(buffer.isNotEmpty()) {
|
||||||
"Illegal empty buffer provided"
|
"Illegal empty buffer provided"
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun < TensorType : TensorStructure<Double>,
|
internal inline fun <TensorType : TensorStructure<Double>,
|
||||||
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
|
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
|
||||||
TorchTensorAlgebraType.checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray): Unit =
|
TorchTensorAlgebraType.checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray): Unit =
|
||||||
check(buffer.size == shape.reduce(Int::times)) {
|
check(buffer.size == shape.reduce(Int::times)) {
|
||||||
|
@ -72,11 +72,8 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testScalarProduct() = DoubleLinearOpsTensorAlgebra {
|
fun testScalarProduct() = DoubleLinearOpsTensorAlgebra {
|
||||||
val a = fromArray(intArrayOf(3), doubleArrayOf(1.8,2.5, 6.8))
|
val a = fromArray(intArrayOf(3), doubleArrayOf(1.8, 2.5, 6.8))
|
||||||
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5,2.6, 6.4))
|
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5, 2.6, 6.4))
|
||||||
DoubleReduceOpsTensorAlgebra {
|
|
||||||
assertEquals(a.dot(b).value(), 59.92)
|
assertEquals(a.dot(b).value(), 59.92)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.as1D
|
||||||
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.structures.toDoubleArray
|
import space.kscience.kmath.structures.toDoubleArray
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
@ -8,7 +10,7 @@ import kotlin.test.assertTrue
|
|||||||
class TestDoubleTensor {
|
class TestDoubleTensor {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun valueTest() = DoubleReduceOpsTensorAlgebra {
|
fun valueTest() = DoubleTensorAlgebra {
|
||||||
val value = 12.5
|
val value = 12.5
|
||||||
val tensor = fromArray(intArrayOf(1), doubleArrayOf(value))
|
val tensor = fromArray(intArrayOf(1), doubleArrayOf(value))
|
||||||
assertEquals(tensor.value(), value)
|
assertEquals(tensor.value(), value)
|
||||||
@ -35,5 +37,13 @@ class TestDoubleTensor {
|
|||||||
|
|
||||||
vector[0] = 109.56
|
vector[0] = 109.56
|
||||||
assertEquals(tensor[intArrayOf(0,1,0)], 109.56)
|
assertEquals(tensor[intArrayOf(0,1,0)], 109.56)
|
||||||
|
|
||||||
|
tensor.matrixSequence().forEach {
|
||||||
|
val a = it.asTensor()
|
||||||
|
val secondRow = a[1].as1D()
|
||||||
|
val secondColumn = a.transpose(0,1)[1].as1D()
|
||||||
|
assertEquals(secondColumn[0], 77.89)
|
||||||
|
assertEquals(secondRow[1], secondColumn[1])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user