Skip to content
Merged
Next Next commit
TransferForestClassifier
  • Loading branch information
atiqm committed Mar 23, 2022
commit bf27e4f8ac9fa9fa53f33ed6712f443193660c54
29 changes: 26 additions & 3 deletions adapt/_tree_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,31 @@
import copy
import numpy as np


def _bootstrap_(size,class_wise=False,y=None):
if class_wise:
if y is None:
print("Error : need labels to apply class wise bootstrap.")
else:
inds = []
oob_inds = []
classes_ = set(y)
ind_classes_ = np.zeros(len(classes_),dtype=object)

for j,c in enumerate(classes_):
ind_classes_[j] = np.where(y==c)[0]
s = ind_classes_[j].size
inds += list(np.random.choice(ind_classes_[j], s, replace=True))
oob_inds += list(set(ind_classes_[j]) - set(inds))

inds,oob_inds = np.array(inds),np.array(oob_inds)
else:
inds = np.random.choice(np.arange(size), size, replace=True)
oob_inds = set(np.arange(size)) - set(inds)
oob_inds = np.array(list(oob_inds))

return inds, oob_inds

def depth_tree(dt,node=0):

if dt.tree_.feature[node] == -2:
Expand Down Expand Up @@ -464,8 +489,7 @@ def coherent_new_split(phi,th,rule):
return 0,1
else:
return 1,0



def all_coherent_splits(rule,all_splits):

inds = np.zeros(all_splits.shape[0],dtype=bool)
Expand Down Expand Up @@ -534,4 +558,3 @@ def bounds_rule(rule,n_features):

return bound_infs,bound_sups


Loading