diff --git a/src/lib.rs b/src/lib.rs index 8f441cd..e8813fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,4 +29,12 @@ mod tests { assert_eq!(a.det(), Ok(30)); assert!(b.det().is_err()); } + + #[test] + fn zero_one_test() { + let a = Matrix::from(vec![vec![0, 0, 0], vec![0, 0, 0]]).unwrap(); + let b = Matrix::from(vec![vec![1, 0], vec![0, 1]]).unwrap(); + assert_eq!(Matrix::::zero(2, 3), a); + assert_eq!(Matrix::::identity(2), b); + } } diff --git a/src/matrix.rs b/src/matrix.rs index 19ed212..8251e53 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,4 +1,7 @@ -use num::{traits::Zero, Integer}; +use num::{ + traits::{One, Zero}, + Integer, +}; use std::{ fmt::{self, Debug, Display, Formatter}, ops::{Add, Mul, Sub}, @@ -10,7 +13,7 @@ pub struct Matrix { entries: Vec>, } -impl Matrix { +impl Matrix { pub fn from(entries: Vec>) -> Result, &'static str> { let mut equal_rows = true; let row_len = entries[0].len(); @@ -93,6 +96,7 @@ impl Matrix { T: Copy, T: Mul, T: Sub, + T: Zero, { if self.is_square() { let out = if self.width() == 1 { @@ -114,6 +118,41 @@ impl Matrix { Err("Provided matrix isn't square.") } } + + pub fn zero(height: usize, width: usize) -> Self + where + T: Zero, + { + let mut out = Vec::new(); + for _ in 0..height { + let mut new_row = Vec::new(); + for _ in 0..width { + new_row.push(T::zero()); + } + out.push(new_row); + } + Matrix { entries: out } + } + + pub fn identity(size: usize) -> Self + where + T: Zero, + T: One, + { + let mut out = Vec::new(); + for i in 0..size { + let mut new_row = Vec::new(); + for j in 0..size { + if i == j { + new_row.push(T::one()); + } else { + new_row.push(T::zero()); + } + } + out.push(new_row); + } + Matrix { entries: out } + } } impl Display for Matrix {