Skip to content

Commit 01f912a

Browse files
bytesnakeLorenz SchmidtYuhanLiin
authored
Introduce checked parameters to all algorithms (#158)
* Introduce param guards to linfa * Add param guards to linfa-bayes + documentation * Add param guards to linfa-elasticnet + docs * Add remaining changes * Reset logistic regression * Add ParamGuard to HierarchicalCluster * Fix PLS errors * Add ParamGuard to logistic regression * Fixed logistic regression tests * Fix Clippy warnings * Tweak logistic regression * Fix elasticnet doctests * Fix doctests in linfa-bayes * Fix Gaussian mixture example * Fix tsne examples * Fix SVM examples * Simplify reduction transform impl return type * Resolve transform double Result issue using marker traits Co-authored-by: Lorenz Schmidt <[email protected]> Co-authored-by: YuhanLiin <[email protected]>
1 parent e06f0be commit 01f912a

File tree

109 files changed

+2981
-2192
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

109 files changed

+2981
-2192
lines changed

.github/workflows/checking.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
fail-fast: false
1212
matrix:
1313
toolchain:
14-
- 1.51.0
14+
- 1.54.0
1515
- stable
1616
experimental: [false]
1717
include:
@@ -37,7 +37,7 @@ jobs:
3737
args: --workspace
3838

3939
- name: Run cargo check (with serde)
40-
if: ${{ matrix.toolchain != '1.51.0' }} # The following args don't work on older versions of Cargo
40+
if: ${{ matrix.toolchain != '1.54.0' }} # The following args don't work on older versions of Cargo
4141
uses: actions-rs/cargo@v1
4242
with:
4343
command: check

.github/workflows/codequality.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
strategy:
1212
matrix:
1313
toolchain:
14-
- 1.51.0
14+
- 1.54.0
1515
- stable
1616

1717
steps:

.github/workflows/testing.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
fail-fast: false
1515
matrix:
1616
toolchain:
17-
- 1.51.0
17+
- 1.54.0
1818
- stable
1919

2020
steps:

Cargo.toml

+1-15
Original file line numberDiff line numberDiff line change
@@ -74,21 +74,7 @@ linfa-datasets = { path = "datasets", features = ["winequality", "iris", "diabet
7474

7575
[workspace]
7676
members = [
77-
"algorithms/linfa-clustering",
78-
"algorithms/linfa-reduction",
79-
"algorithms/linfa-kernel",
80-
"algorithms/linfa-linear",
81-
"algorithms/linfa-logistic",
82-
"algorithms/linfa-trees",
83-
"algorithms/linfa-svm",
84-
"algorithms/linfa-hierarchical",
85-
"algorithms/linfa-ica",
86-
"algorithms/linfa-bayes",
87-
"algorithms/linfa-elasticnet",
88-
"algorithms/linfa-pls",
89-
"algorithms/linfa-tsne",
90-
"algorithms/linfa-preprocessing",
91-
"algorithms/linfa-nn",
77+
"algorithms/*",
9278
"datasets",
9379
]
9480

algorithms/linfa-bayes/README.md

+37-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Naive Bayes
22

3-
`linfa-bayes` aims to provide pure Rust implementations of Naive Bayes algorithms.
3+
`linfa-bayes` provides pure Rust implementations of Naive Bayes algorithms for the Linfa toolkit.
44

55
## The Big Picture
66

@@ -14,14 +14,43 @@
1414

1515
## Examples
1616

17-
There is an usage example in the `examples/` directory. To run, use:
17+
You can find an example in the `examples/` directory. To run, use:
1818

1919
```bash
20-
$ cargo run --example winequality
20+
$ cargo run --example winequality --release
2121
```
2222

23-
## License
24-
Dual-licensed to be compatible with the Rust project.
25-
26-
Licensed under the Apache License, Version 2.0 <http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <http://opensource.org/licenses/MIT>, at your option. This file may not be copied, modified, or distributed except according to those terms.
27-
23+
<details>
24+
<summary style="cursor: pointer; display:list-item;">
25+
Show source code
26+
</summary>
27+
28+
```rust, no_run
29+
use linfa::metrics::ToConfusionMatrix;
30+
use linfa::traits::{Fit, Predict};
31+
use linfa_bayes::{GaussianNb, Result};
32+
33+
// Read in the dataset and convert targets to binary data
34+
let (train, valid) = linfa_datasets::winequality()
35+
.map_targets(|x| if *x > 6 { "good" } else { "bad" })
36+
.split_with_ratio(0.9);
37+
38+
// Train the model
39+
let model = GaussianNb::params().fit(&train)?;
40+
41+
// Predict the validation dataset
42+
let pred = model.predict(&valid);
43+
44+
// Construct confusion matrix
45+
let cm = pred.confusion_matrix(&valid)?;
46+
47+
// classes | bad | good
48+
// bad | 130 | 12
49+
// good | 7 | 10
50+
//
51+
// accuracy 0.8805031, MCC 0.45080978
52+
println!("{:?}", cm);
53+
println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
54+
# Result::Ok(())
55+
```
56+
</details>

algorithms/linfa-bayes/examples/winequality.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
use linfa::metrics::ToConfusionMatrix;
22
use linfa::traits::{Fit, Predict};
3-
use linfa_bayes::{GaussianNbParams, Result};
3+
use linfa_bayes::{GaussianNb, Result};
44

55
fn main() -> Result<()> {
6-
// Read in the dataset and convert continuous target into categorical
6+
// Read in the dataset and convert targets to binary data
77
let (train, valid) = linfa_datasets::winequality()
8-
.map_targets(|x| if *x > 6 { 1 } else { 0 })
8+
.map_targets(|x| if *x > 6 { "good" } else { "bad" })
99
.split_with_ratio(0.9);
1010

1111
// Train the model
12-
let model = GaussianNbParams::params().fit(&train.view())?;
12+
let model = GaussianNb::params().fit(&train)?;
1313

1414
// Predict the validation dataset
1515
let pred = model.predict(&valid);
1616

1717
// Construct confusion matrix
1818
let cm = pred.confusion_matrix(&valid)?;
1919

20-
// classes | 1 | 0
21-
// 1 | 10 | 12
22-
// 0 | 7 | 130
20+
// classes | bad | good
21+
// bad | 130 | 12
22+
// good | 7 | 10
2323
//
2424
// accuracy 0.8805031, MCC 0.45080978
2525
println!("{:?}", cm);

algorithms/linfa-bayes/src/error.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
use ndarray_stats::errors::MinMaxError;
22
use thiserror::Error;
33

4-
pub type Result<T> = std::result::Result<T, BayesError>;
4+
/// Simplified `Result` using [`NaiveBayesError`](crate::NaiveBayesError) as error type
5+
pub type Result<T> = std::result::Result<T, NaiveBayesError>;
56

6-
/// An error when using a GaussianNB classifier
7+
/// Error variants from hyper-parameter construction or model estimation
78
#[derive(Error, Debug)]
8-
pub enum BayesError {
9+
pub enum NaiveBayesError {
910
/// Error when performing Max operation on data
1011
#[error("invalid statistical operation {0}")]
1112
Stats(#[from] MinMaxError),
13+
/// Invalid smoothing parameter
14+
#[error("invalid smoothing parameter {0}")]
15+
InvalidSmoothing(f64),
1216
#[error(transparent)]
1317
BaseCrate(#[from] linfa::Error),
1418
}

0 commit comments

Comments
 (0)