commit f0c10ec7d0e1e97594d895dce0c914c1607f652a Author: SinTan1729 Date: Wed May 24 21:46:25 2023 -0500 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..a14c311 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "matrix" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..5ebe3ea --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,22 @@ +pub mod matrix; + +#[cfg(test)] +mod tests { + use super::*; + use matrix::Matrix; + #[test] + fn mul_test() { + let a = Matrix::from(vec![vec![1, 2, 4], vec![3, 4, 9]]).unwrap(); + let b = Matrix::from(vec![vec![1, 2], vec![2, 3], vec![5, 1]]).unwrap(); + let c = Matrix::from(vec![vec![25, 12], vec![56, 27]]).unwrap(); + assert_eq!(a * b, c); + } + + #[test] + fn add_test() { + let a = Matrix::from(vec![vec![1, 2, 3], vec![0, 1, 2]]).unwrap(); + let b = Matrix::from(vec![vec![0, 0, 1], vec![2, 1, 3]]).unwrap(); + let c = Matrix::from(vec![vec![1, 2, 4], vec![2, 2, 5]]).unwrap(); + assert_eq!(a + b, c); + } +} diff --git a/src/matrix.rs b/src/matrix.rs new file mode 100644 index 0000000..562d82d --- /dev/null +++ b/src/matrix.rs @@ -0,0 +1,112 @@ +use std::{ + fmt::{self, Debug, Display, Formatter}, + ops::{Add, Mul}, + result::Result, +}; + +#[derive(PartialEq, Debug)] +pub struct Matrix { + entries: Vec>, +} + +impl Matrix { + pub fn from(entries: Vec>) -> Result, &'static str> { + let mut equal_rows = true; + let row_len = entries[0].len(); + for row in &entries { + if row_len != row.len() { + equal_rows = false; + break; + } + } + if equal_rows { + Ok(Matrix { entries }) + } else { + Err("Unequal rows.") + } + } + + pub fn height(&self) -> usize { + self.entries.len() + } + + pub fn width(&self) -> usize { + self.entries[0].len() + } + + pub fn transpose(&self) -> Matrix + where + T: Copy, + { + let mut out = Vec::new(); + for i in 0..self.width() { + let mut column = Vec::new(); + for row in &self.entries { + column.push(row[i]); + } + out.push(column) + } + Matrix { entries: out } + } + + pub fn rows(&self) -> Vec> + where + T: Copy, + { + self.entries.clone() + } + + pub fn columns(&self) -> Vec> + where + T: Copy, + { + self.transpose().entries + } +} + +impl Display for Matrix { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{:?}", self.entries) + } +} + +impl + Add + Copy> Mul for Matrix { + type Output = Self; + fn mul(self, other: Self) -> Self { + let width = self.width(); + if width != other.height() { + panic!("Row length of first matrix must be same as column length of second matrix."); + } else { + let mut out = Vec::new(); + for row in self.rows() { + let mut new_row = Vec::new(); + for col in other.columns() { + let mut prod = row[0] * col[0]; + for i in 1..width { + prod = prod + (row[i] * col[i]); + } + new_row.push(prod) + } + out.push(new_row); + } + Matrix { entries: out } + } + } +} + +impl + Mul + Copy> Add for Matrix { + type Output = Self; + fn add(self, other: Self) -> Self { + if self.height() == other.height() && self.width() == other.width() { + let mut out = self.entries.clone(); + for (i, row) in self.rows().iter().enumerate() { + for (j, entry) in other.rows()[i].iter().enumerate() { + out[i][j] = row[j] + *entry; + } + } + Matrix { entries: out } + } else { + panic!("Both matrices must be of same dimensions."); + } + } +}