Skip to content

Commit 857a8c6

Browse files
committed
Fix reference to InputValidator
1 parent 80fc8f2 commit 857a8c6

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

autosklearn/automl.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,15 +1741,15 @@ def score(self, X: SUPPORTED_FEAT_TYPES, y: SUPPORTED_TARGET_TYPES) -> float:
17411741
check_is_fitted(self)
17421742

17431743
prediction = self.predict(X)
1744-
y = self.InputValidator.target_validator.transform(y)
1744+
y = self.input_validator.target_validator.transform(y)
17451745

17461746
# Encode the prediction using the input validator
17471747
# We train autosklearn with a encoded version of y,
17481748
# which is decoded by predict().
17491749
# Above call to validate() encodes the y given for score()
17501750
# Below call encodes the prediction, so we compare in the
17511751
# same representation domain
1752-
prediction = self.InputValidator.target_validator.transform(prediction)
1752+
prediction = self.input_validator.target_validator.transform(prediction)
17531753

17541754
return compute_single_metric(
17551755
solution=y,
@@ -2267,16 +2267,15 @@ def predict(
22672267
n_jobs: int = 1,
22682268
) -> np.ndarray:
22692269
check_is_fitted(self)
2270-
assert self.InputValidator is not None
2271-
22722270
probabilities = self.predict_proba(X, batch_size=batch_size, n_jobs=n_jobs)
2271+
validator = self.input_validator
22732272

2274-
if self.InputValidator.target_validator.is_single_column_target():
2273+
if validator.target_validator.is_single_column_target():
22752274
predicted_indexes = np.argmax(probabilities, axis=1)
22762275
else:
22772276
predicted_indexes = (probabilities > 0.5).astype(int)
22782277

2279-
return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
2278+
return validator.target_validator.inverse_transform(predicted_indexes)
22802279

22812280
def predict_proba(
22822281
self,

autosklearn/estimators.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from scipy.sparse import spmatrix
1414
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
1515
from sklearn.ensemble import VotingClassifier, VotingRegressor
16-
from sklearn.exceptions import NotFittedError
1716
from sklearn.model_selection._split import (
1817
BaseCrossValidator,
1918
BaseShuffleSplit,
@@ -1522,10 +1521,7 @@ def classes_(self) -> np.ndarray:
15221521
np.ndarray
15231522
Class labels seen during fit
15241523
"""
1525-
if self.automl.InputValidator is None:
1526-
raise NotFittedError("Please call fit first")
1527-
1528-
return self.automl.InputValidator.target_validator.classes_
1524+
return self.automl.input_validator.classes_
15291525

15301526
def predict_proba(
15311527
self,

0 commit comments

Comments
 (0)