diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index 243eb2d6c..905263534 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -1,5 +1,6 @@ package scientifik.kmath.linear +import scientifik.kmath.histogram.Point import scientifik.kmath.operations.* import scientifik.kmath.structures.* @@ -37,19 +38,6 @@ abstract class MatrixSpace>(val rows: Int, val columns: Int return produce { i, j -> with(field) { a[i, j] * k } } } - /** - * Dot product. Throws exception on dimension mismatch - */ - fun multiply(a: Matrix, b: Matrix): Matrix { - if (a.rows != b.columns) { - //TODO replace by specific exception - error("Dimension mismatch in linear structure dot product: [${a.rows},${a.columns}]*[${b.rows},${b.columns}]") - } - return produceSpace(a.rows, b.columns).produce { i, j -> - (0 until a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) } - } - } - override fun equals(other: Any?): Boolean { if (this === other) return true if (other !is MatrixSpace<*,*>) return false @@ -69,8 +57,6 @@ abstract class MatrixSpace>(val rows: Int, val columns: Int } } -infix fun > Matrix.dot(b: Matrix): Matrix = this.context.multiply(this, b) - /** * A matrix-like structure */ @@ -138,6 +124,31 @@ interface Matrix> : SpaceElement, MatrixSpace> Matrix.dot(b: Matrix): Matrix { + if (columns != b.rows) { + //TODO replace by specific exception + error("Dimension mismatch in linear structure dot product: [$rows,$columns]*[${b.rows},${b.columns}]") + } + return context.produceSpace(rows, b.columns).produce { i, j -> + (0 until columns).asSequence().map { k -> context.field.multiply(this[i, k], b[k, j]) }.reduce { first, second -> context.field.add(first, second) } + } +} + +/** + * Matrix x Vector dot product. + */ +infix fun > Matrix.dot(b: Point): Matrix { + if (columns != b.size) { + //TODO replace by specific exception + error("Dimension mismatch in linear structure dot product: [$rows,$columns]*[${b.size},1]") + } + return context.produceSpace(rows, 1).produce { i, j -> + (0 until columns).asSequence().map { k -> context.field.multiply(this[i, k], b[k]) }.reduce { first, second -> context.field.add(first, second) } + } +}