From c199221b214f194a73e6b02dafa51115a85757cf Mon Sep 17 00:00:00 2001 From: Can Date: Mon, 10 Feb 2025 14:03:41 +0100 Subject: [PATCH 01/13] add support for graphical lasso and adaptive (reweighted) graphical lasso --- examples/plot_graphical_lasso.py | 113 ++++++++++++++++++ skglm/estimators.py | 189 +++++++++++++++++++++++++++++-- skglm/tests/test_estimators.py | 129 ++++++++++++++++++++- skglm/utils/data.py | 20 ++++ 4 files changed, 434 insertions(+), 17 deletions(-) create mode 100644 examples/plot_graphical_lasso.py diff --git a/examples/plot_graphical_lasso.py b/examples/plot_graphical_lasso.py new file mode 100644 index 000000000..85d4c0e17 --- /dev/null +++ b/examples/plot_graphical_lasso.py @@ -0,0 +1,113 @@ +import numpy as np +from numpy.linalg import norm +import matplotlib.pyplot as plt +from sklearn.metrics import f1_score +from sklearn.datasets import make_sparse_spd_matrix +from sklearn.covariance import GraphicalLasso as skGraphicalLasso + +from skglm.estimators import GraphicalLasso, AdaptiveGraphicalLasso +from skglm.utils.data import generate_GraphicalLasso_data + +# Data +p = 20 +n = 100 +S, Theta_true, alpha_max = generate_GraphicalLasso_data(n, p) + +alphas = alpha_max*np.geomspace(1, 1e-3, num=30) + + +penalties = [ + "L1", + "R-L1", +] + +models_tol = 1e-4 +models = [ + GraphicalLasso(algo="mazumder", + warm_start=True, tol=models_tol), + AdaptiveGraphicalLasso(warm_start=True, n_reweights=10, tol=models_tol), + +] + +my_glasso_nmses = {penalty: [] for penalty in penalties} +my_glasso_f1_scores = {penalty: [] for penalty in penalties} + +sk_glasso_nmses = [] +sk_glasso_f1_scores = [] + + +for i, (penalty, model) in enumerate(zip(penalties, models)): + print(penalty) + for alpha_idx, alpha in enumerate(alphas): + print(f"======= alpha {alpha_idx+1}/{len(alphas)} =======") + model.alpha = alpha + model.fit(S) + Theta = model.precision_ + + my_nmse = norm(Theta - Theta_true)**2 / norm(Theta_true)**2 + + my_f1_score = f1_score(Theta.flatten() != 0., + Theta_true.flatten() != 0.) 
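+        # NMSE: squared Frobenius distance to the true precision matrix,
+        # normalized by its squared norm; F1 compares the estimated and
+        # true sparsity patterns (support recovery).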
+ print(f"NMSE: {my_nmse:.3f}") + print(f"F1 : {my_f1_score:.3f}") + + my_glasso_nmses[penalty].append(my_nmse) + my_glasso_f1_scores[penalty].append(my_f1_score) + + +plt.close('all') +fig, ax = plt.subplots(2, 1, sharex=True, figsize=( + [12.6, 4.63]), layout="constrained") +cmap = plt.get_cmap("tab10") +for i, penalty in enumerate(penalties): + + ax[0].semilogx(alphas/alpha_max, + my_glasso_nmses[penalty], + color=cmap(i), + linewidth=2., + label=penalty) + min_nmse = np.argmin(my_glasso_nmses[penalty]) + ax[0].vlines( + x=alphas[min_nmse] / alphas[0], + ymin=0, + ymax=np.min(my_glasso_nmses[penalty]), + linestyle='--', + color=cmap(i)) + line0 = ax[0].plot( + [alphas[min_nmse] / alphas[0]], + 0, + clip_on=False, + marker='X', + color=cmap(i), + markersize=12) + + ax[1].semilogx(alphas/alpha_max, + my_glasso_f1_scores[penalty], + linewidth=2., + color=cmap(i)) + max_f1 = np.argmax(my_glasso_f1_scores[penalty]) + ax[1].vlines( + x=alphas[max_f1] / alphas[0], + ymin=0, + ymax=np.max(my_glasso_f1_scores[penalty]), + linestyle='--', + color=cmap(i)) + line1 = ax[1].plot( + [alphas[max_f1] / alphas[0]], + 0, + clip_on=False, + marker='X', + markersize=12, + color=cmap(i)) + + +ax[0].set_title(f"{p=},{n=}", fontsize=18) +ax[0].set_ylabel("NMSE", fontsize=18) +ax[1].set_ylabel("F1 score", fontsize=18) +ax[1].set_xlabel(f"$\lambda / \lambda_\mathrm{{max}}$", fontsize=18) + +ax[0].legend(fontsize=14) +ax[0].grid(which='both', alpha=0.9) +ax[1].grid(which='both', alpha=0.9) +# plt.show(block=False) +plt.show() diff --git a/skglm/estimators.py b/skglm/estimators.py index a785fd5c5..0bbfa5253 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -21,7 +21,7 @@ from skglm.utils.jit_compilation import compiled_clone from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC, - QuadraticMultiTask, QuadraticGroup,) + QuadraticMultiTask, QuadraticGroup, QuadraticHessian) from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2, MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1) from skglm.utils.data import grp_converter @@ -126,7 +126,8 @@ def _glm_fit(X, y, model, datafit, penalty, solver): w = np.zeros(n_features + fit_intercept, dtype=X_.dtype) Xw = np.zeros(n_samples, dtype=X_.dtype) else: # multitask - w = np.zeros((n_features + fit_intercept, y.shape[1]), dtype=X_.dtype) + w = np.zeros((n_features + fit_intercept, + y.shape[1]), dtype=X_.dtype) Xw = np.zeros(y.shape, dtype=X_.dtype) # check consistency of weights for WeightedL1 @@ -576,7 +577,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): raise ValueError("The number of weights must match the number of \ features. Got %s, expected %s." % ( len(weights), X.shape[1])) - penalty = compiled_clone(WeightedL1(self.alpha, weights, self.positive)) + penalty = compiled_clone(WeightedL1( + self.alpha, weights, self.positive)) datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32) solver = AndersonCD( self.max_iter, self.max_epochs, self.p0, tol=self.tol, @@ -601,7 +603,8 @@ def fit(self, X, y): Fitted estimator. 
""" if self.weights is None: - warnings.warn('Weights are not provided, fitting with Lasso penalty') + warnings.warn( + 'Weights are not provided, fitting with Lasso penalty') penalty = L1(self.alpha, self.positive) else: penalty = WeightedL1(self.alpha, self.weights, self.positive) @@ -734,7 +737,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): The number of iterations along the path. If return_n_iter is set to ``True``. """ - penalty = compiled_clone(L1_plus_L2(self.alpha, self.l1_ratio, self.positive)) + penalty = compiled_clone(L1_plus_L2( + self.alpha, self.l1_ratio, self.positive)) datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32) solver = AndersonCD( self.max_iter, self.max_epochs, self.p0, tol=self.tol, @@ -912,7 +916,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): f"Got {len(self.weights)}, expected {X.shape[1]}." ) penalty = compiled_clone( - WeightedMCPenalty(self.alpha, self.gamma, self.weights, self.positive) + WeightedMCPenalty(self.alpha, self.gamma, + self.weights, self.positive) ) datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32) solver = AndersonCD( @@ -1307,7 +1312,8 @@ def fit(self, X, y): # copy/paste from https://github.com/scikit-learn/scikit-learn/blob/ \ # 23ff51c07ebc03c866984e93c921a8993e96d1f9/sklearn/utils/ \ # estimator_checks.py#L3886 - raise ValueError("requires y to be passed, but the target y is None") + raise ValueError( + "requires y to be passed, but the target y is None") y = check_array( y, accept_sparse=False, @@ -1322,7 +1328,8 @@ def fit(self, X, y): f"two columns. Got one column.\nAssuming that `y` " "is the vector of times and there is no censoring." ) - y = np.column_stack((y, np.ones_like(y))).astype(X.dtype, order="F") + y = np.column_stack((y, np.ones_like(y))).astype( + X.dtype, order="F") elif y.shape[1] > 2: raise ValueError( f"{repr(self)} requires the vector of response `y` to have " @@ -1347,7 +1354,8 @@ def fit(self, X, y): # init solver if self.l1_ratio == 0.: - solver = LBFGS(max_iter=self.max_iter, tol=self.tol, verbose=self.verbose) + solver = LBFGS(max_iter=self.max_iter, + tol=self.tol, verbose=self.verbose) else: solver = ProxNewton( max_iter=self.max_iter, tol=self.tol, verbose=self.verbose, @@ -1485,7 +1493,8 @@ def fit(self, X, Y): if not self.warm_start or not hasattr(self, "coef_"): self.coef_ = None - datafit_jit = compiled_clone(QuadraticMultiTask(), X.dtype == np.float32) + datafit_jit = compiled_clone( + QuadraticMultiTask(), X.dtype == np.float32) penalty_jit = compiled_clone(L2_1(self.alpha), X.dtype == np.float32) solver = MultiTaskBCD( @@ -1540,7 +1549,8 @@ def path(self, X, Y, alphas, coef_init=None, return_n_iter=False, **params): The number of iterations along the path. If return_n_iter is set to ``True``. """ - datafit = compiled_clone(QuadraticMultiTask(), to_float32=X.dtype == np.float32) + datafit = compiled_clone(QuadraticMultiTask(), + to_float32=X.dtype == np.float32) penalty = compiled_clone(L2_1(self.alpha)) solver = MultiTaskBCD( self.max_iter, self.max_epochs, self.p0, tol=self.tol, @@ -1664,7 +1674,8 @@ def fit(self, X, y): "The total number of group members must equal the number of features. 
" f"Got {n_features}, expected {X.shape[1]}.") - weights = np.ones(len(group_sizes)) if self.weights is None else self.weights + weights = np.ones( + len(group_sizes)) if self.weights is None else self.weights group_penalty = WeightedGroupL2(alpha=self.alpha, grp_ptr=grp_ptr, grp_indices=grp_indices, weights=weights, positive=self.positive) @@ -1675,3 +1686,157 @@ def fit(self, X, y): verbose=self.verbose) return _glm_fit(X, y, self, quad_group, group_penalty, solver) + + +class GraphicalLasso(): + def __init__(self, + alpha=1., + weights=None, + algo="banerjee", + max_iter=1000, + tol=1e-8, + warm_start=False, + ): + self.alpha = alpha + self.weights = weights + self.algo = algo + self.max_iter = max_iter + self.tol = tol + self.warm_start = warm_start + + def fit(self, S): + p = S.shape[-1] + indices = np.arange(p) + + if self.weights is None: + Weights = np.ones((p, p)) + else: + Weights = self.weights + if not np.allclose(Weights, Weights.T): + raise ValueError("Weights should be symmetric.") + + if self.warm_start and hasattr(self, "precision_"): + if self.algo == "banerjee": + raise ValueError( + "Banerjee does not support warm start for now.") + Theta = self.precision_ + W = self.covariance_ + else: + W = S.copy() # + alpha*np.eye(p) + Theta = np.linalg.pinv(W, hermitian=True) + + datafit = compiled_clone(QuadraticHessian()) + penalty = compiled_clone( + WeightedL1(alpha=self.alpha, weights=Weights[0, :-1])) + + solver = AndersonCD(warm_start=True, + fit_intercept=False, + ws_strategy="fixpoint") + + for it in range(self.max_iter): + Theta_old = Theta.copy() + for col in range(p): + indices_minus_col = np.concatenate( + [indices[:col], indices[col + 1:]]) + _11 = indices_minus_col[:, None], indices_minus_col[None] + _12 = indices_minus_col, col + _21 = col, indices_minus_col + _22 = col, col + + W_11 = W[_11] + w_12 = W[_12] + w_22 = W[_22] + s_12 = S[_12] + s_22 = S[_22] + + penalty.weights = Weights[_12] + + if self.algo == "banerjee": + w_init = Theta[_12]/Theta[_22] + Xw_init = W_11 @ w_init + Q = W_11 + elif self.algo == "mazumder": + inv_Theta_11 = W_11 - np.outer(w_12, w_12)/w_22 + Q = inv_Theta_11 + w_init = Theta[_12] * w_22 + Xw_init = inv_Theta_11 @ w_init + else: + raise ValueError(f"Unsupported algo {self.algo}") + + beta, _, _ = solver._solve( + Q, + s_12, + datafit, + penalty, + w_init=w_init, + Xw_init=Xw_init, + ) + + if self.algo == "banerjee": + w_12 = -W_11 @ beta + W[_12] = w_12 + W[_21] = w_12 + Theta[_22] = 1/(s_22 + beta @ w_12) + Theta[_12] = beta*Theta[_22] + else: # mazumder + theta_12 = beta / s_22 + theta_22 = 1/s_22 + theta_12 @ inv_Theta_11 @ theta_12 + + Theta[_12] = theta_12 + Theta[_21] = theta_12 + Theta[_22] = theta_22 + + w_22 = 1/(theta_22 - theta_12 @ inv_Theta_11 @ theta_12) + w_12 = -w_22*inv_Theta_11 @ theta_12 + W_11 = inv_Theta_11 + np.outer(w_12, w_12)/w_22 + W[_11] = W_11 + W[_12] = w_12 + W[_21] = w_12 + W[_22] = w_22 + + if np.linalg.norm(Theta - Theta_old) < self.tol: + print(f"Weighted Glasso converged at CD epoch {it + 1}") + break + else: + print(f"Not converged at epoch {it + 1}, " + f"diff={np.linalg.norm(Theta - Theta_old):.2e}") + self.precision_, self.covariance_ = Theta, W + self.n_iter_ = it + 1 + + return self + + +class AdaptiveGraphicalLasso(): + def __init__( + self, + alpha=1., + n_reweights=5, + max_iter=1000, + tol=1e-8, + warm_start=False, + # verbose=False, + ): + self.alpha = alpha + self.n_reweights = n_reweights + self.max_iter = max_iter + self.tol = tol + self.warm_start = warm_start + + def fit(self, S): + 
glasso = GraphicalLasso( + alpha=self.alpha, algo="mazumder", max_iter=self.max_iter, + tol=self.tol, warm_start=True) + Weights = np.ones(S.shape) + self.n_iter_ = [] + for it in range(self.n_reweights): + glasso.weights = Weights + glasso.fit(S) + Theta = glasso.precision_ + Weights = 1/(np.abs(Theta) + 1e-10) + self.n_iter_.append(glasso.n_iter_) + # TODO print losses for original problem? + glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True) + self.precision_ = glasso.precision_ + self.covariance_ = glasso.covariance_ + + return self diff --git a/skglm/tests/test_estimators.py b/skglm/tests/test_estimators.py index ec7536f19..366e63be9 100644 --- a/skglm/tests/test_estimators.py +++ b/skglm/tests/test_estimators.py @@ -14,19 +14,23 @@ from sklearn.linear_model import ElasticNet as ElasticNet_sklearn from sklearn.linear_model import LogisticRegression as LogReg_sklearn from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn +from sklearn.covariance import GraphicalLasso as GraphicalLasso_sklearn from sklearn.model_selection import GridSearchCV from sklearn.svm import LinearSVC as LinearSVC_sklearn from sklearn.utils.estimator_checks import check_estimator +from sklearn.utils import check_random_state from skglm.utils.data import (make_correlated_data, make_dummy_survival_data, _alpha_max_group_lasso, grp_converter) from skglm.estimators import ( GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet, - MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator) + MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator, GraphicalLasso, + AdaptiveGraphicalLasso) from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE from skglm.solvers import AndersonCD, FISTA, ProxNewton from skglm.utils.jit_compilation import compiled_clone +from skglm.utils.data import generate_GraphicalLasso_data n_samples = 50 n_tasks = 9 @@ -118,7 +122,8 @@ def test_estimator(estimator_name, X, fit_intercept, positive): pytest.xfail("Intercept is not supported for SVC.") if positive and estimator_name not in ( "Lasso", "ElasticNet", "wLasso", "MCP", "wMCP", "GroupLasso"): - pytest.xfail("`positive` option is only supported by L1, L1_plus_L2 and wL1.") + pytest.xfail( + "`positive` option is only supported by L1, L1_plus_L2 and wL1.") estimator_sk = clone(dict_estimators_sk[estimator_name]) estimator_ours = clone(dict_estimators_ours[estimator_name]) @@ -364,7 +369,8 @@ def test_equivalence_cox_SLOPE_cox_L1(use_efron, issparse): w, *_ = solver.solve(X, y, datafit, penalty) method = 'efron' if use_efron else 'breslow' - estimator = CoxEstimator(alpha, l1_ratio=1., method=method, tol=1e-9).fit(X, y) + estimator = CoxEstimator( + alpha, l1_ratio=1., method=method, tol=1e-9).fit(X, y) np.testing.assert_allclose(w, estimator.coef_, atol=1e-5) @@ -510,7 +516,8 @@ def test_grid_search(estimator_name): for attr in res_attr: np.testing.assert_allclose(sk_clf.cv_results_[attr], ours_clf.cv_results_[attr], rtol=1e-3) - np.testing.assert_allclose(sk_clf.best_score_, ours_clf.best_score_, rtol=1e-3) + np.testing.assert_allclose( + sk_clf.best_score_, ours_clf.best_score_, rtol=1e-3) np.testing.assert_allclose(sk_clf.best_params_["alpha"], ours_clf.best_params_["alpha"], rtol=1e-3) @@ -607,7 +614,8 @@ def test_SparseLogReg_elasticnet(X, l1_ratio): estimator_ours = clone(dict_estimators_ours['LogisticRegression']) 
estimator_sk.set_params(fit_intercept=True, solver='saga', penalty='elasticnet', l1_ratio=l1_ratio, max_iter=10_000) - estimator_ours.set_params(fit_intercept=True, l1_ratio=l1_ratio, max_iter=10_000) + estimator_ours.set_params( + fit_intercept=True, l1_ratio=l1_ratio, max_iter=10_000) estimator_sk.fit(X, y) estimator_ours.fit(X, y) @@ -619,6 +627,117 @@ def test_SparseLogReg_elasticnet(X, l1_ratio): np.testing.assert_allclose( estimator_sk.intercept_, estimator_ours.intercept_, rtol=1e-4) +# Graphical Lasso tests + + +def test_GraphicalLasso_equivalence_sklearn(): + S, _, lmbd_max = generate_GraphicalLasso_data(200, 50) + alpha = lmbd_max / 5 + + model_sk = GraphicalLasso_sklearn( + alpha=alpha, covariance="precomputed", tol=1e-10) + model_sk.fit(S) + + for algo in ("banerjee", "mazumder"): + model = GraphicalLasso( + alpha=alpha, + warm_start=False, + max_iter=1000, + tol=1e-14, + algo=algo, + ).fit(S) + + np.testing.assert_allclose( + model.precision_, model_sk.precision_, atol=2e-4) + np.testing.assert_allclose( + model.covariance_, model_sk.covariance_, atol=2e-4) + + # check that we did not mess up lambda: + np.testing.assert_array_less(S.shape[0] + 1, (model.precision_ != 0).sum()) + + +def test_GraphicalLasso_warm_start(): + S, _, lmbd_max = generate_GraphicalLasso_data(200, 50) + + alpha = lmbd_max / 5 + + model = GraphicalLasso( + alpha=alpha, + warm_start=True, + max_iter=1000, + tol=1e-14, + algo="mazumder", + ).fit(S) + np.testing.assert_array_less(1, model.n_iter_) + + model.fit(S) + np.testing.assert_equal(model.n_iter_, 1) + + model.algo = "banerjee" + with pytest.raises(ValueError, match="does not support"): + model.fit(S) + + +def test_GraphicalLasso_weights(): + S, _, lmbd_max = generate_GraphicalLasso_data(200, 50) + + alpha = lmbd_max / 10 + + model = GraphicalLasso( + alpha=alpha, + warm_start=False, + max_iter=2000, + tol=1e-14, + algo="mazumder", + ).fit(S) + prec = model.precision_.copy() + + scal = 2. + model.weights = np.full(S.shape, scal) + model.alpha /= scal + model.fit(S) + np.testing.assert_allclose(prec, model.precision_) + + mask = np.random.randn(*S.shape) > 0 + mask = mask + mask.T + mask.flat[::S.shape[0] + 1] = 0 + weights = mask.astype(float) + model.weights = weights + model.fit(S) + np.testing.assert_array_less(1e-4, np.abs(model.precision_[~mask])) + + +def test_GraphicalLasso_adaptive(): + S, _, lmbd_max = generate_GraphicalLasso_data(200, 50) + + alpha = lmbd_max / 10 + tol = 1e-14 + model = GraphicalLasso( + alpha=alpha, + warm_start=True, + max_iter=1000, + tol=tol, + algo="mazumder", + ).fit(S) + n_iter = [model.n_iter_] + Theta1 = model.precision_ + weights = 1 / (np.abs(Theta1) + 1e-10) + model.weights = weights + + model.fit(S) + n_iter.append(model.n_iter_) + print("ada:") + + # TODO test more than 2 reweightings? 
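+    # The two manual reweighting steps above must match AdaptiveGraphicalLasso
+    # with n_reweights=2, both in the estimated precision and in the per-fit
+    # iteration counts, as asserted below.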
+ model_a = AdaptiveGraphicalLasso( + alpha=alpha, n_reweights=2, tol=tol).fit(S) + + np.testing.assert_allclose(model_a.precision_, model.precision_) + np.testing.assert_allclose(model_a.n_iter_, n_iter) + + # support is decreasing: + assert not np.any(model_a.precision_[Theta1 == 0]) + if __name__ == "__main__": pass diff --git a/skglm/utils/data.py b/skglm/utils/data.py index 8f51e04d8..7fdd1ef02 100644 --- a/skglm/utils/data.py +++ b/skglm/utils/data.py @@ -3,6 +3,7 @@ from numpy.linalg import norm from sklearn.utils import check_random_state from sklearn.preprocessing import StandardScaler +from sklearn.datasets import make_sparse_spd_matrix def make_correlated_data( @@ -252,3 +253,22 @@ def _alpha_max_group_lasso(X, y, grp_indices, grp_ptr, weights): norm(X[:, grp_g_indices].T @ y) / (n_samples * weights[g]) ) return alpha_max + + +def generate_GraphicalLasso_data(n_samples, n_features): + rng = check_random_state(0) + Theta_true = make_sparse_spd_matrix( + n_features, alpha=0.9, random_state=rng) + Theta_true += 0.1*np.eye(n_features) + Sigma_true = np.linalg.pinv(Theta_true, hermitian=True) + X = rng.multivariate_normal( + mean=np.zeros(n_features), + cov=Sigma_true, + size=n_samples, + ) + S = np.cov(X, bias=True, rowvar=False) + S_cpy = np.copy(S) + np.fill_diagonal(S_cpy, 0.) + lmbd_max = np.max(np.abs(S_cpy)) + + return S, Theta_true, lmbd_max From 5696d40160ac6a1579615762a291b57552766e93 Mon Sep 17 00:00:00 2001 From: Can Date: Thu, 20 Mar 2025 11:37:11 +0100 Subject: [PATCH 02/13] Update estimators to use barebones solver, make it at least as fast as sklearn --- skglm/estimators.py | 185 +++++++++++++++++++++------------ skglm/solvers/gram_cd.py | 36 ++++++- skglm/tests/test_estimators.py | 23 ++-- 3 files changed, 166 insertions(+), 78 deletions(-) diff --git a/skglm/estimators.py b/skglm/estimators.py index 0bbfa5253..e9c62b272 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -2,6 +2,7 @@ import warnings import numpy as np +from scipy.linalg import pinvh from scipy.sparse import issparse from scipy.special import expit from numbers import Integral, Real @@ -25,6 +26,10 @@ from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2, MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1) from skglm.utils.data import grp_converter +from skglm.utils.prox_funcs import ST_vec + +from numba import njit +from skglm.solvers.gram_cd import barebones_cd_gram def _glm_fit(X, y, model, datafit, penalty, solver): @@ -1687,15 +1692,24 @@ def fit(self, X, y): return _glm_fit(X, y, self, quad_group, group_penalty, solver) +#################### +# WIP Graphical Lasso +#################### + class GraphicalLasso(): + """ A first-order BCD Graphical Lasso solver implementing the GLasso algorithm + described in Friedman et al., 2008 and the P-GLasso algorithm described in + Mazumder et al., 2012.""" + def __init__(self, alpha=1., weights=None, - algo="banerjee", - max_iter=1000, + algo="dual", + max_iter=100, tol=1e-8, warm_start=False, + inner_tol=1e-4, ): self.alpha = alpha self.weights = weights @@ -1703,6 +1717,7 @@ def __init__(self, self.max_iter = max_iter self.tol = tol self.warm_start = warm_start + self.inner_tol = inner_tol def fit(self, S): p = S.shape[-1] @@ -1716,90 +1731,110 @@ def fit(self, S): raise ValueError("Weights should be symmetric.") if self.warm_start and hasattr(self, "precision_"): - if self.algo == "banerjee": + if self.algo == "dual": raise ValueError( - "Banerjee does not support warm start for now.") + "dual does not support warm start 
for now.") Theta = self.precision_ W = self.covariance_ - else: - W = S.copy() # + alpha*np.eye(p) - Theta = np.linalg.pinv(W, hermitian=True) - datafit = compiled_clone(QuadraticHessian()) - penalty = compiled_clone( - WeightedL1(alpha=self.alpha, weights=Weights[0, :-1])) + else: + W = S.copy() + W *= 0.95 + diagonal = S.flat[:: p + 1] + W.flat[:: p + 1] = diagonal + Theta = pinvh(W) - solver = AndersonCD(warm_start=True, - fit_intercept=False, - ws_strategy="fixpoint") + W_11 = np.copy(W[1:, 1:], order="C") + eps = np.finfo(np.float64).eps + it = 0 + Theta_old = Theta.copy() for it in range(self.max_iter): Theta_old = Theta.copy() + for col in range(p): - indices_minus_col = np.concatenate( - [indices[:col], indices[col + 1:]]) - _11 = indices_minus_col[:, None], indices_minus_col[None] - _12 = indices_minus_col, col - _21 = col, indices_minus_col - _22 = col, col - - W_11 = W[_11] - w_12 = W[_12] - w_22 = W[_22] - s_12 = S[_12] - s_22 = S[_22] - - penalty.weights = Weights[_12] - - if self.algo == "banerjee": - w_init = Theta[_12]/Theta[_22] - Xw_init = W_11 @ w_init + if self.algo == "primal": + indices_minus_col = np.concatenate( + [indices[:col], indices[col + 1:]]) + _11 = indices_minus_col[:, None], indices_minus_col[None] + _12 = indices_minus_col, col + _21 = col, indices_minus_col + _22 = col, col + + elif self.algo == "dual": + if col > 0: + di = col - 1 + W_11[di] = W[di][indices != col] + W_11[:, di] = W[:, di][indices != col] + else: + W_11[:] = W[1:, 1:] + + s_12 = S[col, indices != col] + + if self.algo == "dual": + beta_init = (Theta[indices != col, col] / + (Theta[col, col] + 1000 * eps)) Q = W_11 - elif self.algo == "mazumder": - inv_Theta_11 = W_11 - np.outer(w_12, w_12)/w_22 + + elif self.algo == "primal": + inv_Theta_11 = (W[_11] - + np.outer(W[_12], + W[_12])/W[_22]) Q = inv_Theta_11 - w_init = Theta[_12] * w_22 - Xw_init = inv_Theta_11 @ w_init + beta_init = Theta[indices != col, col] * S[col, col] else: raise ValueError(f"Unsupported algo {self.algo}") - beta, _, _ = solver._solve( + beta = barebones_cd_gram( Q, s_12, - datafit, - penalty, - w_init=w_init, - Xw_init=Xw_init, + x=beta_init, + alpha=self.alpha, + weights=Weights[indices != col, col], + tol=self.inner_tol, + max_iter=self.max_iter, ) - if self.algo == "banerjee": - w_12 = -W_11 @ beta - W[_12] = w_12 - W[_21] = w_12 - Theta[_22] = 1/(s_22 + beta @ w_12) - Theta[_12] = beta*Theta[_22] - else: # mazumder - theta_12 = beta / s_22 - theta_22 = 1/s_22 + theta_12 @ inv_Theta_11 @ theta_12 - - Theta[_12] = theta_12 - Theta[_21] = theta_12 - Theta[_22] = theta_22 - - w_22 = 1/(theta_22 - theta_12 @ inv_Theta_11 @ theta_12) - w_12 = -w_22*inv_Theta_11 @ theta_12 - W_11 = inv_Theta_11 + np.outer(w_12, w_12)/w_22 - W[_11] = W_11 - W[_12] = w_12 - W[_21] = w_12 - W[_22] = w_22 + if self.algo == "dual": + w_12 = -np.dot(W_11, beta) + W[col, indices != col] = w_12 + W[indices != col, col] = w_12 + + Theta[col, col] = 1 / \ + (W[col, col] + np.dot(beta, w_12)) + Theta[indices != col, col] = beta*Theta[col, col] + Theta[col, indices != col] = beta*Theta[col, col] + + else: # primal + Theta[indices != col, col] = beta / S[col, col] + Theta[col, indices != col] = beta / S[col, col] + Theta[col, col] = (1/S[col, col] + + Theta[col, indices != col] @ + inv_Theta_11 @ + Theta[indices != col, col]) + W[col, col] = (1/(Theta[col, col] - + Theta[indices != col, col] @ + inv_Theta_11 @ + Theta[indices != col, col])) + W[indices != col, col] = (-W[col, col] * + inv_Theta_11 @ + Theta[indices != col, col]) + W[col, indices != 
col] = (-W[col, col] * + inv_Theta_11 @ + Theta[indices != col, col]) + # Maybe W_11 can be done smarter ? + W[_11] = (inv_Theta_11 + + np.outer(W[indices != col, col], + W[indices != col, col])/W[col, col]) if np.linalg.norm(Theta - Theta_old) < self.tol: print(f"Weighted Glasso converged at CD epoch {it + 1}") break else: - print(f"Not converged at epoch {it + 1}, " - f"diff={np.linalg.norm(Theta - Theta_old):.2e}") + print( + f"Not converged at epoch {it + 1}, " + f"diff={np.linalg.norm(Theta - Theta_old):.2e}" + ) self.precision_, self.covariance_ = Theta, W self.n_iter_ = it + 1 @@ -1810,13 +1845,14 @@ class AdaptiveGraphicalLasso(): def __init__( self, alpha=1., + strategy="log", n_reweights=5, max_iter=1000, tol=1e-8, warm_start=False, - # verbose=False, ): self.alpha = alpha + self.strategy = strategy self.n_reweights = n_reweights self.max_iter = max_iter self.tol = tol @@ -1824,19 +1860,32 @@ def __init__( def fit(self, S): glasso = GraphicalLasso( - alpha=self.alpha, algo="mazumder", max_iter=self.max_iter, - tol=self.tol, warm_start=True) + alpha=self.alpha, + algo="primal", + max_iter=self.max_iter, + tol=self.tol, + warm_start=True) Weights = np.ones(S.shape) self.n_iter_ = [] for it in range(self.n_reweights): glasso.weights = Weights glasso.fit(S) Theta = glasso.precision_ - Weights = 1/(np.abs(Theta) + 1e-10) + if self.strategy == "log": + Weights = 1/(np.abs(Theta) + 1e-10) + elif self.strategy == "sqrt": + Weights = 1/(2*np.sqrt(np.abs(Theta)) + 1e-10) + elif self.strategy == "mcp": + gamma = 3. + Weights = np.zeros_like(Theta) + Weights[np.abs(Theta) < gamma*self.alpha] = (self.alpha - + np.abs(Theta[np.abs(Theta) < gamma*self.alpha])/gamma) + else: + raise ValueError(f"Unknown strategy {self.strategy}") + self.n_iter_.append(glasso.n_iter_) # TODO print losses for original problem? glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True) self.precision_ = glasso.precision_ self.covariance_ = glasso.covariance_ - return self diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index 9ecf42bfb..ad639ee06 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -5,6 +5,7 @@ from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration +from skglm.utils.prox_funcs import ST_vec class GramCD(BaseSolver): @@ -118,7 +119,8 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): # perform Anderson extrapolation if self.use_acc: - w_acc, grad_acc, is_extrapolated = accelerator.extrapolate(w, grad) + w_acc, grad_acc, is_extrapolated = accelerator.extrapolate( + w, grad) if is_extrapolated: # omit constant term for comparison @@ -165,3 +167,35 @@ def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd): grad += (w[j] - old_w_j) * scaled_gram[:, j] return penalty.subdiff_distance(w, grad, all_features) + + +@njit +def barebones_cd_gram(H, q, x, alpha, weights, max_iter=100, tol=1e-4): + """ + Solve min .5 * x.T H x + q.T @ x + alpha * norm(x, 1). + + H must be symmetric. 
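+
+    The penalty is weighted coordinate-wise: entry j is penalized by
+    alpha * weights[j] * |x_j|, so unit weights recover the plain L1
+    objective above. Each update uses H[j, j] as the coordinate-wise
+    Lipschitz constant and soft-thresholding as the proximal step.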
+ """ + dim = H.shape[0] + lc = np.zeros(dim) + for j in range(dim): + lc[j] = H[j, j] + + # Hx = H @ x + Hx = np.dot(H, x) + for _ in range(max_iter): + max_delta = 0 # max coeff change + + for j in range(dim): + x_j_prev = x[j] + x[j] = ST_vec(x[j] - (Hx[j] + q[j]) / lc[j], + alpha*weights[j] / lc[j]) + + max_delta = max(max_delta, np.abs(x_j_prev - x[j])) + + if x_j_prev != x[j]: + Hx += (x[j] - x_j_prev) * H[j] + if max_delta <= tol: + break + + return x diff --git a/skglm/tests/test_estimators.py b/skglm/tests/test_estimators.py index 366e63be9..c2a01995c 100644 --- a/skglm/tests/test_estimators.py +++ b/skglm/tests/test_estimators.py @@ -627,8 +627,10 @@ def test_SparseLogReg_elasticnet(X, l1_ratio): np.testing.assert_allclose( estimator_sk.intercept_, estimator_ours.intercept_, rtol=1e-4) -# Graphical Lasso tests +####################### +# WIP Graphical Lasso tests +####################### def test_GraphicalLasso_equivalence_sklearn(): S, _, lmbd_max = generate_GraphicalLasso_data(200, 50) @@ -638,7 +640,7 @@ def test_GraphicalLasso_equivalence_sklearn(): alpha=alpha, covariance="precomputed", tol=1e-10) model_sk.fit(S) - for algo in ("banerjee", "mazumder"): + for algo in ("primal", "dual"): model = GraphicalLasso( alpha=alpha, warm_start=False, @@ -648,9 +650,9 @@ def test_GraphicalLasso_equivalence_sklearn(): ).fit(S) np.testing.assert_allclose( - model.precision_, model_sk.precision_, atol=2e-4) + model.precision_, model_sk.precision_, atol=1e-4) np.testing.assert_allclose( - model.covariance_, model_sk.covariance_, atol=2e-4) + model.covariance_, model_sk.covariance_, atol=1e-4) # check that we did not mess up lambda: np.testing.assert_array_less(S.shape[0] + 1, (model.precision_ != 0).sum()) @@ -666,14 +668,14 @@ def test_GraphicalLasso_warm_start(): warm_start=True, max_iter=1000, tol=1e-14, - algo="mazumder", + algo="primal", ).fit(S) np.testing.assert_array_less(1, model.n_iter_) model.fit(S) np.testing.assert_equal(model.n_iter_, 1) - model.algo = "banerjee" + model.algo = "dual" with pytest.raises(ValueError, match="does not support"): model.fit(S) @@ -688,7 +690,7 @@ def test_GraphicalLasso_weights(): warm_start=False, max_iter=2000, tol=1e-14, - algo="mazumder", + algo="primal", ).fit(S) prec = model.precision_.copy() @@ -717,10 +719,11 @@ def test_GraphicalLasso_adaptive(): warm_start=True, max_iter=1000, tol=tol, - algo="mazumder", + algo="primal", ).fit(S) n_iter = [model.n_iter_] Theta1 = model.precision_ + # TODO test the other strategies weights = 1 / (np.abs(Theta1) + 1e-10) model.weights = weights @@ -730,7 +733,9 @@ def test_GraphicalLasso_adaptive(): # TODO test more than 2 reweightings? 
model_a = AdaptiveGraphicalLasso( - alpha=alpha, n_reweights=2, tol=tol).fit(S) + alpha=alpha, + n_reweights=2, + tol=tol).fit(S) np.testing.assert_allclose(model_a.precision_, model.precision_) np.testing.assert_allclose(model_a.n_iter_, n_iter) From 3e3df79db0876af9d6fbe1a6bfd459fdd89d9d59 Mon Sep 17 00:00:00 2001 From: Can Date: Fri, 21 Mar 2025 10:05:54 +0100 Subject: [PATCH 03/13] add Reweighted GLasso regularization path example --- examples/plot_reweighted_glasso_reg_path.py | 143 ++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 examples/plot_reweighted_glasso_reg_path.py diff --git a/examples/plot_reweighted_glasso_reg_path.py b/examples/plot_reweighted_glasso_reg_path.py new file mode 100644 index 000000000..595de1b17 --- /dev/null +++ b/examples/plot_reweighted_glasso_reg_path.py @@ -0,0 +1,143 @@ +import numpy as np +from numpy.linalg import norm +import matplotlib.pyplot as plt +from sklearn.metrics import f1_score +from sklearn.datasets import make_sparse_spd_matrix +from sklearn.utils import check_random_state + +from skglm.estimators import GraphicalLasso +from skglm.estimators import AdaptiveGraphicalLasso + +# Data +p = 100 +n = 1000 +rng = check_random_state(0) +Theta_true = make_sparse_spd_matrix( + p, + alpha=0.9, + random_state=rng) + +Theta_true += 0.1*np.eye(p) +Sigma_true = np.linalg.pinv(Theta_true, hermitian=True) +X = rng.multivariate_normal( + mean=np.zeros(p), + cov=Sigma_true, + size=n, +) + +S = np.cov(X, bias=True, rowvar=False) +S_cpy = np.copy(S) +np.fill_diagonal(S_cpy, 0.) +alpha_max = np.max(np.abs(S_cpy)) + +alphas = alpha_max*np.geomspace(1, 1e-4, num=10) + + +penalties = [ + "L1", + "R-L1 (log)", + "R-L1 (L0.5)", + "R-L1 (MCP)", +] +n_reweights = 5 +models_tol = 1e-4 +models = [ + GraphicalLasso(algo="primal", + warm_start=True, + tol=models_tol), + AdaptiveGraphicalLasso(warm_start=True, + strategy="log", + n_reweights=n_reweights, + tol=models_tol), + AdaptiveGraphicalLasso(warm_start=True, + strategy="sqrt", + n_reweights=n_reweights, + tol=models_tol), + AdaptiveGraphicalLasso(warm_start=True, + strategy="mcp", + n_reweights=n_reweights, + tol=models_tol), +] + +my_glasso_nmses = {penalty: [] for penalty in penalties} +my_glasso_f1_scores = {penalty: [] for penalty in penalties} + +sk_glasso_nmses = [] +sk_glasso_f1_scores = [] + + +for i, (penalty, model) in enumerate(zip(penalties, models)): + print(penalty) + for alpha_idx, alpha in enumerate(alphas): + print(f"======= alpha {alpha_idx+1}/{len(alphas)} =======") + model.alpha = alpha + model.fit(S) + Theta = model.precision_ + + my_nmse = norm(Theta - Theta_true)**2 / norm(Theta_true)**2 + + my_f1_score = f1_score(Theta.flatten() != 0., + Theta_true.flatten() != 0.) 
+ print(f"NMSE: {my_nmse:.3f}") + print(f"F1 : {my_f1_score:.3f}") + + my_glasso_nmses[penalty].append(my_nmse) + my_glasso_f1_scores[penalty].append(my_f1_score) + + +plt.close('all') +fig, ax = plt.subplots(2, 1, sharex=True, figsize=( + [6.11, 3.91]), layout="constrained") +cmap = plt.get_cmap("tab10") +for i, penalty in enumerate(penalties): + + ax[0].semilogx(alphas/alpha_max, + my_glasso_nmses[penalty], + color=cmap(i), + linewidth=2., + label=penalty) + min_nmse = np.argmin(my_glasso_nmses[penalty]) + ax[0].vlines( + x=alphas[min_nmse] / alphas[0], + ymin=0, + ymax=np.min(my_glasso_nmses[penalty]), + linestyle='--', + color=cmap(i)) + line0 = ax[0].plot( + [alphas[min_nmse] / alphas[0]], + 0, + clip_on=False, + marker='X', + color=cmap(i), + markersize=12) + + ax[1].semilogx(alphas/alpha_max, + my_glasso_f1_scores[penalty], + linewidth=2., + color=cmap(i)) + max_f1 = np.argmax(my_glasso_f1_scores[penalty]) + ax[1].vlines( + x=alphas[max_f1] / alphas[0], + ymin=0, + ymax=np.max(my_glasso_f1_scores[penalty]), + linestyle='--', + color=cmap(i)) + line1 = ax[1].plot( + [alphas[max_f1] / alphas[0]], + 0, + clip_on=False, + marker='X', + markersize=12, + color=cmap(i)) + + +ax[0].set_title(f"{p=},{n=}", fontsize=18) +ax[0].set_ylabel("NMSE", fontsize=18) +ax[1].set_ylabel("F1 score", fontsize=18) +ax[1].set_xlabel(f"$\lambda / \lambda_\mathrm{{max}}$", fontsize=18) + +ax[0].legend(fontsize=14) +ax[0].grid(which='both', alpha=0.9) +ax[1].grid(which='both', alpha=0.9) +# plt.savefig(f"./non_convex_p{p}_n{n}.pdf") +plt.show(block=False) From ca6960f03f915673513c611f784ca79f1e6e3b03 Mon Sep 17 00:00:00 2001 From: Can Date: Wed, 2 Apr 2025 16:11:21 +0200 Subject: [PATCH 04/13] fix issues in glasso reg path example --- examples/plot_graphical_lasso.py | 113 ---------------- examples/plot_reweighted_glasso_reg_path.py | 139 ++++++++------------ 2 files changed, 58 insertions(+), 194 deletions(-) delete mode 100644 examples/plot_graphical_lasso.py diff --git a/examples/plot_graphical_lasso.py b/examples/plot_graphical_lasso.py deleted file mode 100644 index 85d4c0e17..000000000 --- a/examples/plot_graphical_lasso.py +++ /dev/null @@ -1,113 +0,0 @@ -import numpy as np -from numpy.linalg import norm -import matplotlib.pyplot as plt -from sklearn.metrics import f1_score -from sklearn.datasets import make_sparse_spd_matrix -from sklearn.covariance import GraphicalLasso as skGraphicalLasso - -from skglm.estimators import GraphicalLasso, AdaptiveGraphicalLasso -from skglm.utils.data import generate_GraphicalLasso_data - -# Data -p = 20 -n = 100 -S, Theta_true, alpha_max = generate_GraphicalLasso_data(n, p) - -alphas = alpha_max*np.geomspace(1, 1e-3, num=30) - - -penalties = [ - "L1", - "R-L1", -] - -models_tol = 1e-4 -models = [ - GraphicalLasso(algo="mazumder", - warm_start=True, tol=models_tol), - AdaptiveGraphicalLasso(warm_start=True, n_reweights=10, tol=models_tol), - -] - -my_glasso_nmses = {penalty: [] for penalty in penalties} -my_glasso_f1_scores = {penalty: [] for penalty in penalties} - -sk_glasso_nmses = [] -sk_glasso_f1_scores = [] - - -for i, (penalty, model) in enumerate(zip(penalties, models)): - print(penalty) - for alpha_idx, alpha in enumerate(alphas): - print(f"======= alpha {alpha_idx+1}/{len(alphas)} =======") - model.alpha = alpha - model.fit(S) - Theta = model.precision_ - - my_nmse = norm(Theta - Theta_true)**2 / norm(Theta_true)**2 - - my_f1_score = f1_score(Theta.flatten() != 0., - Theta_true.flatten() != 0.) 
- print(f"NMSE: {my_nmse:.3f}") - print(f"F1 : {my_f1_score:.3f}") - - my_glasso_nmses[penalty].append(my_nmse) - my_glasso_f1_scores[penalty].append(my_f1_score) - - -plt.close('all') -fig, ax = plt.subplots(2, 1, sharex=True, figsize=( - [12.6, 4.63]), layout="constrained") -cmap = plt.get_cmap("tab10") -for i, penalty in enumerate(penalties): - - ax[0].semilogx(alphas/alpha_max, - my_glasso_nmses[penalty], - color=cmap(i), - linewidth=2., - label=penalty) - min_nmse = np.argmin(my_glasso_nmses[penalty]) - ax[0].vlines( - x=alphas[min_nmse] / alphas[0], - ymin=0, - ymax=np.min(my_glasso_nmses[penalty]), - linestyle='--', - color=cmap(i)) - line0 = ax[0].plot( - [alphas[min_nmse] / alphas[0]], - 0, - clip_on=False, - marker='X', - color=cmap(i), - markersize=12) - - ax[1].semilogx(alphas/alpha_max, - my_glasso_f1_scores[penalty], - linewidth=2., - color=cmap(i)) - max_f1 = np.argmax(my_glasso_f1_scores[penalty]) - ax[1].vlines( - x=alphas[max_f1] / alphas[0], - ymin=0, - ymax=np.max(my_glasso_f1_scores[penalty]), - linestyle='--', - color=cmap(i)) - line1 = ax[1].plot( - [alphas[max_f1] / alphas[0]], - 0, - clip_on=False, - marker='X', - markersize=12, - color=cmap(i)) - - -ax[0].set_title(f"{p=},{n=}", fontsize=18) -ax[0].set_ylabel("NMSE", fontsize=18) -ax[1].set_ylabel("F1 score", fontsize=18) -ax[1].set_xlabel(f"$\lambda / \lambda_\mathrm{{max}}$", fontsize=18) - -ax[0].legend(fontsize=14) -ax[0].grid(which='both', alpha=0.9) -ax[1].grid(which='both', alpha=0.9) -# plt.show(block=False) -plt.show() diff --git a/examples/plot_reweighted_glasso_reg_path.py b/examples/plot_reweighted_glasso_reg_path.py index 595de1b17..fd80c3225 100644 --- a/examples/plot_reweighted_glasso_reg_path.py +++ b/examples/plot_reweighted_glasso_reg_path.py @@ -1,43 +1,33 @@ +# Authors: Can Pouliquen +# Mathurin Massias +""" +======================================================================= +Regularization paths for the Graphical Lasso and its Adaptive variation +======================================================================= +Highlight the importance of using non-convex regularization for improved performance, +solved using the reweighting strategy. +""" + import numpy as np from numpy.linalg import norm import matplotlib.pyplot as plt from sklearn.metrics import f1_score -from sklearn.datasets import make_sparse_spd_matrix -from sklearn.utils import check_random_state +from skglm.utils.data import generate_GraphicalLasso_data from skglm.estimators import GraphicalLasso from skglm.estimators import AdaptiveGraphicalLasso -# Data + p = 100 n = 1000 -rng = check_random_state(0) -Theta_true = make_sparse_spd_matrix( - p, - alpha=0.9, - random_state=rng) - -Theta_true += 0.1*np.eye(p) -Sigma_true = np.linalg.pinv(Theta_true, hermitian=True) -X = rng.multivariate_normal( - mean=np.zeros(p), - cov=Sigma_true, - size=n, -) - -S = np.cov(X, bias=True, rowvar=False) -S_cpy = np.copy(S) -np.fill_diagonal(S_cpy, 0.) 
-alpha_max = np.max(np.abs(S_cpy)) - +S, Theta_true, alpha_max = generate_GraphicalLasso_data(n, p) alphas = alpha_max*np.geomspace(1, 1e-4, num=10) - penalties = [ "L1", - "R-L1 (log)", - "R-L1 (L0.5)", - "R-L1 (MCP)", + "Log", + "L0.5", + "MCP", ] n_reweights = 5 models_tol = 1e-4 @@ -67,9 +57,8 @@ for i, (penalty, model) in enumerate(zip(penalties, models)): - print(penalty) for alpha_idx, alpha in enumerate(alphas): - print(f"======= alpha {alpha_idx+1}/{len(alphas)} =======") + print(f"======= {penalty} penalty, alpha {alpha_idx+1}/{len(alphas)} =======") model.alpha = alpha model.fit(S) Theta = model.precision_ @@ -78,66 +67,54 @@ my_f1_score = f1_score(Theta.flatten() != 0., Theta_true.flatten() != 0.) - print(f"NMSE: {my_nmse:.3f}") - print(f"F1 : {my_f1_score:.3f}") my_glasso_nmses[penalty].append(my_nmse) my_glasso_f1_scores[penalty].append(my_f1_score) plt.close('all') -fig, ax = plt.subplots(2, 1, sharex=True, figsize=( - [6.11, 3.91]), layout="constrained") +fig, axarr = plt.subplots(2, 1, sharex=True, figsize=([6.11, 3.91]), + layout="constrained") cmap = plt.get_cmap("tab10") for i, penalty in enumerate(penalties): - ax[0].semilogx(alphas/alpha_max, - my_glasso_nmses[penalty], - color=cmap(i), - linewidth=2., - label=penalty) - min_nmse = np.argmin(my_glasso_nmses[penalty]) - ax[0].vlines( - x=alphas[min_nmse] / alphas[0], - ymin=0, - ymax=np.min(my_glasso_nmses[penalty]), - linestyle='--', - color=cmap(i)) - line0 = ax[0].plot( - [alphas[min_nmse] / alphas[0]], - 0, - clip_on=False, - marker='X', - color=cmap(i), - markersize=12) - - ax[1].semilogx(alphas/alpha_max, - my_glasso_f1_scores[penalty], - linewidth=2., - color=cmap(i)) - max_f1 = np.argmax(my_glasso_f1_scores[penalty]) - ax[1].vlines( - x=alphas[max_f1] / alphas[0], - ymin=0, - ymax=np.max(my_glasso_f1_scores[penalty]), - linestyle='--', - color=cmap(i)) - line1 = ax[1].plot( - [alphas[max_f1] / alphas[0]], - 0, - clip_on=False, - marker='X', - markersize=12, - color=cmap(i)) - - -ax[0].set_title(f"{p=},{n=}", fontsize=18) -ax[0].set_ylabel("NMSE", fontsize=18) -ax[1].set_ylabel("F1 score", fontsize=18) -ax[1].set_xlabel(f"$\lambda / \lambda_\mathrm{{max}}$", fontsize=18) - -ax[0].legend(fontsize=14) -ax[0].grid(which='both', alpha=0.9) -ax[1].grid(which='both', alpha=0.9) -# plt.savefig(f"./non_convex_p{p}_n{n}.pdf") + for j, ax in enumerate(axarr): + + if j == 0: + metric = my_glasso_nmses + best_idx = np.argmin(metric[penalty]) + ystop = np.min(metric[penalty]) + else: + metric = my_glasso_f1_scores + best_idx = np.argmax(metric[penalty]) + ystop = np.max(metric[penalty]) + + ax.semilogx(alphas/alpha_max, + metric[penalty], + color=cmap(i), + linewidth=2., + label=penalty) + + ax.vlines( + x=alphas[best_idx] / alphas[0], + ymin=0, + ymax=ystop, + linestyle='--', + color=cmap(i)) + line = ax.plot( + [alphas[best_idx] / alphas[0]], + 0, + clip_on=False, + marker='X', + color=cmap(i), + markersize=12) + + ax.grid(which='both', alpha=0.9) + +axarr[0].legend(fontsize=14) +axarr[0].set_title(f"{p=},{n=}", fontsize=18) +axarr[0].set_ylabel("NMSE", fontsize=18) +axarr[1].set_ylabel("F1 score", fontsize=18) +axarr[1].set_xlabel(r"$\lambda / \lambda_\mathrm{{max}}$", fontsize=18) + plt.show(block=False) From 4538002c853e4c83705d996ea021200c95217754 Mon Sep 17 00:00:00 2001 From: Can Date: Wed, 2 Apr 2025 17:48:07 +0200 Subject: [PATCH 05/13] fix glasso solver issues, move estimator to own file, create dedicated tests file --- examples/plot_reweighted_glasso_reg_path.py | 7 +- skglm/covariance.py | 218 ++++++++++++++++++ 
skglm/estimators.py | 241 ++------------------ skglm/solvers/gram_cd.py | 14 +- skglm/tests/test_covariance.py | 119 ++++++++++ skglm/tests/test_estimators.py | 134 +---------- skglm/utils/data.py | 2 +- 7 files changed, 363 insertions(+), 372 deletions(-) create mode 100644 skglm/covariance.py create mode 100644 skglm/tests/test_covariance.py diff --git a/examples/plot_reweighted_glasso_reg_path.py b/examples/plot_reweighted_glasso_reg_path.py index fd80c3225..6dd755083 100644 --- a/examples/plot_reweighted_glasso_reg_path.py +++ b/examples/plot_reweighted_glasso_reg_path.py @@ -13,14 +13,13 @@ import matplotlib.pyplot as plt from sklearn.metrics import f1_score -from skglm.utils.data import generate_GraphicalLasso_data -from skglm.estimators import GraphicalLasso -from skglm.estimators import AdaptiveGraphicalLasso +from skglm.covariance import GraphicalLasso, AdaptiveGraphicalLasso +from skglm.utils.data import make_dummy_covariance_data p = 100 n = 1000 -S, Theta_true, alpha_max = generate_GraphicalLasso_data(n, p) +S, Theta_true, alpha_max = make_dummy_covariance_data(n, p) alphas = alpha_max*np.geomspace(1, 1e-4, num=10) penalties = [ diff --git a/skglm/covariance.py b/skglm/covariance.py new file mode 100644 index 000000000..9294a220b --- /dev/null +++ b/skglm/covariance.py @@ -0,0 +1,218 @@ +# License: BSD 3 clause + +import numpy as np +from scipy.linalg import pinvh + +from skglm.solvers.gram_cd import barebones_cd_gram + + +class GraphicalLasso(): + """ A first-order BCD Graphical Lasso solver implementing the GLasso algorithm + described in Friedman et al., 2008 and the P-GLasso algorithm described in + Mazumder et al., 2012.""" + + def __init__(self, + alpha=1., + weights=None, + algo="dual", + max_iter=100, + tol=1e-8, + warm_start=False, + inner_tol=1e-4, + verbose=False + ): + self.alpha = alpha + self.weights = weights + self.algo = algo + self.max_iter = max_iter + self.tol = tol + self.warm_start = warm_start + self.inner_tol = inner_tol + self.verbose = verbose + + def fit(self, S): + p = S.shape[-1] + indices = np.arange(p) + + if self.weights is None: + Weights = np.ones((p, p)) + else: + Weights = self.weights + if not np.allclose(Weights, Weights.T): + raise ValueError("Weights should be symmetric.") + + if self.warm_start and hasattr(self, "precision_"): + if self.algo == "dual": + raise ValueError( + "dual does not support warm start for now.") + Theta = self.precision_ + W = self.covariance_ + else: + W = S.copy() + W *= 0.95 + diagonal = S.flat[:: p + 1] + W.flat[:: p + 1] = diagonal + Theta = pinvh(W) + + W_11 = np.copy(W[1:, 1:], order="C") + eps = np.finfo(np.float64).eps + it = 0 + Theta_old = Theta.copy() + + for it in range(self.max_iter): + Theta_old = Theta.copy() + + for col in range(p): + if self.algo == "primal": + indices_minus_col = np.concatenate( + [indices[:col], indices[col + 1:]]) + _11 = indices_minus_col[:, None], indices_minus_col[None] + _12 = indices_minus_col, col + _22 = col, col + + elif self.algo == "dual": + if col > 0: + di = col - 1 + W_11[di] = W[di][indices != col] + W_11[:, di] = W[:, di][indices != col] + else: + W_11[:] = W[1:, 1:] + + s_12 = S[col, indices != col] + + if self.algo == "dual": + beta_init = (Theta[indices != col, col] / + (Theta[col, col] + 1000 * eps)) + Q = W_11 + + elif self.algo == "primal": + inv_Theta_11 = (W[_11] - + np.outer(W[_12], + W[_12])/W[_22]) + Q = inv_Theta_11 + beta_init = Theta[indices != col, col] * S[col, col] + else: + raise ValueError(f"Unsupported algo {self.algo}") + + beta = 
barebones_cd_gram( + Q, + s_12, + x=beta_init, + alpha=self.alpha, + weights=Weights[indices != col, col], + tol=self.inner_tol, + max_iter=self.max_iter, + ) + + if self.algo == "dual": + w_12 = -np.dot(W_11, beta) + W[col, indices != col] = w_12 + W[indices != col, col] = w_12 + + Theta[col, col] = 1 / \ + (W[col, col] + np.dot(beta, w_12)) + Theta[indices != col, col] = beta*Theta[col, col] + Theta[col, indices != col] = beta*Theta[col, col] + + else: # primal + s_22 = S[col, col] + + # Updating Theta + theta_12 = beta / s_22 + Theta[indices != col, col] = theta_12 + Theta[col, indices != col] = theta_12 + Theta[col, col] = (1/s_22 + + theta_12 @ + inv_Theta_11 @ + theta_12) + theta_22 = Theta[col, col] + + # Updating W + W[col, col] = (1/(theta_22 - + theta_12 @ + inv_Theta_11 @ + theta_12)) + w_22 = W[col, col] + + w_12 = (-w_22 * inv_Theta_11 @ theta_12) + W[indices != col, col] = w_12 + W[col, indices != col] = w_12 + + # Maybe W_11 can be done smarter ? + W[_11] = (inv_Theta_11 + + np.outer(w_12, + w_12)/w_22) + + if np.linalg.norm(Theta - Theta_old) < self.tol: + if self.verbose: + print(f"Weighted Glasso converged at CD epoch {it + 1}") + break + else: + if self.verbose: + print( + f"Not converged at epoch {it + 1}, " + f"diff={np.linalg.norm(Theta - Theta_old):.2e}" + ) + self.precision_, self.covariance_ = Theta, W + self.n_iter_ = it + 1 + + return self + + +class AdaptiveGraphicalLasso(): + """ An adaptive version of the Graphical Lasso that solves non-convex penalty + variations using the reweighting strategy from Candès et al., 2007.""" + + def __init__( + self, + alpha=1., + strategy="log", + n_reweights=5, + max_iter=1000, + tol=1e-8, + warm_start=False, + ): + self.alpha = alpha + self.strategy = strategy + self.n_reweights = n_reweights + self.max_iter = max_iter + self.tol = tol + self.warm_start = warm_start + + def fit(self, S): + glasso = GraphicalLasso( + alpha=self.alpha, + algo="primal", + max_iter=self.max_iter, + tol=self.tol, + warm_start=True) + Weights = np.ones(S.shape) + self.n_iter_ = [] + for it in range(self.n_reweights): + glasso.weights = Weights + glasso.fit(S) + Theta = glasso.precision_ + Weights = update_weights(Theta, self.alpha, strategy=self.strategy) + self.n_iter_.append(glasso.n_iter_) + # TODO print losses for original problem? + glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True) + self.precision_ = glasso.precision_ + self.covariance_ = glasso.covariance_ + return self + + +def update_weights(Theta, alpha, strategy="log"): + if strategy == "log": + return 1/(np.abs(Theta) + 1e-10) + elif strategy == "sqrt": + return 1/(2*np.sqrt(np.abs(Theta)) + 1e-10) + elif strategy == "mcp": + gamma = 3. 
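+        # Reweight with the MCP derivative: alpha - |theta|/gamma for
+        # |theta| < gamma*alpha and 0 beyond, so large entries are left
+        # unpenalized; gamma = 3 is a common default.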
+ Weights = np.zeros_like(Theta) + Weights[np.abs(Theta) + < gamma*alpha] = (alpha - + np.abs(Theta[np.abs(Theta) + < gamma*alpha])/gamma) + return Weights + else: + raise ValueError(f"Unknown strategy {strategy}") diff --git a/skglm/estimators.py b/skglm/estimators.py index c7749dd8c..0da3be721 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -2,11 +2,9 @@ import warnings import numpy as np -from scipy.linalg import pinvh from scipy.sparse import issparse from scipy.special import expit from numbers import Integral, Real -from skglm.solvers import ProxNewton, LBFGS from sklearn.utils.validation import (check_is_fitted, check_array, check_consistent_length) @@ -20,16 +18,12 @@ from sklearn.multiclass import OneVsRestClassifier, check_classification_targets from skglm.utils.jit_compilation import compiled_clone -from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD +from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD, ProxNewton, LBFGS from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC, - QuadraticMultiTask, QuadraticGroup, QuadraticHessian) + QuadraticMultiTask, QuadraticGroup,) from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2, MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1) from skglm.utils.data import grp_converter -from skglm.utils.prox_funcs import ST_vec - -from numba import njit -from skglm.solvers.gram_cd import barebones_cd_gram def _glm_fit(X, y, model, datafit, penalty, solver): @@ -131,8 +125,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver): w = np.zeros(n_features + fit_intercept, dtype=X_.dtype) Xw = np.zeros(n_samples, dtype=X_.dtype) else: # multitask - w = np.zeros((n_features + fit_intercept, - y.shape[1]), dtype=X_.dtype) + w = np.zeros((n_features + fit_intercept, y.shape[1]), dtype=X_.dtype) Xw = np.zeros(y.shape, dtype=X_.dtype) # check consistency of weights for WeightedL1 @@ -582,8 +575,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): raise ValueError("The number of weights must match the number of \ features. Got %s, expected %s." % ( len(weights), X.shape[1])) - penalty = compiled_clone(WeightedL1( - self.alpha, weights, self.positive)) + penalty = compiled_clone(WeightedL1(self.alpha, weights, self.positive)) datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32) solver = AndersonCD( self.max_iter, self.max_epochs, self.p0, tol=self.tol, @@ -608,8 +600,7 @@ def fit(self, X, y): Fitted estimator. """ if self.weights is None: - warnings.warn( - 'Weights are not provided, fitting with Lasso penalty') + warnings.warn('Weights are not provided, fitting with Lasso penalty') penalty = L1(self.alpha, self.positive) else: penalty = WeightedL1(self.alpha, self.weights, self.positive) @@ -742,8 +733,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): The number of iterations along the path. If return_n_iter is set to ``True``. """ - penalty = compiled_clone(L1_plus_L2( - self.alpha, self.l1_ratio, self.positive)) + penalty = compiled_clone(L1_plus_L2(self.alpha, self.l1_ratio, self.positive)) datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32) solver = AndersonCD( self.max_iter, self.max_epochs, self.p0, tol=self.tol, @@ -921,8 +911,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): f"Got {len(self.weights)}, expected {X.shape[1]}." 
) penalty = compiled_clone( - WeightedMCPenalty(self.alpha, self.gamma, - self.weights, self.positive) + WeightedMCPenalty(self.alpha, self.gamma, self.weights, self.positive) ) datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32) solver = AndersonCD( @@ -1325,8 +1314,7 @@ def fit(self, X, y): # copy/paste from https://github.com/scikit-learn/scikit-learn/blob/ \ # 23ff51c07ebc03c866984e93c921a8993e96d1f9/sklearn/utils/ \ # estimator_checks.py#L3886 - raise ValueError( - "requires y to be passed, but the target y is None") + raise ValueError("requires y to be passed, but the target y is None") y = check_array( y, accept_sparse=False, @@ -1341,8 +1329,7 @@ def fit(self, X, y): f"two columns. Got one column.\nAssuming that `y` " "is the vector of times and there is no censoring." ) - y = np.column_stack((y, np.ones_like(y))).astype( - X.dtype, order="F") + y = np.column_stack((y, np.ones_like(y))).astype(X.dtype, order="F") elif y.shape[1] > 2: raise ValueError( f"{repr(self)} requires the vector of response `y` to have " @@ -1367,8 +1354,7 @@ def fit(self, X, y): # init solver if self.l1_ratio == 0.: - solver = LBFGS(max_iter=self.max_iter, - tol=self.tol, verbose=self.verbose) + solver = LBFGS(max_iter=self.max_iter, tol=self.tol, verbose=self.verbose) else: solver = ProxNewton( max_iter=self.max_iter, tol=self.tol, verbose=self.verbose, @@ -1506,8 +1492,7 @@ def fit(self, X, Y): if not self.warm_start or not hasattr(self, "coef_"): self.coef_ = None - datafit_jit = compiled_clone( - QuadraticMultiTask(), X.dtype == np.float32) + datafit_jit = compiled_clone(QuadraticMultiTask(), X.dtype == np.float32) penalty_jit = compiled_clone(L2_1(self.alpha), X.dtype == np.float32) solver = MultiTaskBCD( @@ -1562,8 +1547,7 @@ def path(self, X, Y, alphas, coef_init=None, return_n_iter=False, **params): The number of iterations along the path. If return_n_iter is set to ``True``. """ - datafit = compiled_clone(QuadraticMultiTask(), - to_float32=X.dtype == np.float32) + datafit = compiled_clone(QuadraticMultiTask(), to_float32=X.dtype == np.float32) penalty = compiled_clone(L2_1(self.alpha)) solver = MultiTaskBCD( self.max_iter, self.max_epochs, self.p0, tol=self.tol, @@ -1687,8 +1671,7 @@ def fit(self, X, y): "The total number of group members must equal the number of features. 
" f"Got {n_features}, expected {X.shape[1]}.") - weights = np.ones( - len(group_sizes)) if self.weights is None else self.weights + weights = np.ones(len(group_sizes)) if self.weights is None else self.weights group_penalty = WeightedGroupL2(alpha=self.alpha, grp_ptr=grp_ptr, grp_indices=grp_indices, weights=weights, positive=self.positive) @@ -1699,201 +1682,3 @@ def fit(self, X, y): verbose=self.verbose) return _glm_fit(X, y, self, quad_group, group_penalty, solver) - -#################### -# WIP Graphical Lasso -#################### - - -class GraphicalLasso(): - """ A first-order BCD Graphical Lasso solver implementing the GLasso algorithm - described in Friedman et al., 2008 and the P-GLasso algorithm described in - Mazumder et al., 2012.""" - - def __init__(self, - alpha=1., - weights=None, - algo="dual", - max_iter=100, - tol=1e-8, - warm_start=False, - inner_tol=1e-4, - ): - self.alpha = alpha - self.weights = weights - self.algo = algo - self.max_iter = max_iter - self.tol = tol - self.warm_start = warm_start - self.inner_tol = inner_tol - - def fit(self, S): - p = S.shape[-1] - indices = np.arange(p) - - if self.weights is None: - Weights = np.ones((p, p)) - else: - Weights = self.weights - if not np.allclose(Weights, Weights.T): - raise ValueError("Weights should be symmetric.") - - if self.warm_start and hasattr(self, "precision_"): - if self.algo == "dual": - raise ValueError( - "dual does not support warm start for now.") - Theta = self.precision_ - W = self.covariance_ - - else: - W = S.copy() - W *= 0.95 - diagonal = S.flat[:: p + 1] - W.flat[:: p + 1] = diagonal - Theta = pinvh(W) - - W_11 = np.copy(W[1:, 1:], order="C") - eps = np.finfo(np.float64).eps - it = 0 - Theta_old = Theta.copy() - - for it in range(self.max_iter): - Theta_old = Theta.copy() - - for col in range(p): - if self.algo == "primal": - indices_minus_col = np.concatenate( - [indices[:col], indices[col + 1:]]) - _11 = indices_minus_col[:, None], indices_minus_col[None] - _12 = indices_minus_col, col - _21 = col, indices_minus_col - _22 = col, col - - elif self.algo == "dual": - if col > 0: - di = col - 1 - W_11[di] = W[di][indices != col] - W_11[:, di] = W[:, di][indices != col] - else: - W_11[:] = W[1:, 1:] - - s_12 = S[col, indices != col] - - if self.algo == "dual": - beta_init = (Theta[indices != col, col] / - (Theta[col, col] + 1000 * eps)) - Q = W_11 - - elif self.algo == "primal": - inv_Theta_11 = (W[_11] - - np.outer(W[_12], - W[_12])/W[_22]) - Q = inv_Theta_11 - beta_init = Theta[indices != col, col] * S[col, col] - else: - raise ValueError(f"Unsupported algo {self.algo}") - - beta = barebones_cd_gram( - Q, - s_12, - x=beta_init, - alpha=self.alpha, - weights=Weights[indices != col, col], - tol=self.inner_tol, - max_iter=self.max_iter, - ) - - if self.algo == "dual": - w_12 = -np.dot(W_11, beta) - W[col, indices != col] = w_12 - W[indices != col, col] = w_12 - - Theta[col, col] = 1 / \ - (W[col, col] + np.dot(beta, w_12)) - Theta[indices != col, col] = beta*Theta[col, col] - Theta[col, indices != col] = beta*Theta[col, col] - - else: # primal - Theta[indices != col, col] = beta / S[col, col] - Theta[col, indices != col] = beta / S[col, col] - Theta[col, col] = (1/S[col, col] + - Theta[col, indices != col] @ - inv_Theta_11 @ - Theta[indices != col, col]) - W[col, col] = (1/(Theta[col, col] - - Theta[indices != col, col] @ - inv_Theta_11 @ - Theta[indices != col, col])) - W[indices != col, col] = (-W[col, col] * - inv_Theta_11 @ - Theta[indices != col, col]) - W[col, indices != col] = (-W[col, col] 
* - inv_Theta_11 @ - Theta[indices != col, col]) - # Maybe W_11 can be done smarter ? - W[_11] = (inv_Theta_11 + - np.outer(W[indices != col, col], - W[indices != col, col])/W[col, col]) - - if np.linalg.norm(Theta - Theta_old) < self.tol: - print(f"Weighted Glasso converged at CD epoch {it + 1}") - break - else: - print( - f"Not converged at epoch {it + 1}, " - f"diff={np.linalg.norm(Theta - Theta_old):.2e}" - ) - self.precision_, self.covariance_ = Theta, W - self.n_iter_ = it + 1 - - return self - - -class AdaptiveGraphicalLasso(): - def __init__( - self, - alpha=1., - strategy="log", - n_reweights=5, - max_iter=1000, - tol=1e-8, - warm_start=False, - ): - self.alpha = alpha - self.strategy = strategy - self.n_reweights = n_reweights - self.max_iter = max_iter - self.tol = tol - self.warm_start = warm_start - - def fit(self, S): - glasso = GraphicalLasso( - alpha=self.alpha, - algo="primal", - max_iter=self.max_iter, - tol=self.tol, - warm_start=True) - Weights = np.ones(S.shape) - self.n_iter_ = [] - for it in range(self.n_reweights): - glasso.weights = Weights - glasso.fit(S) - Theta = glasso.precision_ - if self.strategy == "log": - Weights = 1/(np.abs(Theta) + 1e-10) - elif self.strategy == "sqrt": - Weights = 1/(2*np.sqrt(np.abs(Theta)) + 1e-10) - elif self.strategy == "mcp": - gamma = 3. - Weights = np.zeros_like(Theta) - Weights[np.abs(Theta) < gamma*self.alpha] = (self.alpha - - np.abs(Theta[np.abs(Theta) < gamma*self.alpha])/gamma) - else: - raise ValueError(f"Unknown strategy {self.strategy}") - - self.n_iter_.append(glasso.n_iter_) - # TODO print losses for original problem? - glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True) - self.precision_ = glasso.precision_ - self.covariance_ = glasso.covariance_ - return self diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py index ad639ee06..436f37a26 100644 --- a/skglm/solvers/gram_cd.py +++ b/skglm/solvers/gram_cd.py @@ -5,7 +5,7 @@ from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration -from skglm.utils.prox_funcs import ST_vec +from skglm.utils.prox_funcs import ST class GramCD(BaseSolver): @@ -119,8 +119,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): # perform Anderson extrapolation if self.use_acc: - w_acc, grad_acc, is_extrapolated = accelerator.extrapolate( - w, grad) + w_acc, grad_acc, is_extrapolated = accelerator.extrapolate(w, grad) if is_extrapolated: # omit constant term for comparison @@ -180,19 +179,14 @@ def barebones_cd_gram(H, q, x, alpha, weights, max_iter=100, tol=1e-4): lc = np.zeros(dim) for j in range(dim): lc[j] = H[j, j] - - # Hx = H @ x Hx = np.dot(H, x) + for _ in range(max_iter): max_delta = 0 # max coeff change - for j in range(dim): x_j_prev = x[j] - x[j] = ST_vec(x[j] - (Hx[j] + q[j]) / lc[j], - alpha*weights[j] / lc[j]) - + x[j] = ST(x[j] - (Hx[j] + q[j]) / lc[j], alpha*weights[j] / lc[j]) max_delta = max(max_delta, np.abs(x_j_prev - x[j])) - if x_j_prev != x[j]: Hx += (x[j] - x_j_prev) * H[j] if max_delta <= tol: diff --git a/skglm/tests/test_covariance.py b/skglm/tests/test_covariance.py new file mode 100644 index 000000000..d83076bf9 --- /dev/null +++ b/skglm/tests/test_covariance.py @@ -0,0 +1,119 @@ +import numpy as np +import pytest + +from sklearn.covariance import GraphicalLasso as GraphicalLasso_sklearn + +from skglm.covariance import GraphicalLasso, AdaptiveGraphicalLasso +from skglm.utils.data import make_dummy_covariance_data + + +def test_GraphicalLasso_equivalence_sklearn(): + S, _, lmbd_max = 
make_dummy_covariance_data(200, 50) + alpha = lmbd_max / 5 + + model_sk = GraphicalLasso_sklearn( + alpha=alpha, covariance="precomputed", tol=1e-10) + model_sk.fit(S) + + for algo in ("primal", "dual"): + model = GraphicalLasso( + alpha=alpha, + warm_start=False, + max_iter=1000, + tol=1e-14, + algo=algo, + ).fit(S) + + np.testing.assert_allclose( + model.precision_, model_sk.precision_, atol=1e-4) + np.testing.assert_allclose( + model.covariance_, model_sk.covariance_, atol=1e-4) + + # check that we did not mess up lambda: + np.testing.assert_array_less(S.shape[0] + 1, (model.precision_ != 0).sum()) + + +def test_GraphicalLasso_warm_start(): + S, _, lmbd_max = make_dummy_covariance_data(200, 50) + + alpha = lmbd_max / 5 + + model = GraphicalLasso( + alpha=alpha, + warm_start=True, + max_iter=1000, + tol=1e-14, + algo="primal", + ).fit(S) + np.testing.assert_array_less(1, model.n_iter_) + + model.fit(S) + np.testing.assert_equal(model.n_iter_, 1) + + model.algo = "dual" + with pytest.raises(ValueError, match="does not support"): + model.fit(S) + + +def test_GraphicalLasso_weights(): + S, _, lmbd_max = make_dummy_covariance_data(200, 50) + + alpha = lmbd_max / 10 + + model = GraphicalLasso( + alpha=alpha, + warm_start=False, + max_iter=2000, + tol=1e-14, + algo="primal", + ).fit(S) + prec = model.precision_.copy() + + scal = 2. + model.weights = np.full(S.shape, scal) + model.alpha /= scal + model.fit(S) + np.testing.assert_allclose(prec, model.precision_) + + mask = np.random.randn(*S.shape) > 0 + mask = mask + mask.T + mask.flat[::S.shape[0] + 1] = 0 + weights = mask.astype(float) + model.weights = weights + model.fit(S) + np.testing.assert_array_less(1e-4, np.abs(model.precision_[~mask])) + + +def test_GraphicalLasso_adaptive(): + S, _, lmbd_max = make_dummy_covariance_data(200, 50) + + alpha = lmbd_max / 10 + tol = 1e-14 + model = GraphicalLasso( + alpha=alpha, + warm_start=True, + max_iter=1000, + tol=tol, + algo="primal", + ).fit(S) + n_iter = [model.n_iter_] + Theta1 = model.precision_ + # TODO test the other strategies + weights = 1 / (np.abs(Theta1) + 1e-10) + model.weights = weights + + model.fit(S) + n_iter.append(model.n_iter_) + print("ada:") + + # TODO test more than 2 reweightings? 
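+    # Editorial sketch (not executed by this test): the two fits above are the
+    # manual counterpart of AdaptiveGraphicalLasso's default "log" strategy,
+    # roughly
+    #     W = np.ones(S.shape)
+    #     for _ in range(n_reweights):
+    #         glasso.weights = W
+    #         glasso.fit(S)
+    #         W = 1 / (np.abs(glasso.precision_) + 1e-10)
+    # so n_reweights=2 below is expected to reproduce both the precision
+    # matrix and the per-fit iteration counts collected in n_iter.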
+ model_a = AdaptiveGraphicalLasso( + alpha=alpha, + n_reweights=2, + tol=tol).fit(S) + + np.testing.assert_allclose(model_a.precision_, model.precision_) + np.testing.assert_allclose(model_a.n_iter_, n_iter) + + # support is decreasing: + assert not np.any(model_a.precision_[Theta1 == 0]) diff --git a/skglm/tests/test_estimators.py b/skglm/tests/test_estimators.py index c2a01995c..ec7536f19 100644 --- a/skglm/tests/test_estimators.py +++ b/skglm/tests/test_estimators.py @@ -14,23 +14,19 @@ from sklearn.linear_model import ElasticNet as ElasticNet_sklearn from sklearn.linear_model import LogisticRegression as LogReg_sklearn from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn -from sklearn.covariance import GraphicalLasso as GraphicalLasso_sklearn from sklearn.model_selection import GridSearchCV from sklearn.svm import LinearSVC as LinearSVC_sklearn from sklearn.utils.estimator_checks import check_estimator -from sklearn.utils import check_random_state from skglm.utils.data import (make_correlated_data, make_dummy_survival_data, _alpha_max_group_lasso, grp_converter) from skglm.estimators import ( GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet, - MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator, GraphicalLasso, - AdaptiveGraphicalLasso) + MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator) from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE from skglm.solvers import AndersonCD, FISTA, ProxNewton from skglm.utils.jit_compilation import compiled_clone -from skglm.utils.data import generate_GraphicalLasso_data n_samples = 50 n_tasks = 9 @@ -122,8 +118,7 @@ def test_estimator(estimator_name, X, fit_intercept, positive): pytest.xfail("Intercept is not supported for SVC.") if positive and estimator_name not in ( "Lasso", "ElasticNet", "wLasso", "MCP", "wMCP", "GroupLasso"): - pytest.xfail( - "`positive` option is only supported by L1, L1_plus_L2 and wL1.") + pytest.xfail("`positive` option is only supported by L1, L1_plus_L2 and wL1.") estimator_sk = clone(dict_estimators_sk[estimator_name]) estimator_ours = clone(dict_estimators_ours[estimator_name]) @@ -369,8 +364,7 @@ def test_equivalence_cox_SLOPE_cox_L1(use_efron, issparse): w, *_ = solver.solve(X, y, datafit, penalty) method = 'efron' if use_efron else 'breslow' - estimator = CoxEstimator( - alpha, l1_ratio=1., method=method, tol=1e-9).fit(X, y) + estimator = CoxEstimator(alpha, l1_ratio=1., method=method, tol=1e-9).fit(X, y) np.testing.assert_allclose(w, estimator.coef_, atol=1e-5) @@ -516,8 +510,7 @@ def test_grid_search(estimator_name): for attr in res_attr: np.testing.assert_allclose(sk_clf.cv_results_[attr], ours_clf.cv_results_[attr], rtol=1e-3) - np.testing.assert_allclose( - sk_clf.best_score_, ours_clf.best_score_, rtol=1e-3) + np.testing.assert_allclose(sk_clf.best_score_, ours_clf.best_score_, rtol=1e-3) np.testing.assert_allclose(sk_clf.best_params_["alpha"], ours_clf.best_params_["alpha"], rtol=1e-3) @@ -614,8 +607,7 @@ def test_SparseLogReg_elasticnet(X, l1_ratio): estimator_ours = clone(dict_estimators_ours['LogisticRegression']) estimator_sk.set_params(fit_intercept=True, solver='saga', penalty='elasticnet', l1_ratio=l1_ratio, max_iter=10_000) - estimator_ours.set_params( - fit_intercept=True, l1_ratio=l1_ratio, max_iter=10_000) + estimator_ours.set_params(fit_intercept=True, l1_ratio=l1_ratio, 
max_iter=10_000) estimator_sk.fit(X, y) estimator_ours.fit(X, y) @@ -628,121 +620,5 @@ def test_SparseLogReg_elasticnet(X, l1_ratio): estimator_sk.intercept_, estimator_ours.intercept_, rtol=1e-4) -####################### -# WIP Graphical Lasso tests -####################### - -def test_GraphicalLasso_equivalence_sklearn(): - S, _, lmbd_max = generate_GraphicalLasso_data(200, 50) - alpha = lmbd_max / 5 - - model_sk = GraphicalLasso_sklearn( - alpha=alpha, covariance="precomputed", tol=1e-10) - model_sk.fit(S) - - for algo in ("primal", "dual"): - model = GraphicalLasso( - alpha=alpha, - warm_start=False, - max_iter=1000, - tol=1e-14, - algo=algo, - ).fit(S) - - np.testing.assert_allclose( - model.precision_, model_sk.precision_, atol=1e-4) - np.testing.assert_allclose( - model.covariance_, model_sk.covariance_, atol=1e-4) - - # check that we did not mess up lambda: - np.testing.assert_array_less(S.shape[0] + 1, (model.precision_ != 0).sum()) - - -def test_GraphicalLasso_warm_start(): - S, _, lmbd_max = generate_GraphicalLasso_data(200, 50) - - alpha = lmbd_max / 5 - - model = GraphicalLasso( - alpha=alpha, - warm_start=True, - max_iter=1000, - tol=1e-14, - algo="primal", - ).fit(S) - np.testing.assert_array_less(1, model.n_iter_) - - model.fit(S) - np.testing.assert_equal(model.n_iter_, 1) - - model.algo = "dual" - with pytest.raises(ValueError, match="does not support"): - model.fit(S) - - -def test_GraphicalLasso_weights(): - S, _, lmbd_max = generate_GraphicalLasso_data(200, 50) - - alpha = lmbd_max / 10 - - model = GraphicalLasso( - alpha=alpha, - warm_start=False, - max_iter=2000, - tol=1e-14, - algo="primal", - ).fit(S) - prec = model.precision_.copy() - - scal = 2. - model.weights = np.full(S.shape, scal) - model.alpha /= scal - model.fit(S) - np.testing.assert_allclose(prec, model.precision_) - - mask = np.random.randn(*S.shape) > 0 - mask = mask + mask.T - mask.flat[::S.shape[0] + 1] = 0 - weights = mask.astype(float) - model.weights = weights - model.fit(S) - np.testing.assert_array_less(1e-4, np.abs(model.precision_[~mask])) - - -def test_GraphicalLasso_adaptive(): - S, _, lmbd_max = generate_GraphicalLasso_data(200, 50) - - alpha = lmbd_max / 10 - tol = 1e-14 - model = GraphicalLasso( - alpha=alpha, - warm_start=True, - max_iter=1000, - tol=tol, - algo="primal", - ).fit(S) - n_iter = [model.n_iter_] - Theta1 = model.precision_ - # TODO test the other strategies - weights = 1 / (np.abs(Theta1) + 1e-10) - model.weights = weights - - model.fit(S) - n_iter.append(model.n_iter_) - print("ada:") - - # TODO test more than 2 reweightings? 
- model_a = AdaptiveGraphicalLasso( - alpha=alpha, - n_reweights=2, - tol=tol).fit(S) - - np.testing.assert_allclose(model_a.precision_, model.precision_) - np.testing.assert_allclose(model_a.n_iter_, n_iter) - - # support is decreasing: - assert not np.any(model_a.precision_[Theta1 == 0]) - - if __name__ == "__main__": pass diff --git a/skglm/utils/data.py b/skglm/utils/data.py index 7fdd1ef02..20601116f 100644 --- a/skglm/utils/data.py +++ b/skglm/utils/data.py @@ -255,7 +255,7 @@ def _alpha_max_group_lasso(X, y, grp_indices, grp_ptr, weights): return alpha_max -def generate_GraphicalLasso_data(n_samples, n_features): +def make_dummy_covariance_data(n_samples, n_features): rng = check_random_state(0) Theta_true = make_sparse_spd_matrix( n_features, alpha=0.9, random_state=rng) From a4ea3fd6689ec6409fbf6e1fee1c0ec684359188 Mon Sep 17 00:00:00 2001 From: Can Date: Wed, 2 Apr 2025 17:53:08 +0200 Subject: [PATCH 06/13] remove snakecase function names in test functions --- skglm/tests/test_covariance.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/tests/test_covariance.py b/skglm/tests/test_covariance.py index d83076bf9..3fb6cd26d 100644 --- a/skglm/tests/test_covariance.py +++ b/skglm/tests/test_covariance.py @@ -7,7 +7,7 @@ from skglm.utils.data import make_dummy_covariance_data -def test_GraphicalLasso_equivalence_sklearn(): +def test_glasso_equivalence_sklearn(): S, _, lmbd_max = make_dummy_covariance_data(200, 50) alpha = lmbd_max / 5 @@ -33,7 +33,7 @@ def test_GraphicalLasso_equivalence_sklearn(): np.testing.assert_array_less(S.shape[0] + 1, (model.precision_ != 0).sum()) -def test_GraphicalLasso_warm_start(): +def test_glasso_warm_start(): S, _, lmbd_max = make_dummy_covariance_data(200, 50) alpha = lmbd_max / 5 @@ -55,7 +55,7 @@ def test_GraphicalLasso_warm_start(): model.fit(S) -def test_GraphicalLasso_weights(): +def test_glasso_weights(): S, _, lmbd_max = make_dummy_covariance_data(200, 50) alpha = lmbd_max / 10 @@ -84,7 +84,7 @@ def test_GraphicalLasso_weights(): np.testing.assert_array_less(1e-4, np.abs(model.precision_[~mask])) -def test_GraphicalLasso_adaptive(): +def test_glasso_adaptive(): S, _, lmbd_max = make_dummy_covariance_data(200, 50) alpha = lmbd_max / 10 From 3148357048136e30c72364f5659aa8fd84f97dbd Mon Sep 17 00:00:00 2001 From: floriankozikowski Date: Fri, 11 Apr 2025 18:23:20 +0200 Subject: [PATCH 07/13] adjust weight updates, still need to test --- skglm/covariance.py | 130 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) diff --git a/skglm/covariance.py b/skglm/covariance.py index 9294a220b..b8aa79b98 100644 --- a/skglm/covariance.py +++ b/skglm/covariance.py @@ -1,9 +1,12 @@ # License: BSD 3 clause +from skglm.utils.data import make_dummy_covariance_data +import matplotlib.pyplot as plt import numpy as np from scipy.linalg import pinvh from skglm.solvers.gram_cd import barebones_cd_gram +from skglm.penalties import L0_5 class GraphicalLasso(): @@ -159,6 +162,94 @@ def fit(self, S): return self +# class AdaptiveGraphicalLasso(): +# """ An adaptive version of the Graphical Lasso that solves non-convex penalty +# variations using the reweighting strategy from Candès et al., 2007.""" + +# def __init__( +# self, +# alpha=1., +# # strategy="log", +# n_reweights=5, +# max_iter=1000, +# tol=1e-8, +# warm_start=False, +# penalty=L0_5(1.), +# ): +# self.alpha = alpha +# # self.strategy = strategy # we can remove this param. 
it if not used elsewhere +# self.n_reweights = n_reweights +# self.max_iter = max_iter +# self.tol = tol +# self.warm_start = warm_start +# self.penalty = penalty + +# def fit(self, S): +# """ Fit the AdaptiveGraphicalLasso model on the empirical covariance matrix S.""" +# glasso = GraphicalLasso( +# alpha=self.alpha, +# algo="primal", +# max_iter=self.max_iter, +# tol=self.tol, +# warm_start=True) +# Weights = np.ones(S.shape) +# self.n_iter_ = [] +# for it in range(self.n_reweights): +# glasso.weights = Weights +# glasso.fit(S) +# Theta = glasso.precision_ + +# Weights = abs(self.penalty.derivative(Theta)) + +# self.n_iter_.append(glasso.n_iter_) +# # TODO print losses for original problem? +# glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True) +# self.precision_ = glasso.precision_ +# self.covariance_ = glasso.covariance_ +# return self + + +# if __name__ == "__main__": +# import matplotlib.pyplot as plt +# from skglm.utils.data import make_dummy_covariance_data +# from skglm.penalties import L1 # Import L0_5 + +# # Define the dimensions for the dummy data +# n = 100 # number of samples +# p = 20 # number of features + +# # Create dummy covariance data +# S, Theta_true, alpha_max = make_dummy_covariance_data(n, p) + +# # Compute the true covariance matrix as the pseudoinverse of the true precision matrix +# true_covariance = np.linalg.pinv(Theta_true, hermitian=True) + +# # Instantiate the AdaptiveGraphicalLasso model with L0_5 penalty +# model = AdaptiveGraphicalLasso( +# # Pass L0_5 object +# alpha=alpha_max * 0.1, n_reweights=5, tol=1e-8, warm_start=True, penalty=L0_5(1.)) + +# # Fit the model on the empirical covariance matrix S +# model.fit(S) + +# # Compute normalized mean squared error (NMSE) between the true and estimated covariance matrices +# nmse = np.linalg.norm(model.covariance_ - true_covariance)**2 / \ +# np.linalg.norm(true_covariance)**2 +# print("Normalized MSE (NMSE): {:.3e}".format(nmse)) + +# fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + +# im0 = axes[0].imshow(true_covariance, cmap="hot", interpolation="nearest") +# axes[0].set_title("True Covariance Matrix") +# plt.colorbar(im0, ax=axes[0]) + +# im1 = axes[1].imshow(model.covariance_, cmap="hot", interpolation="nearest") +# axes[1].set_title("Estimated Covariance Matrix\n(Adaptive Graphical Lasso)") +# plt.colorbar(im1, ax=axes[1]) + +# plt.tight_layout() +# plt.show() + class AdaptiveGraphicalLasso(): """ An adaptive version of the Graphical Lasso that solves non-convex penalty variations using the reweighting strategy from Candès et al., 2007.""" @@ -216,3 +307,42 @@ def update_weights(Theta, alpha, strategy="log"): return Weights else: raise ValueError(f"Unknown strategy {strategy}") + + +if __name__ == "__main__": + + # Define the dimensions for the dummy data + n = 100 # number of samples + p = 20 # number of features + + # Create dummy covariance data + S, Theta_true, alpha_max = make_dummy_covariance_data(n, p) + + # Compute the true covariance matrix as the pseudoinverse of the true precision matrix + true_covariance = np.linalg.pinv(Theta_true, hermitian=True) + + # Instantiate the AdaptiveGraphicalLasso model with L0_5 penalty + model = AdaptiveGraphicalLasso( + # Pass L0_5 object + alpha=alpha_max * 0.1, n_reweights=5, tol=1e-8, warm_start=True) + + # Fit the model on the empirical covariance matrix S + model.fit(S) + + # Compute normalized mean squared error (NMSE) between the true and estimated covariance matrices + nmse = np.linalg.norm(model.covariance_ - true_covariance)**2 / \ + 
np.linalg.norm(true_covariance)**2 + print("Normalized MSE (NMSE): {:.3e}".format(nmse)) + + fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + + im0 = axes[0].imshow(true_covariance, cmap="hot", interpolation="nearest") + axes[0].set_title("True Covariance Matrix") + plt.colorbar(im0, ax=axes[0]) + + im1 = axes[1].imshow(model.covariance_, cmap="hot", interpolation="nearest") + axes[1].set_title("Estimated Covariance Matrix\n(Adaptive Graphical Lasso)") + plt.colorbar(im1, ax=axes[1]) + + plt.tight_layout() + plt.show() From 082c00d27d9bb6cb81c09e238ec43f749e0e4a84 Mon Sep 17 00:00:00 2001 From: floriankozikowski Date: Mon, 14 Apr 2025 16:37:08 +0200 Subject: [PATCH 08/13] trying out different update methods, no success so far --- skglm/covariance.py | 279 ++++++++++++++++++++++++++------------------ 1 file changed, 163 insertions(+), 116 deletions(-) diff --git a/skglm/covariance.py b/skglm/covariance.py index b8aa79b98..d1304c7f7 100644 --- a/skglm/covariance.py +++ b/skglm/covariance.py @@ -1,5 +1,10 @@ # License: BSD 3 clause +from mpl_toolkits.axes_grid1 import make_axes_locatable +import matplotlib.ticker as mticker +import matplotlib.colors as mcolors +from skglm.penalties.separable import LogSumPenalty +from sklearn.datasets import make_sparse_spd_matrix from skglm.utils.data import make_dummy_covariance_data import matplotlib.pyplot as plt import numpy as np @@ -162,95 +167,66 @@ def fit(self, S): return self -# class AdaptiveGraphicalLasso(): -# """ An adaptive version of the Graphical Lasso that solves non-convex penalty -# variations using the reweighting strategy from Candès et al., 2007.""" - -# def __init__( -# self, -# alpha=1., -# # strategy="log", -# n_reweights=5, -# max_iter=1000, -# tol=1e-8, -# warm_start=False, -# penalty=L0_5(1.), -# ): -# self.alpha = alpha -# # self.strategy = strategy # we can remove this param. it if not used elsewhere -# self.n_reweights = n_reweights -# self.max_iter = max_iter -# self.tol = tol -# self.warm_start = warm_start -# self.penalty = penalty - -# def fit(self, S): -# """ Fit the AdaptiveGraphicalLasso model on the empirical covariance matrix S.""" -# glasso = GraphicalLasso( -# alpha=self.alpha, -# algo="primal", -# max_iter=self.max_iter, -# tol=self.tol, -# warm_start=True) -# Weights = np.ones(S.shape) -# self.n_iter_ = [] -# for it in range(self.n_reweights): -# glasso.weights = Weights -# glasso.fit(S) -# Theta = glasso.precision_ - -# Weights = abs(self.penalty.derivative(Theta)) - -# self.n_iter_.append(glasso.n_iter_) -# # TODO print losses for original problem? 
-# glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True) -# self.precision_ = glasso.precision_ -# self.covariance_ = glasso.covariance_ -# return self - - -# if __name__ == "__main__": -# import matplotlib.pyplot as plt -# from skglm.utils.data import make_dummy_covariance_data -# from skglm.penalties import L1 # Import L0_5 - -# # Define the dimensions for the dummy data -# n = 100 # number of samples -# p = 20 # number of features - -# # Create dummy covariance data -# S, Theta_true, alpha_max = make_dummy_covariance_data(n, p) - -# # Compute the true covariance matrix as the pseudoinverse of the true precision matrix -# true_covariance = np.linalg.pinv(Theta_true, hermitian=True) - -# # Instantiate the AdaptiveGraphicalLasso model with L0_5 penalty -# model = AdaptiveGraphicalLasso( -# # Pass L0_5 object -# alpha=alpha_max * 0.1, n_reweights=5, tol=1e-8, warm_start=True, penalty=L0_5(1.)) - -# # Fit the model on the empirical covariance matrix S -# model.fit(S) - -# # Compute normalized mean squared error (NMSE) between the true and estimated covariance matrices -# nmse = np.linalg.norm(model.covariance_ - true_covariance)**2 / \ -# np.linalg.norm(true_covariance)**2 -# print("Normalized MSE (NMSE): {:.3e}".format(nmse)) - -# fig, axes = plt.subplots(1, 2, figsize=(12, 5)) - -# im0 = axes[0].imshow(true_covariance, cmap="hot", interpolation="nearest") -# axes[0].set_title("True Covariance Matrix") -# plt.colorbar(im0, ax=axes[0]) - -# im1 = axes[1].imshow(model.covariance_, cmap="hot", interpolation="nearest") -# axes[1].set_title("Estimated Covariance Matrix\n(Adaptive Graphical Lasso)") -# plt.colorbar(im1, ax=axes[1]) - -# plt.tight_layout() -# plt.show() - -class AdaptiveGraphicalLasso(): +class AdaptiveGraphicalLassoPenalty(): + """ An adaptive version of the Graphical Lasso that solves non-convex penalty + variations using the reweighting strategy from Candès et al., 2007.""" + + def __init__( + self, + alpha=1., + # strategy="log", + n_reweights=5, + max_iter=1000, + tol=1e-8, + warm_start=False, + penalty=L0_5(1.), + ): + self.alpha = alpha + # self.strategy = strategy # we can remove this param. it if not used elsewhere + self.n_reweights = n_reweights + self.max_iter = max_iter + self.tol = tol + self.warm_start = warm_start + self.penalty = penalty + + def fit(self, S): + """ Fit the AdaptiveGraphicalLasso model on the empirical covariance matrix S.""" + glasso = GraphicalLasso( + alpha=self.alpha, + algo="primal", + max_iter=self.max_iter, + tol=self.tol, + warm_start=True) + Weights = np.ones(S.shape) + self.n_iter_ = [] + for it in range(self.n_reweights): + glasso.weights = Weights + glasso.fit(S) + Theta = glasso.precision_ + + Weights = abs(self.penalty.derivative(Theta)) + # Theta = (Theta + Theta.T) / 2 + # Weights = (Weights + Weights.T) / 2 + # np.fill_diagonal(Weights, 0) + + print( + f"Min/Max Weights after penalty derivative: {Weights.min():.2e}, {Weights.max():.2e}") + + self.n_iter_.append(glasso.n_iter_) + # TODO print losses for original problem? 
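+            # Editorial sketch, inferred from the penalty definitions rather
+            # than verified against skglm internals: for L0_5 with value
+            # alpha * sum(|t| ** 0.5), abs(derivative(t)) = alpha / (2 * sqrt(|t|)),
+            # i.e. the old "sqrt" strategy up to the alpha factor -- which is
+            # what the alpha consistency check below guards against. At t = 0
+            # the derivative is unbounded; the explicit zero handling added
+            # later in this series caps the weight at 1 / penalty.eps.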
+ + glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True) + self.precision_ = glasso.precision_ + self.covariance_ = glasso.covariance_ + if not np.isclose(self.alpha, self.penalty.alpha): + print( + f"Alpha mismatch: GLasso alpha = {self.alpha}, Penalty alpha = {self.penalty.alpha}") + else: + print(f"Alpha values match: {self.alpha}") + return self + + +class AdaptiveGraphicalLassoStrategy(): """ An adaptive version of the Graphical Lasso that solves non-convex penalty variations using the reweighting strategy from Candès et al., 2007.""" @@ -309,40 +285,111 @@ def update_weights(Theta, alpha, strategy="log"): raise ValueError(f"Unknown strategy {strategy}") -if __name__ == "__main__": - - # Define the dimensions for the dummy data - n = 100 # number of samples - p = 20 # number of features +# Testing - # Create dummy covariance data - S, Theta_true, alpha_max = make_dummy_covariance_data(n, p) +def frobenius_norm_diff(A, B): + """Relative Frobenius norm difference between A and B.""" + return np.linalg.norm(A - B, ord='fro') / np.linalg.norm(B, ord='fro') - # Compute the true covariance matrix as the pseudoinverse of the true precision matrix - true_covariance = np.linalg.pinv(Theta_true, hermitian=True) - # Instantiate the AdaptiveGraphicalLasso model with L0_5 penalty - model = AdaptiveGraphicalLasso( - # Pass L0_5 object - alpha=alpha_max * 0.1, n_reweights=5, tol=1e-8, warm_start=True) +def generate_problem(dim=20, n_samples=100, seed=42): + """Generate data from a known sparse precision matrix.""" + np.random.seed(seed) - # Fit the model on the empirical covariance matrix S - model.fit(S) + # Ground-truth sparse precision matrix (positive definite) + Theta_true = make_sparse_spd_matrix(n_dim=dim, alpha=0.95, smallest_coef=0.1) + Sigma_true = np.linalg.inv(Theta_true) - # Compute normalized mean squared error (NMSE) between the true and estimated covariance matrices - nmse = np.linalg.norm(model.covariance_ - true_covariance)**2 / \ - np.linalg.norm(true_covariance)**2 - print("Normalized MSE (NMSE): {:.3e}".format(nmse)) + # Sample from multivariate normal to get empirical covariance matrix S + X = np.random.multivariate_normal(np.zeros(dim), Sigma_true, size=n_samples) + S = np.cov(X, rowvar=False) - fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + return S, Theta_true - im0 = axes[0].imshow(true_covariance, cmap="hot", interpolation="nearest") - axes[0].set_title("True Covariance Matrix") - plt.colorbar(im0, ax=axes[0]) - im1 = axes[1].imshow(model.covariance_, cmap="hot", interpolation="nearest") - axes[1].set_title("Estimated Covariance Matrix\n(Adaptive Graphical Lasso)") - plt.colorbar(im1, ax=axes[1]) +if __name__ == "__main__": + # Set test parameters + dim = 20 + alpha = 0.1 + n_reweights = 5 + seed = 42 + + # Get empirical covariance and ground truth + S, Theta_true = generate_problem(dim=dim, seed=seed) + + # Define non-convex penalty — this is consistent with 'log' strategy + penalty = LogSumPenalty(alpha=alpha, eps=1e-10) + + # Fit new penalty-based model + model_penalty = AdaptiveGraphicalLassoPenalty( + alpha=alpha, + penalty=penalty, + n_reweights=n_reweights, + ) + model_penalty.fit(S) + + # Fit old strategy-based model + model_strategy = AdaptiveGraphicalLassoStrategy( + alpha=alpha, + strategy="log", + n_reweights=n_reweights, + ) + model_strategy.fit(S) + + # Extract precision matrices + Theta_penalty = model_penalty.precision_ + Theta_strategy = model_strategy.precision_ + + # Compare the two estimated models + rel_diff_between_models = 
frobenius_norm_diff(Theta_penalty, Theta_strategy) + print( + f"\n Frobenius norm relative difference between models: {rel_diff_between_models:.2e}") + print(" Matrices are close?", np.allclose( + Theta_penalty, Theta_strategy, atol=1e-4)) + + # Compare both to ground truth + rel_diff_penalty_vs_true = frobenius_norm_diff(Theta_penalty, Theta_true) + rel_diff_strategy_vs_true = frobenius_norm_diff(Theta_strategy, Theta_true) + + print( + f"\n Penalty vs true Θ: Frobenius norm diff = {rel_diff_penalty_vs_true:.2e}") + print( + f"Strategy vs true Θ: Frobenius norm diff = {rel_diff_strategy_vs_true:.2e}") + + print("\nTrue precision matrix:\n", Theta_true) + print("\nPenalty-based estimate:\n", Theta_penalty) + print("\nStrategy-based estimate:\n", Theta_strategy) + + # Visualization + n_features = Theta_true.shape[0] + + plt.close('all') + cmap = plt.cm.bwr + + matrices = [Theta_true, Theta_penalty, Theta_strategy] + titles = [r"$\Theta_{\mathrm{True}}$", + r"$\Theta_{\mathrm{Penalty}}$", r"$\Theta_{\mathrm{Strategy}}$"] + + fig, ax = plt.subplots(3, 1, layout="constrained", + figsize=(4.42, 9.33)) + + vmax = max(np.max(mat) for mat in matrices) / 2 + vmin = min(np.min(mat) for mat in matrices) / 2 + norm = mcolors.TwoSlopeNorm(vmin=vmin, vmax=vmax, vcenter=0) + + for i in range(3): + im = ax[i].imshow(matrices[i], cmap=cmap, norm=norm) + sparsity = 100 * (1 - np.count_nonzero(matrices[i]) / (n_features**2)) + ax[i].set_title(f"{titles[i]}\nsparsity = {sparsity:.2f}%", fontsize=12) + ax[i].set_xticks([]) + ax[i].set_yticks([]) + + divider = make_axes_locatable(ax[i]) + cax = divider.append_axes("right", size="3%", pad=0.05) + cbar = fig.colorbar(im, cax=cax, orientation='vertical') + ticks_loc = cbar.ax.get_yticks().tolist() + cbar.ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) + cbar.ax.set_yticklabels([f'{i:.0e}' for i in cbar.get_ticks()]) + cbar.ax.tick_params(labelsize=10) - plt.tight_layout() plt.show() From a4e192f2152a69a394be0725d736bd7f34c0900c Mon Sep 17 00:00:00 2001 From: floriankozikowski Date: Tue, 22 Apr 2025 16:27:05 +0200 Subject: [PATCH 09/13] empty From 33ca52cbf0999f79f653766484584605b66a95c3 Mon Sep 17 00:00:00 2001 From: floriankozikowski Date: Tue, 22 Apr 2025 16:42:28 +0200 Subject: [PATCH 10/13] add explicit handling of zeros in penalties.derivative, works now, todos: review math, check if it works with all other penalties and clean up code to one version if everything is correct --- skglm/covariance.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/skglm/covariance.py b/skglm/covariance.py index d1304c7f7..62a530e04 100644 --- a/skglm/covariance.py +++ b/skglm/covariance.py @@ -204,10 +204,12 @@ def fit(self, S): glasso.fit(S) Theta = glasso.precision_ - Weights = abs(self.penalty.derivative(Theta)) - # Theta = (Theta + Theta.T) / 2 - # Weights = (Weights + Weights.T) / 2 - # np.fill_diagonal(Weights, 0) + Theta_sym = (Theta + Theta.T) / 2 + Weights = np.where( + Theta_sym == 0, + 1 / self.penalty.eps, + np.abs(self.penalty.derivative(Theta_sym)) + ) print( f"Min/Max Weights after penalty derivative: {Weights.min():.2e}, {Weights.max():.2e}") From fc7000923f92f6fecbee68e72da7d57e96d1f78d Mon Sep 17 00:00:00 2001 From: floriankozikowski Date: Tue, 22 Apr 2025 16:54:21 +0200 Subject: [PATCH 11/13] leave original name of old strategy based version, so tests dont fail --- skglm/covariance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/covariance.py b/skglm/covariance.py index 62a530e04..c485b4e18 
100644 --- a/skglm/covariance.py +++ b/skglm/covariance.py @@ -228,7 +228,7 @@ def fit(self, S): return self -class AdaptiveGraphicalLassoStrategy(): +class AdaptiveGraphicalLasso(): """ An adaptive version of the Graphical Lasso that solves non-convex penalty variations using the reweighting strategy from Candès et al., 2007.""" From a927e501b51ddac7df6a758ae70d7b1891920790 Mon Sep 17 00:00:00 2001 From: floriankozikowski Date: Tue, 22 Apr 2025 17:08:17 +0200 Subject: [PATCH 12/13] fix minor name dependency --- skglm/covariance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/covariance.py b/skglm/covariance.py index c485b4e18..3d889347f 100644 --- a/skglm/covariance.py +++ b/skglm/covariance.py @@ -331,7 +331,7 @@ def generate_problem(dim=20, n_samples=100, seed=42): model_penalty.fit(S) # Fit old strategy-based model - model_strategy = AdaptiveGraphicalLassoStrategy( + model_strategy = AdaptiveGraphicalLasso( alpha=alpha, strategy="log", n_reweights=n_reweights, From b62990d5f3fd699d64ff82cc693c7d0caf597c52 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Wed, 30 Apr 2025 14:58:16 +0200 Subject: [PATCH 13/13] ci trigger
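Taken together, the series ends with skglm.covariance exposing GraphicalLasso
(weighted, primal or dual block coordinate descent) and AdaptiveGraphicalLasso
(strategy-based reweighting). A minimal usage sketch, with names and defaults
taken from the constructors in the diffs above (the exact module state on this
branch is an assumption, not re-verified):

    import numpy as np
    from skglm.covariance import GraphicalLasso, AdaptiveGraphicalLasso
    from skglm.utils.data import make_dummy_covariance_data

    # empirical covariance S, ground-truth precision, and a reference alpha
    S, Theta_true, alpha_max = make_dummy_covariance_data(200, 50)
    alpha = alpha_max / 10

    # one weighted Graphical Lasso fit, primal block coordinate descent
    glasso = GraphicalLasso(alpha=alpha, algo="primal", tol=1e-10).fit(S)

    # adaptive variant: n_reweights successive fits with "log" weight updates
    ada = AdaptiveGraphicalLasso(alpha=alpha, n_reweights=5, tol=1e-10).fit(S)

    nmse = (np.linalg.norm(ada.precision_ - Theta_true) ** 2
            / np.linalg.norm(Theta_true) ** 2)
    print(f"NMSE: {nmse:.3e}; nonzeros: {(ada.precision_ != 0).sum()}")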