From f15ac203236355117e72f81159af7a496f5c6706 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Mon, 1 Mar 2021 22:52:58 +0000 Subject: [PATCH] Cannot afford to inherit from RingWithNumbers --- .../kmath/tensors/RealTensorAlgebra.kt | 35 +++++++++---------- .../kscience/kmath/tensors/TensorAlgebra.kt | 9 ++--- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt index 0a28ace4f..6023d2b72 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt @@ -19,25 +19,6 @@ public class RealTensor( public class RealTensorAlgebra : TensorPartialDivisionAlgebra { - override fun add(a: RealTensor, b: RealTensor): RealTensor { - TODO("Not yet implemented") - } - - override fun multiply(a: RealTensor, k: Number): RealTensor { - TODO("Not yet implemented") - } - - override val zero: RealTensor - get() = TODO("Not yet implemented") - - override fun multiply(a: RealTensor, b: RealTensor): RealTensor { - TODO("Not yet implemented") - } - - override val one: RealTensor - get() = TODO("Not yet implemented") - - override fun Double.plus(other: RealTensor): RealTensor { val n = other.buffer.size val arr = other.buffer.array @@ -51,6 +32,10 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra>: RingWithNumbers { +public interface TensorAlgebra>{ public operator fun T.plus(other: TensorType): TensorType public operator fun TensorType.plus(value: T): TensorType + public operator fun TensorType.plus(other: TensorType): TensorType public operator fun TensorType.plusAssign(value: T): Unit public operator fun TensorType.plusAssign(other: TensorType): Unit public operator fun T.minus(other: TensorType): TensorType public operator fun TensorType.minus(value: T): TensorType + public operator fun TensorType.minus(other: TensorType): TensorType public operator fun TensorType.minusAssign(value: T): Unit public operator fun TensorType.minusAssign(other: TensorType): Unit public operator fun T.times(other: TensorType): TensorType public operator fun TensorType.times(value: T): TensorType + public operator fun TensorType.times(other: TensorType): TensorType public operator fun TensorType.timesAssign(value: T): Unit public operator fun TensorType.timesAssign(other: TensorType): Unit + public operator fun TensorType.unaryMinus(): TensorType //https://pytorch.org/docs/stable/generated/torch.matmul.html