diff --git a/Cargo.toml b/Cargo.toml index a14c311..fb8d6d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,3 +6,4 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +num = "0.4.0" diff --git a/src/lib.rs b/src/lib.rs index 5ebe3ea..5a1062c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,4 +19,12 @@ mod tests { let c = Matrix::from(vec![vec![1, 2, 4], vec![2, 2, 5]]).unwrap(); assert_eq!(a + b, c); } + + #[test] + fn det_test() { + let a = Matrix::from(vec![vec![1, 2, 0], vec![0, 3, 5], vec![0, 0, 10]]).unwrap(); + let b = Matrix::from(vec![vec![1, 2, 0], vec![0, 3, 5]]).unwrap(); + assert_eq!(a.det(), Ok(30)); + assert!(b.det().is_err()); + } } diff --git a/src/matrix.rs b/src/matrix.rs index 562d82d..0b8dc19 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,15 +1,16 @@ +use num::{traits::Zero, Integer}; use std::{ fmt::{self, Debug, Display, Formatter}, - ops::{Add, Mul}, + ops::{Add, Mul, Sub}, result::Result, }; #[derive(PartialEq, Debug)] -pub struct Matrix { +pub struct Matrix { entries: Vec>, } -impl Matrix { +impl Matrix { pub fn from(entries: Vec>) -> Result, &'static str> { let mut equal_rows = true; let row_len = entries[0].len(); @@ -34,7 +35,7 @@ impl Matrix { self.entries[0].len() } - pub fn transpose(&self) -> Matrix + pub fn transpose(&self) -> Self where T: Copy, { @@ -62,15 +63,66 @@ impl Matrix { { self.transpose().entries } + + pub fn is_square(&self) -> bool { + self.height() == self.width() + } + + pub fn submatrix(&self, i: usize, j: usize) -> Self + where + T: Copy, + { + let mut out = Vec::new(); + for (m, row) in self.rows().iter().enumerate() { + if m == i { + continue; + } + let mut new_row = Vec::new(); + for (n, entry) in row.iter().enumerate() { + if n != j { + new_row.push(*entry); + } + } + out.push(new_row); + } + Matrix { entries: out } + } + + pub fn det(&self) -> Result + where + T: Copy, + T: Mul, + T: Sub, + { + if self.is_square() { + let out = if self.width() == 1 { + self.entries[0][0] + } else { + let n = 0..self.width(); + let mut out = T::zero(); + for i in n { + if i.is_even() { + out = out + (self.entries[0][i] * self.submatrix(0, i).det().unwrap()); + } else { + out = out - (self.entries[0][i] * self.submatrix(0, i).det().unwrap()); + } + } + out + }; + Ok(out) + } else { + Err("Provided matrix isn't square.") + } + } } -impl Display for Matrix { +impl Display for Matrix { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "{:?}", self.entries) } } -impl + Add + Copy> Mul for Matrix { +impl + Add + Sub + Copy + Zero> Mul for Matrix { type Output = Self; fn mul(self, other: Self) -> Self { let width = self.width(); @@ -94,7 +146,7 @@ impl + Add + Copy> Mul for Matrix { } } -impl + Mul + Copy> Add for Matrix { +impl + Sub + Mul + Copy + Zero> Add for Matrix { type Output = Self; fn add(self, other: Self) -> Self { if self.height() == other.height() && self.width() == other.width() {