Merge remote-tracking branch 'ups/feature/tensor-algebra' into andrew

This commit is contained in:
Andrei Kislitsyn 2021-03-26 17:43:20 +03:00
commit d6a1bee93b
6 changed files with 60 additions and 8 deletions

View File

@ -984,6 +984,8 @@ public final class space/kscience/kmath/nd/MutableStructure1D$DefaultImpls {
} }
public abstract interface class space/kscience/kmath/nd/MutableStructure2D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure2D { public abstract interface class space/kscience/kmath/nd/MutableStructure2D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure2D {
public abstract fun getColumns ()Ljava/util/List;
public abstract fun getRows ()Ljava/util/List;
public abstract fun set (IILjava/lang/Object;)V public abstract fun set (IILjava/lang/Object;)V
} }
@ -2608,6 +2610,15 @@ public final class space/kscience/kmath/structures/VirtualBuffer : space/kscienc
public fun iterator ()Ljava/util/Iterator; public fun iterator ()Ljava/util/Iterator;
} }
public final class space/kscience/kmath/structures/VirtualMutableBuffer : space/kscience/kmath/structures/MutableBuffer {
public fun <init> (ILkotlin/jvm/functions/Function1;)V
public fun copy ()Lspace/kscience/kmath/structures/MutableBuffer;
public fun get (I)Ljava/lang/Object;
public fun getSize ()I
public fun iterator ()Ljava/util/Iterator;
public fun set (ILjava/lang/Object;)V
}
public abstract interface class space/kscience/kmath/tensors/AnalyticTensorAlgebra : space/kscience/kmath/tensors/OrderedTensorAlgebra, space/kscience/kmath/tensors/TensorPartialDivisionAlgebra { public abstract interface class space/kscience/kmath/tensors/AnalyticTensorAlgebra : space/kscience/kmath/tensors/OrderedTensorAlgebra, space/kscience/kmath/tensors/TensorPartialDivisionAlgebra {
public abstract fun acos (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND; public abstract fun acos (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun acosh (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND; public abstract fun acosh (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
@ -2709,6 +2720,7 @@ public abstract interface class space/kscience/kmath/tensors/TensorAlgebra {
public abstract fun cumsum (Lspace/kscience/kmath/nd/MutableStructureND;I)Lspace/kscience/kmath/nd/MutableStructureND; public abstract fun cumsum (Lspace/kscience/kmath/nd/MutableStructureND;I)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun diagonalEmbedding (Lspace/kscience/kmath/nd/MutableStructureND;III)Lspace/kscience/kmath/nd/MutableStructureND; public abstract fun diagonalEmbedding (Lspace/kscience/kmath/nd/MutableStructureND;III)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun dot (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND; public abstract fun dot (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun eq (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;Ljava/lang/Object;)Z
public abstract fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND; public abstract fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun flatten (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND; public abstract fun flatten (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun full (Ljava/lang/Object;[I)Lspace/kscience/kmath/nd/MutableStructureND; public abstract fun full (Ljava/lang/Object;[I)Lspace/kscience/kmath/nd/MutableStructureND;
@ -2970,9 +2982,10 @@ public class space/kscience/kmath/tensors/core/DoubleTensorAlgebra : space/kscie
public fun divAssign (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)V public fun divAssign (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)V
public synthetic fun dot (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND; public synthetic fun dot (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
public fun dot (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Lspace/kscience/kmath/tensors/core/DoubleTensor; public fun dot (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Lspace/kscience/kmath/tensors/core/DoubleTensor;
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;D)Z public synthetic fun eq (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;Ljava/lang/Object;)Z
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Z
public fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;D)Z
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;Lkotlin/jvm/functions/Function2;)Z public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;Lkotlin/jvm/functions/Function2;)Z
public static synthetic fun eq$default (Lspace/kscience/kmath/tensors/core/DoubleTensorAlgebra;Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;DILjava/lang/Object;)Z
public synthetic fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND; public synthetic fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND;
public fun eye (I)Lspace/kscience/kmath/tensors/core/DoubleTensor; public fun eye (I)Lspace/kscience/kmath/tensors/core/DoubleTensor;
public synthetic fun flatten (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND; public synthetic fun flatten (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND;

View File

@ -14,6 +14,7 @@ import kotlin.reflect.KClass
* @param T the type of items. * @param T the type of items.
*/ */
public typealias Matrix<T> = Structure2D<T> public typealias Matrix<T> = Structure2D<T>
public typealias MutableMatrix<T> = MutableStructure2D<T>
/** /**
* Alias or using [Buffer] as a point/vector in a many-dimensional space. * Alias or using [Buffer] as a point/vector in a many-dimensional space.

View File

@ -2,7 +2,9 @@ package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.VirtualBuffer import space.kscience.kmath.structures.VirtualBuffer
import space.kscience.kmath.structures.VirtualMutableBuffer
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -69,6 +71,18 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
* @param value the value. * @param value the value.
*/ */
public operator fun set(i: Int, j: Int, value: T) public operator fun set(i: Int, j: Int, value: T)
/**
* The buffer of rows of this structure. It gets elements from the structure dynamically.
*/
override val rows: List<MutableBuffer<T>>
get() = List(rowNum) { i -> VirtualMutableBuffer(colNum) { j -> get(i, j) } }
/**
* The buffer of columns of this structure. It gets elements from the structure dynamically.
*/
override val columns: List<MutableBuffer<T>>
get() = List(colNum) { j -> VirtualMutableBuffer(rowNum) { i -> get(i, j) } }
} }
/** /**

View File

@ -223,11 +223,7 @@ public inline class MutableListBuffer<T>(public val list: MutableList<T>) : Muta
} }
/** /**
<<<<<<< HEAD
* Returns an [MutableListBuffer] that wraps the original list. * Returns an [MutableListBuffer] that wraps the original list.
=======
* Returns an [ListBuffer] that wraps the original list.
>>>>>>> dev
*/ */
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this) public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)
@ -286,6 +282,24 @@ public class VirtualBuffer<T>(override val size: Int, private val generator: (In
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator() override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
} }
public class VirtualMutableBuffer<T>(override val size: Int, private val generator: (Int) -> T) : MutableBuffer<T> {
private val bufferHolder: MutableListBuffer<T> = (0 until size).map(generator).toMutableList().asMutableBuffer()
override operator fun get(index: Int): T {
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
return bufferHolder[index]
}
override operator fun iterator(): Iterator<T> = bufferHolder.iterator()
override fun set(index: Int, value: T) {
bufferHolder[index] = value
}
override fun copy(): MutableBuffer<T> = bufferHolder.copy()
}
/** /**
* Convert this buffer to read-only buffer. * Convert this buffer to read-only buffer.
*/ */

View File

@ -9,8 +9,8 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
public operator fun TensorType.divAssign(other: TensorType) public operator fun TensorType.divAssign(other: TensorType)
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean //https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType public fun TensorType.mean(dim: Int = 0, keepDim: Boolean = false): TensorType
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var //https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType public fun TensorType.variance(dim: Int = 0, unbiased: Boolean = true, keepDim: Boolean = false): TensorType
} }

View File

@ -70,4 +70,14 @@ class TestDoubleLinearOpsTensorAlgebra {
assertTrue { invTensor.shape contentEquals expectedShape } assertTrue { invTensor.shape contentEquals expectedShape }
assertTrue { invTensor.buffer.array().epsEqual(expectedBuffer) } assertTrue { invTensor.buffer.array().epsEqual(expectedBuffer) }
} }
@Test
fun testScalarProduct() = DoubleLinearOpsTensorAlgebra {
val a = fromArray(intArrayOf(3), doubleArrayOf(1.8,2.5, 6.8))
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5,2.6, 6.4))
DoubleReduceOpsTensorAlgebra {
assertEquals(a.dot(b).value(), 59.92)
}
}
} }