Skip to content

Commit e06f0be

Browse files
authored
Multinomial logistic regression (#159)
* Generalized argmin param across dimensions * Wrote multinomial loss and gradient * Finished multi fitted model * Implement dot and norm on ArgminParams * Add test for multinomial loss and grad * Add multi logistic regression tests * Add docs
1 parent 992938e commit e06f0be

File tree

4 files changed

+704
-248
lines changed

4 files changed

+704
-248
lines changed

algorithms/linfa-logistic/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ categories = ["algorithms", "mathematics", "science"]
1616
[dependencies]
1717
ndarray = { version = "0.15", features = ["approx", "blas"] }
1818
ndarray-linalg = "0.14"
19+
ndarray-stats = "0.5.0"
1920
num-traits = "0.2"
2021
argmin = { version = "0.4.6", features = ["ndarrayl"] }
2122
serde = "1.0"
+27-18
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//! This module defines newtypes for ndarray's Array1.
1+
//! This module defines newtypes for ndarray's Array.
22
//!
33
//! This is necessary to be able to abstract over floats (f32 and f64) so that
44
//! the logistic regression code can be abstract in the float type it works
@@ -8,51 +8,60 @@
88
99
use crate::float::Float;
1010
use argmin::prelude::*;
11-
use ndarray::Array1;
11+
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
1212
use serde::{Deserialize, Serialize};
1313

14+
pub fn elem_dot<F: linfa::Float, A1: Data<Elem = F>, A2: Data<Elem = F>, D: Dimension>(
15+
a: &ArrayBase<A1, D>,
16+
b: &ArrayBase<A2, D>,
17+
) -> F {
18+
Zip::from(a)
19+
.and(b)
20+
.fold(F::zero(), |acc, &a, &b| acc + a * b)
21+
}
22+
1423
#[derive(Serialize, Clone, Deserialize, Debug, Default)]
15-
pub struct ArgminParam<F>(pub Array1<F>);
24+
pub struct ArgminParam<F, D: Dimension>(pub Array<F, D>);
1625

17-
impl<F> ArgminParam<F> {
26+
impl<F, D: Dimension> ArgminParam<F, D> {
1827
#[inline]
19-
pub fn as_array(&self) -> &Array1<F> {
28+
pub fn as_array(&self) -> &Array<F, D> {
2029
&self.0
2130
}
2231
}
2332

24-
impl<F: Float> ArgminSub<ArgminParam<F>, ArgminParam<F>> for ArgminParam<F> {
25-
fn sub(&self, other: &ArgminParam<F>) -> ArgminParam<F> {
33+
impl<F: Float, D: Dimension> ArgminSub<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
34+
fn sub(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
2635
ArgminParam(&self.0 - &other.0)
2736
}
2837
}
2938

30-
impl<F: Float> ArgminAdd<ArgminParam<F>, ArgminParam<F>> for ArgminParam<F> {
31-
fn add(&self, other: &ArgminParam<F>) -> ArgminParam<F> {
39+
impl<F: Float, D: Dimension> ArgminAdd<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
40+
fn add(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
3241
ArgminParam(&self.0 + &other.0)
3342
}
3443
}
3544

36-
impl<F: Float> ArgminDot<ArgminParam<F>, F> for ArgminParam<F> {
37-
fn dot(&self, other: &ArgminParam<F>) -> F {
38-
self.0.dot(&other.0)
45+
impl<F: Float, D: Dimension> ArgminDot<ArgminParam<F, D>, F> for ArgminParam<F, D> {
46+
fn dot(&self, other: &ArgminParam<F, D>) -> F {
47+
elem_dot(&self.0, &other.0)
3948
}
4049
}
4150

42-
impl<F: Float> ArgminNorm<F> for ArgminParam<F> {
51+
impl<F: Float, D: Dimension> ArgminNorm<F> for ArgminParam<F, D> {
4352
fn norm(&self) -> F {
44-
self.0.dot(&self.0)
53+
num_traits::Float::sqrt(elem_dot(&self.0, &self.0))
4554
}
4655
}
4756

48-
impl<F: Float> ArgminMul<F, ArgminParam<F>> for ArgminParam<F> {
49-
fn mul(&self, other: &F) -> ArgminParam<F> {
57+
impl<F: Float, D: Dimension> ArgminMul<F, ArgminParam<F, D>> for ArgminParam<F, D> {
58+
fn mul(&self, other: &F) -> ArgminParam<F, D> {
5059
ArgminParam(&self.0 * *other)
5160
}
5261
}
5362

54-
impl<F: Float> ArgminMul<ArgminParam<F>, ArgminParam<F>> for ArgminParam<F> {
55-
fn mul(&self, other: &ArgminParam<F>) -> ArgminParam<F> {
63+
impl<F: Float, D: Dimension> ArgminMul<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
64+
fn mul(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
5665
ArgminParam(&self.0 * &other.0)
5766
}
5867
}

algorithms/linfa-logistic/src/float.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::argmin_param::ArgminParam;
22
use argmin::prelude::{ArgminFloat, ArgminMul};
3-
use ndarray::NdFloat;
3+
use ndarray::{Dimension, Ix1, Ix2, NdFloat};
44
use ndarray_linalg::Lapack;
55
use num_traits::FromPrimitive;
66

@@ -13,21 +13,22 @@ pub trait Float:
1313
+ Default
1414
+ Clone
1515
+ FromPrimitive
16-
+ ArgminMul<ArgminParam<Self>, ArgminParam<Self>>
16+
+ ArgminMul<ArgminParam<Self, Ix1>, ArgminParam<Self, Ix1>>
17+
+ ArgminMul<ArgminParam<Self, Ix2>, ArgminParam<Self, Ix2>>
1718
+ linfa::Float
1819
{
1920
const POSITIVE_LABEL: Self;
2021
const NEGATIVE_LABEL: Self;
2122
}
2223

23-
impl ArgminMul<ArgminParam<Self>, ArgminParam<Self>> for f64 {
24-
fn mul(&self, other: &ArgminParam<Self>) -> ArgminParam<Self> {
24+
impl<D: Dimension> ArgminMul<ArgminParam<Self, D>, ArgminParam<Self, D>> for f64 {
25+
fn mul(&self, other: &ArgminParam<Self, D>) -> ArgminParam<Self, D> {
2526
ArgminParam(&other.0 * *self)
2627
}
2728
}
2829

29-
impl ArgminMul<ArgminParam<Self>, ArgminParam<Self>> for f32 {
30-
fn mul(&self, other: &ArgminParam<Self>) -> ArgminParam<Self> {
30+
impl<D: Dimension> ArgminMul<ArgminParam<Self, D>, ArgminParam<Self, D>> for f32 {
31+
fn mul(&self, other: &ArgminParam<Self, D>) -> ArgminParam<Self, D> {
3132
ArgminParam(&other.0 * *self)
3233
}
3334
}

0 commit comments

Comments
 (0)