// matrix-basic/src/lib.rs
//! This is a crate for very basic matrix operations
//! with any type that implements [`Add`], [`Sub`], [`Mul`],
//! [`Zero`] and [`Copy`]. Additional properties might be
//! needed for certain operations.
//! I created it mostly to learn how to use generic types
//! and traits.
//!
//! Sayantan Santra (2023)
use num::{
    traits::{One, Zero},
    Integer,
};
use std::{
    fmt::{self, Debug, Display, Formatter},
    ops::{Add, Div, Mul, Neg, Sub},
    result::Result,
};
mod tests;
/// A generic matrix struct (over any type with addition, subtraction
/// and multiplication defined on it).
///
/// The entries are stored row by row; [`from`](Self::from()) rejects
/// input whose rows have different lengths.
/// Look at [`from`](Self::from()) to see examples.
///
/// The arithmetic bounds live on the `impl` blocks rather than on the
/// struct itself, so the type can be named without repeating them.
#[derive(PartialEq, Debug, Clone)]
pub struct Matrix<T> {
    entries: Vec<Vec<T>>,
}
2023-05-27 01:02:06 -05:00
impl<T: Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy> Matrix<T> {
2023-05-25 20:28:36 -05:00
/// Creates a matrix from given 2D "array" in a `Vec<Vec<T>>` form.
2023-05-25 20:40:09 -05:00
/// It'll throw an error if all the given rows aren't of the same size.
2023-05-25 20:28:36 -05:00
/// # Example
/// ```
2023-05-25 21:06:16 -05:00
/// use matrix_basic::Matrix;
2023-05-25 20:28:36 -05:00
/// let m = Matrix::from(vec![vec![1,2,3], vec![4,5,6]]);
/// ```
/// will create the following matrix:
/// ⌈1,2,3⌉
/// ⌊4,5,6⌋
pub fn from(entries: Vec<Vec<T>>) -> Result<Matrix<T>, &'static str> {
let mut equal_rows = true;
let row_len = entries[0].len();
for row in &entries {
if row_len != row.len() {
equal_rows = false;
break;
}
}
if equal_rows {
Ok(Matrix { entries })
} else {
Err("Unequal rows.")
}
}
2023-05-26 17:33:51 -05:00
/// Returns the height of a matrix.
2023-05-25 20:28:36 -05:00
pub fn height(&self) -> usize {
self.entries.len()
}
2023-05-26 17:33:51 -05:00
/// Returns the width of a matrix.
2023-05-25 20:28:36 -05:00
pub fn width(&self) -> usize {
self.entries[0].len()
}
2023-05-26 17:33:51 -05:00
/// Returns the transpose of a matrix.
2023-05-27 01:02:06 -05:00
pub fn transpose(&self) -> Self {
2023-05-25 20:28:36 -05:00
let mut out = Vec::new();
for i in 0..self.width() {
let mut column = Vec::new();
for row in &self.entries {
column.push(row[i]);
}
out.push(column)
}
Matrix { entries: out }
}
2023-05-26 17:33:51 -05:00
/// Returns a reference to the rows of a matrix as `&Vec<Vec<T>>`.
2023-05-25 20:28:36 -05:00
pub fn rows(&self) -> &Vec<Vec<T>> {
&self.entries
}
/// Return the columns of a matrix as `Vec<Vec<T>>`.
2023-05-27 01:02:06 -05:00
pub fn columns(&self) -> Vec<Vec<T>> {
2023-05-25 20:28:36 -05:00
self.transpose().entries
}
/// Return true if a matrix is square and false otherwise.
pub fn is_square(&self) -> bool {
self.height() == self.width()
}
2023-05-26 17:33:51 -05:00
/// Returns a matrix after removing the provided row and column from it.
2023-05-25 20:28:36 -05:00
/// Note: Row and column numbers are 0-indexed.
/// # Example
/// ```
2023-05-25 21:06:16 -05:00
/// use matrix_basic::Matrix;
2023-05-25 20:28:36 -05:00
/// let m = Matrix::from(vec![vec![1,2,3],vec![4,5,6]]).unwrap();
/// let n = Matrix::from(vec![vec![5,6]]).unwrap();
/// assert_eq!(m.submatrix(0,0),n);
/// ```
2023-05-27 01:02:06 -05:00
pub fn submatrix(&self, row: usize, col: usize) -> Self {
2023-05-25 20:28:36 -05:00
let mut out = Vec::new();
for (m, row_iter) in self.entries.iter().enumerate() {
if m == row {
continue;
}
let mut new_row = Vec::new();
for (n, entry) in row_iter.iter().enumerate() {
if n != col {
new_row.push(*entry);
}
}
out.push(new_row);
}
Matrix { entries: out }
}
2023-05-27 01:09:52 -05:00
/// Returns the determinant of a square matrix.
/// This uses basic recursive algorithm using cofactor-minor.
2023-05-25 22:59:01 -05:00
/// See [`det_in_field`](Self::det_in_field()) for faster determinant calculation in fields.
2023-05-25 20:40:09 -05:00
/// It'll throw an error if the provided matrix isn't square.
/// # Example
/// ```
2023-05-25 21:06:16 -05:00
/// use matrix_basic::Matrix;
2023-05-25 20:40:09 -05:00
/// let m = Matrix::from(vec![vec![1,2],vec![3,4]]).unwrap();
/// assert_eq!(m.det(),Ok(-2));
/// ```
2023-05-27 01:02:06 -05:00
pub fn det(&self) -> Result<T, &'static str> {
2023-05-25 20:28:36 -05:00
if self.is_square() {
2023-05-25 20:40:09 -05:00
// It's a recursive algorithm using minors.
2023-05-25 22:59:01 -05:00
// TODO: Implement a faster algorithm.
2023-05-25 20:28:36 -05:00
let out = if self.width() == 1 {
self.entries[0][0]
} else {
2023-05-25 20:40:09 -05:00
// Add the minors multiplied by cofactors.
2023-05-25 20:28:36 -05:00
let n = 0..self.width();
let mut out = T::zero();
for i in n {
if i.is_even() {
out = out + (self.entries[0][i] * self.submatrix(0, i).det().unwrap());
} else {
out = out - (self.entries[0][i] * self.submatrix(0, i).det().unwrap());
}
}
out
};
Ok(out)
} else {
Err("Provided matrix isn't square.")
}
}
2023-05-26 17:33:51 -05:00
/// Returns the determinant of a square matrix over a field i.e. needs [`One`] and [`Div`] traits.
2023-05-25 22:59:01 -05:00
/// See [`det`](Self::det()) for determinants in rings.
/// This method uses row reduction as is much faster.
/// It'll throw an error if the provided matrix isn't square.
/// # Example
/// ```
/// use matrix_basic::Matrix;
2023-05-26 00:06:41 -05:00
/// let m = Matrix::from(vec![vec![1.0,2.0],vec![3.0,4.0]]).unwrap();
/// assert_eq!(m.det(),Ok(-2.0));
2023-05-25 22:59:01 -05:00
/// ```
pub fn det_in_field(&self) -> Result<T, &'static str>
where
T: One,
T: PartialEq,
T: Div<Output = T>,
{
if self.is_square() {
// Cloning is necessary as we'll be doing row operations on it.
let mut rows = self.entries.clone();
let mut multiplier = T::one();
2023-05-26 00:43:33 -05:00
let h = self.height();
let w = self.width();
for i in 0..h {
2023-05-25 22:59:01 -05:00
// First check if the row has diagonal element 0, if yes, then swap.
if rows[i][i] == T::zero() {
let mut zero_column = true;
2023-05-26 00:43:33 -05:00
for j in (i + 1)..h {
2023-05-25 22:59:01 -05:00
if rows[j][i] != T::zero() {
rows.swap(i, j);
multiplier = T::zero() - multiplier;
zero_column = false;
break;
}
}
if zero_column {
return Ok(T::zero());
}
}
2023-05-26 00:43:33 -05:00
for j in (i + 1)..h {
2023-05-26 00:06:41 -05:00
let ratio = rows[j][i] / rows[i][i];
2023-05-26 00:43:33 -05:00
for k in i..w {
2023-05-26 00:06:41 -05:00
rows[j][k] = rows[j][k] - rows[i][k] * ratio;
2023-05-25 22:59:01 -05:00
}
}
}
for (i, row) in rows.iter().enumerate() {
multiplier = multiplier * row[i];
}
Ok(multiplier)
} else {
Err("Provided matrix isn't square.")
}
}
2023-05-27 01:09:52 -05:00
/// Returns the row echelon form of a matrix over a field i.e. needs the [`Div`] trait.
2023-05-26 00:06:41 -05:00
/// # Example
/// ```
/// use matrix_basic::Matrix;
/// let m = Matrix::from(vec![vec![1.0,2.0,3.0],vec![3.0,4.0,5.0]]).unwrap();
/// let n = Matrix::from(vec![vec![1.0,2.0,3.0], vec![0.0,-2.0,-4.0]]).unwrap();
/// assert_eq!(m.row_echelon(),n);
/// ```
pub fn row_echelon(&self) -> Self
where
T: PartialEq,
T: Div<Output = T>,
{
// Cloning is necessary as we'll be doing row operations on it.
let mut rows = self.entries.clone();
let mut offset = 0;
2023-05-26 00:43:33 -05:00
let h = self.height();
let w = self.width();
for i in 0..h {
2023-05-26 00:06:41 -05:00
// Check if all the rows below are 0
if i + offset >= self.width() {
break;
}
// First check if the row has diagonal element 0, if yes, then swap.
if rows[i][i + offset] == T::zero() {
let mut zero_column = true;
2023-05-26 00:43:33 -05:00
for j in (i + 1)..h {
2023-05-26 00:06:41 -05:00
if rows[j][i + offset] != T::zero() {
rows.swap(i, j);
zero_column = false;
break;
}
}
if zero_column {
offset += 1;
}
}
2023-05-26 00:43:33 -05:00
for j in (i + 1)..h {
2023-05-26 00:06:41 -05:00
let ratio = rows[j][i + offset] / rows[i][i + offset];
2023-05-26 00:43:33 -05:00
for k in (i + offset)..w {
2023-05-26 00:06:41 -05:00
rows[j][k] = rows[j][k] - rows[i][k] * ratio;
}
}
}
Matrix { entries: rows }
}
2023-05-27 01:09:52 -05:00
/// Returns the column echelon form of a matrix over a field i.e. needs the [`Div`] trait.
2023-05-26 00:43:33 -05:00
/// It's just the transpose of the row echelon form of the transpose.
/// See [`row_echelon`](Self::row_echelon()) and [`transpose`](Self::transpose()).
pub fn column_echelon(&self) -> Self
where
T: PartialEq,
T: Div<Output = T>,
{
self.transpose().row_echelon().transpose()
}
2023-05-27 01:09:52 -05:00
/// Returns the reduced row echelon form of a matrix over a field i.e. needs the `Div`] trait.
2023-05-26 00:43:33 -05:00
/// # Example
/// ```
/// use matrix_basic::Matrix;
/// let m = Matrix::from(vec![vec![1.0,2.0,3.0],vec![3.0,4.0,5.0]]).unwrap();
/// let n = Matrix::from(vec![vec![1.0,2.0,3.0], vec![0.0,1.0,2.0]]).unwrap();
/// assert_eq!(m.reduced_row_echelon(),n);
/// ```
pub fn reduced_row_echelon(&self) -> Self
where
T: PartialEq,
T: Div<Output = T>,
{
let mut echelon = self.row_echelon();
let mut offset = 0;
for row in &mut echelon.entries {
while row[offset] == T::zero() {
offset += 1;
}
let divisor = row[offset];
for entry in row.iter_mut().skip(offset) {
*entry = *entry / divisor;
}
offset += 1;
}
echelon
}
2023-05-25 20:28:36 -05:00
/// Creates a zero matrix of a given size.
2023-05-27 01:02:06 -05:00
pub fn zero(height: usize, width: usize) -> Self {
2023-05-25 20:28:36 -05:00
let mut out = Vec::new();
for _ in 0..height {
let mut new_row = Vec::new();
for _ in 0..width {
new_row.push(T::zero());
}
out.push(new_row);
}
Matrix { entries: out }
}
/// Creates an identity matrix of a given size.
2023-05-27 01:09:52 -05:00
/// It needs the [`One`] trait.
2023-05-25 20:28:36 -05:00
pub fn identity(size: usize) -> Self
where
T: One,
{
let mut out = Vec::new();
for i in 0..size {
let mut new_row = Vec::new();
for j in 0..size {
if i == j {
new_row.push(T::one());
} else {
new_row.push(T::zero());
}
}
out.push(new_row);
}
Matrix { entries: out }
}
2023-05-26 00:43:33 -05:00
2023-05-27 01:09:52 -05:00
/// Returns the trace of a square matrix.
/// It'll throw an error if the provided matrix isn't square.
/// # Example
/// ```
/// use matrix_basic::Matrix;
/// let m = Matrix::from(vec![vec![1,2],vec![3,4]]).unwrap();
/// assert_eq!(m.det(),Ok(-2));
/// ```
pub fn trace(self) -> Result<T, &'static str> {
if self.is_square() {
let mut out = self.entries[0][0];
for i in 1..self.height() {
out = out + self.entries[i][i];
}
Ok(out)
} else {
Err("Provided matrix isn't square.")
}
}
2023-05-26 00:43:33 -05:00
// TODO: Canonical forms, eigenvalues, eigenvectors etc.
2023-05-25 20:28:36 -05:00
}
2023-05-27 01:02:06 -05:00
impl<T: Debug + Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy> Display
for Matrix<T>
{
2023-05-25 20:28:36 -05:00
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{:?}", self.entries)
}
}
2023-05-27 01:02:06 -05:00
impl<T: Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy + Copy + Zero> Mul
for Matrix<T>
{
2023-05-25 22:59:01 -05:00
// TODO: Implement a faster algorithm.
2023-05-25 20:28:36 -05:00
type Output = Self;
2023-05-26 01:18:52 -05:00
fn mul(self, other: Self) -> Self::Output {
2023-05-25 20:28:36 -05:00
let width = self.width();
if width != other.height() {
panic!("Row length of first matrix must be same as column length of second matrix.");
} else {
let mut out = Vec::new();
for row in self.rows() {
let mut new_row = Vec::new();
for col in other.columns() {
let mut prod = row[0] * col[0];
for i in 1..width {
prod = prod + (row[i] * col[i]);
}
new_row.push(prod)
}
out.push(new_row);
}
Matrix { entries: out }
}
}
}
2023-05-27 01:02:06 -05:00
impl<T: Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy + Copy + Zero> Add
for Matrix<T>
{
2023-05-25 20:28:36 -05:00
type Output = Self;
2023-05-26 01:18:52 -05:00
fn add(self, other: Self) -> Self::Output {
2023-05-25 20:28:36 -05:00
if self.height() == other.height() && self.width() == other.width() {
let mut out = self.entries.clone();
for (i, row) in self.rows().iter().enumerate() {
for (j, entry) in other.rows()[i].iter().enumerate() {
out[i][j] = row[j] + *entry;
}
}
Matrix { entries: out }
} else {
panic!("Both matrices must be of same dimensions.");
}
}
}
2023-05-27 01:02:06 -05:00
impl<
T: Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy + Copy + Neg<Output = T>,
> Neg for Matrix<T>
{
2023-05-25 20:28:36 -05:00
type Output = Self;
2023-05-26 01:18:52 -05:00
fn neg(self) -> Self::Output {
let mut out = self;
for row in &mut out.entries {
for entry in row {
*entry = -*entry;
2023-05-25 20:28:36 -05:00
}
2023-05-26 01:18:52 -05:00
}
out
}
}
2023-05-27 01:02:06 -05:00
impl<
T: Mul<Output = T>
+ Add<Output = T>
+ Sub<Output = T>
+ Zero
+ Copy
+ Copy
+ Zero
+ Neg<Output = T>,
> Sub for Matrix<T>
{
2023-05-26 01:18:52 -05:00
type Output = Self;
fn sub(self, other: Self) -> Self::Output {
if self.height() == other.height() && self.width() == other.width() {
self + -other
2023-05-25 20:28:36 -05:00
} else {
panic!("Both matrices must be of same dimensions.");
}
2023-05-25 20:07:10 -05:00
}
2023-05-24 21:46:25 -05:00
}
2023-05-27 00:44:36 -05:00
/// Trait for conversion between matrices of different types.
/// It only has a `convert_to()` method.
/// This is needed since negative trait bound are not supported in stable Rust
/// yet, so we'll have a conflict trying to implement [`From`].
/// I plan to change this to the default From trait as soon as some sort
/// of specialization system is implemented.
/// You can track this issue [here](https://github.com/rust-lang/rust/issues/42721).
2023-05-27 01:02:06 -05:00
pub trait MatrixInto<T: Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy> {
2023-05-27 00:44:36 -05:00
/// Method for converting a matrix into a matrix of type `Matrix<T>`
fn matrix_into(self) -> Matrix<T>;
}
/// Blanket implementation of MatrixInto for converting `Matrix<S>` to `Matrix<T>` whenever
/// `S` implements `Into<T>`.
/// # Example
/// ```
/// use matrix_basic::Matrix;
/// use matrix_basic::MatrixInto;
///
/// let a = Matrix::from(vec![vec![1, 2, 3], vec![0, 1, 2]]).unwrap();
/// let b = Matrix::from(vec![vec![1.0, 2.0, 3.0], vec![0.0, 1.0, 2.0]]).unwrap();
/// let c: Matrix<f64> = a.matrix_into();
///
/// assert_eq!(c, b);
/// ```
2023-05-27 01:02:06 -05:00
impl<
T: Mul<Output = T> + Add<Output = T> + Sub<Output = T> + Zero + Copy,
S: Mul<Output = S> + Add<Output = S> + Sub<Output = S> + Zero + Copy + Into<T>,
> MatrixInto<T> for Matrix<S>
{
2023-05-27 00:44:36 -05:00
fn matrix_into(self) -> Matrix<T> {
let mut out = Vec::new();
for row in self.entries {
let mut new_row: Vec<T> = Vec::new();
for entry in row {
new_row.push(entry.into());
}
out.push(new_row)
}
Matrix { entries: out }
}
}