new: Added diagonal_matrix method

This commit is contained in:
Sayantan Santra 2023-05-27 01:28:01 -05:00
parent 4a26f0cf71
commit 56222e04f1
Signed by: SinTan1729
GPG key ID: EB3E68BFBA25C85F
2 changed files with 31 additions and 20 deletions

View file

@ -312,19 +312,11 @@ impl<T: Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy> Matri
where where
T: One, T: One,
{ {
let mut out = Vec::new(); let mut out = Matrix::zero(size, size);
for i in 0..size { for (i, row) in out.entries.iter_mut().enumerate() {
let mut new_row = Vec::new(); row[i] = T::one();
for j in 0..size {
if i == j {
new_row.push(T::one());
} else {
new_row.push(T::zero());
} }
} out
out.push(new_row);
}
Matrix { entries: out }
} }
/// Returns the trace of a square matrix. /// Returns the trace of a square matrix.
@ -333,7 +325,7 @@ impl<T: Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy> Matri
/// ``` /// ```
/// use matrix_basic::Matrix; /// use matrix_basic::Matrix;
/// let m = Matrix::from(vec![vec![1,2], vec![3,4]]).unwrap(); /// let m = Matrix::from(vec![vec![1,2], vec![3,4]]).unwrap();
/// assert_eq!(m.det(),Ok(-2)); /// assert_eq!(m.trace(),Ok(5));
/// ``` /// ```
pub fn trace(self) -> Result<T, &'static str> { pub fn trace(self) -> Result<T, &'static str> {
if self.is_square() { if self.is_square() {
@ -347,6 +339,24 @@ impl<T: Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy> Matri
} }
} }
/// Returns a diagonal matrix with a given diagonal.
/// # Example
/// ```
/// use matrix_basic::Matrix;
/// let m = Matrix::diagonal_matrix(vec![1,2,3]);
/// let n = Matrix::from(vec![vec![1,0,0], vec![0,2,0], vec![0,0,3]]).unwrap();
///
/// assert_eq!(m,n);
/// ```
pub fn diagonal_matrix(diag: Vec<T>) -> Self {
let size = diag.len();
let mut out = Matrix::zero(size, size);
for (i, row) in out.entries.iter_mut().enumerate() {
row[i] = diag[i];
}
out
}
// TODO: Canonical forms, eigenvalues, eigenvectors etc. // TODO: Canonical forms, eigenvalues, eigenvectors etc.
} }

View file

@ -41,12 +41,13 @@ fn det_trace_test() {
} }
#[test] #[test]
fn zero_one_test() { fn zero_one_diag_test() {
let a = Matrix::from(vec![vec![0, 0, 0], vec![0, 0, 0]]).unwrap(); let a = Matrix::from(vec![vec![0, 0, 0], vec![0, 0, 0]]).unwrap();
let b = Matrix::from(vec![vec![1, 0], vec![0, 1]]).unwrap(); let b = Matrix::from(vec![vec![1, 0], vec![0, 1]]).unwrap();
assert_eq!(Matrix::<i32>::zero(2, 3), a); assert_eq!(Matrix::<i32>::zero(2, 3), a);
assert_eq!(Matrix::<i32>::identity(2), b); assert_eq!(Matrix::<i32>::identity(2), b);
assert_eq!(Matrix::diagonal_matrix(vec![1, 1]), b);
} }
#[test] #[test]