-
Notifications
You must be signed in to change notification settings - Fork 197
Expand file tree
/
Copy pathmodel_training.py
More file actions
151 lines (112 loc) · 4.69 KB
/
Copy pathmodel_training.py
File metadata and controls
151 lines (112 loc) · 4.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn import base, datasets, linear_model, metrics, svm, utils
from sklearn.model_selection import train_test_split
from hamilton import function_modifiers
@function_modifiers.config.when(data_loader="iris")
def data__iris() -> utils.Bunch:
    """Load the iris dataset.

    This variant is registered for the configuration ``data_loader="iris"``.

    :return: the bunch produced by ``datasets.load_iris()``.
    """
    # Must be load_iris(): loading digits here would make the "iris"
    # configuration indistinguishable from the "digits" one.
    return datasets.load_iris()
@function_modifiers.config.when(data_loader="digits")
def data__digits() -> utils.Bunch:
    """Load the digits dataset.

    This variant is registered for the configuration ``data_loader="digits"``.

    :return: the bunch produced by ``datasets.load_digits()``.
    """
    digits_bunch = datasets.load_digits()
    return digits_bunch
def target(data: utils.Bunch) -> np.ndarray:
    """Pull the ``target`` attribute out of the loaded dataset.

    :param data: the loaded dataset bunch.
    :return: ``data.target``, unchanged.
    """
    targets = data.target
    return targets
def target_names(data: utils.Bunch) -> np.ndarray:
    """Pull the ``target_names`` attribute out of the loaded dataset.

    :param data: the loaded dataset bunch.
    :return: ``data.target_names``, unchanged.
    """
    names = data.target_names
    return names
def feature_matrix(data: utils.Bunch) -> np.ndarray:
    """Pull the ``data`` attribute (the features) out of the loaded dataset.

    :param data: the loaded dataset bunch.
    :return: ``data.data``, unchanged.
    """
    features = data.data
    return features
@function_modifiers.config.when(clf="svm")
def prefit_clf__svm(gamma: float = 0.001) -> base.ClassifierMixin:
    """Build an SVM classifier that has not been fitted yet.

    This variant is registered for the configuration ``clf="svm"``.

    :param gamma: forwarded to ``svm.SVC`` as its ``gamma`` argument.
    :return: a new, unfitted ``svm.SVC`` instance.
    """
    unfitted_svc = svm.SVC(gamma=gamma)
    return unfitted_svc
@function_modifiers.config.when(clf="logistic")
def prefit_clf__logreg(penalty: str) -> base.ClassifierMixin:
    """Build a Logistic Regression classifier that has not been fitted yet.

    This variant is registered for the configuration ``clf="logistic"``.

    :param penalty: forwarded to ``linear_model.LogisticRegression`` as its
        ``penalty`` argument.
    :return: a new, unfitted ``linear_model.LogisticRegression`` instance.
    """
    unfitted_logreg = linear_model.LogisticRegression(penalty=penalty)
    return unfitted_logreg
@function_modifiers.extract_fields(
    {"X_train": np.ndarray, "X_test": np.ndarray, "y_train": np.ndarray, "y_test": np.ndarray}
)
def train_test_split_func(
    feature_matrix: np.ndarray,
    target: np.ndarray,
    test_size_fraction: float,
    shuffle_train_test_split: bool,
) -> dict[str, np.ndarray]:
    """Split the features and targets into training and test partitions.

    The returned dictionary is then extracted into its constituent fields
    (see the ``extract_fields`` decorator) so each can be used downstream.

    :param feature_matrix: the features to split.
    :param target: the targets to split alongside the features.
    :param test_size_fraction: passed to ``train_test_split`` as ``test_size``.
    :param shuffle_train_test_split: passed to ``train_test_split`` as ``shuffle``.
    :return: dictionary with the keys ``X_train``, ``X_test``, ``y_train``
        and ``y_test``.
    """
    # Same order in which train_test_split returns its pieces for two inputs.
    split_names = ("X_train", "X_test", "y_train", "y_test")
    split_arrays = train_test_split(
        feature_matrix,
        target,
        test_size=test_size_fraction,
        shuffle=shuffle_train_test_split,
    )
    return dict(zip(split_names, split_arrays))
def y_test_with_labels(y_test: np.ndarray, target_names: np.ndarray) -> np.ndarray:
    """Map the test targets from class indices to their label names.

    :param y_test: integer class indices into ``target_names``.
    :param target_names: label for each class index.
    :return: a new array holding ``target_names[idx]`` for every ``idx`` in
        ``y_test``; it keeps the dtype of ``target_names`` even when
        ``y_test`` is empty.
    """
    label_lookup = np.asarray(target_names)
    # An explicit integer dtype keeps empty input usable as an index array.
    class_indices = np.asarray(y_test, dtype=np.intp)
    # Integer-array indexing does the lookup in one vectorized step and
    # always returns a fresh array.
    return label_lookup[class_indices]
def fit_clf(
    prefit_clf: base.ClassifierMixin, X_train: np.ndarray, y_train: np.ndarray
) -> base.ClassifierMixin:
    """Fit the classifier on the training data.

    ``fit`` is called on the object that was passed in, so that object is
    mutated and the very same instance is handed back.

    :param prefit_clf: the classifier to fit.
    :param X_train: training features passed to ``fit``.
    :param y_train: training targets passed to ``fit``.
    :return: ``prefit_clf`` itself, after ``fit`` has been called on it.
    """
    classifier = prefit_clf
    classifier.fit(X_train, y_train)
    return classifier
def predicted_output(fit_clf: base.ClassifierMixin, X_test: np.ndarray) -> np.ndarray:
    """Run the fitted classifier on the test features.

    :param fit_clf: the fitted classifier.
    :param X_test: test features passed to ``predict``.
    :return: whatever ``fit_clf.predict(X_test)`` returns.
    """
    predictions = fit_clf.predict(X_test)
    return predictions
def predicted_output_with_labels(
    predicted_output: np.ndarray, target_names: np.ndarray
) -> np.ndarray:
    """Map the predictions from class indices to their label names.

    :param predicted_output: integer class indices into ``target_names``.
    :param target_names: label for each class index.
    :return: a new array holding ``target_names[idx]`` for every ``idx`` in
        ``predicted_output``; it keeps the dtype of ``target_names`` even
        when ``predicted_output`` is empty.
    """
    label_lookup = np.asarray(target_names)
    # An explicit integer dtype keeps empty input usable as an index array.
    class_indices = np.asarray(predicted_output, dtype=np.intp)
    # Integer-array indexing does the lookup in one vectorized step and
    # always returns a fresh array.
    return label_lookup[class_indices]
def classification_report(
    predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray
) -> str:
    """Build a classification report comparing predictions with the truth.

    :param predicted_output_with_labels: the predicted labels.
    :param y_test_with_labels: the true labels.
    :return: the result of ``metrics.classification_report``.
    """
    # The parameters arrive as (predicted, true) but sklearn expects
    # (true, predicted); keywords make the swap explicit.
    return metrics.classification_report(
        y_true=y_test_with_labels,
        y_pred=predicted_output_with_labels,
    )
def confusion_matrix(
    predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray
) -> np.ndarray:
    """Build a confusion matrix comparing predictions with the truth.

    :param predicted_output_with_labels: the predicted labels.
    :param y_test_with_labels: the true labels.
    :return: the result of ``metrics.confusion_matrix``.
    """
    # The parameters arrive as (predicted, true) but sklearn expects
    # (true, predicted); keywords make the swap explicit.
    return metrics.confusion_matrix(
        y_true=y_test_with_labels,
        y_pred=predicted_output_with_labels,
    )
def confusion_matrix_figure(confusion_matrix: np.ndarray, target_names: np.ndarray) -> go.Figure:
    """Render the confusion matrix as an interactive plotly heatmap.

    Both axes are ticked with the integer positions
    ``0 .. len(target_names) - 1`` rather than with the label names.

    :param confusion_matrix: the matrix to display.
    :param target_names: only its length is used, to size the axes.
    :return: the figure produced by ``px.imshow``.
    """
    n_classes = len(target_names)
    axis_ticks = np.arange(n_classes)
    axis_titles = {"x": "Predicted labels", "y": "True labels", "color": "Count"}
    return px.imshow(confusion_matrix, x=axis_ticks, y=axis_ticks, labels=axis_titles)
def model_parameters(fit_clf: base.ClassifierMixin) -> dict:
    """Expose the parameters of the fitted classifier.

    :param fit_clf: the fitted classifier.
    :return: whatever ``fit_clf.get_params()`` returns.
    """
    params = fit_clf.get_params()
    return params