Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add test tradaboost
  • Loading branch information
antoinedemathelin committed Apr 11, 2022
commit eea283cd61ed86a899c6c0454a373882a31e70d3
9 changes: 6 additions & 3 deletions adapt/instance_based/_tradaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def fit(self, X, y, Xt=None, yt=None,
Xt, yt = self._get_target_data(Xt, yt)
Xt, yt = check_arrays(Xt, yt, accept_sparse=True)

if isinstance(self, TrAdaBoost):
if not isinstance(self, TrAdaBoostR2) and isinstance(self.estimator, BaseEstimator):
self.label_encoder_ = LabelEncoder()
ys = self.label_encoder_.fit_transform(ys)
yt = self.label_encoder_.transform(yt)
Expand Down Expand Up @@ -454,7 +454,10 @@ def predict(self, X):
predictions.append(y_pred)
predictions = np.stack(predictions, -1)
weighted_vote = predictions.dot(weights).argmax(1)
return self.label_encoder_.inverse_transform(weighted_vote)
if hasattr(self, "label_encoder_"):
return self.label_encoder_.inverse_transform(weighted_vote)
else:
return weighted_vote


def predict_weights(self, domain="src"):
Expand Down Expand Up @@ -951,7 +954,7 @@ def func(x):
def _cross_val_score(self, Xs, ys, Xt, yt,
sample_weight_src, sample_weight_tgt,
**fit_params):
if len(Xt) >= self.cv:
if Xt.shape[0] >= self.cv:
cv = self.cv
else:
cv = Xt.shape[0]
Expand Down
16 changes: 16 additions & 0 deletions src_docs/_static/css/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,19 @@ img.map-adapt {
#selecting-the-right-domain-adaptation-model {
padding-bottom: 600px;
}


/* Block quotes: mark them with a light-gray bar on the left edge and
   add horizontal padding so the text does not touch the bar. */
blockquote {
border-left: 5px solid #D3D3D3;
padding: 0 1em;
}


/* Info alert boxes (elements carrying the classes "alert alert-block alert-info"):
   pale-blue background with a thick, darker-blue solid band along the top edge. */
div.alert.alert-block.alert-info {
background: #e7f2fa;
padding: 12px;
margin-bottom: 12px;
border-top-color: #6ab0de;
border-top-width: 12px;
border-top-style: solid;
}
10 changes: 9 additions & 1 deletion src_docs/_templates/layout.html
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
{%- if pathto("examples/Regression") == "#" %}{% set Regression = "" %}{% else %}{% set Regression = "#" %}{% endif %}
{%- if pathto("examples/sample_bias") == "#" %}{% set sample_bias = "" %}{% else %}{% set sample_bias = "#" %}{% endif %}
{%- if pathto("examples/Multi_fidelity") == "#" %}{% set Multi_fidelity = "" %}{% else %}{% set Multi_fidelity = "#" %}{% endif %}
{%- if pathto("examples/Rotation") == "#" %}{% set Rotation = "" %}{% else %}{% set Rotation = "#" %}{% endif %}
{%- if pathto("examples/Rotation") == "#" %}{% set Rotation = "" %}{% else %}{% set Rotation = "#" %}{% endif %}
{%- if pathto("examples/tradaboost_experiments") == "#" %}{% set tradaboost_experiments = "" %}{% else %}{% set tradaboost_experiments = "#" %}{% endif %}

{% block menu %}
<p class="caption" role="heading"><span class="caption-text">Installation</span></p>
Expand Down Expand Up @@ -150,5 +151,12 @@
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("examples/Multi_fidelity") }}{{ Multi_fidelity }}{{ "RegularTransferNN" }}">RegularTransferNN</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="{{ pathto("examples/tradaboost_experiments") }}">TrAdaBoost Experiments</a><ul>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("examples/tradaboost_experiments") }}{{ tradaboost_experiments }}{{ "Mushrooms" }}">Mushrooms</a></li>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("examples/tradaboost_experiments") }}{{ tradaboost_experiments }}{{ "20-NewsGroup" }}">20-NewsGroup</a></li>
</ul>
</li>


</ul>
{% endblock %}
37 changes: 21 additions & 16 deletions src_docs/examples/tradaboost_experiments.ipynb
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "8ea3c629-ffd7-48f6-9ff1-96a43d120c9f",
"metadata": {},
"source": [
"# Reproduction of the TrAdaBoost experiments"
]
},
{
"cell_type": "markdown",
"id": "3aedc081-d6af-45dc-a9e1-bcd71e83f90b",
"metadata": {},
"source": [
"<div class=\"btn btn-notebook\" role=\"button\">\n",
" <img src=\"../_static/images/github_logo_32px.png\"> [View on GitHub](https://github.com/adapt-python/notebooks/blob/d0364973c642ea4880756cef4e9f2ee8bb5e8495/Two_moons.ipynb)\n",
"</div>"
]
},
{
"cell_type": "markdown",
"id": "a22504a0-5ff7-498e-bc35-a6c101926204",
"metadata": {},
"source": [
"# Reproduction of the TrAdaBoost experiments\n",
"\n",
"The purpose of this example is to reproduce the results obtained in the paper [Boosting for Transfer Learning (2007)](https://cse.hkust.edu.hk/~qyang/Docs/2007/tradaboost.pdf). In this work, the authors developed a transfer algorithm called TrAdaBoost dedicated for [supervised domain adaptation](https://adapt-python.github.io/adapt/map.html). You can find more details about this algorithm [here](https://adapt-python.github.io/adapt/generated/adapt.instance_based.TrAdaBoost.html). The goal of this algorithm is to combine a source dataset with many labeled instances to a target dataset with few labels in order to learn a good model on the target domain.\n",
"\n",
"We try to reproduce the following two experiments:\n",
"\n",
"- Mushrooms\n",
"- 20newsgroups\n",
"\n"
Expand Down Expand Up @@ -314,8 +331,7 @@
"metadata": {},
"source": [
"<div class=\"alert alert-block alert-info\">\n",
"<b>Note:</b> When looking at the number of instances in each category of the *stalk-shape* attribute, it seems that the authors inversed the source data set with the target one in the text above. Indeed, when looking at Table 1 in the paper, the number of source instances should be 4608 which corresponds to the <b>tapering</b> class and not the <b>enlarging</b> one.</div>\n",
"\n"
"**Note:** When looking at the number of instances in each category of the *stalk-shape* attribute, it seems that the authors swapped the source data set and the target one in the text above. Indeed, according to Table 1 in the paper, the number of source instances should be 4608, which corresponds to the **tapering** class and not the **enlarging** one.</div>"
]
},
{
Expand Down Expand Up @@ -552,7 +568,7 @@
"id": "babd1cce-c39e-4516-9a10-9d7e9f00f190",
"metadata": {},
"source": [
"## 20 NewsGroup experiments"
"## 20 NewsGroup"
]
},
{
Expand Down Expand Up @@ -641,17 +657,6 @@
"We conduct the three proposed experiments \"rec vs talk\", \"rec vs sci\" and \"sci vs talk\". We set the number of TrAdaBoost estimators to 10 instead of 100. We found that using 100 estimators gives poor results for TrAdaBoost."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "1d47ec68-f638-42aa-9c11-c37caa61fe14",
"metadata": {},
"outputs": [],
"source": [
"# source_sci = ['sci.crypt', 'sci.electronics']\n",
"# target_sci = ['sci.med', 'sci.space']"
]
},
{
"cell_type": "markdown",
"id": "054fa3be-c83e-4c64-a3ff-58c51ea397fe",
Expand Down
28 changes: 26 additions & 2 deletions tests/test_tradaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import copy
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
import scipy
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, RidgeClassifier
from sklearn.metrics import r2_score, accuracy_score
import tensorflow as tf

Expand Down Expand Up @@ -184,4 +185,27 @@ def test_tradaboost_lr():
model.fit(Xs, ys_classif)
err2 = model.estimator_errors_

assert np.sum(err1) > 10 * np.sum(err2)
assert np.sum(err1) > 5 * np.sum(err2)


def test_tradaboost_sparse_matrix():
    """Smoke-test the TrAdaBoost estimators on scipy sparse (CSR) inputs.

    A 200x200 identity matrix in CSR format is split into a target half
    (first 100 rows) and a source half (last 100 rows). Each estimator is
    built with the first 10 target rows as labeled target data, then fitted
    on the source half, scored on the target half and asked to predict on
    the source half. There are no assertions: the test passes as long as
    none of these calls raises.
    """
    n_half = 100
    features = scipy.sparse.csr_matrix(np.eye(2 * n_half))
    # Draw the regression targets first, then the string class labels,
    # so the sequence of random calls is unchanged.
    reg_targets = np.random.randn(n_half)
    class_labels = np.random.choice(["e", "p"], n_half)
    target_X, source_X = features[:n_half], features[n_half:]

    def _fit_score_predict(estimator, labels):
        # Exercise the full public cycle on sparse data.
        estimator.fit(source_X, labels)
        estimator.score(target_X, labels)
        estimator.predict(source_X)

    # Classification with string labels.
    _fit_score_predict(
        TrAdaBoost(RidgeClassifier(), Xt=target_X[:10], yt=class_labels[:10]),
        class_labels,
    )

    # Regression variant.
    _fit_score_predict(
        TrAdaBoostR2(Ridge(), Xt=target_X[:10], yt=reg_targets[:10]),
        reg_targets,
    )

    # Two-stage regression variant, limited to 3 estimators.
    _fit_score_predict(
        TwoStageTrAdaBoostR2(
            Ridge(), Xt=target_X[:10], yt=reg_targets[:10], n_estimators=3
        ),
        reg_targets,
    )