From bc8ffd02c2c513781f6b72dbea0dc3416b273875 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 11 Apr 2022 17:27:11 +0300 Subject: [PATCH] Change tensor algebra operations to prefix notations --- .../kscience/kmath/tensors/neuralNetwork.kt | 4 +- .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 35 ++--- .../tensors/api/AnalyticTensorAlgebra.kt | 143 ++++++++++-------- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 28 ++-- .../core/TestDoubleAnalyticTensorAlgebra.kt | 22 +-- 5 files changed, 125 insertions(+), 107 deletions(-) diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt index 5c41ab0f1..66044eabe 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt @@ -48,7 +48,7 @@ fun reluDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra { class ReLU : Activation(::relu, ::reluDer) fun sigmoid(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra { - 1.0 / (1.0 + (-x).exp()) + 1.0 / (1.0 + exp(-x)) } fun sigmoidDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra { @@ -116,7 +116,7 @@ class NeuralNetwork(private val layers: List) { onesForAnswers[intArrayOf(index, label)] = 1.0 } - val softmaxValue = yPred.exp() / yPred.exp().sum(dim = 1, keepDim = true) + val softmaxValue = exp(yPred) / exp(yPred).sum(dim = 1, keepDim = true) (-onesForAnswers + softmaxValue) / (yPred.shape[0].toDouble()) } diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index ceb384f0d..084d8353d 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -114,27 +114,29 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe override fun StructureND.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.mean(keepDim, dim).wrap() - override fun StructureND.exp(): Nd4jArrayStructure = Transforms.exp(ndArray).wrap() - override fun StructureND.ln(): Nd4jArrayStructure = Transforms.log(ndArray).wrap() override fun StructureND.sqrt(): Nd4jArrayStructure = Transforms.sqrt(ndArray).wrap() - override fun StructureND.cos(): Nd4jArrayStructure = Transforms.cos(ndArray).wrap() - override fun StructureND.acos(): Nd4jArrayStructure = Transforms.acos(ndArray).wrap() - override fun StructureND.cosh(): Nd4jArrayStructure = Transforms.cosh(ndArray).wrap() - override fun StructureND.acosh(): Nd4jArrayStructure = - Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() + override fun exp(arg: StructureND): Nd4jArrayStructure = Transforms.exp(arg.ndArray).wrap() + override fun ln(arg: StructureND): Nd4jArrayStructure = Transforms.log(arg.ndArray).wrap() + override fun cos(arg: StructureND): Nd4jArrayStructure = Transforms.cos(arg.ndArray).wrap() + override fun acos(arg: StructureND): Nd4jArrayStructure = Transforms.acos(arg.ndArray).wrap() + override fun cosh(arg: StructureND): Nd4jArrayStructure = Transforms.cosh(arg.ndArray).wrap() - override fun StructureND.sin(): Nd4jArrayStructure = Transforms.sin(ndArray).wrap() - override fun StructureND.asin(): Nd4jArrayStructure = Transforms.asin(ndArray).wrap() - override fun StructureND.sinh(): Tensor = Transforms.sinh(ndArray).wrap() + override fun acosh(arg: StructureND): Nd4jArrayStructure = + Nd4j.getExecutioner().exec(ACosh(arg.ndArray, arg.ndArray.ulike())).wrap() - override fun StructureND.asinh(): Nd4jArrayStructure = - Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() + override fun sin(arg: StructureND): Nd4jArrayStructure = Transforms.sin(arg.ndArray).wrap() + override fun asin(arg: StructureND): Nd4jArrayStructure = Transforms.asin(arg.ndArray).wrap() + override fun sinh(arg: StructureND): Tensor = Transforms.sinh(arg.ndArray).wrap() + + override fun asinh(arg: StructureND): Nd4jArrayStructure = + Nd4j.getExecutioner().exec(ASinh(arg.ndArray, arg.ndArray.ulike())).wrap() + + override fun tan(arg: StructureND): Nd4jArrayStructure = Transforms.tan(arg.ndArray).wrap() + override fun atan(arg: StructureND): Nd4jArrayStructure = Transforms.atan(arg.ndArray).wrap() + override fun tanh(arg: StructureND): Nd4jArrayStructure = Transforms.tanh(arg.ndArray).wrap() + override fun atanh(arg: StructureND): Nd4jArrayStructure = Transforms.atanh(arg.ndArray).wrap() - override fun StructureND.tan(): Nd4jArrayStructure = Transforms.tan(ndArray).wrap() - override fun StructureND.atan(): Nd4jArrayStructure = Transforms.atan(ndArray).wrap() - override fun StructureND.tanh(): Nd4jArrayStructure = Transforms.tanh(ndArray).wrap() - override fun StructureND.atanh(): Nd4jArrayStructure = Transforms.atanh(ndArray).wrap() override fun StructureND.ceil(): Nd4jArrayStructure = Transforms.ceil(ndArray).wrap() override fun StructureND.floor(): Nd4jArrayStructure = Transforms.floor(ndArray).wrap() override fun StructureND.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure = @@ -178,7 +180,6 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { return array.wrap() } - @OptIn(PerformancePitfall::class) override val StructureND.ndArray: INDArray get() = when (this) { diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt index b32fcf608..5232b1b8c 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt @@ -72,78 +72,95 @@ public interface AnalyticTensorAlgebra> : */ public fun StructureND.variance(dim: Int, keepDim: Boolean): Tensor - //For information: https://pytorch.org/docs/stable/generated/torch.exp.html - public fun StructureND.exp(): Tensor + override fun sin(arg: StructureND): Tensor - //For information: https://pytorch.org/docs/stable/generated/torch.log.html - public fun StructureND.ln(): Tensor + override fun cos(arg: StructureND): Tensor + + override fun asin(arg: StructureND): Tensor + + override fun acos(arg: StructureND): Tensor + + override fun tan(arg: StructureND): Tensor + + override fun atan(arg: StructureND): Tensor + + override fun exp(arg: StructureND): Tensor + + override fun ln(arg: StructureND): Tensor + + override fun sinh(arg: StructureND): Tensor + + override fun cosh(arg: StructureND): Tensor + + override fun tanh(arg: StructureND): Tensor + + override fun asinh(arg: StructureND): Tensor + + override fun acosh(arg: StructureND): Tensor + + override fun atanh(arg: StructureND): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html public fun StructureND.sqrt(): Tensor - //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos - public fun StructureND.cos(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos - public fun StructureND.acos(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh - public fun StructureND.cosh(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh - public fun StructureND.acosh(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin - public fun StructureND.sin(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin - public fun StructureND.asin(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh - public fun StructureND.sinh(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh - public fun StructureND.asinh(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan - public fun StructureND.tan(): Tensor - - //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan - public fun StructureND.atan(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh - public fun StructureND.tanh(): Tensor - - //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh - public fun StructureND.atanh(): Tensor - //For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil public fun StructureND.ceil(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor public fun StructureND.floor(): Tensor - override fun sin(arg: StructureND): Tensor = arg.sin() - - override fun cos(arg: StructureND): Tensor = arg.cos() - - override fun asin(arg: StructureND): Tensor = arg.asin() - - override fun acos(arg: StructureND): Tensor = arg.acos() - - override fun atan(arg: StructureND): Tensor = arg.atan() - - override fun exp(arg: StructureND): Tensor = arg.exp() - - override fun ln(arg: StructureND): Tensor = arg.ln() - - override fun sinh(arg: StructureND): Tensor = arg.sinh() - - override fun cosh(arg: StructureND): Tensor = arg.cosh() - - override fun asinh(arg: StructureND): Tensor = arg.asinh() - - override fun acosh(arg: StructureND): Tensor = arg.acosh() - - override fun atanh(arg: StructureND): Tensor = arg.atanh() +// //For information: https://pytorch.org/docs/stable/generated/torch.exp.html +// public fun StructureND.exp(): Tensor = exp(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.log.html +// @JvmName("lnChain") +// public fun StructureND.ln(): Tensor = ln(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos +// @JvmName("cosChain") +// public fun StructureND.cos(): Tensor = cos(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos +// @JvmName("acosChain") +// public fun StructureND.acos(): Tensor = acos(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh +// @JvmName("coshChain") +// public fun StructureND.cosh(): Tensor = cosh(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh +// @JvmName("acoshChain") +// public fun StructureND.acosh(): Tensor = acosh(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin +// @JvmName("sinChain") +// public fun StructureND.sin(): Tensor = sin(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin +// @JvmName("asinChain") +// public fun StructureND.asin(): Tensor = asin(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh +// @JvmName("sinhChain") +// public fun StructureND.sinh(): Tensor = sinh(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh +// @JvmName("asinhChain") +// public fun StructureND.asinh(): Tensor = asinh(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan +// @JvmName("tanChain") +// public fun StructureND.tan(): Tensor = tan(this) +// +// //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan +// @JvmName("atanChain") +// public fun StructureND.atan(): Tensor = atan(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh +// @JvmName("tanhChain") +// public fun StructureND.tanh(): Tensor = tanh(this) +// +// //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh +// @JvmName("atanhChain") +// public fun StructureND.atanh(): Tensor = atanh(this) } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index e9dc34748..e722d2b2b 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -685,35 +685,35 @@ public open class DoubleTensorAlgebra : return resTensor } - override fun StructureND.exp(): DoubleTensor = tensor.map { exp(it) } + override fun exp(arg: StructureND): DoubleTensor = arg.tensor.map { exp(it) } - override fun StructureND.ln(): DoubleTensor = tensor.map { ln(it) } + override fun ln(arg: StructureND): DoubleTensor = arg.tensor.map { ln(it) } override fun StructureND.sqrt(): DoubleTensor = tensor.map { sqrt(it) } - override fun StructureND.cos(): DoubleTensor = tensor.map { cos(it) } + override fun cos(arg: StructureND): DoubleTensor = arg.tensor.map { cos(it) } - override fun StructureND.acos(): DoubleTensor = tensor.map { acos(it) } + override fun acos(arg: StructureND): DoubleTensor = arg.tensor.map { acos(it) } - override fun StructureND.cosh(): DoubleTensor = tensor.map { cosh(it) } + override fun cosh(arg: StructureND): DoubleTensor = arg.tensor.map { cosh(it) } - override fun StructureND.acosh(): DoubleTensor = tensor.map { acosh(it) } + override fun acosh(arg: StructureND): DoubleTensor = arg.tensor.map { acosh(it) } - override fun StructureND.sin(): DoubleTensor = tensor.map { sin(it) } + override fun sin(arg: StructureND): DoubleTensor = arg.tensor.map { sin(it) } - override fun StructureND.asin(): DoubleTensor = tensor.map { asin(it) } + override fun asin(arg: StructureND): DoubleTensor = arg.tensor.map { asin(it) } - override fun StructureND.sinh(): DoubleTensor = tensor.map { sinh(it) } + override fun sinh(arg: StructureND): DoubleTensor = arg.tensor.map { sinh(it) } - override fun StructureND.asinh(): DoubleTensor = tensor.map { asinh(it) } + override fun asinh(arg: StructureND): DoubleTensor = arg.tensor.map { asinh(it) } - override fun StructureND.tan(): DoubleTensor = tensor.map { tan(it) } + override fun tan(arg: StructureND): DoubleTensor = arg.tensor.map { tan(it) } - override fun StructureND.atan(): DoubleTensor = tensor.map { atan(it) } + override fun atan(arg: StructureND): DoubleTensor = arg.tensor.map { atan(it) } - override fun StructureND.tanh(): DoubleTensor = tensor.map { tanh(it) } + override fun tanh(arg: StructureND): DoubleTensor = arg.tensor.map { tanh(it) } - override fun StructureND.atanh(): DoubleTensor = tensor.map { atanh(it) } + override fun atanh(arg: StructureND): DoubleTensor = arg.tensor.map { atanh(it) } override fun StructureND.ceil(): DoubleTensor = tensor.map { ceil(it) } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleAnalyticTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleAnalyticTensorAlgebra.kt index 1e21379b4..71639b5a1 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleAnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleAnalyticTensorAlgebra.kt @@ -32,12 +32,12 @@ internal class TestDoubleAnalyticTensorAlgebra { @Test fun testExp() = DoubleTensorAlgebra { - assertTrue { tensor.exp() eq expectedTensor(::exp) } + assertTrue { exp(tensor) eq expectedTensor(::exp) } } @Test fun testLog() = DoubleTensorAlgebra { - assertTrue { tensor.ln() eq expectedTensor(::ln) } + assertTrue { ln(tensor) eq expectedTensor(::ln) } } @Test @@ -47,48 +47,48 @@ internal class TestDoubleAnalyticTensorAlgebra { @Test fun testCos() = DoubleTensorAlgebra { - assertTrue { tensor.cos() eq expectedTensor(::cos) } + assertTrue { cos(tensor) eq expectedTensor(::cos) } } @Test fun testCosh() = DoubleTensorAlgebra { - assertTrue { tensor.cosh() eq expectedTensor(::cosh) } + assertTrue { cosh(tensor) eq expectedTensor(::cosh) } } @Test fun testAcosh() = DoubleTensorAlgebra { - assertTrue { tensor.acosh() eq expectedTensor(::acosh) } + assertTrue { acosh(tensor) eq expectedTensor(::acosh) } } @Test fun testSin() = DoubleTensorAlgebra { - assertTrue { tensor.sin() eq expectedTensor(::sin) } + assertTrue { sin(tensor) eq expectedTensor(::sin) } } @Test fun testSinh() = DoubleTensorAlgebra { - assertTrue { tensor.sinh() eq expectedTensor(::sinh) } + assertTrue { sinh(tensor) eq expectedTensor(::sinh) } } @Test fun testAsinh() = DoubleTensorAlgebra { - assertTrue { tensor.asinh() eq expectedTensor(::asinh) } + assertTrue { asinh(tensor) eq expectedTensor(::asinh) } } @Test fun testTan() = DoubleTensorAlgebra { - assertTrue { tensor.tan() eq expectedTensor(::tan) } + assertTrue { tan(tensor) eq expectedTensor(::tan) } } @Test fun testAtan() = DoubleTensorAlgebra { - assertTrue { tensor.atan() eq expectedTensor(::atan) } + assertTrue { atan(tensor) eq expectedTensor(::atan) } } @Test fun testTanh() = DoubleTensorAlgebra { - assertTrue { tensor.tanh() eq expectedTensor(::tanh) } + assertTrue { tanh(tensor) eq expectedTensor(::tanh) } } @Test -- 2.34.1