Skip to content

Commit 8a4fbbd

Browse files
committed
bleugh
1 parent bdcd084 commit 8a4fbbd

File tree

5 files changed

+292
-109
lines changed

5 files changed

+292
-109
lines changed

geolearn_env.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ dependencies:
4141
- optuna
4242
- plotly
4343
- openpyxl
44+
- skorch
4445

4546
- pip:
4647
- morphsnakes
4748
- xmltodict
4849
- simpledbf
4950
- pyfftw
5051
- phasepack
51-
5252
- SimpleCRF
5353
- git+https://github.com/Ciaran1981/geospatial-learn#egg=geospatial-learn
5454
#py

geospatial_learn/convutils.py

+38-48
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,24 @@
1414
import numpy as np
1515
from skimage.exposure import rescale_intensity
1616
import os
17-
from glob2 import glob
17+
from glob import glob
1818
import matplotlib.pyplot as plt
1919
# Albumentations
2020
from collections import defaultdict
2121
import copy
2222
import random
2323
import albumentations as A
2424
from albumentations.pytorch import ToTensorV2
25-
import ternausnet.models
2625
from tqdm import tqdm
2726
import torch
2827
import torch.backends.cudnn as cudnn
2928
import torch.nn as nn
3029
import torch.optim
3130
from torch.utils.data import Dataset, DataLoader
3231
from torchvision.models import segmentation
33-
import gdal
32+
from osgeo import gdal
3433
import segmentation_models_pytorch as smp
3534
import skimage.morphology as skm
36-
import gdal
3735
import pandas as pd
3836
from mpl_toolkits.axes_grid1 import ImageGrid
3937
gdal.UseExceptions()
@@ -472,51 +470,43 @@ def create_model(params, activation, proc="cuda:0"):
472470

473471
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
474472

475-
if params["model"] == "UNet11" or params["model"] == "UNet16":
476-
model = getattr(ternausnet.models, params["model"])(pretrained=True)
477-
if torch.cuda.device_count() > 1:
478-
#consider also DistributedDataParallel
479-
model= nn.DataParallel(model)
480-
hrdWare = torch.device(proc)
481-
model = model.to(hrdWare)
482-
473+
474+
#Unet, UNet16, ULinknet, FPN, PSPNet,PAN, DeepLabV3 and DeepLabV3+
475+
if params["model"] == 'Unet':
476+
model = smp.Unet(encoder_name=params['encoder'],
477+
classes=params['classes'],in_channels=params['in_channels'],
478+
activation=activation)
479+
if params["model"] == 'Linknet':
480+
model = smp.Linknet(encoder_name=params['encoder'],
481+
classes=params['classes'],in_channels=params['in_channels'],
482+
activation=activation)
483+
if params["model"] == 'FPN':
484+
model = smp.FPN(encoder_name=params['encoder'],
485+
classes=params['classes'],in_channels=params['in_channels'],
486+
activation=activation)
487+
if params["model"] == 'PSPNet':
488+
model = smp.PSPNet(encoder_name=params['encoder'],
489+
classes=params['classes'],in_channels=params['in_channels'],
490+
activation=activation)
491+
if params["model"] == 'PAN':
492+
model = smp.PAN(encoder_name=params['encoder'],
493+
classes=params['classes'],in_channels=params['in_channels'],
494+
activation=activation)
495+
if params["model"] == 'DeepLabV3':
496+
model = smp.DeepLabV3(encoder_name=params['encoder'],
497+
classes=params['classes'],in_channels=params['in_channels'],
498+
activation=activation)
499+
if params["model"] == 'DeepLabV3+':
500+
model = smp.DeepLabV3(encoder_name=params['encoder'],
501+
classes=params['classes'],in_channels=params['in_channels'],
502+
activation=activation)
503+
if torch.cuda.device_count() > 1:
504+
#consider also DistributedDataParallel
505+
model= nn.DataParallel(model)
506+
model = model.to(device)
483507
else:
484-
#Unet, UNet16, ULinknet, FPN, PSPNet,PAN, DeepLabV3 and DeepLabV3+
485-
if params["model"] == 'Unet':
486-
model = smp.Unet(encoder_name=params['encoder'],
487-
classes=params['classes'],in_channels=params['in_channels'],
488-
activation=activation)
489-
if params["model"] == 'Linknet':
490-
model = smp.Linknet(encoder_name=params['encoder'],
491-
classes=params['classes'],in_channels=params['in_channels'],
492-
activation=activation)
493-
if params["model"] == 'FPN':
494-
model = smp.FPN(encoder_name=params['encoder'],
495-
classes=params['classes'],in_channels=params['in_channels'],
496-
activation=activation)
497-
if params["model"] == 'PSPNet':
498-
model = smp.PSPNet(encoder_name=params['encoder'],
499-
classes=params['classes'],in_channels=params['in_channels'],
500-
activation=activation)
501-
if params["model"] == 'PAN':
502-
model = smp.PAN(encoder_name=params['encoder'],
503-
classes=params['classes'],in_channels=params['in_channels'],
504-
activation=activation)
505-
if params["model"] == 'DeepLabV3':
506-
model = smp.DeepLabV3(encoder_name=params['encoder'],
507-
classes=params['classes'],in_channels=params['in_channels'],
508-
activation=activation)
509-
if params["model"] == 'DeepLabV3+':
510-
model = smp.DeepLabV3(encoder_name=params['encoder'],
511-
classes=params['classes'],in_channels=params['in_channels'],
512-
activation=activation)
513-
if torch.cuda.device_count() > 1:
514-
#consider also DistributedDataParallel
515-
model= nn.DataParallel(model)
516-
model = model.to(device)
517-
else:
518-
hrdWare = torch.device(proc)
519-
model = model.to(hrdWare)
508+
hrdWare = torch.device(proc)
509+
model = model.to(hrdWare)
520510

521511
return model
522512

geospatial_learn/learning.py

+86-58
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import numpy as np
2828
from sklearn.pipeline import Pipeline
2929
from sklearn.model_selection import (StratifiedKFold, GroupKFold, KFold,
30-
train_test_split,GroupShuffleSplit, PredefinedSplit)
30+
train_test_split,GroupShuffleSplit,
31+
StratifiedGroupKFold,
32+
PredefinedSplit)
3133
from sklearn.ensemble import (RandomForestClassifier, ExtraTreesClassifier,
3234
GradientBoostingClassifier,RandomForestRegressor,
3335
GradientBoostingRegressor, ExtraTreesRegressor,
@@ -39,6 +41,7 @@
3941
from sklearn.preprocessing import (LabelEncoder, MaxAbsScaler, MinMaxScaler,
4042
Normalizer, PowerTransformer,StandardScaler,
4143
QuantileTransformer)
44+
from sklearn.svm import SVC, SVR, NuSVC, NuSVR, LinearSVC, LinearSVR
4245
from sklearn.feature_selection import VarianceThreshold, RFECV
4346
from sklearn.inspection import permutation_importance
4447
from sklearn import metrics
@@ -256,7 +259,7 @@ def objective(trial, X, y, cv, group, score=scr):
256259
print(f"\tBest value (rmse or r2): {study.best_value:.5f}")
257260
print(f"\tBest params:")
258261

259-
def _group_cv(X_train, y_train, group, test_size=0.2, cv=10):
262+
def _group_cv(X_train, y_train, group, test_size=0.2, cv=10, strat=False):
260263

261264
"""
262265
Return the splits and and vars for a group grid search
@@ -275,17 +278,27 @@ def _group_cv(X_train, y_train, group, test_size=0.2, cv=10):
275278
y_train = y_train[train_inds]
276279
group_trn = group[train_inds]
277280

278-
group_kfold = GroupKFold(n_splits=cv)
279-
# Create a nested list of train and test indices for each fold
280-
k_kfold = group_kfold.split(X_train, y_train, group_trn)
281+
if strat == True:
282+
group_kfold = StratifiedGroupKFold(n_splits=cv).split(X_train,
283+
y_train,
284+
group_trn)
285+
else:
286+
group_kfold = GroupKFold(n_splits=cv).split(X_train,
287+
y_train,
288+
group_trn)
289+
290+
# all this not required produces same as above - keep for ref though
291+
# # Create a nested list of train and test indices for each fold
292+
# k_kfold = group_kfold.split(X_train, y_train, groups=group_trn)
281293

282-
train_ind2, test_ind2 = [list(traintest) for traintest in zip(*k_kfold)]
294+
# train_ind2, test_ind2 = [list(traintest) for traintest in zip(*k_kfold)]
283295

284-
cv = [*zip(train_ind2, test_ind2)]
296+
# cv = [*zip(train_ind2, test_ind2)]
285297

286-
return X_train, y_train, X_test, y_test, cv
298+
return X_train, y_train, X_test, y_test, group_kfold
287299

288-
def rec_feat_sel(X_train, featnames, preproc=('scaler', None), clf='erf', group=None,
300+
def rec_feat_sel(X_train, featnames, preproc=('scaler', None), clf='erf',
301+
group=None,
289302
cv=5, params=None, cores=-1, strat=True,
290303
test_size=0.3, regress=False, return_test=True,
291304
scoring=None, class_names=None, save=True, cat_feat=None):
@@ -550,7 +563,10 @@ class names in order of their numercial equivalents
550563
# devices='0:1'),
551564
'lgbm': lgb.LGBMClassifier(random_state=0),
552565

553-
'hgb': HistGradientBoostingClassifier(random_state=0)}
566+
'hgb': HistGradientBoostingClassifier(random_state=0),
567+
'svm': SVC(),
568+
'nusvc': NuSVC(),
569+
'linsvc': LinearSVC()}
554570

555571
regdict = {'rf': RandomForestRegressor(random_state=0),
556572
'erf': ExtraTreesRegressor(random_state=0),
@@ -563,16 +579,19 @@ class names in order of their numercial equivalents
563579
# task_type="GPU",
564580
# devices='0:1'),
565581
'lgbm': lgb.LGBMRegressor(random_state=0),
566-
567-
'hgb': HistGradientBoostingRegressor(random_state=0)}
582+
'hgb': HistGradientBoostingRegressor(random_state=0),
583+
'svm': SVR(),
584+
'nusvc': NuSVR(),
585+
'linsvc': LinearSVR()}
568586

569587
if regress is True:
570588
model = regdict[clf]
571589
if scoring is None:
572590
scoring = 'r2'
573591
else:
574592
model = clfdict[clf]
575-
cv = StratifiedKFold(cv)
593+
if group is None:
594+
cv = StratifiedKFold(cv)
576595
if scoring is None:
577596
scoring = 'accuracy'
578597

@@ -600,25 +619,18 @@ class names in order of their numercial equivalents
600619

601620

602621
# this is not a good way to do this
603-
if group is not None:
622+
if regress == True:
623+
strat = False # failsafe
604624

625+
if group is not None: # becoming a mess
626+
605627
X_train, y_train, X_test, y_test, cv = _group_cv(X_train, y_train,
606-
group, test_size,
607-
cv)
608-
628+
group, test_size,
629+
cv, strat=strat)
609630
else:
610631
X_train, X_test, y_train, y_test = train_test_split(
611632
X_train, y_train, test_size=test_size, random_state=0)
612-
613-
#
614-
# if clf[0:4] == 'catb':
615-
# # Quick and quiet but can't enter the group cv indices or the sklearn
616-
# # pipe
617-
# ds = Pool(X_train, label=y_train)
618-
619-
# # fails at end saying
620-
# model.grid_search(param_grid, ds, cv=k_kfold)
621-
#CatBoostError: /src/catboost/catboost/private/libs/options/cross_validation_params.cpp:21: FoldCount is 0
633+
#cv = StratifiedKFold(cv)
622634

623635

624636
if pipe == 'default':
@@ -650,10 +662,8 @@ class names in order of their numercial equivalents
650662
grid = GridSearchCV(sk_pipe, param_grid=sclr,
651663
cv=cv, n_jobs=cores,
652664
scoring=scoring, verbose=1)
653-
654665

655-
656-
666+
657667
grid.fit(X_train, y_train)
658668

659669
joblib.dump(grid.best_estimator_, outModel)
@@ -667,12 +677,17 @@ class names in order of their numercial equivalents
667677
else:
668678
crDf = hp.plot_classif_report(y_test, testresult, target_names=class_names,
669679
save=outModel[:-3]+'._classif_report.png')
680+
681+
confmat = metrics.confusion_matrix(testresult, y_test, labels=class_names)
682+
disp = metrics.ConfusionMatrixDisplay(confusion_matrix=confmat,
683+
display_labels=class_names)
684+
disp.plot()
670685

671-
confmat = hp.plt_confmat(X_test, y_test, grid.best_estimator_,
672-
class_names=class_names,
673-
cmap=plt.cm.Blues,
674-
fmt="%d",
675-
save=outModel[:-3]+'_confmat.png')
686+
# confmat = hp.plt_confmat(X_test, y_test, grid.best_estimator_,
687+
# class_names=class_names,
688+
# cmap=plt.cm.Blues,
689+
# fmt="%d",
690+
# save=outModel[:-3]+'_confmat.png')
676691

677692
results = [grid, crDf, confmat]
678693

@@ -776,13 +791,26 @@ class names in order of their numercial equivalents
776791
# we only wish to predict really - but necessary
777792
# for sklearn model construct
778793
else:
779-
clfdict = {'rf': RandomForestClassifier, 'erf': ExtraTreesClassifier,
780-
'gb': GradientBoostingClassifier, 'xgb': XGBClassifier,
781-
'logit': LogisticRegression, 'hgb': HistGradientBoostingClassifier}
782-
783-
regdict = {'rf': RandomForestRegressor, 'erf': ExtraTreesRegressor,
784-
'gb': GradientBoostingRegressor, 'xgb': XGBRegressor,
785-
'hgb': HistGradientBoostingRegressor}
794+
clfdict = {'rf': RandomForestClassifier(random_state=0),
795+
'erf': ExtraTreesClassifier(random_state=0),
796+
'gb': GradientBoostingClassifier(random_state=0),
797+
'xgb': XGBClassifier(random_state=0),
798+
'logit': LogisticRegression(),
799+
'lgbm': lgb.LGBMClassifier(random_state=0),
800+
'hgb': HistGradientBoostingClassifier(random_state=0),
801+
'svm': SVC(),
802+
'nusvc': NuSVC(),
803+
'linsvc': LinearSVC()}
804+
805+
regdict = {'rf': RandomForestRegressor(random_state=0),
806+
'erf': ExtraTreesRegressor(random_state=0),
807+
'gb': GradientBoostingRegressor(random_state=0),
808+
'xgb': XGBRegressor(random_state=0),
809+
'lgbm': lgb.LGBMRegressor(random_state=0),
810+
'hgb': HistGradientBoostingRegressor(random_state=0),
811+
'svm': SVR(),
812+
'nusvc': NuSVR(),
813+
'linsvc': LinearSVR()}
786814

787815
if mtype == 'regress':
788816
# won't accept the dict even with the ** to unpack it
@@ -840,23 +868,18 @@ def regression_results(y_true, y_pred):
840868
print('r2: ', round(r2,4))
841869
print('MAE: ', round(mean_absolute_error,4))
842870
print('MSE: ', round(mse,4))
871+
print('MedianAE', round(median_absolute_error, 4))
843872
print('RMSE: ', round(np.sqrt(mse),4))
844-
#TODO add when sklearn updated
845-
# display = metrics.PredictionErrorDisplay.from_predictions(
846-
# y_true=y,
847-
# y_pred=y_pred,
848-
# kind="actual_vs_predicted",
849-
# ax=ax,
850-
# scatter_kwargs={"alpha": 0.2, "color": "tab:blue"},
851-
# line_kwargs={"color": "tab:red"},
852-
# )
853-
# print(grid.best_params_)
854-
# print(grid.best_estimator_)
855-
# print(grid.oob_score_)
856-
857-
# plt.plot(est_range, grid_mean_scores)
858-
# plt.xlabel('no of estimators')
859-
# plt.ylabel('Cross validated accuracy')
873+
#TODO add when sklearn updated
874+
display = metrics.PredictionErrorDisplay.from_predictions(
875+
y_true=y_true,
876+
y_pred=y_pred,
877+
kind="actual_vs_predicted",
878+
#ax=ax,
879+
scatter_kwargs={"alpha": 0.2, "color": "tab:blue"},
880+
line_kwargs={"color": "tab:red"},
881+
)
882+
860883

861884

862885
def RF_oob_opt(model, X_train, min_est, max_est, step, group=None,
@@ -1103,6 +1126,11 @@ def plot_feat_importance_permutation(modelPth, featureNames, X_test, y_test,
11031126
featureNames : list of strings
11041127
a list of feature names
11051128
1129+
Returns
1130+
-------
1131+
1132+
pandas df of importances
1133+
11061134
"""
11071135

11081136
if modelPth is not str:

0 commit comments

Comments
 (0)