@@ -465,7 +465,8 @@ class names in order of their numercial equivalents
 def create_model(X_train, outModel, clf='erf', group=None, random=False,
                  cv=5, params=None, pipe='default', cores=-1, strat=True,
                  test_size=0.3, regress=False, return_test=True,
-                 scoring=None, class_names=None, save=True, cat_feat=None):
+                 scoring=None, class_names=None, save=True, cat_feat=None,
+                 plot=True):

     """
     Brute force or random model creating using scikit learn.
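
A minimal call sketch for the new keyword (the import path and toy array below are illustrative assumptions, not part of this change; plot defaults to True, so existing callers keep the previous plotting behaviour):

    # hypothetical import path -- substitute the module this diff edits
    from learning import create_model
    import numpy as np

    # toy stand-in for the training data; the real layout is whatever
    # create_model documents for X_train
    X_train = np.random.rand(200, 6)

    # plot=False skips the confusion-matrix / prediction-error figures,
    # which is useful for headless or batch runs
    results = create_model(X_train, 'model.gz', clf='erf', plot=False)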
@@ -679,24 +680,25 @@ class names in order of their numercial equivalents
     testresult = grid.best_estimator_.predict(X_test)

     if regress == True:
-        regrslt = regression_results(y_test, testresult)
+        regrslt = regression_results(y_test, testresult, plot=plot)
         results = [grid]

     else:
         crDf = hp.plot_classif_report(y_test, testresult, target_names=class_names,
                                       save=outModel[:-3]+'._classif_report.png')

         confmat = metrics.confusion_matrix(testresult, y_test, labels=class_names)
-        disp = metrics.ConfusionMatrixDisplay(confusion_matrix=confmat,
-                                              display_labels=class_names)
-        disp.plot()
+
+        if plot == True:
+            disp = metrics.ConfusionMatrixDisplay(confusion_matrix=confmat,
+                                                  display_labels=class_names)
+            disp.plot()

         # confmat = hp.plt_confmat(X_test, y_test, grid.best_estimator_,
         #                          class_names=class_names,
         #                          cmap=plt.cm.Blues,
         #                          fmt="%d",
-        #                          save=outModel[:-3]+'_confmat.png')
-
+        #                          save=outModel[:-3]+'_confmat.png')
         results = [grid, crDf, confmat]

     if return_test == True:
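
The pattern above in isolation: the confusion matrix is still computed unconditionally, and only the matplotlib figure is gated. A self-contained sketch against the public scikit-learn API (show_confmat is a made-up helper name for illustration):

    from sklearn import metrics

    def show_confmat(y_true, y_pred, class_names, plot=True):
        # the matrix is always computed and returned, so the numbers
        # remain available even when no figure is wanted
        confmat = metrics.confusion_matrix(y_true, y_pred, labels=class_names)
        if plot:
            # matplotlib is only touched inside disp.plot(), so gating here
            # keeps headless runs free of display side effects
            disp = metrics.ConfusionMatrixDisplay(confusion_matrix=confmat,
                                                  display_labels=class_names)
            disp.plot()
        return confmat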
@@ -861,7 +863,7 @@ class names in order of their numercial equivalents
         return comb, X_test, y_test


-def regression_results(y_true, y_pred):
+def regression_results(y_true, y_pred, plot=True):

     # Regression metrics
     explained_variance = metrics.explained_variance_score(y_true, y_pred)
@@ -878,15 +880,16 @@ def regression_results(y_true, y_pred):
     print('MSE: ', round(mse, 4))
     print('MedianAE', round(median_absolute_error, 4))
     print('RMSE: ', round(np.sqrt(mse), 4))
-    #TODO add when sklearn updated
-    display = metrics.PredictionErrorDisplay.from_predictions(
-        y_true=y_true,
-        y_pred=y_pred,
-        kind="actual_vs_predicted",
-        #ax=ax,
-        scatter_kwargs={"alpha": 0.2, "color": "tab:blue"},
-        line_kwargs={"color": "tab:red"},
-    )
+    #TODO add when sklearn updated
+    if plot == True:
+        display = metrics.PredictionErrorDisplay.from_predictions(
+            y_true=y_true,
+            y_pred=y_pred,
+            kind="actual_vs_predicted",
+            #ax=ax,
+            scatter_kwargs={"alpha": 0.2, "color": "tab:blue"},
+            line_kwargs={"color": "tab:red"},
+        )
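
PredictionErrorDisplay was added in scikit-learn 1.2, which is presumably what the TODO above refers to. A self-contained sketch of the gated plot with toy values:

    import numpy as np
    from sklearn import metrics

    y_true = np.array([3.0, 2.5, 4.1, 7.2, 6.3])
    y_pred = np.array([2.8, 2.7, 3.9, 7.6, 6.0])

    plot = True
    if plot:
        # scatter of actual vs predicted values with a reference diagonal;
        # requires scikit-learn >= 1.2
        metrics.PredictionErrorDisplay.from_predictions(
            y_true=y_true,
            y_pred=y_pred,
            kind="actual_vs_predicted",
            scatter_kwargs={"alpha": 0.2, "color": "tab:blue"},
            line_kwargs={"color": "tab:red"},
        )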