1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use crate::Error;
use ndarray::{Array1, Array2, Axis};
pub fn solve_by_thomas(matrix: &Array2<f64>, b: &Array1<f64>) -> Result<Vec<f64>, Error> {
if !matrix.is_square() || matrix.len_of(Axis(0)) != b.len() {
return Err(Error::WrongDims)
}
let mut solution = vec![0_f64; b.len()];
let mut c = Array1::from_elem(b.len() - 1, 0_f64);
let mut d = Array1::from_elem(b.len(), 0_f64);
c[0] = matrix[[0, 1]] / matrix[[0, 0]];
d[0] = b[0] / matrix[[0, 0]];
for i in 1..b.len() - 1 {
c[i] = matrix[[i, i + 1]] / (matrix[[i, i]] - matrix[[i, i - 1]] * c[i - 1]);
d[i] = (b[i] - matrix[[i, i - 1]] * d[i - 1])
/ (matrix[[i, i]] - matrix[[i, i - 1]] * c[i - 1]);
}
d[b.len() - 1] = (b[b.len() - 1] - matrix[[b.len() - 1, b.len() - 2]] * d[b.len() - 2])
/ (matrix[[b.len() - 1, b.len() - 1]]
- matrix[[b.len() - 1, b.len() - 2]] * c[b.len() - 2]);
solution[b.len() - 1] = d[b.len() - 1];
for i in (0..b.len() - 1).rev() {
solution[i] = d[i] - c[i] * solution[i + 1];
}
Ok(solution)
}
#[cfg(test)]
mod test {
    use super::solve_by_thomas;
    use ndarray::{Array1, Array2};

    /// 3x3 tridiagonal system whose exact solution is (0.6, 0.2, -0.4);
    /// each component is checked against a loose enclosing interval.
    #[test]
    fn solve_3x3() {
        let rhs: Array1<f64> = Array1::from(vec![1., 0., 0.]);
        let coefficients: Array2<f64> = Array2::from(vec![
            [1., 2., 0.],
            [1., 1., 2.],
            [0., 2., 1.],
        ]);

        let x = solve_by_thomas(&coefficients, &rhs).unwrap();

        // (lower bound, upper bound) for each component of the solution.
        let bounds = [(0.5, 0.7), (0.1, 0.3), (-0.5, -0.3)];
        for (i, &(lower, upper)) in bounds.iter().enumerate() {
            assert!(x[i] <= upper && x[i] >= lower);
        }
    }

    /// 5x5 tridiagonal system whose exact solution is
    /// (1/9, 4/9, -2/3, 8/9, -4/9); each component is checked against a
    /// loose enclosing interval.
    #[test]
    fn solve_5x5() {
        let rhs: Array1<f64> = Array1::from(vec![1., 0., 0., 0., 0.]);
        let coefficients: Array2<f64> = Array2::from(vec![
            [1., 2., 0., 0., 0.],
            [2., 1., 1., 0., 0.],
            [0., 1., 2., 1., 0.],
            [0., 0., 2., 2., 1.],
            [0., 0., 0., 1., 2.],
        ]);

        let x = solve_by_thomas(&coefficients, &rhs).unwrap();
        println!("{:?}", x);

        // (lower bound, upper bound) for each component of the solution.
        let bounds = [
            (0.09, 0.13),
            (0.43, 0.46),
            (-0.70, -0.60),
            (0.86, 0.90),
            (-0.46, -0.42),
        ];
        for (i, &(lower, upper)) in bounds.iter().enumerate() {
            assert!(x[i] <= upper && x[i] >= lower);
        }
    }
}