Tensor algebra generified
This commit is contained in:
parent
4db7398a28
commit
46e7da9ae0
@ -81,8 +81,8 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
|
|||||||
|
|
||||||
override fun StructureND<Double>.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it }
|
override fun StructureND<Double>.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it }
|
||||||
|
|
||||||
override fun StructureND<Double>.div(other: StructureND<Double>): DoubleBufferND =
|
override fun StructureND<Double>.div(arg: StructureND<Double>): DoubleBufferND =
|
||||||
zipInline(toBufferND(), other.toBufferND()) { l, r -> l / r }
|
zipInline(toBufferND(), arg.toBufferND()) { l, r -> l / r }
|
||||||
|
|
||||||
override fun divide(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
|
override fun divide(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
|
||||||
zipInline(left.toBufferND(), right.toBufferND()) { l: Double, r: Double -> l / r }
|
zipInline(left.toBufferND(), right.toBufferND()) { l: Double, r: Double -> l / r }
|
||||||
@ -101,8 +101,8 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
|
|||||||
override fun StructureND<Double>.minus(arg: StructureND<Double>): DoubleBufferND =
|
override fun StructureND<Double>.minus(arg: StructureND<Double>): DoubleBufferND =
|
||||||
zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l - r }
|
zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l - r }
|
||||||
|
|
||||||
override fun StructureND<Double>.times(other: StructureND<Double>): DoubleBufferND =
|
override fun StructureND<Double>.times(arg: StructureND<Double>): DoubleBufferND =
|
||||||
zipInline(toBufferND(), other.toBufferND()) { l: Double, r: Double -> l * r }
|
zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l * r }
|
||||||
|
|
||||||
override fun StructureND<Double>.times(k: Number): DoubleBufferND =
|
override fun StructureND<Double>.times(k: Number): DoubleBufferND =
|
||||||
mapInline(toBufferND()) { it * k.toDouble() }
|
mapInline(toBufferND()) { it * k.toDouble() }
|
||||||
|
@ -270,10 +270,10 @@ public interface FieldOps<T> : RingOps<T> {
|
|||||||
* Division of two elements.
|
* Division of two elements.
|
||||||
*
|
*
|
||||||
* @receiver the dividend.
|
* @receiver the dividend.
|
||||||
* @param other the divisor.
|
* @param arg the divisor.
|
||||||
* @return the quotient.
|
* @return the quotient.
|
||||||
*/
|
*/
|
||||||
public operator fun T.div(other: T): T = divide(this, other)
|
public operator fun T.div(arg: T): T = divide(this, arg)
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||||
DIV_OPERATION -> ::divide
|
DIV_OPERATION -> ::divide
|
||||||
|
@ -34,18 +34,18 @@ public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<B
|
|||||||
} else DoubleBuffer(DoubleArray(left.size) { left[it] + right[it] })
|
} else DoubleBuffer(DoubleArray(left.size) { left[it] + right[it] })
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Buffer<Double>.plus(other: Buffer<Double>): DoubleBuffer = add(this, other)
|
override fun Buffer<Double>.plus(arg: Buffer<Double>): DoubleBuffer = add(this, arg)
|
||||||
|
|
||||||
override fun Buffer<Double>.minus(other: Buffer<Double>): DoubleBuffer {
|
override fun Buffer<Double>.minus(arg: Buffer<Double>): DoubleBuffer {
|
||||||
require(other.size == this.size) {
|
require(arg.size == this.size) {
|
||||||
"The size of the first buffer ${this.size} should be the same as for second one: ${other.size} "
|
"The size of the first buffer ${this.size} should be the same as for second one: ${arg.size} "
|
||||||
}
|
}
|
||||||
|
|
||||||
return if (this is DoubleBuffer && other is DoubleBuffer) {
|
return if (this is DoubleBuffer && arg is DoubleBuffer) {
|
||||||
val aArray = this.array
|
val aArray = this.array
|
||||||
val bArray = other.array
|
val bArray = arg.array
|
||||||
DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] })
|
DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] })
|
||||||
} else DoubleBuffer(DoubleArray(this.size) { this[it] - other[it] })
|
} else DoubleBuffer(DoubleArray(this.size) { this[it] - arg[it] })
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -102,10 +102,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
|
|||||||
override inline fun norm(arg: Double): Double = abs(arg)
|
override inline fun norm(arg: Double): Double = abs(arg)
|
||||||
|
|
||||||
override inline fun Double.unaryMinus(): Double = -this
|
override inline fun Double.unaryMinus(): Double = -this
|
||||||
override inline fun Double.plus(other: Double): Double = this + other
|
override inline fun Double.plus(arg: Double): Double = this + arg
|
||||||
override inline fun Double.minus(other: Double): Double = this - other
|
override inline fun Double.minus(arg: Double): Double = this - arg
|
||||||
override inline fun Double.times(other: Double): Double = this * other
|
override inline fun Double.times(arg: Double): Double = this * arg
|
||||||
override inline fun Double.div(other: Double): Double = this / other
|
override inline fun Double.div(arg: Double): Double = this / arg
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Double.Companion.algebra: DoubleField get() = DoubleField
|
public val Double.Companion.algebra: DoubleField get() = DoubleField
|
||||||
@ -155,10 +155,10 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
|||||||
override inline fun norm(arg: Float): Float = abs(arg)
|
override inline fun norm(arg: Float): Float = abs(arg)
|
||||||
|
|
||||||
override inline fun Float.unaryMinus(): Float = -this
|
override inline fun Float.unaryMinus(): Float = -this
|
||||||
override inline fun Float.plus(other: Float): Float = this + other
|
override inline fun Float.plus(arg: Float): Float = this + arg
|
||||||
override inline fun Float.minus(other: Float): Float = this - other
|
override inline fun Float.minus(arg: Float): Float = this - arg
|
||||||
override inline fun Float.times(other: Float): Float = this * other
|
override inline fun Float.times(arg: Float): Float = this * arg
|
||||||
override inline fun Float.div(other: Float): Float = this / other
|
override inline fun Float.div(arg: Float): Float = this / arg
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Float.Companion.algebra: FloatField get() = FloatField
|
public val Float.Companion.algebra: FloatField get() = FloatField
|
||||||
@ -180,9 +180,9 @@ public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
|
|||||||
override inline fun norm(arg: Int): Int = abs(arg)
|
override inline fun norm(arg: Int): Int = abs(arg)
|
||||||
|
|
||||||
override inline fun Int.unaryMinus(): Int = -this
|
override inline fun Int.unaryMinus(): Int = -this
|
||||||
override inline fun Int.plus(other: Int): Int = this + other
|
override inline fun Int.plus(arg: Int): Int = this + arg
|
||||||
override inline fun Int.minus(other: Int): Int = this - other
|
override inline fun Int.minus(arg: Int): Int = this - arg
|
||||||
override inline fun Int.times(other: Int): Int = this * other
|
override inline fun Int.times(arg: Int): Int = this * arg
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Int.Companion.algebra: IntRing get() = IntRing
|
public val Int.Companion.algebra: IntRing get() = IntRing
|
||||||
@ -204,9 +204,9 @@ public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short>
|
|||||||
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
||||||
|
|
||||||
override inline fun Short.unaryMinus(): Short = (-this).toShort()
|
override inline fun Short.unaryMinus(): Short = (-this).toShort()
|
||||||
override inline fun Short.plus(other: Short): Short = (this + other).toShort()
|
override inline fun Short.plus(arg: Short): Short = (this + arg).toShort()
|
||||||
override inline fun Short.minus(other: Short): Short = (this - other).toShort()
|
override inline fun Short.minus(arg: Short): Short = (this - arg).toShort()
|
||||||
override inline fun Short.times(other: Short): Short = (this * other).toShort()
|
override inline fun Short.times(arg: Short): Short = (this * arg).toShort()
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Short.Companion.algebra: ShortRing get() = ShortRing
|
public val Short.Companion.algebra: ShortRing get() = ShortRing
|
||||||
@ -230,7 +230,7 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
|
|||||||
override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
|
override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
|
||||||
override inline fun Byte.plus(arg: Byte): Byte = (this + arg).toByte()
|
override inline fun Byte.plus(arg: Byte): Byte = (this + arg).toByte()
|
||||||
override inline fun Byte.minus(arg: Byte): Byte = (this - arg).toByte()
|
override inline fun Byte.minus(arg: Byte): Byte = (this - arg).toByte()
|
||||||
override inline fun Byte.times(other: Byte): Byte = (this * other).toByte()
|
override inline fun Byte.times(arg: Byte): Byte = (this * arg).toByte()
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Byte.Companion.algebra: ByteRing get() = ByteRing
|
public val Byte.Companion.algebra: ByteRing get() = ByteRing
|
||||||
@ -252,9 +252,9 @@ public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
|
|||||||
override fun norm(arg: Long): Long = abs(arg)
|
override fun norm(arg: Long): Long = abs(arg)
|
||||||
|
|
||||||
override inline fun Long.unaryMinus(): Long = (-this)
|
override inline fun Long.unaryMinus(): Long = (-this)
|
||||||
override inline fun Long.plus(other: Long): Long = (this + other)
|
override inline fun Long.plus(arg: Long): Long = (this + arg)
|
||||||
override inline fun Long.minus(other: Long): Long = (this - other)
|
override inline fun Long.minus(arg: Long): Long = (this - arg)
|
||||||
override inline fun Long.times(other: Long): Long = (this * other)
|
override inline fun Long.times(arg: Long): Long = (this * arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Long.Companion.algebra: LongRing get() = LongRing
|
public val Long.Companion.algebra: LongRing get() = LongRing
|
||||||
|
@ -59,8 +59,8 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
|
|||||||
override inline fun Double.unaryMinus(): Double = -this
|
override inline fun Double.unaryMinus(): Double = -this
|
||||||
override inline fun Double.plus(arg: Double): Double = this + arg
|
override inline fun Double.plus(arg: Double): Double = this + arg
|
||||||
override inline fun Double.minus(arg: Double): Double = this - arg
|
override inline fun Double.minus(arg: Double): Double = this - arg
|
||||||
override inline fun Double.times(other: Double): Double = this * other
|
override inline fun Double.times(arg: Double): Double = this * arg
|
||||||
override inline fun Double.div(other: Double): Double = this / other
|
override inline fun Double.div(arg: Double): Double = this / arg
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -108,8 +108,8 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
|
|||||||
override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg)
|
override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg)
|
||||||
|
|
||||||
override inline fun Double.unaryMinus(): Double = -this
|
override inline fun Double.unaryMinus(): Double = -this
|
||||||
override inline fun Double.plus(other: Double): Double = this + other
|
override inline fun Double.plus(arg: Double): Double = this + arg
|
||||||
override inline fun Double.minus(other: Double): Double = this - other
|
override inline fun Double.minus(arg: Double): Double = this - arg
|
||||||
override inline fun Double.times(other: Double): Double = this * other
|
override inline fun Double.times(arg: Double): Double = this * arg
|
||||||
override inline fun Double.div(other: Double): Double = this / other
|
override inline fun Double.div(arg: Double): Double = this / arg
|
||||||
}
|
}
|
||||||
|
@ -294,11 +294,11 @@ public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.divAssign(other: StructureND<T>) {
|
override fun Tensor<T>.divAssign(arg: StructureND<T>) {
|
||||||
if (this is MultikTensor) {
|
if (this is MultikTensor) {
|
||||||
array.divAssign(other.asMultik().array)
|
array.divAssign(arg.asMultik().array)
|
||||||
} else {
|
} else {
|
||||||
mapInPlace { index, t -> elementAlgebra.divide(t, other[index]) }
|
mapInPlace { index, t -> elementAlgebra.divide(t, arg[index]) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -148,8 +148,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
|
|||||||
ndArray.divi(value)
|
ndArray.divi(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.divAssign(other: StructureND<T>) {
|
override fun Tensor<T>.divAssign(arg: StructureND<T>) {
|
||||||
ndArray.divi(other.ndArray)
|
ndArray.divi(arg.ndArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
override fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||||
|
@ -53,9 +53,9 @@ public interface TensorPartialDivisionAlgebra<T, A : Field<T>> : TensorAlgebra<T
|
|||||||
public operator fun Tensor<T>.divAssign(value: T)
|
public operator fun Tensor<T>.divAssign(value: T)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Each element of this tensor is divided by each element of the [other] tensor.
|
* Each element of this tensor is divided by each element of the [arg] tensor.
|
||||||
*
|
*
|
||||||
* @param other tensor to be divided by.
|
* @param arg tensor to be divided by.
|
||||||
*/
|
*/
|
||||||
public operator fun Tensor<T>.divAssign(other: StructureND<T>)
|
public operator fun Tensor<T>.divAssign(arg: StructureND<T>)
|
||||||
}
|
}
|
||||||
|
@ -74,8 +74,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<Double>.div(other: StructureND<Double>): DoubleTensor {
|
override fun StructureND<Double>.div(arg: StructureND<Double>): DoubleTensor {
|
||||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
val broadcast = broadcastTensors(tensor, arg.tensor)
|
||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
|
val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
|
||||||
@ -85,8 +85,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<Double>.divAssign(other: StructureND<Double>) {
|
override fun Tensor<Double>.divAssign(arg: StructureND<Double>) {
|
||||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
val newOther = broadcastTo(arg.tensor, tensor.shape)
|
||||||
for (i in 0 until tensor.indices.linearSize) {
|
for (i in 0 until tensor.indices.linearSize) {
|
||||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
||||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||||
|
@ -314,11 +314,11 @@ public open class DoubleTensorAlgebra :
|
|||||||
return DoubleTensor(shape, resBuffer)
|
return DoubleTensor(shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<Double>.div(other: StructureND<Double>): DoubleTensor {
|
override fun StructureND<Double>.div(arg: StructureND<Double>): DoubleTensor {
|
||||||
checkShapesCompatible(tensor, other)
|
checkShapesCompatible(tensor, arg)
|
||||||
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
||||||
tensor.mutableBuffer.array()[other.tensor.bufferStart + i] /
|
tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] /
|
||||||
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
|
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(tensor.shape, resBuffer)
|
return DoubleTensor(tensor.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -329,11 +329,11 @@ public open class DoubleTensorAlgebra :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<Double>.divAssign(other: StructureND<Double>) {
|
override fun Tensor<Double>.divAssign(arg: StructureND<Double>) {
|
||||||
checkShapesCompatible(tensor, other)
|
checkShapesCompatible(tensor, arg)
|
||||||
for (i in 0 until tensor.numElements) {
|
for (i in 0 until tensor.numElements) {
|
||||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
||||||
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
|
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user