diff --git a/src/lib.rs b/src/lib.rs index 3701b70..d3a52bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ use num::{ }; use std::{ fmt::{self, Debug, Display, Formatter}, - ops::{Add, Div, Mul, Sub}, + ops::{Add, Div, Mul, Neg, Sub}, result::Result, }; @@ -377,7 +377,7 @@ impl Display for Matrix { impl + Add + Sub + Copy + Zero> Mul for Matrix { // TODO: Implement a faster algorithm. type Output = Self; - fn mul(self, other: Self) -> Self { + fn mul(self, other: Self) -> Self::Output { let width = self.width(); if width != other.height() { panic!("Row length of first matrix must be same as column length of second matrix."); @@ -401,7 +401,7 @@ impl + Add + Sub + Copy + Zero> Mul for Matrix { impl + Sub + Mul + Copy + Zero> Add for Matrix { type Output = Self; - fn add(self, other: Self) -> Self { + fn add(self, other: Self) -> Self::Output { if self.height() == other.height() && self.width() == other.width() { let mut out = self.entries.clone(); for (i, row) in self.rows().iter().enumerate() { @@ -416,17 +416,24 @@ impl + Sub + Mul + Copy + Zero> Add for Matrix { } } -impl + Mul + Copy + Zero> Sub for Matrix { +impl + Mul + Copy + Neg> Neg for Matrix { type Output = Self; - fn sub(self, other: Self) -> Self { - if self.height() == other.height() && self.width() == other.width() { - let mut out = self.entries.clone(); - for (i, row) in self.rows().iter().enumerate() { - for (j, entry) in other.rows()[i].iter().enumerate() { - out[i][j] = row[j] - *entry; - } + fn neg(self) -> Self::Output { + let mut out = self; + for row in &mut out.entries { + for entry in row { + *entry = -*entry; } - Matrix { entries: out } + } + out + } +} + +impl + Mul + Copy + Zero + Neg> Sub for Matrix { + type Output = Self; + fn sub(self, other: Self) -> Self::Output { + if self.height() == other.height() && self.width() == other.width() { + self + -other } else { panic!("Both matrices must be of same dimensions."); } diff --git a/src/tests.rs b/src/tests.rs index 0f4b0d0..608c3ec 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -14,8 +14,10 @@ fn add_sub_test() { let b = Matrix::from(vec![vec![0, 0, 1], vec![2, 1, 3]]).unwrap(); let c = Matrix::from(vec![vec![1, 2, 4], vec![2, 2, 5]]).unwrap(); let d = Matrix::from(vec![vec![1, 2, 2], vec![-2, 0, -1]]).unwrap(); + let e = Matrix::from(vec![vec![-1, -2, -4], vec![-2, -2, -5]]).unwrap(); assert_eq!(a.clone() + b.clone(), c); assert_eq!(a - b, d); + assert_eq!(-c, e); } #[test]