From 6520436401cbf55d2656841830d3674fabf76847 Mon Sep 17 00:00:00 2001 From: minsuk-heo Date: Sat, 9 Jun 2018 18:35:47 -0700 Subject: [PATCH 1/8] w2v added --- data_science/nlp/word2vec_tensorflow.ipynb | 693 +++++++++++++++++++++ 1 file changed, 693 insertions(+) create mode 100644 data_science/nlp/word2vec_tensorflow.ipynb diff --git a/data_science/nlp/word2vec_tensorflow.ipynb b/data_science/nlp/word2vec_tensorflow.ipynb new file mode 100644 index 0000000..0521874 --- /dev/null +++ b/data_science/nlp/word2vec_tensorflow.ipynb @@ -0,0 +1,693 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Word2Vec\n", + "here I implement word2vec with very simple example using tensorflow \n", + "word2vec is vector representation for words with similarity" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Collect Data\n", + "we will use only 10 sentences to create word vectors" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "corpus = ['king is a strong man', \n", + " 'queen is a wise woman', \n", + " 'boy is a young man',\n", + " 'girl is a young woman',\n", + " 'prince is a young king',\n", + " 'princess is a young queen',\n", + " 'man is strong', \n", + " 'woman is pretty',\n", + " 'prince is a boy will be king',\n", + " 'princess is a girl will be queen']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Remove stop words\n", + "In order for efficiency of creating word vector, we will remove commonly used words" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def remove_stop_words(corpus):\n", + " stop_words = ['is', 'a', 'will', 'be']\n", + " results = []\n", + " for text in corpus:\n", + " tmp = text.split(' ')\n", + " for stop_word in stop_words:\n", + " if stop_word in tmp:\n", + " tmp.remove(stop_word)\n", + " results.append(\" \".join(tmp))\n", + " 
\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "corpus = remove_stop_words(corpus)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "words = []\n", + "for text in corpus:\n", + " for word in text.split(' '):\n", + " words.append(word)\n", + "\n", + "words = set(words)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "here we have word set by which we will have word vector" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'boy',\n", + " 'girl',\n", + " 'king',\n", + " 'man',\n", + " 'pretty',\n", + " 'prince',\n", + " 'princess',\n", + " 'queen',\n", + " 'strong',\n", + " 'wise',\n", + " 'woman',\n", + " 'young'}" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "words" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# data generation\n", + "we will generate label for each word using skip gram. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "word2int = {}\n", + "\n", + "for i,word in enumerate(words):\n", + " word2int[word] = i\n", + "\n", + "sentences = []\n", + "for sentence in corpus:\n", + " sentences.append(sentence.split())\n", + " \n", + "WINDOW_SIZE = 2\n", + "\n", + "data = []\n", + "for sentence in sentences:\n", + " for idx, word in enumerate(sentence):\n", + " for neighbor in sentence[max(idx - WINDOW_SIZE, 0) : min(idx + WINDOW_SIZE, len(sentence)) + 1] : \n", + " if neighbor != word:\n", + " data.append([word, neighbor])" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "king strong man\n", + "queen wise woman\n", + "boy young man\n", + "girl young woman\n", + "prince young king\n", + "princess young queen\n", + "man strong\n", + "woman pretty\n", + "prince boy king\n", + "princess girl queen\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "for text in corpus:\n", + " print(text)\n", + "\n", + "df = pd.DataFrame(data, columns = ['input', 'label'])" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
inputlabel
0kingstrong
1kingman
2strongking
3strongman
4manking
5manstrong
6queenwise
7queenwoman
8wisequeen
9wisewoman
\n", + "
" + ], + "text/plain": [ + " input label\n", + "0 king strong\n", + "1 king man\n", + "2 strong king\n", + "3 strong man\n", + "4 man king\n", + "5 man strong\n", + "6 queen wise\n", + "7 queen woman\n", + "8 wise queen\n", + "9 wise woman" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(52, 2)" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'boy': 2,\n", + " 'girl': 10,\n", + " 'king': 7,\n", + " 'man': 5,\n", + " 'pretty': 9,\n", + " 'prince': 1,\n", + " 'princess': 11,\n", + " 'queen': 6,\n", + " 'strong': 4,\n", + " 'wise': 3,\n", + " 'woman': 8,\n", + " 'young': 0}" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "word2int" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define Tensorflow Graph" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import numpy as np\n", + "\n", + "ONE_HOT_DIM = len(words)\n", + "\n", + "# function to convert numbers to one hot vectors\n", + "def to_one_hot_encoding(data_point_index):\n", + " one_hot_encoding = np.zeros(ONE_HOT_DIM)\n", + " one_hot_encoding[data_point_index] = 1\n", + " return one_hot_encoding\n", + "\n", + "X = [] # input word\n", + "Y = [] # target word\n", + "\n", + "for x, y in zip(df['input'], df['label']):\n", + " X.append(to_one_hot_encoding(word2int[ x ]))\n", + " Y.append(to_one_hot_encoding(word2int[ y ]))\n", + "\n", + "# convert them to numpy arrays\n", + "X_train = np.asarray(X)\n", + "Y_train = 
np.asarray(Y)\n", + "\n", + "# making placeholders for X_train and Y_train\n", + "x = tf.placeholder(tf.float32, shape=(None, ONE_HOT_DIM))\n", + "y_label = tf.placeholder(tf.float32, shape=(None, ONE_HOT_DIM))\n", + "\n", + "# word embedding will be 2 dimension for 2d visualization\n", + "EMBEDDING_DIM = 2 \n", + "\n", + "# hidden layer: which represents word vector eventually\n", + "W1 = tf.Variable(tf.random_normal([ONE_HOT_DIM, EMBEDDING_DIM]))\n", + "b1 = tf.Variable(tf.random_normal([1])) #bias\n", + "hidden_layer = tf.add(tf.matmul(x,W1), b1)\n", + "\n", + "# output layer\n", + "W2 = tf.Variable(tf.random_normal([EMBEDDING_DIM, ONE_HOT_DIM]))\n", + "b2 = tf.Variable(tf.random_normal([1]))\n", + "prediction = tf.nn.softmax(tf.add( tf.matmul(hidden_layer, W2), b2))\n", + "\n", + "# loss function: cross entropy\n", + "loss = tf.reduce_mean(-tf.reduce_sum(y_label * tf.log(prediction), axis=[1]))\n", + "\n", + "# training operation\n", + "train_op = tf.train.GradientDescentOptimizer(0.05).minimize(loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "iteration 0 loss is : 3.2637517\n", + "iteration 3000 loss is : 1.8743205\n", + "iteration 6000 loss is : 1.8236102\n", + "iteration 9000 loss is : 1.7924767\n", + "iteration 12000 loss is : 1.7737043\n", + "iteration 15000 loss is : 1.7602454\n", + "iteration 18000 loss is : 1.7496274\n" + ] + } + ], + "source": [ + "sess = tf.Session()\n", + "init = tf.global_variables_initializer()\n", + "sess.run(init) \n", + "\n", + "iteration = 20000\n", + "for i in range(iteration):\n", + " # input is X_train which is one hot encoded word\n", + " # label is Y_train which is one hot encoded neighbor word\n", + " sess.run(train_op, feed_dict={x: X_train, y_label: Y_train})\n", + " if i % 3000 == 0:\n", + " print('iteration '+str(i)+' 
loss is : ', sess.run(loss, feed_dict={x: X_train, y_label: Y_train}))" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 1.2848336e-01 1.5337467e-03]\n", + " [-2.0280635e+00 2.3127823e+00]\n", + " [ 1.7113209e-02 6.8822104e-01]\n", + " [-2.9886007e+00 -2.8444753e+00]\n", + " [-3.6029336e-01 3.9953849e+00]\n", + " [-3.3837869e+00 2.9268253e+00]\n", + " [-2.3786058e+00 -1.4266893e+00]\n", + " [ 7.8221262e-02 5.0086755e-01]\n", + " [-3.9878520e-01 -5.2032435e-01]\n", + " [-3.4548821e+00 -2.6956251e+00]\n", + " [-3.6400979e+00 -2.7140243e+00]\n", + " [-1.7572699e+00 -1.7303076e+00]]\n" + ] + } + ], + "source": [ + "# Now the hidden layer (W1 + b1) is actually the word look up table\n", + "vectors = sess.run(W1 + b1)\n", + "print(vectors)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# word vector in table" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
wordx1x2
0young0.1284830.001534
1prince-2.0280642.312782
2boy0.0171130.688221
3wise-2.988601-2.844475
4strong-0.3602933.995385
5man-3.3837872.926825
6queen-2.378606-1.426689
7king0.0782210.500868
8woman-0.398785-0.520324
9pretty-3.454882-2.695625
10girl-3.640098-2.714024
11princess-1.757270-1.730308
\n", + "
" + ], + "text/plain": [ + " word x1 x2\n", + "0 young 0.128483 0.001534\n", + "1 prince -2.028064 2.312782\n", + "2 boy 0.017113 0.688221\n", + "3 wise -2.988601 -2.844475\n", + "4 strong -0.360293 3.995385\n", + "5 man -3.383787 2.926825\n", + "6 queen -2.378606 -1.426689\n", + "7 king 0.078221 0.500868\n", + "8 woman -0.398785 -0.520324\n", + "9 pretty -3.454882 -2.695625\n", + "10 girl -3.640098 -2.714024\n", + "11 princess -1.757270 -1.730308" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "w2v_df = pd.DataFrame(vectors, columns = ['x1', 'x2'])\n", + "w2v_df['word'] = words\n", + "w2v_df = w2v_df[['word', 'x1', 'x2']]\n", + "w2v_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# word vector in 2d chart" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlUAAAJDCAYAAAAiieE0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X2U3WV97/3PlQQSDRBcBS1IYOi50QCZSTIZlZgnnkFFWIlE4YAnIUU0llpdBSnFggZsTyVV6/EIB0uBioQIiK1iMQSxEE2FmRAghoQHz4gW7zYsMCVGKJP87j/E3CIPAXJldibzeq3FWrNn//b1+/5my+LttffsKU3TBACArTOk1QMAAOwIRBUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUMKzGIqWU3iRPJtmYpK9pmq4a6wIADBRVoupZhzVN81jF9QAABgwv/wEAVFArqpoki0spPaWUMyqtCQAwYNR6+W9y0zSPllJen+SWUsrqpmlu/+0Dno2tM5Jk5MiRE8eMGVPp1AAA205PT89jTdPsuaXjSu2//VdK+WSS9U3TLHixY7q6upru7u6q5wUA2BZKKT0v55fwtvrlv1LKyFLKrr/5OsnRSVZu7boAAANJjZf/3pDkxlLKb9a7pmmamyusCwAwYGx1VDVN8+Mk4yrMAgAwYPlIBQCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAIPW5z//+WzYsKHVY7CDEFUADFovFVUbN27s52kY6EQVAI
PCL3/5y7zrXe/KuHHjMnbs2HzqU5/Ko48+msMOOyyHHXZYkmSXXXbJ+eefn7e97W1ZtmxZbr311kyYMCHt7e2ZO3dunn766SRJW1tbLrjggnR2dqa9vT2rV69OkqxduzZHHXVUOjs788EPfjD77bdfHnvssZZdM/1LVAEwKNx8883Ze++9c88992TlypX56Ec/mr333ju33XZbbrvttiS/Dq+xY8fmhz/8Ybq6ujJnzpwsWrQo9913X/r6+nLJJZdsXm+PPfbI8uXLM2/evCxYsCBJ8qlPfSqHH354li9fnhkzZuSRRx5pybXSGtWiqpQytJRydynlW7XWBIBa2tvbs2TJkpxzzjm54447MmrUqOcdM3To0LznPe9JkqxZsyb7779/3vSmNyVJZs+endtvv33zsTNnzkySTJw4Mb29vUmSpUuX5qSTTkqSHHvssXnd6163LS+J7cywimv9SZL7k+xWcU0AqOJNb3pTenp68u1vfzvnnntujj766OcdM2LEiAwdOjRJ0jTNS643fPjwJL8Osb6+vpf1GHZsVXaqSin7JHlXkr+rsR4A1Pboo4/mta99bU499dScddZZWb58eXbdddc8+eSTL3j8mDFj0tvbm4ceeihJ8pWvfCXTp09/yXNMmTIlX/va15IkixcvzhNPPFH3Itiu1dqp+nySjyfZ9cUOKKWckeSMJNl3330rnRYAXp777rsvZ599doYMGZKddtopl1xySZYtW5Z3vOMd2WuvvTa/r+o3RowYkSuuuCKzZs1KX19f3vKWt+RDH/rQS57jggsuyMknn5xFixZl+vTp2WuvvbLrri/6n0Z2MGVrtypLKccleWfTNB8upRya5KymaY57qcd0dXU13d3dW3VeANjePP300xk6dGiGDRuWZcuWZd68eVmxYkWrx2IrlVJ6mqbp2tJxNXaqJic5vpTyziQjkuxWSrm6aZpTK6wNAAPGI488kve+973ZtGlTdt5553z5y19u9Uj0o63eqXrOYnaqAIAdzMvdqfI5VQAAFdT8SIU0TfO9JN+ruSYAwEBgpwoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVLDVUVVKGVFKubOUck8p5UellE/VGAwAYCAZVmGNp5Mc3jTN+lLKTkmWllL+uWmaf62wNgDAgLDVUdU0TZNk/bM3d3r2n2Zr1wUAGEiqvKeqlDK0lLIiyX8kuaVpmh/WWJeBrbe3N2PGjMnpp5+esWPH5pRTTsmSJUsyefLkHHDAAbnzzjtz55135u1vf3smTJiQt7/97VmzZk2S5Morr8zMmTNz7LHH5oADDsjHP/7xFl8NALy0KlHVNM3GpmnGJ9knyVtLKWN/95hSyhmllO5SSvfatWtrnJYB4KGHHsqf/Mmf5N57783q1atzzTXXZOnSpVmwYEH+8i//MmPGjMntt9+eu+++O/Pnz8+f//mfb37sihUrsmjRotx3331ZtGhRfvrTn7bwSgDgpdV4T9VmTdP8opTyvSTHJln5O/ddluSyJOnq6vLy4CCx//77p729PUly8MEH54gjjkgpJe3t7ent7c26desye/bsPPjggyml5Jlnntn82COOOCKjRo1Kkhx00EH5yU9+ktGjR7fkOgBgS2r89t+epZTdn/36NUmOTLJ6a9dlxzB8+PDNXw8ZMmTz7SFDhqSvry9/8Rd/kcMOOywrV6
7MN7/5zTz11FMv+NihQ4emr6+v/wYHgFeoxk7VXkmuKqUMza8j7WtN03yrwroMAuvWrcsb3/jGJL9+HxUADFRbvVPVNM29TdNMaJqmo2masU3TzK8xGIPDxz/+8Zx77rmZPHlyNm7c2OpxAOBVK7/+RIT+1dXV1XR3d/f7eQEAXqlSSk/TNF1bOs6fqQEAqEBUAQBUIKoAACoQVQAAFYgqAIAKRBUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUIKoAACoQVQAAFYgqAIAKRBUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUIKoAACoQVQAAFYgqAIAKRBUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUIKoAACoQVQAAFYgqYLtx/vnnZ8mSJa0eA+BVGdbqAQCSZOPGjZk/f36rxwB41exUAdtcb29vxowZk9mzZ6ejoyMnnnhiNmzYkLa2tsyfPz9TpkzJddddlzlz5uT6669PkrS1teWCCy5IZ2dn2tvbs3r16iTJ+vXrc9ppp6W9vT0dHR254YYbkiSLFy/OpEmT0tnZmVmzZmX9+vUtu15gcBJVQL9Ys2ZNzjjjjNx7773Zbbfd8qUvfSlJMmLEiCxdujQnnXTS8x6zxx57ZPny5Zk3b14WLFiQJLnwwgszatSo3Hfffbn33ntz+OGH57HHHstFF12UJUuWZPny5enq6spnP/vZfr0+AC//Af1i9OjRmTx5cpLk1FNPzRe+8IUkyfve974XfczMmTOTJBMnTszXv/71JMmSJUty7bXXbj7mda97Xb71rW9l1apVm9f/r//6r0yaNGmbXAfAixFVQL8opbzg7ZEjR77oY4YPH54kGTp0aPr6+pIkTdM8b62maXLUUUdl4cKFNUcGeEW8/Af0i0ceeSTLli1LkixcuDBTpkx5VescffTR+eIXv7j59hNPPJFDDjkk3//+9/PQQw8lSTZs2JAHHnhg64cGeAVEFdAvDjzwwFx11VXp6OjI448/nnnz5r2qdT7xiU/kiSeeyNixYzNu3Ljcdttt2XPPPXPllVfm5JNPTkdHRw455JDNb2wH6C+laZp+P2lXV1fT3d3d7+cFWqO3tzfHHXdcVq5c2epRAF6xUkpP0zRdWzrOThUAQAWiCtjm2tra7FIBOzxRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABVsdVSVUkaXUm4rpdxfSvlRKeVPagwGADCQDKuwRl+SP22aZnkpZdckPaWUW5qmWVVhbQCAAWGrd6qapvl50zTLn/36yST3J3nj1q4LADCQVH1PVSmlLcmEJD98gfvOKKV0l1K6165dW/O0AAAtVy2qSim7JLkhyUebpvnP372/aZrLmqbpapqma88996x1WgCA7UKVqCql7JRfB9VXm6b5eo01AQAGkhq//VeSXJ7k/qZpPrv1IwEADDw1dqomJ3l/ksNLKSue/eedFdYFABgwtvojFZqmWZqkVJgFAGDA8onqAAAViCoAgApEFQBABaIKAKACUQUAUIGoAgCoQFQBAFQgqgAAKhBVAAAViCoAgApEFQBABaIKAKACUQUAUIGoAgCoQFQBAFQgqgAAKhBVAAAViCoAgApEFQBABaIKAKACUQUAUIGoAgCoQFQBAFQgqgAAKhBVAAAViCoAgApEFQBABaIKAKACUQUAUIGoAgCoQFQBAFQgqgAAKhBVAAAViCoAgApEFQBABaIKAKACUQUAUIGoAgCoQFQBAFQgqgAAKhBVAAAViCoAgApEFQBABaIKAKACUQUAUIGoAgCoQFQBAFQgqgAAKhBVAAAViCoAgApEFQBABaIKAKCCKlFVSvn7Usp/lFJW1lgPAGCgqbVTdWWSYyutBQAw4FSJqqZpbk/yeI21AAAGIu+pAgCooN+iqpRyRimlu5TSvXbt2v46LQBAv+i3qGqa5rKmabqapunac889++u0AAD9ws
t/AAAV1PpIhYVJliV5cynlZ6WUP6yxLgDAQDGsxiJN05xcYx0AgIHKy38AABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAFAC/T29mbs2LGtHoOKRBUAQAWiCgBapK+vL7Nnz05HR0dOPPHEbNiwIbfeemsmTJiQ9vb2zJ07N08//XRuvfXWzJgxY/PjbrnllsycObOFk/NCRBUAtMiaNWtyxhln5N57781uu+2Wz372s5kzZ04WLVqU++67L319fbnkkkty+OGH5/7778/atWuTJFdccUVOO+20Fk/P7xJVANAio0ePzuTJk5Mkp556am699dbsv//+edOb3pQkmT17dm6//faUUvL+978/V199dX7xi19k2bJlecc73tHK0XkBVf72HwDwypVSXvaxp512Wt797ndnxIgRmTVrVoYN85/w7Y2dKgBokUceeSTLli1LkixcuDBHHnlkent789BDDyVJvvKVr2T69OlJkr333jt77713LrroosyZM6dVI/MSRBUAtMiBBx6Yq666Kh0dHXn88cfzsY99LFdccUVmzZqV9vb2DBkyJB/60Ic2H3/KKadk9OjROeigg1o4NS/G3iEAtEBbW1tWrVr1vO8fccQRufvuu1/wMUuXLs0HPvCBbT0ar5KoAoABYOLEiRk5cmT+5m/+ptWj8CJEFQAMAD09Pa0egS3wnioAgApEFQBsR17obwJ2d3fnIx/5SIsm4uXy8h8AbOe6urrS1dXV6jHYAjtVALCd+vGPf5wJEybk4osvznHHHZck+eQnP5m5c+fm0EMPzR/8wR/kC1/4wubjL7zwwowZMyZHHXVUTj755CxYsKBVow9KdqoAYDu0Zs2anHTSSbniiivyi1/8Iv/yL/+y+b7Vq1fntttuy5NPPpk3v/nNmTdvXu65557ccMMNufvuu9PX15fOzs5MnDixhVcw+NipAoDtzNq1a3PCCSfk6quvzvjx4593/7ve9a4MHz48e+yxR17/+tfn3//937N06dKccMIJec1rXpNdd9017373u1sw+eAmqgBgOzNq1KiMHj063//+91/w/uHDh2/+eujQoenr60vTNP01Hi9CVAHAdmbnnXfON77xjfzDP/xDrrnmmpf1mClTpuSb3/xmnnrqqaxfvz433XTTNp6S3yWqAGA7NHLkyHzrW9/K5z73uaxbt26Lx7/lLW/J8ccfn3HjxmXmzJnp6urKqFGj+mFSfqO0Yruwq6ur6e7u7vfzAsCObP369dlll12yYcOGTJs2LZdddlk6OztbPdaAV0rpaZpmi59p4bf/AGAHccYZZ2TVqlV56qmnMnv2bEHVz0QVAOwgXu77r9g2vKcKAKACUQUAUIGoAgCoQFQBAFQgqgAAKhBVAAAViCoAgApEFQBABaIKAKACUQUAUIGoAgCoQFQBAFQgqgAAKhBVAAAViCoAgApEFQBABaIKAKACUQUAUIGoAgCoQFQBAFQgqgAAKhBVAAAViCoA2M79xV/8Rf72b/928+3zzjsvf/u3f5uzzz47Y8eOTXt7exYtWpQk+d73vpfjjjtu87FnnnlmrrzyyiRJW1tbLrjggnR2dqa9vT2rV69OkqxduzZHHXVUOjs788EPfjD77bdfHnvssf67wB2EqAKA7dwf/uEf5qqrrkqSbNq0Kddee2322WefrFixIvfcc0+WLFmSs88+Oz//+c+3uNYee+yR5cuXZ968eVmwYEGS5FOf+lQOP/zwLF++PDNmzMgjjzyyTa9nR1Ulqkopx5ZS1pRSHiql/FmNNQGAX2tra8vv/d7v5e67787ixYszYcKELF26NCeffHKGDh2aN7zhDZk+fXruuuuuLa41c+bMJMnEiRPT29ubJFm6dGlOOumkJMmxxx6b173uddvsWnZkWx1VpZShSf53knckOSjJyaWUg7Z2XQDg/3f66afnyiuvzBVXXJG5c+emaZoXPG7YsGHZtGnT5ttPPfXUc+4fPnx4kmTo0KHp6+
tLkhddi1emxk7VW5M81DTNj5um+a8k1yY5ocK6AMCzZsyYkZtvvjl33XVXjjnmmEybNi2LFi3Kxo0bs3bt2tx+++1561vfmv322y+rVq3K008/nXXr1uXWW2/d4tpTpkzJ1772tSTJ4sWL88QTT2zry9khDauwxhuT/PS3bv8sydt+96BSyhlJzkiSfffdt8JpAWDw2HnnnXPYYYdl9913z9ChQzNjxowsW7Ys48aNSykln/nMZ/L7v//7SZL3vve96ejoyAEHHJAJEyZsce0LLrggJ598chYtWpTp06dnr732yq677rqtL2mHU7Z2y6+UMivJMU3TnP7s7fcneWvTNH/8Yo/p6upquru7t+q8ADCYbNq0KZ2dnbnuuutywAEHVF376aefztChQzNs2LAsW7Ys8+bNy4oVK6qeYyArpfQ0TdO1peNq7FT9LMno37q9T5JHK6wLACRZtWpVjjvuuMyYMaN6UCXJI488kve+973ZtGlTdt5553z5y1+ufo7BoMZO1bAkDyQ5Ism/JbkryX9vmuZHL/YYO1UAwEDRbztVTdP0lVLOTPKdJEOT/P1LBRUAwI6oxst/aZrm20m+XWMtAICByCeqAwBUIKoAACoQVQAAFYgqAIAKRBUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUIKoAACoQVQAAFYgqAIAKRBUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUIKoAACoQVQAAFYgqAIAKRBUAQAWiCoAd0mc+85l84QtfSJJ87GMfy+GHH54kufXWW3Pqqadm4cKFaW9vz9ixY3POOedsftwuu+ySc845JxMnTsyRRx6ZO++8M4ceemj+4A/+IP/0T/+UJOnt7c3UqVPT2dmZzs7O/OAHP0iSfO9738uhhx6aE088MWPGjMkpp5ySpmn6+cppFVEFwA5p2rRpueOOO5Ik3d3dWb9+fZ555pksXbo0BxxwQM4555x897vfzYoVK3LXXXflG9/4RpLkl7/8ZQ499ND09PRk1113zSc+8YnccsstufHGG3P++ecnSV7/+tfnlltuyfLly7No0aJ85CMf2Xzeu+++O5///OezatWq/PjHP873v//9/r94WkJUAbBDmjhxYnp6evLkk09m+PDhmTRpUrq7u3PHHXdk9913z6GHHpo999wzw4YNyymnnJLbb789SbLzzjvn2GOPTZK0t7dn+vTp2WmnndLe3p7e3t4kyTPPPJMPfOADaW9vz6xZs7Jq1arN533rW9+affbZJ0OGDMn48eM3P4Yd37BWDwAA28JOO+2Utra2XHHFFXn729+ejo6O3HbbbXn44Yez7777pqen50UfV0pJkgwZMiTDhw/f/HVfX1+S5HOf+1ze8IY35J577smmTZsyYsSIzY//zfFJMnTo0M2PYcdnpwqAHda0adOyYMGCTJs2LVOnTs2ll16a8ePH55BDDsm//Mu/5LHHHsvGjRuzcOHCTJ8+/WWvu27duuy1114ZMmRIvvKVr2Tjxo3b8CoYKEQVADusqVOn5uc//3kmTZqUN7zhDRkxYkSmTp2avfbaK3/1V3+Vww47LOPGjUtnZ2dOOOGEl73uhz/84Vx11VU55JBD8sADD2TkyJHb8CoYKEorfiuhq6ur6e7u7vfzAgC8UqWUnqZpurZ0nJ0qAIAKRBUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUIKoAACoQVQAAFYgqAIAKRBUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUIKoAACoQVQAAFYgqAIAKRBUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUIKoAACoQVQAAFWxVVJVSZpVSflRK2VRK6ao1FADAQLO1O1Urk8xMcnuFWQAABqxhW/PgpmnuT5JSSp1pAAAGKO+pAgCoYIs7VaWUJUl+/wXuOq9pmn98uScqpZyR5Iwk2XfffV/2gAAAA8EWo6ppmiNrnKhpmsuSXJYkXV1dTY01AQC2F17+AwCoYGs/UmFGKeVnSS
YluamU8p06YwEADCxb+9t/Nya5sdIsAAADlpf/AAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVMEO7tOf/nTe/OY358gjj8zJJ5+cBQsW5NBDD013d3eS5LHHHktbW1uSZOPGjTn77LPzlre8JR0dHfk//+f/bF7n4osv3vz9Cy64IEnS29ubAw88MB/4wAdy8MEH5+ijj86vfvWrfr9GgO2BqIIdWE9PT6699trcfffd+frXv5677rrrJY+//PLLM2rUqNx1112566678uUvfzn/9//+3yxevDgPPvhg7rzzzqxYsSI9PT25/fbbkyQPPvhg/uiP/ig/+tGPsvvuu+eGG27oj0sD2O4Ma/UAwLZzxx13ZMaMGXnta1+bJDn++ONf8vjFixfn3nvvzfXXX58kWbduXR588MEsXrw4ixcvzoQJE5Ik69evz4MPPph99903+++/f8aPH58kmThxYnp7e7fdBQFsx0QV7OBKKc/73rBhw7Jp06YkyVNPPbX5+03T5H/9r/+VY4455jnHf+c738m5556bD37wg8/5fm9vb4YPH7759tChQ738BwxaXv6DHdi0adNy44035le/+lWefPLJfPOb30yStLW1paenJ0k270olyTHHHJNLLrkkzzzzTJLkgQceyC9/+cscc8wx+fu///usX78+SfJv//Zv+Y//+I9+vhqA7ZudKtiBdXZ25n3ve1/Gjx+f/fbbL1OnTk2SnHXWWXnve9+br3zlKzn88MM3H3/66aent7c3nZ2daZome+65Z77xjW/k6KOPzv33359JkyYlSXbZZZdcffXVGTp0aEuuC2B7VJqm6feTdnV1Nb/5zSOg/3zyk5/MLrvskrPOOqvVowAMGKWUnqZpurZ0nJf/AAAq8PIfDCKf/OQnWz0CwA7LThUAQAWiCgCgAlEFAFCBqAIAqEBUAQBUIKoAACoQVQAAFYgqAIAKRBUwKJ1//vlZsmRJq8cAdiA+UR0YdDZu3Jj58+e3egxgB2OnCtih9Pb2ZsyYMZk9e3Y6Ojpy4oknZsOGDWlra8v8+fMzZcqUXHfddZkzZ06uv/76JElbW1suuOCCdHZ2pr29PatXr06SrF+/Pqeddlra29vT0dGRG264IUmyePHiTJo0KZ2dnZk1a1bWr1+fJPmzP/uzHHTQQeno6Nj8R6uvu+66jB07NuPGjcu0adNa8BMB+oudKmCHs2bNmlx++eWZPHly5s6dmy996UtJkhEjRmTp0qVJkptvvvk5j9ljjz2yfPnyfOlLX8qCBQvyd3/3d7nwwgszatSo3HfffUmSJ554Io899lguuuiiLFmyJCNHjsxf//Vf57Of/WzOPPPM3HjjjVm9enVKKfnFL36RJJk/f36+853v5I1vfOPm7wE7JjtVwA5n9OjRmTx5cpLk1FNP3RxS73vf+170MTNnzkySTJw4Mb29vUmSJUuW5I/+6I82H/O6170u//qv/5pVq1Zl8uTJGT9+fK666qr85Cc/yW677ZYRI0bk9NNPz9e//vW89rWvTZJMnjw5c+bMyZe//OVs3LhxW1wusJ2wUwXscEopL3h75MiRL/qY4cOHJ0mGDh2avr6+JEnTNM9bq2maHHXUUVm4cOHz1rjzzjtz66235tprr80Xv/jFfPe7382ll16aH/7wh7npppsyfvz4rFixIr/3e7+3VdcHbJ/sVAE7nEceeSTLli1LkixcuDBTpkx5VescffTR+eIXv7j59hNPPJFDDjkk3//+9/PQQw8lSTZs2JAHHngg69evz7p16/LOd74zn//857NixYokycMPP5y3ve1tmT9/fvbYY4/89Kc/3cqrA7ZXogrY4Rx44IG56qqr0tHRkccffzzz5s17Vet84hOfyBNPPLH5jea33XZb9txzz1x55ZU5+eST09HRkUMOOSSrV6/Ok08+meOOOy
4dHR2ZPn16Pve5zyVJzj777LS3t2fs2LGZNm1axo0bV/NSge1IaZqm30+W3u9hAAAMqElEQVTa1dXVdHd39/t5gR1fb29vjjvuuKxcubLVowA7iFJKT9M0XVs6zk4VAEAFogrYobS1tdmlAlpCVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAq2KqoKqVcXEpZXUq5t5RyYyll91qDAQAMJFu7U3VLkrFN03QkeSDJuVs/EgDAwLNVUdU0zeKmafqevfmvSfbZ+pEAAAaemu+pmpvkn1/szlLKGaWU7lJK99q1ayueFgCg9YZt6YBSypIkv/8Cd53XNM0/PnvMeUn6knz1xdZpmuayJJclSVdXV/OqpgUA2E5tMaqapjnype4vpcxOclySI5qmEUsAwKC0xah6KaWUY5Ock2R60zQb6owEADDwbO17qr6YZNckt5RSVpRSLq0wEwDAgLNVO1VN0/w/tQYBABjIfKI6AEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqaKlvfOMbWbVq1ebbV155ZR599NEWTgQAr46oYpvbuHHji94nqgDYUYgqnuP888/PkiVLXvC+OXPm5Prrr3/O93p7ezNmzJjMnj07HR0dOfHEE7Nhw4a0tbVl/vz5mTJlSq677ro8/PDDOfbYYzNx4sRMnTo1q1evzg9+8IP80z/9U84+++yMHz8+f/3Xf53u7u6ccsopGT9+fG666abMmDFj87luueWWzJw5c5tePwC8WsNaPQDbl/nz57/g919qt2nNmjW5/PLLM3ny5MydOzdf+tKXkiQjRozI0qVLkyRHHHFELr300hxwwAH54Q9/mA9/+MP57ne/m+OPPz7HHXdcTjzxxCTJP//zP2fBggXp6upK0zT50z/906xduzZ77rlnrrjiipx22mmVrxgA6hBVg9iFF16Yr371qxk9enT22GOPTJw4MStXrtwcOW1tbZk7d24WL16cM88880XXGT16dCZPnpwkOfXUU/OFL3whSfK+970vSbJ+/fr84Ac/yKxZszY/5umnn97ifKWUvP/978/VV1+d0047LcuWLcs//MM/bM0lA8A2I6oGqe7u7txwww25++6709fXl87OzkycOPF5x/32btPNN9/8gmuVUl7w9siRI5MkmzZtyu67754VK1a84jlPO+20vPvd786IESMya9asDBvmf7IAbJ+8p2qQWrp0aU444YS85jWvya677pp3v/vdL3jcb3abXsojjzySZcuWJUkWLlyYKVOmPOf+3XbbLfvvv3+uu+66JEnTNLnnnnuSJLvuumuefPLJzcf+7u299947e++9dy666KLMmTPnFV0jAPQnUTVINU3zso77zW7TSznwwANz1VVXpaOjI48//njmzZv3vGO++tWv5vLLL8+4ceNy8MEH5x//8R+TJCeddFIuvvjiTJgwIQ8//HDmzJmTD33oQxk/fnx+9atfJUlOOeWUjB49OgcddNAruEIA6F9eSxmkpkyZkg9+8IM599xz09fXl5tuuikf+MAHXtVaQ4YMyaWXXvqc7/X29j7n9v777/+CLx9Onjz5OR+p8N/+23/Le97znuccs3Tp0lc9GwD0F1E1SL3lLW/J8ccfn3HjxmW//fZLV1dXRo0a1eqxnmfixIkZOXJk/uZv/qbVowDASyov92Wgmrq6upru7u5+Py/PtX79+uyyyy7ZsGFDpk2blssuuyydnZ2tHqvl3vnOd+aaa67J7rvv3upRANgOlFJ6mqbp2tJxdqoGsTPOOCOrVq3KU089ldmzZwuqZ337299u9QgADEDeqD6IXXPNNV
mxYkVWr16dc889t9Xj9JvPfOYzmz9L62Mf+1gOP/zwJMmtt96aU089NW1tbXnsscfyy1/+Mu9617sybty4jB07NosWLUqS9PT0ZPr06Zk4cWKOOeaY/PznP2/ZtQCw/RBVDDrTpk3LHXfckeTXn9e1fv36PPPMM1m6dGmmTp26+bibb745e++9d+65556sXLkyxx57bJ555pn88R//ca6//vr09PRk7ty5Oe+881p1KQBsR0QVg87EiRPT09OTJ598MsOHD8+kSZPS3d2dO+644zlR1d7eniVLluScc87JHXfckVGjRmXNmjVZuXJljjrqqIwfPz4XXXRRfvazn7XwagDYXnhPFYPOTjvtlLa2tlxxxRV5+9vfno6Ojtx22215+OGHc+CBB24+7k1velN6enry7W9/O+eee26OPvrozJgxIwcffPDmDzsFgN+wU8WgNG3atCxYsCDTpk3L1KlTc+mll2b8+PHP+ZM7jz76aF772tfm1FNPzVlnnZXly5fnzW9+c9auXbs5qp555pn86Ec/atVlALAdsVPFoDR16tR8+tOfzqRJkzJy5MiMGDHiOS/9Jcl9992Xs88+O0OGDMlOO+2USy65JDvvvHOuv/76fOQjH8m6devS19eXj370ozn44INbdCUAbC98ThUAwEt4uZ9T5eU/AIAKRBUAQAWiCgCgAlEFAFDBVkVVKeXCUsq9pZQVpZTFpZS9aw0GADCQbO1O1cVN03Q0TTM+ybeSnF9hJgCAAWeroqppmv/8rZsjk/T/5zMAAGwHtvrDP0spn07yP5KsS3LYVk8EADAAbXGnqpSypJSy8gX+OSFJmqY5r2ma0Um+muTMl1jnjFJKdymle+3atfWuAABgO1DtE9VLKfslualpmrFbOtYnqgMAA0W/fKJ6KeWA37p5fJLVW7MeAMBAtbXvqfqfpZQ3J9mU5CdJPrT1IwEADDxbFVVN07yn1iAAAAOZT1QHAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoAJRBQBQgagCAKhAVAEAVCCqAAAqEFUAABWIKgCACkQVAEAFogoAoILSNE3/n7SUtUl+0u8nbr09kjzW6iEGOc9Ba/n5t57noPU8B633Sp+D/Zqm2XNLB7UkqgarUkp30zRdrZ5jMPMctJaff+t5DlrPc9B62+o58PIfAEAFogoAoAJR1b8ua/UAeA5azM+/9TwHrec5aL1t8hx4TxUAQAV2qgAAKhBVLVJKOauU0pRS9mj1LINJKeXCUsq9pZQVpZTFpZS9Wz3TYFNKubiUsvrZ5+HGUsrurZ5psCmlzCql/KiUsqmU4rfQ+kkp5dhSyppSykOllD9r9TyDTSnl70sp/1FKWbmtziGqWqCUMjrJUUkeafUsg9DFTdN0NE0zPsm3kpzf6oEGoVuSjG2apiPJA0nObfE8g9HKJDOT3N7qQQaLUsrQJP87yTuSHJTk5FLKQa2datC5Msmx2/IEoqo1Ppfk40m8oa2fNU3zn791c2Q8B/2uaZrFTdP0PXvzX5Ps08p5BqOmae5vmmZNq+cYZN6a5KGmaX7cNM1/Jbk2yQktnmlQaZrm9iSPb8tzDNuWi/N8pZTjk/xb0zT3lFJaPc6gVEr5dJL/kWRdksNaPM5gNzfJolYPAf3gjUl++lu3f5bkbS2ahW1EVG0DpZQlSX7/Be46L8mfJzm6fycaXF7q5980zT82TXNekvNKKecmOTPJBf064CCwpefg2WPOS9KX5Kv9Odtg8XKeA/rVC/2/aDvlOxhRtQ00TXPkC3
2/lNKeZP8kv9ml2ifJ8lLKW5um+X/7ccQd2ov9/F/ANUluiqiqbkvPQSlldpLjkhzR+FyXbeIV/HtA//hZktG/dXufJI+2aBa2EVHVj5qmuS/J639zu5TSm6SraRp/WLOflFIOaJrmwWdvHp9kdSvnGYxKKccmOSfJ9KZpNrR6HugndyU5oJSyf5J/S3JSkv/e2pGozRvVGWz+ZyllZSnl3vz6Zdg/afVAg9AXk+ya5JZnP9ri0lYPNNiUUmaUUn6WZFKSm0op32n1TDu6Z38548wk30lyf5KvNU3zo9ZONbiUUhYmWZbkzaWUn5VS/rD6Oey8AwBsPTtVAAAViCoAgApEFQBABaIKAKACUQUAUIGoAgCoQFQBAFQgqgAAKvj/AF/v+l5TFDz/AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "fig, ax = plt.subplots()\n", + "\n", + "for word, x1, x2 in zip(w2v_df['word'], w2v_df['x1'], w2v_df['x2']):\n", + " ax.annotate(word, (x1,x2 ))\n", + " \n", + "PADDING = 1.0\n", + "x_axis_min = np.amin(vectors, axis=0)[0] - PADDING\n", + "y_axis_min = np.amin(vectors, axis=0)[1] - PADDING\n", + "x_axis_max = np.amax(vectors, axis=0)[0] + PADDING\n", + "y_axis_max = np.amax(vectors, axis=0)[1] + PADDING\n", + " \n", + "plt.xlim(x_axis_min,x_axis_max)\n", + "plt.ylim(y_axis_min,y_axis_max)\n", + "plt.rcParams[\"figure.figsize\"] = (10,10)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From b8f279823dd15b075f9f571553790c3d4246a980 Mon Sep 17 00:00:00 2001 From: minsuk-heo Date: Sun, 14 Oct 2018 00:38:55 -0700 Subject: [PATCH 2/8] voting added --- data_science/ensemble/voting.ipynb | 279 +++++++++++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 data_science/ensemble/voting.ipynb diff --git a/data_science/ensemble/voting.ipynb b/data_science/ensemble/voting.ipynb new file mode 100644 index 0000000..e42e8d3 --- /dev/null +++ b/data_science/ensemble/voting.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Voting\n", + "Based on the idea that classifiers can complement each other, \n", + "Aggregating individual classifier's prediction to make better 
prediction." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import datasets\n", + "from sklearn import tree\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.svm import SVC\n", + "from sklearn.ensemble import VotingClassifier\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# load digits dataset (scikit-learn's small MNIST-like set)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "mnist = datasets.load_digits()\n", + "features, labels = mnist.data, mnist.target\n", + "X_train,X_test,y_train,y_test=train_test_split(features,labels,test_size=0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# single classifier accuracy on mnist\n", + "build a decision tree, knn, and svm, then check each one's accuracy on the MNIST data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "dtree = tree.DecisionTreeClassifier(\n", + " criterion=\"gini\", max_depth=8, max_features=32,random_state=35)\n", + "\n", + "dtree = dtree.fit(X_train, y_train)\n", + "dtree_predicted = dtree.predict(X_test)\n", + "\n", + "knn = KNeighborsClassifier(n_neighbors=299).fit(X_train, y_train)\n", + "knn_predicted = knn.predict(X_test)\n", + "\n", + "svm = SVC(C=0.1, gamma=0.003,\n", + " probability=True,random_state=35).fit(X_train, y_train)\n", + "svm_predicted = svm.predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[accuarcy]\n", + "d-tree: 0.7972222222222223\n", + "knn : 0.8416666666666667\n", + "svm : 0.85\n" + ] + } + ], + "source": [ + "print(\"[accuarcy]\")\n", + "print(\"d-tree: \",accuracy_score(y_test, dtree_predicted))\n", + "print(\"knn : \",accuracy_score(y_test, knn_predicted))\n", + "print(\"svm : \",accuracy_score(y_test, svm_predicted))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "we can easily do soft voting or hard voting using sklearn's VotingClassifier. \n", + "if you want to implement soft voting from scratch, you can use predict_proba as shown below. \n", + "Below is an example of the SVM's predicted probabilities (digits 0 to 9) for two MNIST samples." 
# %% Class probabilities predicted by the SVM for the first two test samples
svm_proba = svm.predict_proba(X_test)
print(svm_proba[:2])

# %% Hard voting: each estimator casts one vote and the majority label wins
voting_clf = VotingClassifier(
    estimators=[
        ('decision_tree', dtree),
        ('knn', knn),
        ('svm', svm),
    ],
    voting='hard',
    weights=[1, 1, 1],
)
voting_clf.fit(X_train, y_train)  # fit() returns the estimator itself
hard_voting_predicted = voting_clf.predict(X_test)
# Last expression of the cell: displayed as the cell's output.
accuracy_score(y_test, hard_voting_predicted)
# %% Soft voting: combine each estimator's predict_proba and pick the top class
voting_clf = VotingClassifier(
    estimators=[
        ('decision_tree', dtree),
        ('knn', knn),
        ('svm', svm),
    ],
    voting='soft',
    weights=[1, 1, 1],
)
voting_clf.fit(X_train, y_train)  # fit() returns the estimator itself
soft_voting_predicted = voting_clf.predict(X_test)
# Last expression of the cell: displayed as the cell's output.
accuracy_score(y_test, soft_voting_predicted)
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAEepJREFUeJzt3XvQHXV9x/H3h2BEES8lqVUghiqoqVaoGbwgikpbwAo4oEK1LQ6V6QVtvc3QwTIWrVXROrViK7SKYpWLiqYYDZWKUK2YIBdJMDQTUFLaMSpSURGRb//YjZwcT/Kc58l58iQ/3q+ZzLOX39n97e5vP2fP75zdpKqQJLVll7mugCRp8gx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoN2nasVL1iwoBYvXjxXq5ekndLVV1/9napaOFW5OQv3xYsXs2rVqrlavSTtlJJ8c5xydstIUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KD5uwOVUmajsWnfmauqzAxt7ztBbO+DsNd2om0EnDbI9zu7+yWkaQGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDfHCYdiqtPDgLfHiWZpdX7pLUIK/cd0KtXL165SrNHq/cJalBhrskNchwl6QGGe6S1CDDXZIaNFa4Jzk8ydok65KcOmL+oiRfSHJNkuuTHDn5qkqSxjVluCeZB5wFHAEsAU5IsmSo2BuBC6vqQOB44H2TrqgkaXzjXLkfBKyrqvVVdTdwPnD0UJkCHtoPPwy4bXJVlCRN1zg3Me0F3DowvgF42lCZNwGXJnkVsDtw2ERqJ0makXHCPSOm1dD4CcC5VfWuJM8AzkvypKq6d7MFJScDJwMsWrRoJvUF2rlDE7xLU9LsGKdbZgOwz8D43vxit8tJwIUAVfWfwG7AguEFVdXZVbW0qpYuXLhwZjWWJE1pnHBfCeyXZN8k8+m+MF02VOZbwPMBkjyRLtw3TrKikqTxTRnuVXUPcAqwAriR7lcxq5OckeSovtjrgFcmuQ74GHBiVQ133UiStpOxngpZVcuB5UPTTh8YXgMcPNmqSZJmyjtUJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktSgscI9yeFJ1iZZl+TULZR5SZI1SVYn+ehkqylJmo5dpyqQZB5wFvCbwAZgZZJlVbVmoMx+wF8AB1fV7Ul+ebYqLEma2jhX7gcB66pqfVXdDZwPHD1U5pXAWVV1O0BVfXuy1ZQkTcc44b4XcOvA+IZ+2qD9gf2TfCnJV5IcPqkKSpKmb8puGSAjptWI5ewHHArsDVyZ5ElV9f3NFpScDJwMsGjRomlXVpI0nnGu3DcA+wyM7w3cNqLMp6vqp1V1M7CWLuw3U1VnV9XSqlq6cOHCmdZZkjSFccJ9JbBfkn2TzAeOB5YNlfkU8FyAJAvoumnWT7KikqTxTRnuVXUPcAqwArgRuLCqVic5I8lRfbEVwHeTrAG+ALyhqr47W5WWJG3dOH3uVNVyYPnQtNMHhgt4bf9PkjTHvENVkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQY
a7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQWOFe5LDk6xNsi7JqVspd1ySSrJ0clWUJE3XlOGeZB5wFnAEsAQ4IcmSEeX2AF4NXDXpSkqSpmecK/eDgHVVtb6q7gbOB44eUe7NwDuAuyZYP0nSDIwT7nsBtw6Mb+in/VySA4F9quqSrS0oyclJViVZtXHjxmlXVpI0nnHCPSOm1c9nJrsA7wZeN9WCqursqlpaVUsXLlw4fi0lSdMyTrhvAPYZGN8buG1gfA/gScDlSW4Bng4s80tVSZo744T7SmC/JPsmmQ8cDyzbNLOq7qiqBVW1uKoWA18BjqqqVbNSY0nSlKYM96q6BzgFWAHcCFxYVauTnJHkqNmuoCRp+nYdp1BVLQeWD007fQtlD932akmStoV3qEpSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQWOFe5LDk6xNsi7JqSPmvzbJmiTXJ7ksyWMmX1VJ0rimDPck84CzgCOAJcAJSZYMFbsGWFpVvw58HHjHpCsqSRrfOFfuBwHrqmp9Vd0NnA8cPVigqr5QVT/qR78C7D3ZakqSpmOccN8LuHVgfEM/bUtOAj47akaSk5OsSrJq48aN49dSkjQt44R7RkyrkQWTlwNLgTNHza+qs6tqaVUtXbhw4fi1lCRNy65jlNkA7DMwvjdw23ChJIcBpwHPqaqfTKZ6kqSZGOfKfSWwX5J9k8wHjgeWDRZIciDwfuCoqvr25KspSZqOKcO9qu4BTgFWADcCF1bV6iRnJDmqL3Ym8BDgoiTXJlm2hcVJkraDcbplqKrlwPKhaacPDB824XpJkraBd6hKUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUFjhXuSw5OsTbIuyakj5j8wyQX9/KuSLJ50RSVJ45sy3JPMA84CjgCWACckWTJU7CTg9qp6HPBu4O2TrqgkaXzjXLkfBKyrqvVVdTdwPnD0UJmjgQ/1wx8Hnp8kk6umJGk6xgn3vYBbB8Y39NNGlqmqe4A7gD0nUUFJ0vTtOkaZUVfgNYMyJDkZOLkfvTPJ2jHWP5cWAN+ZzRVkx+3Acttn2f15++/P2w7bvP2PGafQOOG+AdhnYHxv4LYtlNmQZFfgYcD3hhdUVWcDZ49TsR1BklVVtXSu6zEX3Pb757bD/Xv7W9r2cbplVgL7Jdk3yXzgeGDZUJllwB/0w8cB/15Vv3DlLknaPqa8cq+qe5KcAqwA5gEfqKrVSc4AVlXVMuCfgfOSrKO7Yj9+NistSdq6cbplqKrlwPKhaacPDN8FvHiyVdsh7DRdSLPAbb//uj9vfzPbHntPJKk9Pn5Akhq004R7kjclef0MX/vlKeYvT/LwmdVss+UcM+Lu3R1SksVJbpjremjrZuM4JbklyYJtXMbDk/zJwPijk3x822s3GUkOSbI6ybVJnpjkdye03AOSHDkwftSoR7LsCHaacN8WVfXMKeYfWVXfn8CqjqF7RMMv6H8iKm1Xs9juHg78PNyr6raqOm6W1j
UTLwPeWVUHAI8EJhLuwAHAz8O9qpZV1dsmtOzJqqod9h9wGrAW+DzwMeD1/fTHAp8DrgauBJ7QT38kcDFwXf/vmf30O/u/jwKuAK4FbgAO6affAizoh1/bz7sB+PN+2mLgRuAcYDVwKfCgobo+k+6XQjf3y38scDnwVuCLwOuAhcAn6H5euhI4uH/t7sAH+mnXAEdvh327GLihH/7Vfr1vAD7Z79v/At4xUP5O4K/7/foV4JFz3T5muN27A5/pt+MGup/wXjgw/1DgXwe2+e19O/s83aM4LgfWA0dtp/puse0Br+zbzHV9u3pwP/1c4G+BLwDvortb/NL+GL8f+Oam9j6wnj8eOt4nAn+/lXPifODHfVs/c6g9nbiVdnQScFO/H88B3jvD4/bSfvrz++36en8OPRD4Q+47F/+lb6939HV9zdByLwCOHBg/FzgW2A34YL/ca4DnAvOBbwEb+2W9tN/W9w689j3Al/s2clw/fRfgff3xu4TuxynHzXrbmeuTbSsH86n9jn0w8FBgHfeF+2XAfv3w0+h+V7/pQG1qfPOAh206Sfu/rwNOG5i/Rz98C92daZvWuTvwkP5gHNg33HuAA/ryFwIvH1HncwcPWt+A3zcw/lHgWf3wIuDGfvitm5ZHd0V0E7D7LO/fxf1J8vi+8R7QN9T1dDeh7UYXAvv05Qt4YT/8DuCNc91GZrjdxwLnDIw/rD9hd+/H/2HgWBRwRD98MV1APgB4CnDtdqrvFtsesOdAubcArxpoh5cA8/rx9wCn98Mv6LdrONwX0j1DatP4Z4FnTXFO3DDcnvrhke0IeDTdufZL/X68kvHDfdRx243usSf799M+zH3n/7ncF66HApdsYbkvAj7UD8/vl/cguqz4YD/9CX0b2Y2BMB/Y1sFwv4guzJds2p909/4s76f/CnA72yHcd+RumUOAi6vqR1X1f/Q3TiV5CN1V8kVJrqW7EnlU/5rn0Z2cVNXPquqOoWWuBF6R5E3Ak6vqB0Pzn9Wv84dVdSfd1cch/bybq+rafvhqusY8jgsGhg8D3tvXexnw0CR7AL8FnNpPv5yuES0ac/nbYiHwabqw2LRtl1XVHdX9vHUN993qfDddYMD0tn9H83XgsCRvT3JI30Y+B7yw78J4Ad0+gW6bPzfwui9W1U/74cXbsc5bantPSnJlkq/TdUP82sBrLqqqn/XDzwY+AlBVn6ELl81U1UZgfZKnJ9mT7k3/S2z9nNiaUe3oILp9+L1+P1405vbD6OP2eLp9c1Nf5kP9tk7HZ4HnJXkg3ZNvr6iqH9Nt93kAVfUNujeo/cdY3qeq6t6qWkPXk0C/rIv66f9L94lq1u3o/cCjfqe5C/D96vrSprewqiuSPJvuBD4vyZlV9eGBIlt7kuVPBoZ/RvfuPo4fDgzvAjyjbzz3rbR7guaxVbW9n7VzB92VysF0V2Twi9u5qY38tPrLkKHpO5WquinJU+n6Tf8myaV0b8B/SvdRfuXAm/7gNt9Lv2+q6t7t/B3KltreucAxVXVdkhPprlA3GWx3MPpcGnYB8BLgG3SBXtvwdNdR7WjGT4rdwnEbvlN+Jsu9K8nlwG/TdbN8rJ81ie3O0N/take+cr8CeFGSB/VXty8E6K/ib07yYuiCMclT+tdcRtd3SJJ5SR46uMAkjwG+XVXn0N1V+xsj1nlMkgcn2Z3uI9uV06jzD4A9tjL/UuCUgfpseoNaAbxq04mU5MBprHNb3E33JfDvT+rXBDu6JI8GflRVHwHeSdcGLu//vpLNP2nt6PYA/ifJA+iu3Lfkik3zkxwBPGIL5T5J1x5O4L79sKVzYqq2PspXgeckeUT/5njsuC/cwnH7BrA4yeP6Yr9H9/3WsKnqej7wCrpPJCv6aYP7bH+6T9Jrx1jWKP8BHJtklySPZPM34Vmzw4Z7VX2NroFdS/dl0WDIvgw4Kcl1dFecm54v/2fAc/uPqVez+cdU6HbqtUmuoWtYfzdinefSNcKrgH
+qqmumUe3zgTckuSbJY0fMfzWwNMn1SdYAf9RPfzNdH+T1/c/e3jyNdW6Tqvoh8DvAa+j6MVv3ZOCrfRfYacBb+u6LS+g+ll+ytRfvYP6Srp3+G13QbclfAc9O8jW6LsBvjSpUVbfTd6FU1Vf7aSPPiar6LvClJDckOXOcylbVf9N9v3QV3RfUa+g+PY5j1HG7iy6UL+rP+XuBfxzx2uuBe5Jcl+Q1I+ZfSted8/nq/s8K6L4Andcv9wLgxKr6CV2XypL+J5YvHbPun6B7uOINdN3IVzH+ds+Yd6hK2m6SPKSq7uyv3C+me1bVxXNdr9k2sN170r1RHtz3v8+anbLfVNJO601JDqP70cClwKfmuD7byyX9jZLzgTfPdrCDV+6S1KQdts9dkjRzhrskNchwl6QGGe6S1CDDXZIaZLhLUoP+H47Jp0tra/pcAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "%matplotlib inline\n", + "\n", + "x = np.arange(5)\n", + "plt.bar(x, height= [accuracy_score(y_test, dtree_predicted),\n", + " accuracy_score(y_test, knn_predicted),\n", + " accuracy_score(y_test, svm_predicted),\n", + " accuracy_score(y_test, hard_voting_predicted),\n", + " accuracy_score(y_test, soft_voting_predicted)])\n", + "plt.xticks(x, ['decision tree','knn','svm','hard voting','soft voting']);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 5e255d87670567349cd5b6e615862d59a615b609 Mon Sep 17 00:00:00 2001 From: minsuk-heo Date: Sat, 20 Oct 2018 17:46:22 -0700 Subject: [PATCH 3/8] ramdomforest --- data_science/ensemble/randomforest.ipynb | 203 +++++++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 data_science/ensemble/randomforest.ipynb diff --git a/data_science/ensemble/randomforest.ipynb b/data_science/ensemble/randomforest.ipynb new file mode 100644 index 0000000..662d955 --- /dev/null +++ b/data_science/ensemble/randomforest.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import datasets\n", + "from sklearn import tree\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import cross_val_score\n", + "\n", + "import 
# %% Load dataset
# NOTE: the notebook heading says "MNIST", but the data comes from
# datasets.load_digits(), i.e. scikit-learn's bundled digits dataset.
mnist = datasets.load_digits()
features, labels = mnist.data, mnist.target

# %% Cross-validation helper
def cross_validation(classifier, features, labels,
                     n_repeats=10, cv=10, scoring='accuracy'):
    """Run k-fold cross-validation several times and collect the mean scores.

    Parameters
    ----------
    classifier : estimator passed straight to ``cross_val_score``.
    features, labels : training data and targets.
    n_repeats : how many times the whole cross-validation is repeated
        (default 10, the value that used to be hard-coded).
    cv : ``cv`` argument forwarded to ``cross_val_score`` (default 10).
    scoring : ``scoring`` argument forwarded to ``cross_val_score``
        (default ``'accuracy'``).

    Returns
    -------
    list of float
        One entry per repetition: the mean of that repetition's fold scores.
    """
    # NOTE(review): with an integer cv the folds are presumably identical on
    # every repetition, so the spread between entries would come only from the
    # estimator's own randomness -- confirm against the scikit-learn docs.
    return [
        cross_val_score(classifier, features, labels,
                        cv=cv, scoring=scoring).mean()
        for _ in range(n_repeats)
    ]

# %% Score a single decision tree
dt_cv_scores = cross_validation(tree.DecisionTreeClassifier(), features, labels)

# %% Score a random forest
rf_cv_scores = cross_validation(RandomForestClassifier(), features, labels)

# %% Random Forest VS Decision Tree: one column of scores per model
cv_list = [
    ['random_forest', rf_cv_scores],
    ['decision_tree', dt_cv_scores],
]
# DataFrame.from_items() was deprecated in pandas 0.23 and removed in 1.0.
# Building the frame from a dict keeps the same columns in the same order.
df = pd.DataFrame(dict(cv_list))
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xd8VFX+//HXSSchkEACKCEkNOmEEIp0RaXo2lnERQULi2VFf3Z3VWTX9l3WFVcW1wJ2sWBBBQuIIoJIIKG3EEISSholjZD2+f1xJ8kklISQMJPcz/PxyCMz956Z+cwQ3vfec889Y0QEpZRS9uDh6gKUUkqdOxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillI16uLqCqkJAQiYiIcHUZSinVoKxbty5TREKra+d2oR8REUFsbKyry1BKqQbFGLO3Ju20e0cppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWzE7cbpq7MjIqQePsb65MPkF5ZwXXQYPl66bVdKWTT0G7j8wmI2ph5lffJh4pKPEJd8hMzc4+XrP1iTzMsT+xIZEuDCKpVS7kJDvwEREZKy8lm/9zBxKVbIbz+YQ0mp9eX2HUICGNEllL7hQfQNDyLlUD6PLNzEFS//wt+v7sm10WEufgdKKVfT0HdjOQVF1l783sPEpRwhLvkwh/OLAGjq60VUuyDuHtmRvuHBRLULIjjAp9Lje5zfnN5hQdz3UTz/7+MN/LIrk5lX9SDQz9sVb0cp5QYaTeiXlgq7M3IJDfSleRNvjDGuLumMlJYKiZm5rN97hLiUw6zfe4Sd6TmItRNP51ZNuax7G8defDCdWjXF06P693h+UBM+vGMQc5Yn8NLSnaxPPszLN/SlT7ugen5HSil3ZKQsVdxETEyM1GbCtUN5hUT//QcAvD0NLQN8CQ30JaSpj+O39VN2OzTQl9CmvjRr4uWSDcTR/KLyLpq4lCPEJx8mu6AYgGZ+XvQNDyY6PJi+4UH0aRdE8yZnv3cem3SI6QviScsu4MHRFzB1WAc8arDhUEq5P2PMOhGJqbZdYwn9/MJiftiaRmZuIRk5x8nMtX4qbheW93078/H0IKSpDyGOjUBIU19CAn2s22XLHBuKZn6120CUlAq70nOsvfjkw6xPPszujDwAPAx0aR1IdPtg+rYLIrp9MJEtA+otjI/mF/HY5xtZvOkgQzuF8OIf+9CqmV+9vJZS6tyxXehXp7RUOHKsqNKGICPnOBm5x8nMKXT8tpZn5Z1iA+Hl4dgw+FQ6Yqh8FOGDv48XW/YfJS75COuTD7Mh5Qh5hSUAtAjwKQ/3vu2C6N0uiKa+57aXTUT4aG0KM77aQoCPF7PG9+Girq3OaQ1KqbqloX8WSkuFw/mF5RuEShuK8tvWEcWhvOOcZPsAgKeHodt5geXdNNHhwYS38Heb8w0J6Tnc80Ec2w/mcNvQSB4ecwG+Xp6uLkspVQsa+udISdkGwqlLKaegmAtaB9I7LIgmPu4dogVFJTy/ZDtvrUqix/nNeHliXzqGNnV1WUqpM6Shr87I0q1pPPTpBgqKSnn6qh6M7xfmNkckSqnq1TT09fp8BcAl3VuzZPpwotoF8fCnG7l3QTzZBUWuLkspVcc09FW5Ns39eO/2gTw0+gIWbzrAuNm/sD75sKvLUkrVIQ19VYmnh+HuizrxybQLARj/6mrmLE846WgmdWYSM3KZsWgLK3dluroUZWPap69OKbugiMc/28TXGw8wuGNL/j0hitY6pv+MJaTn8sqPu1i0YT+lAl4ehlnj+3B137auLk01Itqnr85aMz9v/jOxL/93fW/iko8w5qUVLNuW5uqyGoyE9BymL4jj0n//zHdb0rh9WAeWPziS/hEtuO+jeN5cucfVJSob0j19VSO7M3L5ywdxbD2QzeTBETw6tit+3u49HNV
VdqXl8PKPCXy9cT9NvD256cL23DGsAyFNfQFrmOz9H8WzZPNBpo3oyCNjLtCRUuqs6ZBNVeeOF5fwwpIdzPt1D13bBPLKjX3p1CrQ1WW5jR0Hc3j5x10s3nQAf29Pbh4cwR3DOtCiyuynYF3f8eSXm3l/TTLj+4Xx3LW98PLUA29Vexr6qt4s357Og59sIK+wmBl/6MGE/u1svae6/WA2Ly/bxeJNB2nq68Utg9tz+9AOJ0x1XZWIMHvZLl5auotRXVvxyo3Rbn8xn3JfGvqqXqVnF/D/Pt7AyoRMLu91Hs9e26tOZgJtSLbut8L+2y1W2E8ZEsFtQyMJ8j992Ff13m97eeLLzfQLD+aNW2LO+PFKgYa+OgdKS4XXfklk1nc7aN3Mj5cnRtGvfQtXl1XvNu87ysvLdvH91jQCfb2YMjSS24ZE0ty/9hu9xZsOcN+CeCJC/Hn71gGc17xJHVas7EBDX50z8SlHuPfDOPYdOcb0UZ25+6JONfqCl4Zm876jvLR0F0u3pRHo58WtQyK59SzD3tmq3ZlMfWcdzZt48/atA+jUqnHPgSQiHM4vOuk5D3XmNPTVOZVTUMQTX2zmi/j9DIxswUs3RDWavdWNqUeYvXQXy7an08zPi9uGdmDykIh66c7avO8ok+evpaS0lHmT+9M3PLjOX8MdHDh6jMc/28TyHRlc07ctD4+5oNH8vbhKnYa+MWYMMBvwBN4QkeerrG8PzANCgUPAJBFJNcZEAXOBZkAJ8IyIfHS619LQb9g+W5/K377YjI+XBy9c15vRPdq4uqRa25ByhNnLdvHj9nSaN/Hm9qGR3DIkgmb1/B3De7PyuOnN38nIOc7cSdGMvKDxfNeBiLBgbQrPfrON4lJhbM82fL3pAB4Gpg7vyLQRHfD3aTTf4npO1VnoG2M8gZ3ApUAqsBaYKCJbndp8AnwtIm8bYy4GpojITcaYLoCIyC5jzPnAOqCbiBw51etp6Dd8ezLzuPfDODbtO8pNg9rz18u7Nagx/XHJh5m9bBc/7cggyN+bO4Z14OYL25/TL5TPyDnO5Pm/s+NgTqO5ejflUD6PfraRXxOyuLBDS164rjfhLf1JOZTPC99u5+uNB2jdzJeHR3flmr5tbfdVnkePFZGWXUCX1rUbBl2XoX8hMENERjvuPwYgIs85tdkCjHbs3RvgqIg0O8lzbQCuF5Fdp3o9Df3GobC4lFnf7+C1FYkE+XvTpVUgkSEBdAgNcPxuSngLf3y83Gds+rq9Vtiv2JlBsL83dwzvwM0XRpzzbzYrk1NQxNR31rE6MYu/Xd6N24d1cEkdZ6u0VHhndRIvfLsDTw/D4+O6MXHAicN81+09xMyvt7Eh5Qi9w5rzxBXd6R/R+AcGZOUeZ96ve3hn1V7aBjdhyfRhtRoCXZehfz0wRkRud9y/CRgoIvc4tfkAWCMis40x1wILgRARyXJqMwB4G+ghIqWnej0N/cbl14RMvozfx57MPPZk5pGZW1i+zsNAuxb+dAgJIDKkKZGhAXQMCSAyNIA2zfzO2dj/2KRDzF62i192ZdIiwIepwztw06D2BLgo7J0VFJXw/z6OZ/Gmg/x5RAceHdO1QV0TkZiRyyMLN7I26TAjuoTy7LW9aBt06r770lLhyw37eGHJDg5mFzCuVxseG9uNdi38z2HV58aBo8d4bUUiH/6ezPHiUsb1PI87R3akZ9vmtXq+moZ+Tf6qT/YXVnVL8SDwijFmMrAC2AcUOxVzHvAucMvJAt8YMxWYChAeHl6DklRDMaRTCEM6hZTfP3qsyLEByCUxI4/EzDz2ZOTxW+IhjhWVlLdr4u1JpGMDULYhiAxpSofQgDrrU/99zyFmL9vJrwlZhDT14fFxXZk0qL1b9Sn7eXvyn4nRtAjYzP9+TiQrt5DnG8DVuyWlwpsrE/nX9zvx9fJg1vg+XBfdttoNloeH4Zq+YYzpcR6vrUjk1Z93s3RrOlOGRnDPRZ3OaRdbfUnKzOPVn3e
zcH0qpQJXR7XlzpEdz9lorTrp3qnSvimwXUTCHPebAT8Bz4nIJ9UVpHv69lRaKqTlFLAnI4/djg3BnsxcEjPzSDmUX+l7iEOa+tAhpGn5RqGDo9sovEVAjbqLfkvMYvbSXaxOtML+z8M78qdB4W4V9lU1pKt3d6bl8NCnG9mQcoRLu7fmmat70qqWs7MePFrAP7/bwcL1qbQM8OGByy5gQv92DXJI8I6DOcxZbs3J5OXpwYSYdkwd3qHOjmLqsnvHC+tE7iisPfi1wI0issWpTQhwSERKjTHPACUi8qQxxgdYAnwlIi/VpHANfVVVYXEpyYfySczIZU9mHokZVldRYmYembnHy9uVdRdFhgRYGwWnDUKbZn78lniIl5buZM2eQ4QG+vLn4R3408D2bhueJ1N29W50eDBvutnVu0Ulpfzv5928vCyBpn5ePH1lD67ofV6ddEdtTD3C37/eytqkw3RtE8jfLu/O0M4h1T/QDcSnHGHO8gR+2JpGgI8nkwa157ZhkbQKrNtpyut6yOY44CWsIZvzROQZY8xMIFZEFjn6/Z/D6vZZAdwtIseNMZOA+cAWp6ebLCLxp3otDX11Jo4eKyLJcb4gMcM6Mig7f5BfWNFd5OPlQWFxKa0CfZk2oiM3DgxvUCOKnC3ZdIDpbnb17pb9R3nok41sPZDNFb3P4+kre9DSMatoXRERlmw+yHNLtpFy6Bijurbi8cu70THU/S5iExFWJ2bx3+W7WZmQSfMm3kwZEsHkwRH1tqHWi7OUrYkIadnHSXScO9iTmUdES3/Gx7RrsGHvrOzq3WZ+Xrxz2wCXzXZ6vLiEV35MYO5Puwny9+EfV/dkTM/6vTajoKiEt1Yl8cqPCRQUlTBpUHvuu6SzWxz1iAg/bk9nzvIE1icfITTQlzuGRXLjwPb1PgpMQ1+pRq7s6t3i0lLmu+Dq3fiUIzz86QZ2puVyXXQYT1zR7ZwGb0bOcV78YScfrU0m0M+b+y7pzKRB7fF2wUnuklJhyeYDzFm+m20Hsmkb1IRpIzsyvl/YOdvJ0NBXygb2ZuVx87zfSc8+d1fvFhSV8O8fdvL6L4m0bubHs9f04qKurrtqePvBbP7x9TZWJmTSITSAv47rxsVdW52Toa2FxaV8Eb+PV3/aTWJmHh1CA7hrZCeuijr/nG98NPSVsgnnq3f/Ob431/QNq7fXWpt0iIc/3ciezDwmDgjnsXFd631aipoo61Z55pttJGbmMbRTCH+7ohtd25xwjWidKCgq4aO1Kfzv593sP1pAj/ObcfdFnRjdo43LRhZp6CtlIzkFRfz53XWs2l0/V+/mHS/mn9/t4O3VSbQNasIL1/WudP2FuygqKeW93/by0tJd5BQUMaF/OA9c1qX8qyrPVk5BEe/9lsybKxPJzC0kpn0wd1/ciZFdQl1+0ZyGvlI2c7zY+u7dur5699eETB5ZuJHUw8eYPDiCh0Zf4BZXK5/OkfxCXlq6i/d+24uftyf3XNyJKUMi8PWqXf/64bxC5v+6h7dWJZFdUMywziHcc1EnBkS2cHnYl9HQV8qGSkqFGYu28O5ve7kuOoznr+tV677l7IIinlu8nQ9/TyYyJIAXruvNgMiGNRfO7oxcnv1mG8u2p9OuRRMeG9uNsT3b1Dio07ILeH1FIh/8nkx+YQmje7TmrpGd6NMuqJ4rP3Ma+krZlIjw8rIE/r10Jxd3bcWcWly9u3xHOo9/tom07ALuGNaB+y/t0qCHuq7clcnfv97KjrQcBkS04IkrutMr7NRz3KQcyufVn3fzSWwqJSJc2ed87hzZsdYzYJ4LGvpK2dz7a/byxBeb6XsGV+8eyS9k5tdb+Wz9Pjq3aso/x/chyg33amujuKSUj2JTePH7nWTlFXJtdFseHt2VNs0rrozdlZbD3J928+WG/Xgaw3X9wrhzREfCW7r/hG8a+kopvt18gHs/jKd9S3/eue30V+9+t+Ugf/tiM4fyCrlrZEfuubhTrfvA3Vl
2QRFzlicwf2USnh6GP4/owLDOIby+Yg/fbT2In5cnNw4M545hHSptENydhr5SCoDVu7OY+k4sgae4ejcr9zhPLdrC1xsP0P28Zvzf9b1rPb1vQ5Kclc/z325j8aaDAAT6eTF5cARThkQ2yO/t1dBXSpXbsv8ot8yrfPWuiPDVxgPMWLSFnIIipo/qzJ9HdHTJFa2uFJt0iJ1puVzR5zy3uOagtjT0lVKVJGflc9O8NaRnH+fZa3uyeNNBftiaRp+w5vxzfB+3Pkmpqqehr5Q6QdnVu1v2Z+Pr5cEDl3Xh1iGRbv+lLKp6dfnNWUqpRiI00JcFUwfx9qokxvU6jw5uOC2xql8a+krZTKCfN/dc3NnVZSgX0WM6pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSykRqFvjFmjDFmhzEmwRjz6EnWtzfGLDPGbDTG/GSMCXNad4sxZpfj55a6LF4ppdSZqTb0jTGewBxgLNAdmGiM6V6l2SzgHRHpDcwEnnM8tgXwFDAQGAA8ZYwJrrvylVJKnYma7OkPABJEJFFECoEFwFVV2nQHljluL3daPxr4QUQOichh4AdgzNmXrZRSqjZqEvptgRSn+6mOZc42ANc5bl8DBBpjWtbwsUoppc6RmoS+Ocmyqt+m/iAwwhgTB4wA9gHFNXwsxpipxphYY0xsRkZGDUpSSilVGzUJ/VSgndP9MGC/cwMR2S8i14pIX+CvjmVHa/JYR9vXRCRGRGJCQ0PP8C0opZSqqZqE/lqgszEm0hjjA9wALHJuYIwJMcaUPddjwDzH7e+Ay4wxwY4TuJc5limllHKBakNfRIqBe7DCehvwsYhsMcbMNMZc6Wg2EthhjNkJtAaecTz2EPB3rA3HWmCmY5lSSikXMCIndLG7VExMjMTGxrq6DKWUalCMMetEJKa6dnpFrlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiXqwtQSrleUVERqampFBQUuLoUVQ0/Pz/CwsLw9vau1eM19JVSpKamEhgYSEREBMac7FtOlTsQEbKyskhNTSUyMrJWz6HdO0opCgoKaNmypQa+mzPG0LJly7M6ItPQV0oBaOA3EGf776Shr5RSNqKhr5RqdCIiIsjMzKyX5z5+/DiXXHIJUVFRfPTRR/XyGvHx8SxevLhenltP5Cql3IqIICJ4eLjnPmlcXBxFRUXEx8fX+DElJSV4enrWuH18fDyxsbGMGzeuNiWeloa+UqqSp7/awtb92XX6nN3Pb8ZTf+hxyvVJSUmMHTuWiy66iNWrVxMVFcWmTZs4duwY119/PU8//TRg7cHfcsstfPXVVxQVFfHJJ5/QtWtXsrKymDhxIhkZGQwYMAARKX/uF198kXnz5gFw++23c99995GUlMSYMWMYOnQov/32G3369GHKlCk89dRTpKen8/777zNgwIAT6kxPT2fSpElkZGQQFRXFwoULSUpK4sEHH6S4uJj+/fszd+5cfH19iYiI4NZbb+X777/nnnvuoX///tx9991kZGTg7+/P66+/TteuXfnkk094+umn8fT0pHnz5ixdupQnn3ySY8eOsXLlSh577DEmTJhQZ/8W7rkpVUrZzo4dO7j55puJi4vjX//6F7GxsWzcuJGff/6ZjRs3lrcLCQlh/fr13HnnncyaNQuAp59+mqFDhxIXF8eVV15JcnIyAOvWrWP+/PmsWbOG3377jddff524uDgAEhISmD59Ohs3bmT79u188MEHrFy5klmzZvHss8+etMZWrVrxxhtvMGzYMOLj42nbti2TJ0/mo48+YtOmTRQXFzN37tzy9n5+fqxcuZIbbriBqVOn8p///Id169Yxa9Ys7rrrLgB
mzpzJd999x4YNG1i0aBE+Pj7MnDmTCRMmEB8fX6eBD7qnr5Sq4nR75PWpffv2DBo0CICPP/6Y1157jeLiYg4cOMDWrVvp3bs3ANdeey0A/fr147PPPgNgxYoV5bcvv/xygoODAVi5ciXXXHMNAQEB5Y/95ZdfuPLKK4mMjKRXr14A9OjRg1GjRmGMoVevXiQlJdWo5h07dhAZGUmXLl0AuOWWW5gzZw733XcfQHlg5+bmsmrVKsaPH1/+2OPHjwMwZMgQJk+ezB//+Mfy91afNPSVUm6hLJj37NnDrFmzWLt2LcHBwUyePLnSuHRfX18APD09KS4uLl9+sqGMzt08VZU9D4CHh0f5fQ8Pj0rPezqne36oeE+lpaUEBQWd9DzAq6++ypo1a/jmm2+Iioo6o3MFtaHdO0opt5KdnU1AQADNmzcnLS2NJUuWVPuY4cOH8/777wOwZMkSDh8+XL78iy++ID8/n7y8PD7//HOGDRtWZ7V27dqVpKQkEhISAHj33XcZMWLECe2aNWtGZGQkn3zyCWBtLDZs2ADA7t27GThwIDNnziQkJISUlBQCAwPJycmpszqdaegrpdxKnz596Nu3Lz169ODWW29lyJAh1T7mqaeeYsWKFURHR/P9998THh4OQHR0NJMnT2bAgAEMHDiQ22+/nb59+9ZZrX5+fsyfP5/x48fTq1cvPDw8mDZt2knbvv/++7z55pv06dOHHj168OWXXwLw0EMP0atXL3r27Mnw4cPp06cPF110EVu3bq2XYaGmusOTcy0mJkZiY2NdXYZStrJt2za6devm6jJUDZ3s38sYs05EYqp7rO7pK6WUjeiJXKWUOon58+cze/bsSsuGDBnCnDlzXFRR3ahR6BtjxgCzAU/gDRF5vsr6cOBtIMjR5lERWWyM8QbeAKIdr/WOiDxXh/UrpVS9mDJlClOmTHF1GXWu2u4dY4wnMAcYC3QHJhpjuldp9jfgYxHpC9wA/NexfDzgKyK9gH7An40xEXVTulJKqTNVkz79AUCCiCSKSCGwALiqShsBmjluNwf2Oy0PMMZ4AU2AQqBur+9WSilVYzUJ/bZAitP9VMcyZzOAScaYVGAx8BfH8k+BPOAAkAzMEpFDZ1OwUkqp2qtJ6J9sxv6q4zwnAm+JSBgwDnjXGOOBdZRQApwPRAIPGGM6nPACxkw1xsQaY2IzMjLO6A0opZSquZqEfirQzul+GBXdN2VuAz4GEJHVgB8QAtwIfCsiRSKSDvwKnDCOVEReE5EYEYkJDQ0983ehlGpUZsyYUT6Z2pkYPHjwadePGzeOI0eO1LasE7z11lvs3181Dt1bTUJ/LdDZGBNpjPHBOlG7qEqbZGAUgDGmG1boZziWX2wsAcAgYHtdFa+UUs5WrVp12vWLFy8mKCiozl7vdKFfUlJSZ69Tl6odsikixcaYe4DvsIZjzhORLcaYmUCsiCwCHgBeN8bcj9X1M1lExBgzB5gPbMbqJpovIhtP/kpKKbew5FE4uKlun7NNLxj7/GmbPPPMM7zzzju0a9eO0NBQ+vXrx+7du086B31aWhrTpk0jMTERgLlz5zJ48GCaNm1Kbm4uBw4cYMKECWRnZ5dPdzxs2DAiIiKIjY0lJCTklPPsjx07lqFDh7Jq1Sratm3Ll19+SZMmTU6o99NPPyU2NpY//elPNGnShNWrV9OtW7cazaGfkZHBtGnTyqeAfumll2o03URdqNE4fRFZjHWC1nnZk063twInVCwiuVjDNpVS6pTWrVvHggULiIuLo7i4mOjoaPr168fUqVN59dVX6dy5M2vWrOGuu+7ixx9/5N5772XEiBF8/vnnlJSUkJubW+n5PvjgA0aPHs1f//pXSkpKyM/PP+H1yubZFxEGDhzIiBEjCA4OZteuXXz44Ye8/vrr/PGPf2ThwoVMmjTphJqvv/56XnnlFWbNmkVMTEWvddkc+gCjRo06af3Tp0/n/vvvZ+jQoSQnJzN69Gi2bdtWD5/
sifSKXKVUZdXskdeHX375hWuuuQZ/f38ArrzySgoKCk45B/2PP/7IO++8A1D+jVPO+vfvz6233kpRURFXX301UVFRldZXN89+Wft+/frVeG79MjWZQ3/p0qVs3bq1fHl2djY5OTkEBgae0WvVhoa+UsotVJ0P/3Rz0Fdn+PDhrFixgm+++YabbrqJhx56iJtvvrl8fU3n2ff09OTYsWNn9No1mUO/tLSU1atXn7TbqL7phGtKKZcbPnw4n3/+OceOHSMnJ4evvvoKf3//U85BP2rUqPKvJSwpKSE7u/I1n3v37qVVq1bccccd3Hbbbaxfv/6E16uLefZPN+/96ebQv+yyy3jllVfK29b3F6c409BXSrlcdHQ0EyZMICoqiuuuu648gE81B/3s2bNZvnw5vXr1ol+/fmzZsqXS8/30009ERUXRt29fFi5cyPTp0094vbqYZ3/y5MlMmzaNqKiokx4RnKr+l19+mdjYWHr37k337t159dVXz/i1a0vn01dK6Xz6DYzOp6+UUqpG9ESuUkpV4+677+bXX3+ttGz69OkNcuplDX2lFGCdaKw6gkZZ3OmLU862S167d5RS+Pn5kZWVddaBouqXiJCVlYWfn1+tn0P39JVShIWFkZqais5y6/78/PwICwur9eM19JVSeHt7ExkZ6eoy1Dmg3TtKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjNQp9Y8wYY8wOY0yCMebRk6wPN8YsN8bEGWM2GmPGOa3rbYxZbYzZYozZZIyp/Tf6KqWUOivVfkeuMcYTmANcCqQCa40xi0Rkq1OzvwEfi8hcY0x3YDEQYYzxAt4DbhKRDcaYlkBRnb8LpZRSNVKTPf0BQIKIJIpIIbAAuKpKGwGaOW43B/Y7bl8GbBSRDQAikiUiJWdftlJKqdqoSei3BVKc7qc6ljmbAUwyxqRi7eX/xbG8CyDGmO+MMeuNMQ+fZb1KKaXOQk1C35xkmVS5PxF4S0TCgHHAu8YYD6zuo6G53EEhAAAOXklEQVTAnxy/rzHGjDrhBYyZaoyJNcbEZmRknNEbUEopVXM1Cf1UoJ3T/TAqum/K3AZ8DCAiqwE/IMTx2J9FJFNE8rGOAqKrvoCIvCYiMSISExoaeubvQimlVI3UJPTXAp2NMZHGGB/gBmBRlTbJwCgAY0w3rNDPAL4Dehtj/B0ndUcAW1FKKeUS1Y7eEZFiY8w9WAHuCcwTkS3GmJlArIgsAh4AXjfG3I/V9TNZRAQ4bIx5EWvDIcBiEfmmvt6MUkqp0zNWNruPmJgYiY2NdXUZSinVoBhj1olITHXt9IpcpZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ39ulZSDEm/QvZ+V1eilFIn8KpJI2PMGGA24Am8ISLPV1kfDrwNBDnaPCoii6us3wrMEJFZdVS7+zm4Gb68Gw7EW/ebt4Ow/tBugPXTuhd4+bi2RqWUrVUb+sYYT2AOcCmQCqw1xiwSka1Ozf4GfCwic40x3YHFQITT+n8DS+qsandTXAgrX4QVs8CvOfzhZSjMg9TfIeV32PKZ1c7LD87va20AwhwbgqatXFu7UspWarKnPwBIEJFEAGPMAuAqrD33MgI0c9xuDpT3bRhjrgYSgby6KNjt7I+DL++BtM3QazyMeQECWjpW3mX9OrrPsQFYCylrYPV/oXS2tS44omI
DENYfWvcEzxodgLmXkmI4shcyd0HmTusnK8H67RcEl/0dLhgHxri6UqVsrSbp0hZIcbqfCgys0mYG8L0x5i9AAHAJgDEmAHgE6yjhwVO9gDFmKjAVIDw8vIalu1hRAfz8Avw6GwJC4YYPoeu4k7dt3haaXwM9rql47IEN1gYg9XfY8zNs+tha5+0Pbfs5uoUGWr/LNyJuoCAbsnY5hbvj9qHdUFJY0S4gFEK6QNcrrKOdBTdCp0th7AvQsqPr6lfK5moS+ifbNZMq9ycCb4nIv4wxFwLvGmN6Ak8D/xaRXHOaPTwReQ14DSAmJqbqc7uf1Fj44i7I3AFRk2D0P6BJcM0f7+0H4QOtHwAROJIMqWutgExZY21MpMRa36KjtQFo1986KmjVDTw86/59lSkthex9lffWywI+50BFO+MJLTpY4d7lMut3SBdo2Qn8W1S0KymC31+Hn56D/w6CwX+BYQ+AT0D9vQel1EkZkdNnrCPEZ4jIaMf9xwBE5DmnNluAMSKS4rifCAwCFgLtHM2CgFLgSRF55VSvFxMTI7GxsbV+Q/Wq6Bj8+A/47b8QeD78YTZ0vqR+Xqsw3+o6SllTsTHIz7TW+QRCWD9Ht9BA6/aZbHTKFB2DrN1Oe+xO3TJF+RXtfJtDSGdHqHeuCPfgiDM7MZ2TBj88CRsXQLMwGP0MdL9Ku3yUqgPGmHUiElNtuxqEvhewExgF7APWAjeKyBanNkuAj0TkLWNMN2AZ0FacntwYMwPIrW70jtuG/t5VVt/9od3QbwpcOhP8mlX/uLoiAocSnY4Gfof0LSCl1vqQCypGCYUNsELZw8N6XF5G5b31sttHUqg4aDMQ1K4i0J3DPSC0boN572pY/BCkbYIOI2HsPyG0S909v1I2VGeh73iyccBLWMMx54nIM8aYmUCsiCxyjNh5HWiKlSIPi8j3VZ5jBg0x9I/nwrKZ8PtrEBQOV/4HOoxwdVWW4zmwb721AUj93dogHDtsrfNrbu2JH06CgqMVj/H2t7pfKoV7Z6sLycf/3NVeUgyx86wjp6I8GHQXjHgYfAPPXQ1KNSJ1GvrnkluFfuLPsOgv1qiUAX+GUU+Cb1NXV3VqIlbXTMoaa0NwNAWCIyvvuTdrax0BuIvcDFg2A+Leg8Dz4LJ/QM/rtMtHqTOkoX82CrLhhydg3VvWHvBVr0D7wa6tqbFLWQuLH7BGNUUMg7H/B627u7oq1ViJwNFUa6j1wc2Qlw7+IRAQYnVnBoRa19AEhIBvswaxE1LT0G+AA8Lr2a6l8NV0yNlvjTIZ+fi57fawq3b94Y7lsP5tqzvt1aEwcBqMfMTqqlKqtoqOQfo2SNtSEfJpm6HgSEUb3+Zw/OjJH+/p69gQOG8QQituB4RAQKuK257e5+Z91ZLu6Zc5dhi++yvEv2+dFL36vxBW7UZT1Yf8Q7DsaVj3tvUf6bK/Q+8JDWJvS7mQCOQcdAT7poqAz0qoGP7sHWAdQbbuCW16WlOjtO5unUsqLoT8LGvgQ9WfXOf7mdaRgfN1Kc6aBJ96g9C0VeV1dXgUod07Z2L7Yvj6fusfdOh9MOIR8PI9tzWoE+1bD4sfhH3rIPxCGPdPaNPL1VUpd1BcCBnbK++5p222QrtM83BHsPes+B0cWTfntETgeLa1AchNP3GDUHa7bJ3zUYWzqkcR5/WBUU/UqiTt3qmJvCz49hHY9In1B3HjR3B+lKurUmXaRsNtSyH+PVg6A/43HPrfARc9Dk2CXF2dOldyM6zhvWXhfnCzdWFkabG13svPumDxgnHWTkHrntC6R/3+jRhjdTv6Na/ZFeblRxFOG4S8DMdGwXE7L90aEl7P7Lunv+ULay/y2BEY/hAMvV9nwHRnxw7Dj89A7JvQpAVc+jT0udG9RiKps1NSZF1Hkra58h58blpFm8Dzq+y997KuCm+I81XVMe3eOZXcdCvst34J50VZffete9Tf66m6dWCj9e+Xssaal2jcP62ZS1XDIWIFefo
2xwlWRx98xvaKfnJPHwi9wAr1spBv3dO95qFyMxr6VYnApk9hycPWtMcjH4XB9+oeQkMkAhsWWFM65GVAzBS4+InK8/0o95CbARnbIH07pG+1gj19W+U+7oBWTnvvju6ZkM5uPwrG3WifvrPsA9aJ2p1LrL3Dq+ZYexGqYTIGoiZas5r+9Dys+Z/VXTfqSYi+uX4no1Mnl3/IsedeFuzbrbB3PrHq1xxadbdmm23VDUK7Wr/1OyXOqca9py9iDcH89nEoOW7tDQ66U0OhsUnbYs3ls/dXq6tn3L+sSehU3Tt2uCLQnX/npVe08W3mCPSuENqt4ndgGx12W490T/9IinWR1e5l0H6INWeOzuPeOLXuAZO/gc0LrWst3hgF0TfBqKesoXDqzBVkV3TFOP92nlrbp6l1xNz5ssoB36ythrsba3yhLwLr5sP3T1ozUI6bBTG36SiPxs4Y6HU9dBltfbnNb3Otk/UXPwExt+rR3akcz4WMHY49dqeAz95X0cariRXuHUZWdMm06mZNj63/rxqcxtW9c2gPfHUv7Flh/YH+4WUIbl+X5amGImOH1eWz52fr5OC4f1V8ac3ZEHH8lFg7FaWO31LqWCZVljutL3WsLy12+imB0qITl5UUVWlT7NSupGJdSVHl+6VV7xdbM5pWff7iY5CZAEeTK96bp681xXVot4pgD+0KQe013BsA+43eydxlXbzj4WXN1Bh9sx5i2p0IbP3C6vLJ3mdNKV0WytUGdGmVMHe0O+FL49yIh7c14sXDyzqy8fCyljnf9/S2bnv6WFNvt+pWEfLBEXpE1IDZr0+/ZSfrAquoG6F5mKurUe7AGGukSOfLYNUrVheG8XD8eFq/PTxOsszTaZnHSZad5rHlbc1JljndLg9jRyB7VrlfdX15YJ+qje6Jq5ppPKFvjPUlHEpV5RNgzdaplEJ3D5RSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykbcbhoGY0wGsPcsniIEyKyjcho6/Swq08+jMv08KjSGz6K9iIRW18jtQv9sGWNiazL/hB3oZ1GZfh6V6edRwU6fhXbvKKWUjWjoK6WUjTTG0H/N1QW4Ef0sKtPPozL9PCrY5rNodH36SimlTq0x7ukrpZQ6hUYT+saYMcaYHcaYBGPMo66ux5WMMe2MMcuNMduMMVuMMdNdXZOrGWM8jTFxxpivXV2Lqxljgowxnxpjtjv+Ri50dU2uZIy53/H/ZLMx5kNjjJ+ra6pPjSL0jTGewBxgLNAdmGiM6e7aqlyqGHhARLoBg4C7bf55AEwHtrm6CDcxG/hWRLoCfbDx52KMaQvcC8SISE/AE7jBtVXVr0YR+sAAIEFEEkWkEFgAXOXimlxGRA6IyHrH7Rys/9RtXVuV6xhjwoDLgTdcXYurGWOaAcOBNwFEpFBEjri2KpfzApoYY7wAf2C/i+upV40l9NsCKU73U7FxyDkzxkQAfYE1rq3EpV4CHgZKXV2IG+gAZADzHd1dbxhjAlxdlKuIyD5gFpAMHACOisj3rq2qfjWW0DcnWWb7YUnGmKbAQuA+Ecl2dT2uYIy5AkgXkXWursVNeAHRwFwR6QvkAbY9B2aMCcbqFYgEzgcCjDGTXFtV/WosoZ8KtHO6H0YjP0SrjjHGGyvw3xeRz1xdjwsNAa40xiRhdftdbIx5z7UluVQqkCoiZUd+n2JtBOzqEmCPiGSISBHwGTDYxTXVq8YS+muBzsaYSGOMD9aJmEUurslljDEGq892m4i86Op6XElEHhORMBGJwPq7+FFEGvWe3OmIyEEgxRhzgWPRKGCrC0tytWRgkDHG3/H/ZhSN/MS2l6sLqAsiUmyMuQf4Duvs+zwR2eLislxpCHATsMkYE+9Y9riILHZhTcp9/AV437GDlAhMcXE9LiMia4wxnwLrsUa9xdHIr87VK3KVUspGGkv3jlJKqRrQ0FdKKRvR0FdKKRvR0FdKKRv
R0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRv5/5Yq8G830UTCAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df.plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Decision Tree Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8343173330831328" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.mean(dt_cv_scores)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Random Forest Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9223850187122359" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.mean(rf_cv_scores)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 59f6cb4ab071f969477e12cfedeedad87e291c25 Mon Sep 17 00:00:00 2001 From: minsuk-heo Date: Fri, 26 Oct 2018 20:51:38 -0700 Subject: [PATCH 4/8] svm added --- data_science/svm/svm.ipynb | 401 +++++++++++++++++++++++++++++++++++++ 1 file changed, 401 insertions(+) create mode 100644 data_science/svm/svm.ipynb diff --git a/data_science/svm/svm.ipynb b/data_science/svm/svm.ipynb new file mode 100644 index 0000000..db0733a --- /dev/null +++ b/data_science/svm/svm.ipynb @@ -0,0 +1,401 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + 
# %% Imports for the SVM notebook
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC

# %% Load dataset
# Seed for the train/test shuffle; change it to draw a different split.
SEED = 42

# load iris data
dataset = load_iris()

# use 80% as train data, 20% as test data
# random_state pins the shuffle so the selected hyperparameters and the test
# report below are reproducible after a kernel restart.
X_train, X_test, y_train, y_test = train_test_split(
    dataset.data, dataset.target, test_size=0.2, random_state=SEED)
+ ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [], + "source": [ + "def svc_param_selection(X, y, nfolds):\n", + " svm_parameters = [\n", + " {'kernel': ['rbf'],\n", + " 'gamma': [0.00001,0.0001, 0.001, 0.01, 0.1, 1],\n", + " 'C': [0.01, 0.1, 1, 10, 100, 1000]\n", + " }\n", + " ]\n", + " \n", + " clf = GridSearchCV(SVC(), svm_parameters, cv=10)\n", + " clf.fit(X_train, y_train)\n", + " print(clf.best_params_)\n", + " \n", + " return clf" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 10, 'gamma': 0.1, 'kernel': 'rbf'}\n" + ] + } + ], + "source": [ + "clf = svc_param_selection(X_train, y_train, 10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 1.00 1.00 1.00 7\n", + " 1 1.00 1.00 1.00 13\n", + " 2 1.00 1.00 1.00 10\n", + "\n", + "avg / total 1.00 1.00 1.00 30\n", + "\n", + "\n", + "accuracy : 1.0\n" + ] + } + ], + "source": [ + "y_true, y_pred = y_test, clf.predict(X_test)\n", + "\n", + "print(classification_report(y_true, y_pred))\n", + "print()\n", + "print(\"accuracy : \"+ str(accuracy_score(y_true, y_pred)) )" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ground_truthprediction
011
111
222
300
411
522
622
722
811
900
1011
1100
1200
1311
1411
1522
1611
1722
1811
1922
2000
2111
2200
2300
2411
2511
2622
2711
2822
2922
\n", + "
" + ], + "text/plain": [ + " ground_truth prediction\n", + "0 1 1\n", + "1 1 1\n", + "2 2 2\n", + "3 0 0\n", + "4 1 1\n", + "5 2 2\n", + "6 2 2\n", + "7 2 2\n", + "8 1 1\n", + "9 0 0\n", + "10 1 1\n", + "11 0 0\n", + "12 0 0\n", + "13 1 1\n", + "14 1 1\n", + "15 2 2\n", + "16 1 1\n", + "17 2 2\n", + "18 1 1\n", + "19 2 2\n", + "20 0 0\n", + "21 1 1\n", + "22 0 0\n", + "23 0 0\n", + "24 1 1\n", + "25 1 1\n", + "26 2 2\n", + "27 1 1\n", + "28 2 2\n", + "29 2 2" + ] + }, + "execution_count": 84, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Visualize true value with prediction value in pandas dataframe.\n", + "comparison = pd.DataFrame({'prediction':y_pred, 'ground_truth':y_true}) \n", + "comparison" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From f37e4b95d151e5fec46319d539a0c8b6cb0dee06 Mon Sep 17 00:00:00 2001 From: minsuk-heo Date: Thu, 15 Nov 2018 09:34:12 -0800 Subject: [PATCH 5/8] deleted --- data_science/ensemble/randomforest.ipynb | 203 ------------ data_science/ensemble/voting.ipynb | 279 ---------------- data_science/svm/svm.ipynb | 401 ----------------------- 3 files changed, 883 deletions(-) delete mode 100644 data_science/ensemble/randomforest.ipynb delete mode 100644 data_science/ensemble/voting.ipynb delete mode 100644 data_science/svm/svm.ipynb diff --git a/data_science/ensemble/randomforest.ipynb 
b/data_science/ensemble/randomforest.ipynb deleted file mode 100644 index 662d955..0000000 --- a/data_science/ensemble/randomforest.ipynb +++ /dev/null @@ -1,203 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn import datasets\n", - "from sklearn import tree\n", - "from sklearn.ensemble import RandomForestClassifier\n", - "from sklearn.model_selection import cross_val_score\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import pandas as pd\n", - "import numpy as np" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load MNIST dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "mnist = datasets.load_digits()\n", - "features, labels = mnist.data, mnist.target" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Cross Validation" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "def cross_validation(classifier,features, labels):\n", - " cv_scores = []\n", - "\n", - " for i in range(10):\n", - " scores = cross_val_score(classifier, features, labels, cv=10, scoring='accuracy')\n", - " cv_scores.append(scores.mean())\n", - " \n", - " return cv_scores" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "dt_cv_scores = cross_validation(tree.DecisionTreeClassifier(), features, labels)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "rf_cv_scores = cross_validation(RandomForestClassifier(), features, labels)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Random Forest VS Decision Tree visualization" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "cv_list = [ \n", - " 
['random_forest',rf_cv_scores],\n", - " ['decision_tree',dt_cv_scores],\n", - " ]\n", - "df = pd.DataFrame.from_items(cv_list)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xd8VFX+//HXSSchkEACKCEkNOmEEIp0RaXo2lnERQULi2VFf3Z3VWTX9l3WFVcW1wJ2sWBBBQuIIoJIIKG3EEISSholjZD2+f1xJ8kklISQMJPcz/PxyCMz956Z+cwQ3vfec889Y0QEpZRS9uDh6gKUUkqdOxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillI16uLqCqkJAQiYiIcHUZSinVoKxbty5TREKra+d2oR8REUFsbKyry1BKqQbFGLO3Ju20e0cppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWzE7cbpq7MjIqQePsb65MPkF5ZwXXQYPl66bVdKWTT0G7j8wmI2ph5lffJh4pKPEJd8hMzc4+XrP1iTzMsT+xIZEuDCKpVS7kJDvwEREZKy8lm/9zBxKVbIbz+YQ0mp9eX2HUICGNEllL7hQfQNDyLlUD6PLNzEFS//wt+v7sm10WEufgdKKVfT0HdjOQVF1l783sPEpRwhLvkwh/OLAGjq60VUuyDuHtmRvuHBRLULIjjAp9Lje5zfnN5hQdz3UTz/7+MN/LIrk5lX9SDQz9sVb0cp5QYaTeiXlgq7M3IJDfSleRNvjDGuLumMlJYKiZm5rN97hLiUw6zfe4Sd6TmItRNP51ZNuax7G8defDCdWjXF06P693h+UBM+vGMQc5Yn8NLSnaxPPszLN/SlT7ugen5HSil3ZKQsVdxETEyM1GbCtUN5hUT//QcAvD0NLQN8CQ30JaSpj+O39VN2OzTQl9CmvjRr4uWSDcTR/KLyLpq4lCPEJx8mu6AYgGZ+XvQNDyY6PJi+4UH0aRdE8yZnv3cem3SI6QviScsu4MHRFzB1WAc8arDhUEq5P2PMOhGJqbZdYwn9/MJiftiaRmZuIRk5x8nMtX4qbheW93078/H0IKSpDyGOjUBIU19CAn2s22XLHBuKZn6120CUlAq70nOsvfjkw6xPPszujDwAPAx0aR1IdPtg+rYLIrp9MJEtA+otjI/mF/HY5xtZvOkgQzuF8OIf+9CqmV+9vJZS6tyxXehXp7RUOHKsqNKGICPnOBm5x8nMKXT8tpZn5Z1iA+Hl4dgw+FQ6Yqh8FOGDv48XW/YfJS75COuTD7Mh5Qh5hSUAtAjwKQ/3vu2C6N0uiKa+57aXTUT4aG0KM77aQoCPF7PG9+Girq3OaQ1KqbqloX8WSkuFw/mF5RuEShuK8tvWEcWhvOOcZPsAgKeHodt5geXdNNHhwYS38Heb8w0J6Tnc80Ec2w/mcNvQSB4ecwG+Xp6uLkspVQsa+udISdkGwqlLKaegmAtaB9I7LIgmPu4dogVFJTy/ZDtvrUqix/nNeHliXzqGNnV1WUqpM6Shr87I
0q1pPPTpBgqKSnn6qh6M7xfmNkckSqnq1TT09fp8BcAl3VuzZPpwotoF8fCnG7l3QTzZBUWuLkspVcc09FW5Ns39eO/2gTw0+gIWbzrAuNm/sD75sKvLUkrVIQ19VYmnh+HuizrxybQLARj/6mrmLE846WgmdWYSM3KZsWgLK3dluroUZWPap69OKbugiMc/28TXGw8wuGNL/j0hitY6pv+MJaTn8sqPu1i0YT+lAl4ehlnj+3B137auLk01Itqnr85aMz9v/jOxL/93fW/iko8w5qUVLNuW5uqyGoyE9BymL4jj0n//zHdb0rh9WAeWPziS/hEtuO+jeN5cucfVJSob0j19VSO7M3L5ywdxbD2QzeTBETw6tit+3u49HNVVdqXl8PKPCXy9cT9NvD256cL23DGsAyFNfQFrmOz9H8WzZPNBpo3oyCNjLtCRUuqs6ZBNVeeOF5fwwpIdzPt1D13bBPLKjX3p1CrQ1WW5jR0Hc3j5x10s3nQAf29Pbh4cwR3DOtCiyuynYF3f8eSXm3l/TTLj+4Xx3LW98PLUA29Vexr6qt4s357Og59sIK+wmBl/6MGE/u1svae6/WA2Ly/bxeJNB2nq68Utg9tz+9AOJ0x1XZWIMHvZLl5auotRXVvxyo3Rbn8xn3JfGvqqXqVnF/D/Pt7AyoRMLu91Hs9e26tOZgJtSLbut8L+2y1W2E8ZEsFtQyMJ8j992Ff13m97eeLLzfQLD+aNW2LO+PFKgYa+OgdKS4XXfklk1nc7aN3Mj5cnRtGvfQtXl1XvNu87ysvLdvH91jQCfb2YMjSS24ZE0ty/9hu9xZsOcN+CeCJC/Hn71gGc17xJHVas7EBDX50z8SlHuPfDOPYdOcb0UZ25+6JONfqCl4Zm876jvLR0F0u3pRHo58WtQyK59SzD3tmq3ZlMfWcdzZt48/atA+jUqnHPgSQiHM4vOuk5D3XmNPTVOZVTUMQTX2zmi/j9DIxswUs3RDWavdWNqUeYvXQXy7an08zPi9uGdmDykIh66c7avO8ok+evpaS0lHmT+9M3PLjOX8MdHDh6jMc/28TyHRlc07ctD4+5oNH8vbhKnYa+MWYMMBvwBN4QkeerrG8PzANCgUPAJBFJNcZEAXOBZkAJ8IyIfHS619LQb9g+W5/K377YjI+XBy9c15vRPdq4uqRa25ByhNnLdvHj9nSaN/Hm9qGR3DIkgmb1/B3De7PyuOnN38nIOc7cSdGMvKDxfNeBiLBgbQrPfrON4lJhbM82fL3pAB4Gpg7vyLQRHfD3aTTf4npO1VnoG2M8gZ3ApUAqsBaYKCJbndp8AnwtIm8bYy4GpojITcaYLoCIyC5jzPnAOqCbiBw51etp6Dd8ezLzuPfDODbtO8pNg9rz18u7Nagx/XHJh5m9bBc/7cggyN+bO4Z14OYL25/TL5TPyDnO5Pm/s+NgTqO5ejflUD6PfraRXxOyuLBDS164rjfhLf1JOZTPC99u5+uNB2jdzJeHR3flmr5tbfdVnkePFZGWXUCX1rUbBl2XoX8hMENERjvuPwYgIs85tdkCjHbs3RvgqIg0O8lzbQCuF5Fdp3o9Df3GobC4lFnf7+C1FYkE+XvTpVUgkSEBdAgNcPxuSngLf3y83Gds+rq9Vtiv2JlBsL83dwzvwM0XRpzzbzYrk1NQxNR31rE6MYu/Xd6N24d1cEkdZ6u0VHhndRIvfLsDTw/D4+O6MXHAicN81+09xMyvt7Eh5Qi9w5rzxBXd6R/R+AcGZOUeZ96ve3hn1V7aBjdhyfRhtRoCXZehfz0wRkRud9y/CRgoIvc4tfkAWCMis40x1wILgRARyXJqMwB4G+ghIqWnej0N/cbl14RMvozfx57MPPZk5pGZW1i+zsNAuxb+dAgJIDKkKZGhAXQMCSAyNIA2zfzO2dj/2KRDzF62i192ZdIiwIepwztw06D2BLgo7J0VFJXw/z6OZ/Gmg/x5RAceHdO1QV0TkZiRyyMLN7I2
6TAjuoTy7LW9aBt06r770lLhyw37eGHJDg5mFzCuVxseG9uNdi38z2HV58aBo8d4bUUiH/6ezPHiUsb1PI87R3akZ9vmtXq+moZ+Tf6qT/YXVnVL8SDwijFmMrAC2AcUOxVzHvAucMvJAt8YMxWYChAeHl6DklRDMaRTCEM6hZTfP3qsyLEByCUxI4/EzDz2ZOTxW+IhjhWVlLdr4u1JpGMDULYhiAxpSofQgDrrU/99zyFmL9vJrwlZhDT14fFxXZk0qL1b9Sn7eXvyn4nRtAjYzP9+TiQrt5DnG8DVuyWlwpsrE/nX9zvx9fJg1vg+XBfdttoNloeH4Zq+YYzpcR6vrUjk1Z93s3RrOlOGRnDPRZ3OaRdbfUnKzOPVn3ezcH0qpQJXR7XlzpEdz9lorTrp3qnSvimwXUTCHPebAT8Bz4nIJ9UVpHv69lRaKqTlFLAnI4/djg3BnsxcEjPzSDmUX+l7iEOa+tAhpGn5RqGDo9sovEVAjbqLfkvMYvbSXaxOtML+z8M78qdB4W4V9lU1pKt3d6bl8NCnG9mQcoRLu7fmmat70qqWs7MePFrAP7/bwcL1qbQM8OGByy5gQv92DXJI8I6DOcxZbs3J5OXpwYSYdkwd3qHOjmLqsnvHC+tE7iisPfi1wI0issWpTQhwSERKjTHPACUi8qQxxgdYAnwlIi/VpHANfVVVYXEpyYfySczIZU9mHokZVldRYmYembnHy9uVdRdFhgRYGwWnDUKbZn78lniIl5buZM2eQ4QG+vLn4R3408D2bhueJ1N29W50eDBvutnVu0Ulpfzv5928vCyBpn5ePH1lD67ofV6ddEdtTD3C37/eytqkw3RtE8jfLu/O0M4h1T/QDcSnHGHO8gR+2JpGgI8nkwa157ZhkbQKrNtpyut6yOY44CWsIZvzROQZY8xMIFZEFjn6/Z/D6vZZAdwtIseNMZOA+cAWp6ebLCLxp3otDX11Jo4eKyLJcb4gMcM6Mig7f5BfWNFd5OPlQWFxKa0CfZk2oiM3DgxvUCOKnC3ZdIDpbnb17pb9R3nok41sPZDNFb3P4+kre9DSMatoXRERlmw+yHNLtpFy6Bijurbi8cu70THU/S5iExFWJ2bx3+W7WZmQSfMm3kwZEsHkwRH1tqHWi7OUrYkIadnHSXScO9iTmUdES3/Gx7RrsGHvrOzq3WZ+Xrxz2wCXzXZ6vLiEV35MYO5Puwny9+EfV/dkTM/6vTajoKiEt1Yl8cqPCRQUlTBpUHvuu6SzWxz1iAg/bk9nzvIE1icfITTQlzuGRXLjwPb1PgpMQ1+pRq7s6t3i0lLmu+Dq3fiUIzz86QZ2puVyXXQYT1zR7ZwGb0bOcV78YScfrU0m0M+b+y7pzKRB7fF2wUnuklJhyeYDzFm+m20Hsmkb1IRpIzsyvl/YOdvJ0NBXygb2ZuVx87zfSc8+d1fvFhSV8O8fdvL6L4m0bubHs9f04qKurrtqePvBbP7x9TZWJmTSITSAv47rxsVdW52Toa2FxaV8Eb+PV3/aTWJmHh1CA7hrZCeuijr/nG98NPSVsgnnq3f/Ob431/QNq7fXWpt0iIc/3ciezDwmDgjnsXFd631aipoo61Z55pttJGbmMbRTCH+7ohtd25xwjWidKCgq4aO1Kfzv593sP1pAj/ObcfdFnRjdo43LRhZp6CtlIzkFRfz53XWs2l0/V+/mHS/mn9/t4O3VSbQNasIL1/WudP2FuygqKeW93/by0tJd5BQUMaF/OA9c1qX8qyrPVk5BEe/9lsybKxPJzC0kpn0wd1/ciZFdQl1+0ZyGvlI2c7zY+u7dur5699eETB5ZuJHUw8eYPDiCh0Zf4BZXK5/OkfxCXlq6i/d+24uftyf3XNyJKUMi8PWqXf/64bxC5v+6h7dWJZFdUMywziHcc1EnBkS2cHnYl9HQV8qGSkqFGYu28O5ve7kuOoznr+tV677l7IIinlu8nQ9/TyYyJIAXruvNgMiGNRfO
7oxcnv1mG8u2p9OuRRMeG9uNsT3b1Dio07ILeH1FIh/8nkx+YQmje7TmrpGd6NMuqJ4rP3Ma+krZlIjw8rIE/r10Jxd3bcWcWly9u3xHOo9/tom07ALuGNaB+y/t0qCHuq7clcnfv97KjrQcBkS04IkrutMr7NRz3KQcyufVn3fzSWwqJSJc2ed87hzZsdYzYJ4LGvpK2dz7a/byxBeb6XsGV+8eyS9k5tdb+Wz9Pjq3aso/x/chyg33amujuKSUj2JTePH7nWTlFXJtdFseHt2VNs0rrozdlZbD3J928+WG/Xgaw3X9wrhzREfCW7r/hG8a+kopvt18gHs/jKd9S3/eue30V+9+t+Ugf/tiM4fyCrlrZEfuubhTrfvA3Vl2QRFzlicwf2USnh6GP4/owLDOIby+Yg/fbT2In5cnNw4M545hHSptENydhr5SCoDVu7OY+k4sgae4ejcr9zhPLdrC1xsP0P28Zvzf9b1rPb1vQ5Kclc/z325j8aaDAAT6eTF5cARThkQ2yO/t1dBXSpXbsv8ot8yrfPWuiPDVxgPMWLSFnIIipo/qzJ9HdHTJFa2uFJt0iJ1puVzR5zy3uOagtjT0lVKVJGflc9O8NaRnH+fZa3uyeNNBftiaRp+w5vxzfB+3Pkmpqqehr5Q6QdnVu1v2Z+Pr5cEDl3Xh1iGRbv+lLKp6dfnNWUqpRiI00JcFUwfx9qokxvU6jw5uOC2xql8a+krZTKCfN/dc3NnVZSgX0WM6pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSykRqFvjFmjDFmhzEmwRjz6EnWtzfGLDPGbDTG/GSMCXNad4sxZpfj55a6LF4ppdSZqTb0jTGewBxgLNAdmGiM6V6l2SzgHRHpDcwEnnM8tgXwFDAQGAA8ZYwJrrvylVJKnYma7OkPABJEJFFECoEFwFVV2nQHljluL3daPxr4QUQOichh4AdgzNmXrZRSqjZqEvptgRSn+6mOZc42ANc5bl8DBBpjWtbwsUoppc6RmoS+Ocmyqt+m/iAwwhgTB4wA9gHFNXwsxpipxphYY0xsRkZGDUpSSilVGzUJ/VSgndP9MGC/cwMR2S8i14pIX+CvjmVHa/JYR9vXRCRGRGJCQ0PP8C0opZSqqZqE/lqgszEm0hjjA9wALHJuYIwJMcaUPddjwDzH7e+Ay4wxwY4TuJc5limllHKBakNfRIqBe7DCehvwsYhsMcbMNMZc6Wg2EthhjNkJtAaecTz2EPB3rA3HWmCmY5lSSikXMCIndLG7VExMjMTGxrq6DKWUalCMMetEJKa6dnpFrlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiXqwtQSrleUVERqampFBQUuLoUVQ0/Pz/CwsLw9vau1eM19JVSpKamEhgYSEREBMac7FtOlTsQEbKyskhNTSUyMrJWz6HdO0opCgoKaNmypQa+mzPG0LJly7M6ItPQV0oBaOA3EGf776Shr5RSNqKhr5RqdCIiIsjMzKyX5z5+/DiXXHIJUVFRfPTRR/XyGvHx8SxevLhenltP5Cql3IqIICJ4eLjnPmlcXBxFRUXEx8fX+DElJSV4enrWuH18fDyxsbGMGzeuNiWeloa+UqqSp7/awtb92XX6nN3Pb8ZTf+hxyvVJSUmMHTuWiy66iNWrVxMVFcWmTZs4duwY119/PU8//TRg7cHfcsstfPXVVxQVFfHJJ5/QtWtXsrKymDhxIhkZGQwYMAARKX/uF198kXnz5gFw++23c99995GUlMSYMWMYOnQov/32G3369GHKlCk89dRTpKen8/777zNgwIAT6kxPT2fSpElkZGQQFRXFwoULSUpK4sEHH6S4
uJj+/fszd+5cfH19iYiI4NZbb+X777/nnnvuoX///tx9991kZGTg7+/P66+/TteuXfnkk094+umn8fT0pHnz5ixdupQnn3ySY8eOsXLlSh577DEmTJhQZ/8W7rkpVUrZzo4dO7j55puJi4vjX//6F7GxsWzcuJGff/6ZjRs3lrcLCQlh/fr13HnnncyaNQuAp59+mqFDhxIXF8eVV15JcnIyAOvWrWP+/PmsWbOG3377jddff524uDgAEhISmD59Ohs3bmT79u188MEHrFy5klmzZvHss8+etMZWrVrxxhtvMGzYMOLj42nbti2TJ0/mo48+YtOmTRQXFzN37tzy9n5+fqxcuZIbbriBqVOn8p///Id169Yxa9Ys7rrrLgBmzpzJd999x4YNG1i0aBE+Pj7MnDmTCRMmEB8fX6eBD7qnr5Sq4nR75PWpffv2DBo0CICPP/6Y1157jeLiYg4cOMDWrVvp3bs3ANdeey0A/fr147PPPgNgxYoV5bcvv/xygoODAVi5ciXXXHMNAQEB5Y/95ZdfuPLKK4mMjKRXr14A9OjRg1GjRmGMoVevXiQlJdWo5h07dhAZGUmXLl0AuOWWW5gzZw733XcfQHlg5+bmsmrVKsaPH1/+2OPHjwMwZMgQJk+ezB//+Mfy91afNPSVUm6hLJj37NnDrFmzWLt2LcHBwUyePLnSuHRfX18APD09KS4uLl9+sqGMzt08VZU9D4CHh0f5fQ8Pj0rPezqne36oeE+lpaUEBQWd9DzAq6++ypo1a/jmm2+Iioo6o3MFtaHdO0opt5KdnU1AQADNmzcnLS2NJUuWVPuY4cOH8/777wOwZMkSDh8+XL78iy++ID8/n7y8PD7//HOGDRtWZ7V27dqVpKQkEhISAHj33XcZMWLECe2aNWtGZGQkn3zyCWBtLDZs2ADA7t27GThwIDNnziQkJISUlBQCAwPJycmpszqdaegrpdxKnz596Nu3Lz169ODWW29lyJAh1T7mqaeeYsWKFURHR/P9998THh4OQHR0NJMnT2bAgAEMHDiQ22+/nb59+9ZZrX5+fsyfP5/x48fTq1cvPDw8mDZt2knbvv/++7z55pv06dOHHj168OWXXwLw0EMP0atXL3r27Mnw4cPp06cPF110EVu3bq2XYaGmusOTcy0mJkZiY2NdXYZStrJt2za6devm6jJUDZ3s38sYs05EYqp7rO7pK6WUjeiJXKWUOon58+cze/bsSsuGDBnCnDlzXFRR3ahR6BtjxgCzAU/gDRF5vsr6cOBtIMjR5lERWWyM8QbeAKIdr/WOiDxXh/UrpVS9mDJlClOmTHF1GXWu2u4dY4wnMAcYC3QHJhpjuldp9jfgYxHpC9wA/NexfDzgKyK9gH7An40xEXVTulJKqTNVkz79AUCCiCSKSCGwALiqShsBmjluNwf2Oy0PMMZ4AU2AQqBur+9WSilVYzUJ/bZAitP9VMcyZzOAScaYVGAx8BfH8k+BPOAAkAzMEpFDZ1OwUkqp2qtJ6J9sxv6q4zwnAm+JSBgwDnjXGOOBdZRQApwPRAIPGGM6nPACxkw1xsQaY2IzMjLO6A0opZSquZqEfirQzul+GBXdN2VuAz4GEJHVgB8QAtwIfCsiRSKSDvwKnDCOVEReE5EYEYkJDQ0983ehlGpUZsyYUT6Z2pkYPHjwadePGzeOI0eO1LasE7z11lvs3181Dt1bTUJ/LdDZGBNpjPHBOlG7qEqbZGAUgDGmG1boZziWX2wsAcAgYHtdFa+UUs5WrVp12vWLFy8mKCiozl7vdKFfUlJSZ69Tl6odsikixcaYe4DvsIZjzhORLcaYmUCsiCwCHgBeN8bcj9X1M1lExBgzB5gPbMbqJpovIhtP/kpKKbew5FE4uKlun7NNLxj7/GmbPPPMM7zzzju0a9eO0NBQ+vXrx+7du086B31aWhrTpk0jMTERgLlz5zJ48GCaNm1Kbm4uBw4cYMKECWRnZ5dPdzxs2DAi
IiKIjY0lJCTklPPsjx07lqFDh7Jq1Sratm3Ll19+SZMmTU6o99NPPyU2NpY//elPNGnShNWrV9OtW7cazaGfkZHBtGnTyqeAfumll2o03URdqNE4fRFZjHWC1nnZk063twInVCwiuVjDNpVS6pTWrVvHggULiIuLo7i4mOjoaPr168fUqVN59dVX6dy5M2vWrOGuu+7ixx9/5N5772XEiBF8/vnnlJSUkJubW+n5PvjgA0aPHs1f//pXSkpKyM/PP+H1yubZFxEGDhzIiBEjCA4OZteuXXz44Ye8/vrr/PGPf2ThwoVMmjTphJqvv/56XnnlFWbNmkVMTEWvddkc+gCjRo06af3Tp0/n/vvvZ+jQoSQnJzN69Gi2bdtWD5/sifSKXKVUZdXskdeHX375hWuuuQZ/f38ArrzySgoKCk45B/2PP/7IO++8A1D+jVPO+vfvz6233kpRURFXX301UVFRldZXN89+Wft+/frVeG79MjWZQ3/p0qVs3bq1fHl2djY5OTkEBgae0WvVhoa+UsotVJ0P/3Rz0Fdn+PDhrFixgm+++YabbrqJhx56iJtvvrl8fU3n2ff09OTYsWNn9No1mUO/tLSU1atXn7TbqL7phGtKKZcbPnw4n3/+OceOHSMnJ4evvvoKf3//U85BP2rUqPKvJSwpKSE7u/I1n3v37qVVq1bccccd3Hbbbaxfv/6E16uLefZPN+/96ebQv+yyy3jllVfK29b3F6c409BXSrlcdHQ0EyZMICoqiuuuu648gE81B/3s2bNZvnw5vXr1ol+/fmzZsqXS8/30009ERUXRt29fFi5cyPTp0094vbqYZ3/y5MlMmzaNqKiokx4RnKr+l19+mdjYWHr37k337t159dVXz/i1a0vn01dK6Xz6DYzOp6+UUqpG9ESuUkpV4+677+bXX3+ttGz69OkNcuplDX2lFGCdaKw6gkZZ3OmLU862S167d5RS+Pn5kZWVddaBouqXiJCVlYWfn1+tn0P39JVShIWFkZqais5y6/78/PwICwur9eM19JVSeHt7ExkZ6eoy1Dmg3TtKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjNQp9Y8wYY8wOY0yCMebRk6wPN8YsN8bEGWM2GmPGOa3rbYxZbYzZYozZZIyp/Tf6KqWUOivVfkeuMcYTmANcCqQCa40xi0Rkq1OzvwEfi8hcY0x3YDEQYYzxAt4DbhKRDcaYlkBRnb8LpZRSNVKTPf0BQIKIJIpIIbAAuKpKGwGaOW43B/Y7bl8GbBSRDQAikiUiJWdftlJKqdqoSei3BVKc7qc6ljmbAUwyxqRi7eX/xbG8CyDGmO+MMeuNMQ+fZb1KKaXOQk1C35xkmVS5PxF4S0TCgHHAu8YYD6zuo6G53EEhAAAOXklEQVTAnxy/rzHGjDrhBYyZaoyJNcbEZmRknNEbUEopVXM1Cf1UoJ3T/TAqum/K3AZ8DCAiqwE/IMTx2J9FJFNE8rGOAqKrvoCIvCYiMSISExoaeubvQimlVI3UJPTXAp2NMZHGGB/gBmBRlTbJwCgAY0w3rNDPAL4Dehtj/B0ndUcAW1FKKeUS1Y7eEZFiY8w9WAHuCcwTkS3GmJlArIgsAh4AXjfG3I/V9TNZRAQ4bIx5EWvDIcBiEfmmvt6MUkqp0zNWNruPmJgYiY2NdXUZSinVoBhj1olITHXt9IpcpZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSy
EQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ39ulZSDEm/QvZ+V1eilFIn8KpJI2PMGGA24Am8ISLPV1kfDrwNBDnaPCoii6us3wrMEJFZdVS7+zm4Gb68Gw7EW/ebt4Ow/tBugPXTuhd4+bi2RqWUrVUb+sYYT2AOcCmQCqw1xiwSka1Ozf4GfCwic40x3YHFQITT+n8DS+qsandTXAgrX4QVs8CvOfzhZSjMg9TfIeV32PKZ1c7LD87va20AwhwbgqatXFu7UspWarKnPwBIEJFEAGPMAuAqrD33MgI0c9xuDpT3bRhjrgYSgby6KNjt7I+DL++BtM3QazyMeQECWjpW3mX9OrrPsQFYCylrYPV/oXS2tS44omIDENYfWvcEzxodgLmXkmI4shcyd0HmTusnK8H67RcEl/0dLhgHxri6UqVsrSbp0hZIcbqfCgys0mYG8L0x5i9AAHAJgDEmAHgE6yjhwVO9gDFmKjAVIDw8vIalu1hRAfz8Avw6GwJC4YYPoeu4k7dt3haaXwM9rql47IEN1gYg9XfY8zNs+tha5+0Pbfs5uoUGWr/LNyJuoCAbsnY5hbvj9qHdUFJY0S4gFEK6QNcrrKOdBTdCp0th7AvQsqPr6lfK5moS+ifbNZMq9ycCb4nIv4wxFwLvGmN6Ak8D/xaRXHOaPTwReQ14DSAmJqbqc7uf1Fj44i7I3AFRk2D0P6BJcM0f7+0H4QOtHwAROJIMqWutgExZY21MpMRa36KjtQFo1986KmjVDTw86/59lSkthex9lffWywI+50BFO+MJLTpY4d7lMut3SBdo2Qn8W1S0KymC31+Hn56D/w6CwX+BYQ+AT0D9vQel1EkZkdNnrCPEZ4jIaMf9xwBE5DmnNluAMSKS4rifCAwCFgLtHM2CgFLgSRF55VSvFxMTI7GxsbV+Q/Wq6Bj8+A/47b8QeD78YTZ0vqR+Xqsw3+o6SllTsTHIz7TW+QRCWD9Ht9BA6/aZbHTKFB2DrN1Oe+xO3TJF+RXtfJtDSGdHqHeuCPfgiDM7MZ2TBj88CRsXQLMwGP0MdL9Ku3yUqgPGmHUiElNtuxqEvhewExgF7APWAjeKyBanNkuAj0TkLWNMN2AZ0FacntwYMwPIrW70jtuG/t5VVt/9od3QbwpcOhP8mlX/uLoiAocSnY4Gfof0LSCl1vqQCypGCYUNsELZw8N6XF5G5b31sttHUqg4aDMQ1K4i0J3DPSC0boN572pY/BCkbYIOI2HsPyG0S909v1I2VGeh73iyccBLWMMx54nIM8aYmUCsiCxyjNh5HWiKlSIPi8j3VZ5jBg0x9I/nwrKZ8PtrEBQOV/4HOoxwdVWW4zmwb721AUj93dogHDtsrfNrbu2JH06CgqMVj/H2t7pfKoV7Z6sLycf/3NVeUgyx86wjp6I8GHQXjHgYfAPPXQ1KNSJ1GvrnkluFfuLPsOgv1qiUAX+GUU+Cb1NXV3VqIlbXTMoaa0NwNAWCIyvvuTdrax0BuIvcDFg2A+Leg8Dz4LJ/QM/rtMtHqTOkoX82CrLhhydg3VvWHvBVr0D7wa6tqbFLWQuLH7BGNUUMg7H/B627u7oq1ViJwNFUa6j1wc2Qlw7+IRAQYnVnBoRa19AEhIBvswaxE1LT0G+AA8Lr2a6l8NV0yNlvjTIZ+fi57fawq3b94Y7lsP5tqzvt1aEwcBqMfMTqqlKqtoqOQfo2SNtSEfJpm6HgSEUb3+Zw/OjJH+/p69gQOG8QQituB4RAQKuK257e5+Z91ZLu6Zc5dhi++yvEv2+dFL36vxBW7UZT1Yf8Q7DsaVj3tvUf6bK/Q+8JDWJvS7mQCOQcdAT7poqAz0qoGP7sHWAdQbbuCW16WlOjtO5unUsqLoT8LGvgQ9WfXOf7mdaRgfN1Kc6aBJ96g9C0VeV1dXgUod07Z2L7Yvj6fusfdOh9MOIR8PI9tzWoE+1bD4sfhH3rIPxCGPdPaNPL1VUpd1BcCBnb
K++5p222QrtM83BHsPes+B0cWTfntETgeLa1AchNP3GDUHa7bJ3zUYWzqkcR5/WBUU/UqiTt3qmJvCz49hHY9In1B3HjR3B+lKurUmXaRsNtSyH+PVg6A/43HPrfARc9Dk2CXF2dOldyM6zhvWXhfnCzdWFkabG13svPumDxgnHWTkHrntC6R/3+jRhjdTv6Na/ZFeblRxFOG4S8DMdGwXE7L90aEl7P7Lunv+ULay/y2BEY/hAMvV9nwHRnxw7Dj89A7JvQpAVc+jT0udG9RiKps1NSZF1Hkra58h58blpFm8Dzq+y997KuCm+I81XVMe3eOZXcdCvst34J50VZffete9Tf66m6dWCj9e+Xssaal2jcP62ZS1XDIWIFefo2xwlWRx98xvaKfnJPHwi9wAr1spBv3dO95qFyMxr6VYnApk9hycPWtMcjH4XB9+oeQkMkAhsWWFM65GVAzBS4+InK8/0o95CbARnbIH07pG+1gj19W+U+7oBWTnvvju6ZkM5uPwrG3WifvrPsA9aJ2p1LrL3Dq+ZYexGqYTIGoiZas5r+9Dys+Z/VXTfqSYi+uX4no1Mnl3/IsedeFuzbrbB3PrHq1xxadbdmm23VDUK7Wr/1OyXOqca9py9iDcH89nEoOW7tDQ66U0OhsUnbYs3ls/dXq6tn3L+sSehU3Tt2uCLQnX/npVe08W3mCPSuENqt4ndgGx12W490T/9IinWR1e5l0H6INWeOzuPeOLXuAZO/gc0LrWst3hgF0TfBqKesoXDqzBVkV3TFOP92nlrbp6l1xNz5ssoB36ythrsba3yhLwLr5sP3T1ozUI6bBTG36SiPxs4Y6HU9dBltfbnNb3Otk/UXPwExt+rR3akcz4WMHY49dqeAz95X0cariRXuHUZWdMm06mZNj63/rxqcxtW9c2gPfHUv7Flh/YH+4WUIbl+X5amGImOH1eWz52fr5OC4f1V8ac3ZEHH8lFg7FaWO31LqWCZVljutL3WsLy12+imB0qITl5UUVWlT7NSupGJdSVHl+6VV7xdbM5pWff7iY5CZAEeTK96bp681xXVot4pgD+0KQe013BsA+43eydxlXbzj4WXN1Bh9sx5i2p0IbP3C6vLJ3mdNKV0WytUGdGmVMHe0O+FL49yIh7c14sXDyzqy8fCyljnf9/S2bnv6WFNvt+pWEfLBEXpE1IDZr0+/ZSfrAquoG6F5mKurUe7AGGukSOfLYNUrVheG8XD8eFq/PTxOsszTaZnHSZad5rHlbc1JljndLg9jRyB7VrlfdX15YJ+qje6Jq5ppPKFvjPUlHEpV5RNgzdaplEJ3D5RSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykbcbhoGY0wGsPcsniIEyKyjcho6/Swq08+jMv08KjSGz6K9iIRW18jtQv9sGWNiazL/hB3oZ1GZfh6V6edRwU6fhXbvKKWUjWjoK6WUjTTG0H/N1QW4Ef0sKtPPozL9PCrY5rNodH36SimlTq0x7ukrpZQ6hUYT+saYMcaYHcaYBGPMo66ux5WMMe2MMcuNMduMMVuMMdNdXZOrGWM8jTFxxpivXV2Lqxljgowxnxpjtjv+Ri50dU2uZIy53/H/ZLMx5kNjjJ+ra6pPjSL0jTGewBxgLNAdmGiM6e7aqlyqGHhARLoBg4C7bf55AEwHtrm6CDcxG/hWRLoCfbDx52KMaQvcC8SISE/AE7jBtVXVr0YR+sAAIEFEEkWkEFgAXOXimlxGRA6IyHrH7Rys/9RtXVuV6xhjwoDLgTdcXYurGWOaAcOBNwFEpFBEjri2KpfzApoYY7wAf2C/i+upV40l9NsCKU73U7FxyDkzxkQAfYE1rq3EpV4CHgZKXV2IG+gAZADzHd1dbxhjAlxdlKuIyD5gFpAMHACOisj3rq2qfjWW0DcnWWb7YUnGmKbAQuA+Ecl2dT2uYIy5AkgXkXWursVNeAHRwFwR
6QvkAbY9B2aMCcbqFYgEzgcCjDGTXFtV/WosoZ8KtHO6H0YjP0SrjjHGGyvw3xeRz1xdjwsNAa40xiRhdftdbIx5z7UluVQqkCoiZUd+n2JtBOzqEmCPiGSISBHwGTDYxTXVq8YS+muBzsaYSGOMD9aJmEUurslljDEGq892m4i86Op6XElEHhORMBGJwPq7+FFEGvWe3OmIyEEgxRhzgWPRKGCrC0tytWRgkDHG3/H/ZhSN/MS2l6sLqAsiUmyMuQf4Duvs+zwR2eLislxpCHATsMkYE+9Y9riILHZhTcp9/AV437GDlAhMcXE9LiMia4wxnwLrsUa9xdHIr87VK3KVUspGGkv3jlJKqRrQ0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRv5/5Yq8G830UTCAAAAAElFTkSuQmCC\n", - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "df.plot()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Decision Tree Accuracy" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.8343173330831328" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.mean(dt_cv_scores)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Random Forest Accuracy" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.9223850187122359" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.mean(rf_cv_scores)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/data_science/ensemble/voting.ipynb b/data_science/ensemble/voting.ipynb deleted file mode 100644 index e42e8d3..0000000 --- a/data_science/ensemble/voting.ipynb +++ /dev/null @@ -1,279 +0,0 @@ -{ - "cells": [ - { - 
"cell_type": "markdown", - "metadata": {}, - "source": [ - "# Voting\n", - "Based on the idea that classifiers can complement each other, \n", - "Aggregating individual classifier's prediction to make better prediction." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn import datasets\n", - "from sklearn import tree\n", - "from sklearn.neighbors import KNeighborsClassifier\n", - "from sklearn.svm import SVC\n", - "from sklearn.ensemble import VotingClassifier\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import accuracy_score" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# load mnist dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "mnist = datasets.load_digits()\n", - "features, labels = mnist.data, mnist.target\n", - "X_train,X_test,y_train,y_test=train_test_split(features,labels,test_size=0.2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# single classifiers accuracy on mnist\n", - "build decision tree, knn, svm and check accuracy on MNIST data." 
- ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "dtree = tree.DecisionTreeClassifier(\n", - " criterion=\"gini\", max_depth=8, max_features=32,random_state=35)\n", - "\n", - "dtree = dtree.fit(X_train, y_train)\n", - "dtree_predicted = dtree.predict(X_test)\n", - "\n", - "knn = KNeighborsClassifier(n_neighbors=299).fit(X_train, y_train)\n", - "knn_predicted = knn.predict(X_test)\n", - "\n", - "svm = SVC(C=0.1, gamma=0.003,\n", - " probability=True,random_state=35).fit(X_train, y_train)\n", - "svm_predicted = svm.predict(X_test)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[accuarcy]\n", - "d-tree: 0.7972222222222223\n", - "knn : 0.8416666666666667\n", - "svm : 0.85\n" - ] - } - ], - "source": [ - "print(\"[accuarcy]\")\n", - "print(\"d-tree: \",accuracy_score(y_test, dtree_predicted))\n", - "print(\"knn : \",accuracy_score(y_test, knn_predicted))\n", - "print(\"svm : \",accuracy_score(y_test, svm_predicted))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "we can easily do soft voting or hard voting using sklearn's voting classifier \n", - "when you want to implement soft voting by scratch, you can use predict_proba just like below, \n", - "Below is the example of SVM's prediction (digit 0 to 9) on two MNIST data." 
- ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[9.95557918e-01 3.42018637e-04 4.57700824e-04 4.19160266e-04\n", - " 4.21146304e-04 7.99436984e-04 4.11439277e-04 6.08753549e-04\n", - " 4.33211441e-04 5.49214707e-04]\n", - " [2.86586264e-03 4.17512273e-03 4.28013091e-03 4.14650212e-03\n", - " 9.27814553e-01 2.24791840e-02 3.06764221e-03 9.50855980e-03\n", - " 1.51437526e-02 6.51868962e-03]]\n" - ] - } - ], - "source": [ - "svm_proba = svm.predict_proba(X_test)\n", - "print(svm_proba[0:2])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# hard voting\n", - "hard voting is just majority vote which collects each classifier's prediction and take the most voted prediction." - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/anaconda3/envs/wikiml/lib/python3.6/site-packages/sklearn/preprocessing/label.py:151: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.\n", - " if diff:\n" - ] - }, - { - "data": { - "text/plain": [ - "0.9083333333333333" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "voting_clf = VotingClassifier(estimators=[\n", - " ('decision_tree', dtree), ('knn', knn), ('svm', svm)], \n", - " weights=[1,1,1], voting='hard').fit(X_train, y_train)\n", - "hard_voting_predicted = voting_clf.predict(X_test)\n", - "accuracy_score(y_test, hard_voting_predicted)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# soft voting\n", - "soft voting takes each classifier's predict_proba and then sum up all probabilities to take the prediction has highest probabilities." 
- ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/anaconda3/envs/wikiml/lib/python3.6/site-packages/sklearn/preprocessing/label.py:151: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.\n", - " if diff:\n" - ] - }, - { - "data": { - "text/plain": [ - "0.9138888888888889" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "voting_clf = VotingClassifier(estimators=[\n", - " ('decision_tree', dtree), ('knn', knn), ('svm', svm)], \n", - " weights=[1,1,1], voting='soft').fit(X_train, y_train)\n", - "soft_voting_predicted = voting_clf.predict(X_test)\n", - "accuracy_score(y_test, soft_voting_predicted)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Visualization\n", - "we can visualize accuracy to check voting result is stabled or better than single model accuracy. \n", - "it is hard to say which voting is better, but we can confirm classifiers complement each other, \n", - "and voting result is better in this example." 
- ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAEepJREFUeJzt3XvQHXV9x/H3h2BEES8lqVUghiqoqVaoGbwgikpbwAo4oEK1LQ6V6QVtvc3QwTIWrVXROrViK7SKYpWLiqYYDZWKUK2YIBdJMDQTUFLaMSpSURGRb//YjZwcT/Kc58l58iQ/3q+ZzLOX39n97e5vP2fP75zdpKqQJLVll7mugCRp8gx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoN2nasVL1iwoBYvXjxXq5ekndLVV1/9napaOFW5OQv3xYsXs2rVqrlavSTtlJJ8c5xydstIUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KD5uwOVUmajsWnfmauqzAxt7ztBbO+DsNd2om0EnDbI9zu7+yWkaQGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDfHCYdiqtPDgLfHiWZpdX7pLUIK/cd0KtXL165SrNHq/cJalBhrskNchwl6QGGe6S1CDDXZIaNFa4Jzk8ydok65KcOmL+oiRfSHJNkuuTHDn5qkqSxjVluCeZB5wFHAEsAU5IsmSo2BuBC6vqQOB44H2TrqgkaXzjXLkfBKyrqvVVdTdwPnD0UJkCHtoPPwy4bXJVlCRN1zg3Me0F3DowvgF42lCZNwGXJnkVsDtw2ERqJ0makXHCPSOm1dD4CcC5VfWuJM8AzkvypKq6d7MFJScDJwMsWrRoJvUF2rlDE7xLU9LsGKdbZgOwz8D43vxit8tJwIUAVfWfwG7AguEFVdXZVbW0qpYuXLhwZjWWJE1pnHBfCeyXZN8k8+m+MF02VOZbwPMBkjyRLtw3TrKikqTxTRnuVXUPcAqwAriR7lcxq5OckeSovtjrgFcmuQ74GHBiVQ133UiStpOxngpZVcuB5UPTTh8YXgMcPNmqSZJmyjtUJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktSgscI9yeFJ1iZZl+TULZR5SZI1SVYn+ehkqylJmo5dpyqQZB5wFvCbwAZgZZJlVbVmoMx+wF8AB1fV7Ul+ebYqLEma2jhX7gcB66pqfVXdDZwPHD1U5pXAWVV1O0BVfXuy1ZQkTcc44b4XcOvA+IZ+2qD9gf2TfCnJV5IcPqkKSpKmb8puGSAjptWI5ewHHArsDVyZ5ElV9f3NFpScDJwMsGjRomlXVpI0nnGu3DcA+wyM7w3cNqLMp6vqp1V1M7CWLuw3U1VnV9XSqlq6cOHCmdZZkjSFccJ9JbBfkn2TzAeOB5YNlfkU8FyAJAvoumnWT7KikqTxTRnuVXUPcAqwArgRuLCqVic5I8lRfbEVwHeTrAG+ALyhqr47W5WWJG3dOH3uVNVyYPnQtNMHhgt4bf9PkjTHvENVkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQY
a7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQWOFe5LDk6xNsi7JqVspd1ySSrJ0clWUJE3XlOGeZB5wFnAEsAQ4IcmSEeX2AF4NXDXpSkqSpmecK/eDgHVVtb6q7gbOB44eUe7NwDuAuyZYP0nSDIwT7nsBtw6Mb+in/VySA4F9quqSrS0oyclJViVZtXHjxmlXVpI0nnHCPSOm1c9nJrsA7wZeN9WCqursqlpaVUsXLlw4fi0lSdMyTrhvAPYZGN8buG1gfA/gScDlSW4Bng4s80tVSZo744T7SmC/JPsmmQ8cDyzbNLOq7qiqBVW1uKoWA18BjqqqVbNSY0nSlKYM96q6BzgFWAHcCFxYVauTnJHkqNmuoCRp+nYdp1BVLQeWD007fQtlD932akmStoV3qEpSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQWOFe5LDk6xNsi7JqSPmvzbJmiTXJ7ksyWMmX1VJ0rimDPck84CzgCOAJcAJSZYMFbsGWFpVvw58HHjHpCsqSRrfOFfuBwHrqmp9Vd0NnA8cPVigqr5QVT/qR78C7D3ZakqSpmOccN8LuHVgfEM/bUtOAj47akaSk5OsSrJq48aN49dSkjQt44R7RkyrkQWTlwNLgTNHza+qs6tqaVUtXbhw4fi1lCRNy65jlNkA7DMwvjdw23ChJIcBpwHPqaqfTKZ6kqSZGOfKfSWwX5J9k8wHjgeWDRZIciDwfuCoqvr25KspSZqOKcO9qu4BTgFWADcCF1bV6iRnJDmqL3Ym8BDgoiTXJlm2hcVJkraDcbplqKrlwPKhaacPDB824XpJkraBd6hKUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUFjhXuSw5OsTbIuyakj5j8wyQX9/KuSLJ50RSVJ45sy3JPMA84CjgCWACckWTJU7CTg9qp6HPBu4O2TrqgkaXzjXLkfBKyrqvVVdTdwPnD0UJmjgQ/1wx8Hnp8kk6umJGk6xgn3vYBbB8Y39NNGlqmqe4A7gD0nUUFJ0vTtOkaZUVfgNYMyJDkZOLkfvTPJ2jHWP5cWAN+ZzRVkx+3Acttn2f15++/P2w7bvP2PGafQOOG+AdhnYHxv4LYtlNmQZFfgYcD3hhdUVWcDZ49TsR1BklVVtXSu6zEX3Pb757bD/Xv7W9r2cbplVgL7Jdk3yXzgeGDZUJllwB/0w8cB/15Vv3DlLknaPqa8cq+qe5KcAqwA5gEfqKrVSc4AVlXVMuCfgfOSrKO7Yj9+NistSdq6cbplqKrlwPKhaacPDN8FvHiyVdsh7DRdSLPAbb//uj9vfzPbHntPJKk9Pn5Akhq004R7kjclef0MX/vlKeYvT/LwmdVss+UcM+Lu3R1SksVJbpjremjrZuM4JbklyYJtXMbDk/zJwPijk3x822s3GUkOSbI6ybVJnpjkdye03AOSHDkwftSoR7LsCHaacN8WVfXMKeYfWVXfn8CqjqF7RMMv6H8iKm1Xs9juHg78PNyr6raqOm6W1j
UTLwPeWVUHAI8EJhLuwAHAz8O9qpZV1dsmtOzJqqod9h9wGrAW+DzwMeD1/fTHAp8DrgauBJ7QT38kcDFwXf/vmf30O/u/jwKuAK4FbgAO6affAizoh1/bz7sB+PN+2mLgRuAcYDVwKfCgobo+k+6XQjf3y38scDnwVuCLwOuAhcAn6H5euhI4uH/t7sAH+mnXAEdvh327GLihH/7Vfr1vAD7Z79v/At4xUP5O4K/7/foV4JFz3T5muN27A5/pt+MGup/wXjgw/1DgXwe2+e19O/s83aM4LgfWA0dtp/puse0Br+zbzHV9u3pwP/1c4G+BLwDvortb/NL+GL8f+Oam9j6wnj8eOt4nAn+/lXPifODHfVs/c6g9nbiVdnQScFO/H88B3jvD4/bSfvrz++36en8OPRD4Q+47F/+lb6939HV9zdByLwCOHBg/FzgW2A34YL/ca4DnAvOBbwEb+2W9tN/W9w689j3Al/s2clw/fRfgff3xu4TuxynHzXrbmeuTbSsH86n9jn0w8FBgHfeF+2XAfv3w0+h+V7/pQG1qfPOAh206Sfu/rwNOG5i/Rz98C92daZvWuTvwkP5gHNg33HuAA/ryFwIvH1HncwcPWt+A3zcw/lHgWf3wIuDGfvitm5ZHd0V0E7D7LO/fxf1J8vi+8R7QN9T1dDeh7UYXAvv05Qt4YT/8DuCNc91GZrjdxwLnDIw/rD9hd+/H/2HgWBRwRD98MV1APgB4CnDtdqrvFtsesOdAubcArxpoh5cA8/rx9wCn98Mv6LdrONwX0j1DatP4Z4FnTXFO3DDcnvrhke0IeDTdufZL/X68kvHDfdRx243usSf799M+zH3n/7ncF66HApdsYbkvAj7UD8/vl/cguqz4YD/9CX0b2Y2BMB/Y1sFwv4guzJds2p909/4s76f/CnA72yHcd+RumUOAi6vqR1X1f/Q3TiV5CN1V8kVJrqW7EnlU/5rn0Z2cVNXPquqOoWWuBF6R5E3Ak6vqB0Pzn9Wv84dVdSfd1cch/bybq+rafvhqusY8jgsGhg8D3tvXexnw0CR7AL8FnNpPv5yuES0ac/nbYiHwabqw2LRtl1XVHdX9vHUN993qfDddYMD0tn9H83XgsCRvT3JI30Y+B7yw78J4Ad0+gW6bPzfwui9W1U/74cXbsc5bantPSnJlkq/TdUP82sBrLqqqn/XDzwY+AlBVn6ELl81U1UZgfZKnJ9mT7k3/S2z9nNiaUe3oILp9+L1+P1405vbD6OP2eLp9c1Nf5kP9tk7HZ4HnJXkg3ZNvr6iqH9Nt93kAVfUNujeo/cdY3qeq6t6qWkPXk0C/rIv66f9L94lq1u3o/cCjfqe5C/D96vrSprewqiuSPJvuBD4vyZlV9eGBIlt7kuVPBoZ/RvfuPo4fDgzvAjyjbzz3rbR7guaxVbW9n7VzB92VysF0V2Twi9u5qY38tPrLkKHpO5WquinJU+n6Tf8myaV0b8B/SvdRfuXAm/7gNt9Lv2+q6t7t/B3KltreucAxVXVdkhPprlA3GWx3MPpcGnYB8BLgG3SBXtvwdNdR7WjGT4rdwnEbvlN+Jsu9K8nlwG/TdbN8rJ81ie3O0N/take+cr8CeFGSB/VXty8E6K/ib07yYuiCMclT+tdcRtd3SJJ5SR46uMAkjwG+XVXn0N1V+xsj1nlMkgcn2Z3uI9uV06jzD4A9tjL/UuCUgfpseoNaAbxq04mU5MBprHNb3E33JfDvT+rXBDu6JI8GflRVHwHeSdcGLu//vpLNP2nt6PYA/ifJA+iu3Lfkik3zkxwBPGIL5T5J1x5O4L79sKVzYqq2PspXgeckeUT/5njsuC/cwnH7BrA4yeP6Yr9H9/3WsKnqej7wCrpPJCv6aYP7bH+6T9Jrx1jWKP8BHJtklySPZPM34Vmzw4Z7VX2NroFdS/dl0WDIvgw4Kcl1dFecm54v/2fAc/uPqVez+cdU6HbqtUmuoWtYfzdinefSNcKrgH
+qqmumUe3zgTckuSbJY0fMfzWwNMn1SdYAf9RPfzNdH+T1/c/e3jyNdW6Tqvoh8DvAa+j6MVv3ZOCrfRfYacBb+u6LS+g+ll+ytRfvYP6Srp3+G13QbclfAc9O8jW6LsBvjSpUVbfTd6FU1Vf7aSPPiar6LvClJDckOXOcylbVf9N9v3QV3RfUa+g+PY5j1HG7iy6UL+rP+XuBfxzx2uuBe5Jcl+Q1I+ZfSted8/nq/s8K6L4Andcv9wLgxKr6CV2XypL+J5YvHbPun6B7uOINdN3IVzH+ds+Yd6hK2m6SPKSq7uyv3C+me1bVxXNdr9k2sN170r1RHtz3v8+anbLfVNJO601JDqP70cClwKfmuD7byyX9jZLzgTfPdrCDV+6S1KQdts9dkjRzhrskNchwl6QGGe6S1CDDXZIaZLhLUoP+H47Jp0tra/pcAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "%matplotlib inline\n", - "\n", - "x = np.arange(5)\n", - "plt.bar(x, height= [accuracy_score(y_test, dtree_predicted),\n", - " accuracy_score(y_test, knn_predicted),\n", - " accuracy_score(y_test, svm_predicted),\n", - " accuracy_score(y_test, hard_voting_predicted),\n", - " accuracy_score(y_test, soft_voting_predicted)])\n", - "plt.xticks(x, ['decision tree','knn','svm','hard voting','soft voting']);" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/data_science/svm/svm.ipynb b/data_science/svm/svm.ipynb deleted file mode 100644 index db0733a..0000000 --- a/data_science/svm/svm.ipynb +++ /dev/null @@ -1,401 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 79, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "from sklearn.datasets import load_iris\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.model_selection import GridSearchCV\n", - "from sklearn.metrics import classification_report\n", - "from sklearn.metrics import accuracy_score\n", - "from sklearn.svm import SVC" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 80, - "metadata": {}, - "outputs": [], - "source": [ - "# load iris data\n", - "dataset = 
load_iris()\n", - "\n", - "# use 80% as train data, 20% as test data\n", - "X_train,X_test,y_train,y_test=train_test_split(dataset.data,dataset.target,test_size=0.2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Find best hyperparamters\n", - "RBF kernel SVM has two parameters.\n", - "1. C (cost): The C parameter trades off correct classification of training examples against maximization of the decision function’s margin. For larger values of C, a smaller margin will be accepted if the decision function is better at classifying all training points correctly. \n", - "\n", - "2. gamma: the gamma parameter defines how far the influence of a single training example reaches, with low values meaning ‘far’ and high values meaning ‘close’. The gamma parameters can be seen as the inverse of the radius of influence of samples selected by the model as support vectors.\n", - "\n", - "reference:\n", - "http://scikit-learn.org/stable/auto_examples/svm/plot_rbf_parameters.html" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Grid Search\n", - "find best hyperparameter using grid search." 
- ] - }, - { - "cell_type": "code", - "execution_count": 81, - "metadata": {}, - "outputs": [], - "source": [ - "def svc_param_selection(X, y, nfolds):\n", - " svm_parameters = [\n", - " {'kernel': ['rbf'],\n", - " 'gamma': [0.00001,0.0001, 0.001, 0.01, 0.1, 1],\n", - " 'C': [0.01, 0.1, 1, 10, 100, 1000]\n", - " }\n", - " ]\n", - " \n", - " clf = GridSearchCV(SVC(), svm_parameters, cv=10)\n", - " clf.fit(X_train, y_train)\n", - " print(clf.best_params_)\n", - " \n", - " return clf" - ] - }, - { - "cell_type": "code", - "execution_count": 82, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'C': 10, 'gamma': 0.1, 'kernel': 'rbf'}\n" - ] - } - ], - "source": [ - "clf = svc_param_selection(X_train, y_train, 10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Test" - ] - }, - { - "cell_type": "code", - "execution_count": 83, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - " 0 1.00 1.00 1.00 7\n", - " 1 1.00 1.00 1.00 13\n", - " 2 1.00 1.00 1.00 10\n", - "\n", - "avg / total 1.00 1.00 1.00 30\n", - "\n", - "\n", - "accuracy : 1.0\n" - ] - } - ], - "source": [ - "y_true, y_pred = y_test, clf.predict(X_test)\n", - "\n", - "print(classification_report(y_true, y_pred))\n", - "print()\n", - "print(\"accuracy : \"+ str(accuracy_score(y_true, y_pred)) )" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ground_truthprediction
011
111
222
300
411
522
622
722
811
900
1011
1100
1200
1311
1411
1522
1611
1722
1811
1922
2000
2111
2200
2300
2411
2511
2622
2711
2822
2922
\n", - "
" - ], - "text/plain": [ - " ground_truth prediction\n", - "0 1 1\n", - "1 1 1\n", - "2 2 2\n", - "3 0 0\n", - "4 1 1\n", - "5 2 2\n", - "6 2 2\n", - "7 2 2\n", - "8 1 1\n", - "9 0 0\n", - "10 1 1\n", - "11 0 0\n", - "12 0 0\n", - "13 1 1\n", - "14 1 1\n", - "15 2 2\n", - "16 1 1\n", - "17 2 2\n", - "18 1 1\n", - "19 2 2\n", - "20 0 0\n", - "21 1 1\n", - "22 0 0\n", - "23 0 0\n", - "24 1 1\n", - "25 1 1\n", - "26 2 2\n", - "27 1 1\n", - "28 2 2\n", - "29 2 2" - ] - }, - "execution_count": 84, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Visualize true value with prediction value in pandas dataframe.\n", - "comparison = pd.DataFrame({'prediction':y_pred, 'ground_truth':y_true}) \n", - "comparison" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 6047e9d92a461d6ea3ebbda08b72934f41c730b3 Mon Sep 17 00:00:00 2001 From: minsuk-heo Date: Sun, 3 Nov 2019 15:06:07 -0800 Subject: [PATCH 6/8] code uploaded --- data_science/ensemble/randomforest.ipynb | 203 ++++++++++++ data_science/ensemble/voting.ipynb | 279 ++++++++++++++++ data_science/svm/svm.ipynb | 401 +++++++++++++++++++++++ 3 files changed, 883 insertions(+) create mode 100755 data_science/ensemble/randomforest.ipynb create mode 100755 data_science/ensemble/voting.ipynb create mode 100755 data_science/svm/svm.ipynb diff --git a/data_science/ensemble/randomforest.ipynb 
b/data_science/ensemble/randomforest.ipynb new file mode 100755 index 0000000..662d955 --- /dev/null +++ b/data_science/ensemble/randomforest.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import datasets\n", + "from sklearn import tree\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import cross_val_score\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load digits dataset (scikit-learn's MNIST-like 8x8 handwritten digits)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "mnist = datasets.load_digits()\n", + "features, labels = mnist.data, mnist.target" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cross Validation" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def cross_validation(classifier,features, labels):\n", + " cv_scores = []\n", + "\n", + " for i in range(10):\n", + " scores = cross_val_score(classifier, features, labels, cv=10, scoring='accuracy')\n", + " cv_scores.append(scores.mean())\n", + " \n", + " return cv_scores" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "dt_cv_scores = cross_validation(tree.DecisionTreeClassifier(), features, labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "rf_cv_scores = cross_validation(RandomForestClassifier(), features, labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Random Forest VS Decision Tree visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "cv_list = [ \n", + " 
['random_forest',rf_cv_scores],\n", + " ['decision_tree',dt_cv_scores],\n", + " ]\n", + "df = pd.DataFrame.from_items(cv_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xd8VFX+//HXSSchkEACKCEkNOmEEIp0RaXo2lnERQULi2VFf3Z3VWTX9l3WFVcW1wJ2sWBBBQuIIoJIIKG3EEISSholjZD2+f1xJ8kklISQMJPcz/PxyCMz956Z+cwQ3vfec889Y0QEpZRS9uDh6gKUUkqdOxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillI16uLqCqkJAQiYiIcHUZSinVoKxbty5TREKra+d2oR8REUFsbKyry1BKqQbFGLO3Ju20e0cppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWzE7cbpq7MjIqQePsb65MPkF5ZwXXQYPl66bVdKWTT0G7j8wmI2ph5lffJh4pKPEJd8hMzc4+XrP1iTzMsT+xIZEuDCKpVS7kJDvwEREZKy8lm/9zBxKVbIbz+YQ0mp9eX2HUICGNEllL7hQfQNDyLlUD6PLNzEFS//wt+v7sm10WEufgdKKVfT0HdjOQVF1l783sPEpRwhLvkwh/OLAGjq60VUuyDuHtmRvuHBRLULIjjAp9Lje5zfnN5hQdz3UTz/7+MN/LIrk5lX9SDQz9sVb0cp5QYaTeiXlgq7M3IJDfSleRNvjDGuLumMlJYKiZm5rN97hLiUw6zfe4Sd6TmItRNP51ZNuax7G8defDCdWjXF06P693h+UBM+vGMQc5Yn8NLSnaxPPszLN/SlT7ugen5HSil3ZKQsVdxETEyM1GbCtUN5hUT//QcAvD0NLQN8CQ30JaSpj+O39VN2OzTQl9CmvjRr4uWSDcTR/KLyLpq4lCPEJx8mu6AYgGZ+XvQNDyY6PJi+4UH0aRdE8yZnv3cem3SI6QviScsu4MHRFzB1WAc8arDhUEq5P2PMOhGJqbZdYwn9/MJiftiaRmZuIRk5x8nMtX4qbheW93078/H0IKSpDyGOjUBIU19CAn2s22XLHBuKZn6120CUlAq70nOsvfjkw6xPPszujDwAPAx0aR1IdPtg+rYLIrp9MJEtA+otjI/mF/HY5xtZvOkgQzuF8OIf+9CqmV+9vJZS6tyxXehXp7RUOHKsqNKGICPnOBm5x8nMKXT8tpZn5Z1iA+Hl4dgw+FQ6Yqh8FOGDv48XW/YfJS75COuTD7Mh5Qh5hSUAtAjwKQ/3vu2C6N0uiKa+57aXTUT4aG0KM77aQoCPF7PG9+Girq3OaQ1KqbqloX8WSkuFw/mF5RuEShuK8tvWEcWhvOOcZPsAgKeHodt5geXdNNHhwYS38Heb8w0J6Tnc80Ec2w/mcNvQSB4ecwG+Xp6uLkspVQsa+udISdkGwqlLKaegmAtaB9I7LIgmPu4dogVFJTy/ZDtvrUqix/nNeHliXzqGNnV1WUqpM6Shr87I
0q1pPPTpBgqKSnn6qh6M7xfmNkckSqnq1TT09fp8BcAl3VuzZPpwotoF8fCnG7l3QTzZBUWuLkspVcc09FW5Ns39eO/2gTw0+gIWbzrAuNm/sD75sKvLUkrVIQ19VYmnh+HuizrxybQLARj/6mrmLE846WgmdWYSM3KZsWgLK3dluroUZWPap69OKbugiMc/28TXGw8wuGNL/j0hitY6pv+MJaTn8sqPu1i0YT+lAl4ehlnj+3B137auLk01Itqnr85aMz9v/jOxL/93fW/iko8w5qUVLNuW5uqyGoyE9BymL4jj0n//zHdb0rh9WAeWPziS/hEtuO+jeN5cucfVJSob0j19VSO7M3L5ywdxbD2QzeTBETw6tit+3u49HNVVdqXl8PKPCXy9cT9NvD256cL23DGsAyFNfQFrmOz9H8WzZPNBpo3oyCNjLtCRUuqs6ZBNVeeOF5fwwpIdzPt1D13bBPLKjX3p1CrQ1WW5jR0Hc3j5x10s3nQAf29Pbh4cwR3DOtCiyuynYF3f8eSXm3l/TTLj+4Xx3LW98PLUA29Vexr6qt4s357Og59sIK+wmBl/6MGE/u1svae6/WA2Ly/bxeJNB2nq68Utg9tz+9AOJ0x1XZWIMHvZLl5auotRXVvxyo3Rbn8xn3JfGvqqXqVnF/D/Pt7AyoRMLu91Hs9e26tOZgJtSLbut8L+2y1W2E8ZEsFtQyMJ8j992Ff13m97eeLLzfQLD+aNW2LO+PFKgYa+OgdKS4XXfklk1nc7aN3Mj5cnRtGvfQtXl1XvNu87ysvLdvH91jQCfb2YMjSS24ZE0ty/9hu9xZsOcN+CeCJC/Hn71gGc17xJHVas7EBDX50z8SlHuPfDOPYdOcb0UZ25+6JONfqCl4Zm876jvLR0F0u3pRHo58WtQyK59SzD3tmq3ZlMfWcdzZt48/atA+jUqnHPgSQiHM4vOuk5D3XmNPTVOZVTUMQTX2zmi/j9DIxswUs3RDWavdWNqUeYvXQXy7an08zPi9uGdmDykIh66c7avO8ok+evpaS0lHmT+9M3PLjOX8MdHDh6jMc/28TyHRlc07ctD4+5oNH8vbhKnYa+MWYMMBvwBN4QkeerrG8PzANCgUPAJBFJNcZEAXOBZkAJ8IyIfHS619LQb9g+W5/K377YjI+XBy9c15vRPdq4uqRa25ByhNnLdvHj9nSaN/Hm9qGR3DIkgmb1/B3De7PyuOnN38nIOc7cSdGMvKDxfNeBiLBgbQrPfrON4lJhbM82fL3pAB4Gpg7vyLQRHfD3aTTf4npO1VnoG2M8gZ3ApUAqsBaYKCJbndp8AnwtIm8bYy4GpojITcaYLoCIyC5jzPnAOqCbiBw51etp6Dd8ezLzuPfDODbtO8pNg9rz18u7Nagx/XHJh5m9bBc/7cggyN+bO4Z14OYL25/TL5TPyDnO5Pm/s+NgTqO5ejflUD6PfraRXxOyuLBDS164rjfhLf1JOZTPC99u5+uNB2jdzJeHR3flmr5tbfdVnkePFZGWXUCX1rUbBl2XoX8hMENERjvuPwYgIs85tdkCjHbs3RvgqIg0O8lzbQCuF5Fdp3o9Df3GobC4lFnf7+C1FYkE+XvTpVUgkSEBdAgNcPxuSngLf3y83Gds+rq9Vtiv2JlBsL83dwzvwM0XRpzzbzYrk1NQxNR31rE6MYu/Xd6N24d1cEkdZ6u0VHhndRIvfLsDTw/D4+O6MXHAicN81+09xMyvt7Eh5Qi9w5rzxBXd6R/R+AcGZOUeZ96ve3hn1V7aBjdhyfRhtRoCXZehfz0wRkRud9y/CRgoIvc4tfkAWCMis40x1wILgRARyXJqMwB4G+ghIqWnej0N/cbl14RMvozfx57MPPZk5pGZW1i+zsNAuxb+dAgJIDKkKZGhAXQMCSAyNIA2zfzO2dj/2KRDzF62i192ZdIiwIepwztw06D2BLgo7J0VFJXw/z6OZ/Gmg/x5RAceHdO1QV0TkZiRyyMLN7I2
6TAjuoTy7LW9aBt06r770lLhyw37eGHJDg5mFzCuVxseG9uNdi38z2HV58aBo8d4bUUiH/6ezPHiUsb1PI87R3akZ9vmtXq+moZ+Tf6qT/YXVnVL8SDwijFmMrAC2AcUOxVzHvAucMvJAt8YMxWYChAeHl6DklRDMaRTCEM6hZTfP3qsyLEByCUxI4/EzDz2ZOTxW+IhjhWVlLdr4u1JpGMDULYhiAxpSofQgDrrU/99zyFmL9vJrwlZhDT14fFxXZk0qL1b9Sn7eXvyn4nRtAjYzP9+TiQrt5DnG8DVuyWlwpsrE/nX9zvx9fJg1vg+XBfdttoNloeH4Zq+YYzpcR6vrUjk1Z93s3RrOlOGRnDPRZ3OaRdbfUnKzOPVn3ezcH0qpQJXR7XlzpEdz9lorTrp3qnSvimwXUTCHPebAT8Bz4nIJ9UVpHv69lRaKqTlFLAnI4/djg3BnsxcEjPzSDmUX+l7iEOa+tAhpGn5RqGDo9sovEVAjbqLfkvMYvbSXaxOtML+z8M78qdB4W4V9lU1pKt3d6bl8NCnG9mQcoRLu7fmmat70qqWs7MePFrAP7/bwcL1qbQM8OGByy5gQv92DXJI8I6DOcxZbs3J5OXpwYSYdkwd3qHOjmLqsnvHC+tE7iisPfi1wI0issWpTQhwSERKjTHPACUi8qQxxgdYAnwlIi/VpHANfVVVYXEpyYfySczIZU9mHokZVldRYmYembnHy9uVdRdFhgRYGwWnDUKbZn78lniIl5buZM2eQ4QG+vLn4R3408D2bhueJ1N29W50eDBvutnVu0Ulpfzv5928vCyBpn5ePH1lD67ofV6ddEdtTD3C37/eytqkw3RtE8jfLu/O0M4h1T/QDcSnHGHO8gR+2JpGgI8nkwa157ZhkbQKrNtpyut6yOY44CWsIZvzROQZY8xMIFZEFjn6/Z/D6vZZAdwtIseNMZOA+cAWp6ebLCLxp3otDX11Jo4eKyLJcb4gMcM6Mig7f5BfWNFd5OPlQWFxKa0CfZk2oiM3DgxvUCOKnC3ZdIDpbnb17pb9R3nok41sPZDNFb3P4+kre9DSMatoXRERlmw+yHNLtpFy6Bijurbi8cu70THU/S5iExFWJ2bx3+W7WZmQSfMm3kwZEsHkwRH1tqHWi7OUrYkIadnHSXScO9iTmUdES3/Gx7RrsGHvrOzq3WZ+Xrxz2wCXzXZ6vLiEV35MYO5Puwny9+EfV/dkTM/6vTajoKiEt1Yl8cqPCRQUlTBpUHvuu6SzWxz1iAg/bk9nzvIE1icfITTQlzuGRXLjwPb1PgpMQ1+pRq7s6t3i0lLmu+Dq3fiUIzz86QZ2puVyXXQYT1zR7ZwGb0bOcV78YScfrU0m0M+b+y7pzKRB7fF2wUnuklJhyeYDzFm+m20Hsmkb1IRpIzsyvl/YOdvJ0NBXygb2ZuVx87zfSc8+d1fvFhSV8O8fdvL6L4m0bubHs9f04qKurrtqePvBbP7x9TZWJmTSITSAv47rxsVdW52Toa2FxaV8Eb+PV3/aTWJmHh1CA7hrZCeuijr/nG98NPSVsgnnq3f/Ob431/QNq7fXWpt0iIc/3ciezDwmDgjnsXFd631aipoo61Z55pttJGbmMbRTCH+7ohtd25xwjWidKCgq4aO1Kfzv593sP1pAj/ObcfdFnRjdo43LRhZp6CtlIzkFRfz53XWs2l0/V+/mHS/mn9/t4O3VSbQNasIL1/WudP2FuygqKeW93/by0tJd5BQUMaF/OA9c1qX8qyrPVk5BEe/9lsybKxPJzC0kpn0wd1/ciZFdQl1+0ZyGvlI2c7zY+u7dur5699eETB5ZuJHUw8eYPDiCh0Zf4BZXK5/OkfxCXlq6i/d+24uftyf3XNyJKUMi8PWqXf/64bxC5v+6h7dWJZFdUMywziHcc1EnBkS2cHnYl9HQV8qGSkqFGYu28O5ve7kuOoznr+tV677l7IIinlu8nQ9/TyYyJIAXruvNgMiGNRfO
7oxcnv1mG8u2p9OuRRMeG9uNsT3b1Dio07ILeH1FIh/8nkx+YQmje7TmrpGd6NMuqJ4rP3Ma+krZlIjw8rIE/r10Jxd3bcWcWly9u3xHOo9/tom07ALuGNaB+y/t0qCHuq7clcnfv97KjrQcBkS04IkrutMr7NRz3KQcyufVn3fzSWwqJSJc2ed87hzZsdYzYJ4LGvpK2dz7a/byxBeb6XsGV+8eyS9k5tdb+Wz9Pjq3aso/x/chyg33amujuKSUj2JTePH7nWTlFXJtdFseHt2VNs0rrozdlZbD3J928+WG/Xgaw3X9wrhzREfCW7r/hG8a+kopvt18gHs/jKd9S3/eue30V+9+t+Ugf/tiM4fyCrlrZEfuubhTrfvA3Vl2QRFzlicwf2USnh6GP4/owLDOIby+Yg/fbT2In5cnNw4M545hHSptENydhr5SCoDVu7OY+k4sgae4ejcr9zhPLdrC1xsP0P28Zvzf9b1rPb1vQ5Kclc/z325j8aaDAAT6eTF5cARThkQ2yO/t1dBXSpXbsv8ot8yrfPWuiPDVxgPMWLSFnIIipo/qzJ9HdHTJFa2uFJt0iJ1puVzR5zy3uOagtjT0lVKVJGflc9O8NaRnH+fZa3uyeNNBftiaRp+w5vxzfB+3Pkmpqqehr5Q6QdnVu1v2Z+Pr5cEDl3Xh1iGRbv+lLKp6dfnNWUqpRiI00JcFUwfx9qokxvU6jw5uOC2xql8a+krZTKCfN/dc3NnVZSgX0WM6pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSykRqFvjFmjDFmhzEmwRjz6EnWtzfGLDPGbDTG/GSMCXNad4sxZpfj55a6LF4ppdSZqTb0jTGewBxgLNAdmGiM6V6l2SzgHRHpDcwEnnM8tgXwFDAQGAA8ZYwJrrvylVJKnYma7OkPABJEJFFECoEFwFVV2nQHljluL3daPxr4QUQOichh4AdgzNmXrZRSqjZqEvptgRSn+6mOZc42ANc5bl8DBBpjWtbwsUoppc6RmoS+Ocmyqt+m/iAwwhgTB4wA9gHFNXwsxpipxphYY0xsRkZGDUpSSilVGzUJ/VSgndP9MGC/cwMR2S8i14pIX+CvjmVHa/JYR9vXRCRGRGJCQ0PP8C0opZSqqZqE/lqgszEm0hjjA9wALHJuYIwJMcaUPddjwDzH7e+Ay4wxwY4TuJc5limllHKBakNfRIqBe7DCehvwsYhsMcbMNMZc6Wg2EthhjNkJtAaecTz2EPB3rA3HWmCmY5lSSikXMCIndLG7VExMjMTGxrq6DKWUalCMMetEJKa6dnpFrlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiXqwtQSrleUVERqampFBQUuLoUVQ0/Pz/CwsLw9vau1eM19JVSpKamEhgYSEREBMac7FtOlTsQEbKyskhNTSUyMrJWz6HdO0opCgoKaNmypQa+mzPG0LJly7M6ItPQV0oBaOA3EGf776Shr5RSNqKhr5RqdCIiIsjMzKyX5z5+/DiXXHIJUVFRfPTRR/XyGvHx8SxevLhenltP5Cql3IqIICJ4eLjnPmlcXBxFRUXEx8fX+DElJSV4enrWuH18fDyxsbGMGzeuNiWeloa+UqqSp7/awtb92XX6nN3Pb8ZTf+hxyvVJSUmMHTuWiy66iNWrVxMVFcWmTZs4duwY119/PU8//TRg7cHfcsstfPXVVxQVFfHJJ5/QtWtXsrKymDhxIhkZGQwYMAARKX/uF198kXnz5gFw++23c99995GUlMSYMWMYOnQov/32G3369GHKlCk89dRTpKen8/777zNgwIAT6kxPT2fSpElkZGQQFRXFwoULSUpK4sEHH6S4
uJj+/fszd+5cfH19iYiI4NZbb+X777/nnnvuoX///tx9991kZGTg7+/P66+/TteuXfnkk094+umn8fT0pHnz5ixdupQnn3ySY8eOsXLlSh577DEmTJhQZ/8W7rkpVUrZzo4dO7j55puJi4vjX//6F7GxsWzcuJGff/6ZjRs3lrcLCQlh/fr13HnnncyaNQuAp59+mqFDhxIXF8eVV15JcnIyAOvWrWP+/PmsWbOG3377jddff524uDgAEhISmD59Ohs3bmT79u188MEHrFy5klmzZvHss8+etMZWrVrxxhtvMGzYMOLj42nbti2TJ0/mo48+YtOmTRQXFzN37tzy9n5+fqxcuZIbbriBqVOn8p///Id169Yxa9Ys7rrrLgBmzpzJd999x4YNG1i0aBE+Pj7MnDmTCRMmEB8fX6eBD7qnr5Sq4nR75PWpffv2DBo0CICPP/6Y1157jeLiYg4cOMDWrVvp3bs3ANdeey0A/fr147PPPgNgxYoV5bcvv/xygoODAVi5ciXXXHMNAQEB5Y/95ZdfuPLKK4mMjKRXr14A9OjRg1GjRmGMoVevXiQlJdWo5h07dhAZGUmXLl0AuOWWW5gzZw733XcfQHlg5+bmsmrVKsaPH1/+2OPHjwMwZMgQJk+ezB//+Mfy91afNPSVUm6hLJj37NnDrFmzWLt2LcHBwUyePLnSuHRfX18APD09KS4uLl9+sqGMzt08VZU9D4CHh0f5fQ8Pj0rPezqne36oeE+lpaUEBQWd9DzAq6++ypo1a/jmm2+Iioo6o3MFtaHdO0opt5KdnU1AQADNmzcnLS2NJUuWVPuY4cOH8/777wOwZMkSDh8+XL78iy++ID8/n7y8PD7//HOGDRtWZ7V27dqVpKQkEhISAHj33XcZMWLECe2aNWtGZGQkn3zyCWBtLDZs2ADA7t27GThwIDNnziQkJISUlBQCAwPJycmpszqdaegrpdxKnz596Nu3Lz169ODWW29lyJAh1T7mqaeeYsWKFURHR/P9998THh4OQHR0NJMnT2bAgAEMHDiQ22+/nb59+9ZZrX5+fsyfP5/x48fTq1cvPDw8mDZt2knbvv/++7z55pv06dOHHj168OWXXwLw0EMP0atXL3r27Mnw4cPp06cPF110EVu3bq2XYaGmusOTcy0mJkZiY2NdXYZStrJt2za6devm6jJUDZ3s38sYs05EYqp7rO7pK6WUjeiJXKWUOon58+cze/bsSsuGDBnCnDlzXFRR3ahR6BtjxgCzAU/gDRF5vsr6cOBtIMjR5lERWWyM8QbeAKIdr/WOiDxXh/UrpVS9mDJlClOmTHF1GXWu2u4dY4wnMAcYC3QHJhpjuldp9jfgYxHpC9wA/NexfDzgKyK9gH7An40xEXVTulJKqTNVkz79AUCCiCSKSCGwALiqShsBmjluNwf2Oy0PMMZ4AU2AQqBur+9WSilVYzUJ/bZAitP9VMcyZzOAScaYVGAx8BfH8k+BPOAAkAzMEpFDZ1OwUkqp2qtJ6J9sxv6q4zwnAm+JSBgwDnjXGOOBdZRQApwPRAIPGGM6nPACxkw1xsQaY2IzMjLO6A0opZSquZqEfirQzul+GBXdN2VuAz4GEJHVgB8QAtwIfCsiRSKSDvwKnDCOVEReE5EYEYkJDQ0983ehlGpUZsyYUT6Z2pkYPHjwadePGzeOI0eO1LasE7z11lvs3181Dt1bTUJ/LdDZGBNpjPHBOlG7qEqbZGAUgDGmG1boZziWX2wsAcAgYHtdFa+UUs5WrVp12vWLFy8mKCiozl7vdKFfUlJSZ69Tl6odsikixcaYe4DvsIZjzhORLcaYmUCsiCwCHgBeN8bcj9X1M1lExBgzB5gPbMbqJpovIhtP/kpKKbew5FE4uKlun7NNLxj7/GmbPPPMM7zzzju0a9eO0NBQ+vXrx+7du086B31aWhrTpk0jMTERgLlz5zJ48GCaNm1Kbm4uBw4cYMKECWRnZ5dPdzxs2DAi
IiKIjY0lJCTklPPsjx07lqFDh7Jq1Sratm3Ll19+SZMmTU6o99NPPyU2NpY//elPNGnShNWrV9OtW7cazaGfkZHBtGnTyqeAfumll2o03URdqNE4fRFZjHWC1nnZk063twInVCwiuVjDNpVS6pTWrVvHggULiIuLo7i4mOjoaPr168fUqVN59dVX6dy5M2vWrOGuu+7ixx9/5N5772XEiBF8/vnnlJSUkJubW+n5PvjgA0aPHs1f//pXSkpKyM/PP+H1yubZFxEGDhzIiBEjCA4OZteuXXz44Ye8/vrr/PGPf2ThwoVMmjTphJqvv/56XnnlFWbNmkVMTEWvddkc+gCjRo06af3Tp0/n/vvvZ+jQoSQnJzN69Gi2bdtWD5/sifSKXKVUZdXskdeHX375hWuuuQZ/f38ArrzySgoKCk45B/2PP/7IO++8A1D+jVPO+vfvz6233kpRURFXX301UVFRldZXN89+Wft+/frVeG79MjWZQ3/p0qVs3bq1fHl2djY5OTkEBgae0WvVhoa+UsotVJ0P/3Rz0Fdn+PDhrFixgm+++YabbrqJhx56iJtvvrl8fU3n2ff09OTYsWNn9No1mUO/tLSU1atXn7TbqL7phGtKKZcbPnw4n3/+OceOHSMnJ4evvvoKf3//U85BP2rUqPKvJSwpKSE7u/I1n3v37qVVq1bccccd3Hbbbaxfv/6E16uLefZPN+/96ebQv+yyy3jllVfK29b3F6c409BXSrlcdHQ0EyZMICoqiuuuu648gE81B/3s2bNZvnw5vXr1ol+/fmzZsqXS8/30009ERUXRt29fFi5cyPTp0094vbqYZ3/y5MlMmzaNqKiokx4RnKr+l19+mdjYWHr37k337t159dVXz/i1a0vn01dK6Xz6DYzOp6+UUqpG9ESuUkpV4+677+bXX3+ttGz69OkNcuplDX2lFGCdaKw6gkZZ3OmLU862S167d5RS+Pn5kZWVddaBouqXiJCVlYWfn1+tn0P39JVShIWFkZqais5y6/78/PwICwur9eM19JVSeHt7ExkZ6eoy1Dmg3TtKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjNQp9Y8wYY8wOY0yCMebRk6wPN8YsN8bEGWM2GmPGOa3rbYxZbYzZYozZZIyp/Tf6KqWUOivVfkeuMcYTmANcCqQCa40xi0Rkq1OzvwEfi8hcY0x3YDEQYYzxAt4DbhKRDcaYlkBRnb8LpZRSNVKTPf0BQIKIJIpIIbAAuKpKGwGaOW43B/Y7bl8GbBSRDQAikiUiJWdftlJKqdqoSei3BVKc7qc6ljmbAUwyxqRi7eX/xbG8CyDGmO+MMeuNMQ+fZb1KKaXOQk1C35xkmVS5PxF4S0TCgHHAu8YYD6zuo6G53EEhAAAOXklEQVTAnxy/rzHGjDrhBYyZaoyJNcbEZmRknNEbUEopVXM1Cf1UoJ3T/TAqum/K3AZ8DCAiqwE/IMTx2J9FJFNE8rGOAqKrvoCIvCYiMSISExoaeubvQimlVI3UJPTXAp2NMZHGGB/gBmBRlTbJwCgAY0w3rNDPAL4Dehtj/B0ndUcAW1FKKeUS1Y7eEZFiY8w9WAHuCcwTkS3GmJlArIgsAh4AXjfG3I/V9TNZRAQ4bIx5EWvDIcBiEfmmvt6MUkqp0zNWNruPmJgYiY2NdXUZSinVoBhj1olITHXt9IpcpZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSy
EQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ39ulZSDEm/QvZ+V1eilFIn8KpJI2PMGGA24Am8ISLPV1kfDrwNBDnaPCoii6us3wrMEJFZdVS7+zm4Gb68Gw7EW/ebt4Ow/tBugPXTuhd4+bi2RqWUrVUb+sYYT2AOcCmQCqw1xiwSka1Ozf4GfCwic40x3YHFQITT+n8DS+qsandTXAgrX4QVs8CvOfzhZSjMg9TfIeV32PKZ1c7LD87va20AwhwbgqatXFu7UspWarKnPwBIEJFEAGPMAuAqrD33MgI0c9xuDpT3bRhjrgYSgby6KNjt7I+DL++BtM3QazyMeQECWjpW3mX9OrrPsQFYCylrYPV/oXS2tS44omIDENYfWvcEzxodgLmXkmI4shcyd0HmTusnK8H67RcEl/0dLhgHxri6UqVsrSbp0hZIcbqfCgys0mYG8L0x5i9AAHAJgDEmAHgE6yjhwVO9gDFmKjAVIDw8vIalu1hRAfz8Avw6GwJC4YYPoeu4k7dt3haaXwM9rql47IEN1gYg9XfY8zNs+tha5+0Pbfs5uoUGWr/LNyJuoCAbsnY5hbvj9qHdUFJY0S4gFEK6QNcrrKOdBTdCp0th7AvQsqPr6lfK5moS+ifbNZMq9ycCb4nIv4wxFwLvGmN6Ak8D/xaRXHOaPTwReQ14DSAmJqbqc7uf1Fj44i7I3AFRk2D0P6BJcM0f7+0H4QOtHwAROJIMqWutgExZY21MpMRa36KjtQFo1986KmjVDTw86/59lSkthex9lffWywI+50BFO+MJLTpY4d7lMut3SBdo2Qn8W1S0KymC31+Hn56D/w6CwX+BYQ+AT0D9vQel1EkZkdNnrCPEZ4jIaMf9xwBE5DmnNluAMSKS4rifCAwCFgLtHM2CgFLgSRF55VSvFxMTI7GxsbV+Q/Wq6Bj8+A/47b8QeD78YTZ0vqR+Xqsw3+o6SllTsTHIz7TW+QRCWD9Ht9BA6/aZbHTKFB2DrN1Oe+xO3TJF+RXtfJtDSGdHqHeuCPfgiDM7MZ2TBj88CRsXQLMwGP0MdL9Ku3yUqgPGmHUiElNtuxqEvhewExgF7APWAjeKyBanNkuAj0TkLWNMN2AZ0FacntwYMwPIrW70jtuG/t5VVt/9od3QbwpcOhP8mlX/uLoiAocSnY4Gfof0LSCl1vqQCypGCYUNsELZw8N6XF5G5b31sttHUqg4aDMQ1K4i0J3DPSC0boN572pY/BCkbYIOI2HsPyG0S909v1I2VGeh73iyccBLWMMx54nIM8aYmUCsiCxyjNh5HWiKlSIPi8j3VZ5jBg0x9I/nwrKZ8PtrEBQOV/4HOoxwdVWW4zmwb721AUj93dogHDtsrfNrbu2JH06CgqMVj/H2t7pfKoV7Z6sLycf/3NVeUgyx86wjp6I8GHQXjHgYfAPPXQ1KNSJ1GvrnkluFfuLPsOgv1qiUAX+GUU+Cb1NXV3VqIlbXTMoaa0NwNAWCIyvvuTdrax0BuIvcDFg2A+Leg8Dz4LJ/QM/rtMtHqTOkoX82CrLhhydg3VvWHvBVr0D7wa6tqbFLWQuLH7BGNUUMg7H/B627u7oq1ViJwNFUa6j1wc2Qlw7+IRAQYnVnBoRa19AEhIBvswaxE1LT0G+AA8Lr2a6l8NV0yNlvjTIZ+fi57fawq3b94Y7lsP5tqzvt1aEwcBqMfMTqqlKqtoqOQfo2SNtSEfJpm6HgSEUb3+Zw/OjJH+/p69gQOG8QQituB4RAQKuK257e5+Z91ZLu6Zc5dhi++yvEv2+dFL36vxBW7UZT1Yf8Q7DsaVj3tvUf6bK/Q+8JDWJvS7mQCOQcdAT7poqAz0qoGP7sHWAdQbbuCW16WlOjtO5unUsqLoT8LGvgQ9WfXOf7mdaRgfN1Kc6aBJ96g9C0VeV1dXgUod07Z2L7Yvj6fusfdOh9MOIR8PI9tzWoE+1bD4sfhH3rIPxCGPdPaNPL1VUpd1BcCBnb
K++5p222QrtM83BHsPes+B0cWTfntETgeLa1AchNP3GDUHa7bJ3zUYWzqkcR5/WBUU/UqiTt3qmJvCz49hHY9In1B3HjR3B+lKurUmXaRsNtSyH+PVg6A/43HPrfARc9Dk2CXF2dOldyM6zhvWXhfnCzdWFkabG13svPumDxgnHWTkHrntC6R/3+jRhjdTv6Na/ZFeblRxFOG4S8DMdGwXE7L90aEl7P7Lunv+ULay/y2BEY/hAMvV9nwHRnxw7Dj89A7JvQpAVc+jT0udG9RiKps1NSZF1Hkra58h58blpFm8Dzq+y997KuCm+I81XVMe3eOZXcdCvst34J50VZffete9Tf66m6dWCj9e+Xssaal2jcP62ZS1XDIWIFefo2xwlWRx98xvaKfnJPHwi9wAr1spBv3dO95qFyMxr6VYnApk9hycPWtMcjH4XB9+oeQkMkAhsWWFM65GVAzBS4+InK8/0o95CbARnbIH07pG+1gj19W+U+7oBWTnvvju6ZkM5uPwrG3WifvrPsA9aJ2p1LrL3Dq+ZYexGqYTIGoiZas5r+9Dys+Z/VXTfqSYi+uX4no1Mnl3/IsedeFuzbrbB3PrHq1xxadbdmm23VDUK7Wr/1OyXOqca9py9iDcH89nEoOW7tDQ66U0OhsUnbYs3ls/dXq6tn3L+sSehU3Tt2uCLQnX/npVe08W3mCPSuENqt4ndgGx12W490T/9IinWR1e5l0H6INWeOzuPeOLXuAZO/gc0LrWst3hgF0TfBqKesoXDqzBVkV3TFOP92nlrbp6l1xNz5ssoB36ythrsba3yhLwLr5sP3T1ozUI6bBTG36SiPxs4Y6HU9dBltfbnNb3Otk/UXPwExt+rR3akcz4WMHY49dqeAz95X0cariRXuHUZWdMm06mZNj63/rxqcxtW9c2gPfHUv7Flh/YH+4WUIbl+X5amGImOH1eWz52fr5OC4f1V8ac3ZEHH8lFg7FaWO31LqWCZVljutL3WsLy12+imB0qITl5UUVWlT7NSupGJdSVHl+6VV7xdbM5pWff7iY5CZAEeTK96bp681xXVot4pgD+0KQe013BsA+43eydxlXbzj4WXN1Bh9sx5i2p0IbP3C6vLJ3mdNKV0WytUGdGmVMHe0O+FL49yIh7c14sXDyzqy8fCyljnf9/S2bnv6WFNvt+pWEfLBEXpE1IDZr0+/ZSfrAquoG6F5mKurUe7AGGukSOfLYNUrVheG8XD8eFq/PTxOsszTaZnHSZad5rHlbc1JljndLg9jRyB7VrlfdX15YJ+qje6Jq5ppPKFvjPUlHEpV5RNgzdaplEJ3D5RSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykbcbhoGY0wGsPcsniIEyKyjcho6/Swq08+jMv08KjSGz6K9iIRW18jtQv9sGWNiazL/hB3oZ1GZfh6V6edRwU6fhXbvKKWUjWjoK6WUjTTG0H/N1QW4Ef0sKtPPozL9PCrY5rNodH36SimlTq0x7ukrpZQ6hUYT+saYMcaYHcaYBGPMo66ux5WMMe2MMcuNMduMMVuMMdNdXZOrGWM8jTFxxpivXV2Lqxljgowxnxpjtjv+Ri50dU2uZIy53/H/ZLMx5kNjjJ+ra6pPjSL0jTGewBxgLNAdmGiM6e7aqlyqGHhARLoBg4C7bf55AEwHtrm6CDcxG/hWRLoCfbDx52KMaQvcC8SISE/AE7jBtVXVr0YR+sAAIEFEEkWkEFgAXOXimlxGRA6IyHrH7Rys/9RtXVuV6xhjwoDLgTdcXYurGWOaAcOBNwFEpFBEjri2KpfzApoYY7wAf2C/i+upV40l9NsCKU73U7FxyDkzxkQAfYE1rq3EpV4CHgZKXV2IG+gAZADzHd1dbxhjAlxdlKuIyD5gFpAMHACOisj3rq2qfjWW0DcnWWb7YUnGmKbAQuA+Ecl2dT2uYIy5AkgXkXWursVNeAHRwFwR
6QvkAbY9B2aMCcbqFYgEzgcCjDGTXFtV/WosoZ8KtHO6H0YjP0SrjjHGGyvw3xeRz1xdjwsNAa40xiRhdftdbIx5z7UluVQqkCoiZUd+n2JtBOzqEmCPiGSISBHwGTDYxTXVq8YS+muBzsaYSGOMD9aJmEUurslljDEGq892m4i86Op6XElEHhORMBGJwPq7+FFEGvWe3OmIyEEgxRhzgWPRKGCrC0tytWRgkDHG3/H/ZhSN/MS2l6sLqAsiUmyMuQf4Duvs+zwR2eLislxpCHATsMkYE+9Y9riILHZhTcp9/AV437GDlAhMcXE9LiMia4wxnwLrsUa9xdHIr87VK3KVUspGGkv3jlJKqRrQ0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRv5/5Yq8G830UTCAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df.plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Decision Tree Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8343173330831328" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.mean(dt_cv_scores)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Random Forest Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9223850187122359" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.mean(rf_cv_scores)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data_science/ensemble/voting.ipynb b/data_science/ensemble/voting.ipynb new file mode 100755 index 0000000..e42e8d3 --- /dev/null +++ b/data_science/ensemble/voting.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "# Voting\n", + "Based on the idea that classifiers can complement each other, \n", + "Aggregating individual classifier's prediction to make better prediction." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import datasets\n", + "from sklearn import tree\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.svm import SVC\n", + "from sklearn.ensemble import VotingClassifier\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# load mnist dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "mnist = datasets.load_digits()\n", + "features, labels = mnist.data, mnist.target\n", + "X_train,X_test,y_train,y_test=train_test_split(features,labels,test_size=0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# single classifiers accuracy on mnist\n", + "build decision tree, knn, svm and check accuracy on MNIST data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "dtree = tree.DecisionTreeClassifier(\n", + " criterion=\"gini\", max_depth=8, max_features=32,random_state=35)\n", + "\n", + "dtree = dtree.fit(X_train, y_train)\n", + "dtree_predicted = dtree.predict(X_test)\n", + "\n", + "knn = KNeighborsClassifier(n_neighbors=299).fit(X_train, y_train)\n", + "knn_predicted = knn.predict(X_test)\n", + "\n", + "svm = SVC(C=0.1, gamma=0.003,\n", + " probability=True,random_state=35).fit(X_train, y_train)\n", + "svm_predicted = svm.predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[accuarcy]\n", + "d-tree: 0.7972222222222223\n", + "knn : 0.8416666666666667\n", + "svm : 0.85\n" + ] + } + ], + "source": [ + "print(\"[accuarcy]\")\n", + "print(\"d-tree: \",accuracy_score(y_test, dtree_predicted))\n", + "print(\"knn : \",accuracy_score(y_test, knn_predicted))\n", + "print(\"svm : \",accuracy_score(y_test, svm_predicted))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "we can easily do soft voting or hard voting using sklearn's voting classifier \n", + "when you want to implement soft voting by scratch, you can use predict_proba just like below, \n", + "Below is the example of SVM's prediction (digit 0 to 9) on two MNIST data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[9.95557918e-01 3.42018637e-04 4.57700824e-04 4.19160266e-04\n", + " 4.21146304e-04 7.99436984e-04 4.11439277e-04 6.08753549e-04\n", + " 4.33211441e-04 5.49214707e-04]\n", + " [2.86586264e-03 4.17512273e-03 4.28013091e-03 4.14650212e-03\n", + " 9.27814553e-01 2.24791840e-02 3.06764221e-03 9.50855980e-03\n", + " 1.51437526e-02 6.51868962e-03]]\n" + ] + } + ], + "source": [ + "svm_proba = svm.predict_proba(X_test)\n", + "print(svm_proba[0:2])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# hard voting\n", + "hard voting is a simple majority vote: it collects each classifier's prediction and takes the most-voted class." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/anaconda3/envs/wikiml/lib/python3.6/site-packages/sklearn/preprocessing/label.py:151: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.\n", + " if diff:\n" + ] + }, + { + "data": { + "text/plain": [ + "0.9083333333333333" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voting_clf = VotingClassifier(estimators=[\n", + " ('decision_tree', dtree), ('knn', knn), ('svm', svm)], \n", + " weights=[1,1,1], voting='hard').fit(X_train, y_train)\n", + "hard_voting_predicted = voting_clf.predict(X_test)\n", + "accuracy_score(y_test, hard_voting_predicted)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# soft voting\n", + "soft voting takes each classifier's predict_proba output, sums the probabilities per class, and predicts the class with the highest total probability." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/anaconda3/envs/wikiml/lib/python3.6/site-packages/sklearn/preprocessing/label.py:151: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.\n", + " if diff:\n" + ] + }, + { + "data": { + "text/plain": [ + "0.9138888888888889" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voting_clf = VotingClassifier(estimators=[\n", + " ('decision_tree', dtree), ('knn', knn), ('svm', svm)], \n", + " weights=[1,1,1], voting='soft').fit(X_train, y_train)\n", + "soft_voting_predicted = voting_clf.predict(X_test)\n", + "accuracy_score(y_test, soft_voting_predicted)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualization\n", + "we can visualize accuracy to check whether the voting result is more stable or better than single-model accuracy. \n", + "it is hard to say which voting is better, but we can confirm classifiers complement each other, \n", + "and voting result is better in this example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAEepJREFUeJzt3XvQHXV9x/H3h2BEES8lqVUghiqoqVaoGbwgikpbwAo4oEK1LQ6V6QVtvc3QwTIWrVXROrViK7SKYpWLiqYYDZWKUK2YIBdJMDQTUFLaMSpSURGRb//YjZwcT/Kc58l58iQ/3q+ZzLOX39n97e5vP2fP75zdpKqQJLVll7mugCRp8gx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoN2nasVL1iwoBYvXjxXq5ekndLVV1/9napaOFW5OQv3xYsXs2rVqrlavSTtlJJ8c5xydstIUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KD5uwOVUmajsWnfmauqzAxt7ztBbO+DsNd2om0EnDbI9zu7+yWkaQGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDfHCYdiqtPDgLfHiWZpdX7pLUIK/cd0KtXL165SrNHq/cJalBhrskNchwl6QGGe6S1CDDXZIaNFa4Jzk8ydok65KcOmL+oiRfSHJNkuuTHDn5qkqSxjVluCeZB5wFHAEsAU5IsmSo2BuBC6vqQOB44H2TrqgkaXzjXLkfBKyrqvVVdTdwPnD0UJkCHtoPPwy4bXJVlCRN1zg3Me0F3DowvgF42lCZNwGXJnkVsDtw2ERqJ0makXHCPSOm1dD4CcC5VfWuJM8AzkvypKq6d7MFJScDJwMsWrRoJvUF2rlDE7xLU9LsGKdbZgOwz8D43vxit8tJwIUAVfWfwG7AguEFVdXZVbW0qpYuXLhwZjWWJE1pnHBfCeyXZN8k8+m+MF02VOZbwPMBkjyRLtw3TrKikqTxTRnuVXUPcAqwAriR7lcxq5OckeSovtjrgFcmuQ74GHBiVQ133UiStpOxngpZVcuB5UPTTh8YXgMcPNmqSZJmyjtUJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktSgscI9yeFJ1iZZl+TULZR5SZI1SVYn+ehkqylJmo5dpyqQZB5wFvCbwAZgZZJlVbVmoMx+wF8AB1fV7Ul+ebYqLEma2jhX7gcB66pqfVXdDZwPHD1U5pXAWVV1O0BVfXuy1ZQkTcc44b4XcOvA+IZ+2qD9gf2TfCnJV5IcPqkKSpKmb8puGSAjptWI5ewHHArsDVyZ5ElV9f3NFpScDJwMsGjRomlXVpI0nnGu3DcA+wyM7w3cNqLMp6vqp1V1M7CWLuw3U1VnV9XSqlq6cOHCmdZZkjSFccJ9JbBfkn2TzAeOB5YNlfkU8FyAJAvoumnWT7KikqTxTRnuVXUPcAqwArgRuLCqVic5I8lRfbEVwHeTrAG+ALyhqr47W5WWJG3dOH3uVNVyYPnQtNMHhgt4bf9PkjTHvENVkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQY
a7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQWOFe5LDk6xNsi7JqVspd1ySSrJ0clWUJE3XlOGeZB5wFnAEsAQ4IcmSEeX2AF4NXDXpSkqSpmecK/eDgHVVtb6q7gbOB44eUe7NwDuAuyZYP0nSDIwT7nsBtw6Mb+in/VySA4F9quqSrS0oyclJViVZtXHjxmlXVpI0nnHCPSOm1c9nJrsA7wZeN9WCqursqlpaVUsXLlw4fi0lSdMyTrhvAPYZGN8buG1gfA/gScDlSW4Bng4s80tVSZo744T7SmC/JPsmmQ8cDyzbNLOq7qiqBVW1uKoWA18BjqqqVbNSY0nSlKYM96q6BzgFWAHcCFxYVauTnJHkqNmuoCRp+nYdp1BVLQeWD007fQtlD932akmStoV3qEpSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQWOFe5LDk6xNsi7JqSPmvzbJmiTXJ7ksyWMmX1VJ0rimDPck84CzgCOAJcAJSZYMFbsGWFpVvw58HHjHpCsqSRrfOFfuBwHrqmp9Vd0NnA8cPVigqr5QVT/qR78C7D3ZakqSpmOccN8LuHVgfEM/bUtOAj47akaSk5OsSrJq48aN49dSkjQt44R7RkyrkQWTlwNLgTNHza+qs6tqaVUtXbhw4fi1lCRNy65jlNkA7DMwvjdw23ChJIcBpwHPqaqfTKZ6kqSZGOfKfSWwX5J9k8wHjgeWDRZIciDwfuCoqvr25KspSZqOKcO9qu4BTgFWADcCF1bV6iRnJDmqL3Ym8BDgoiTXJlm2hcVJkraDcbplqKrlwPKhaacPDB824XpJkraBd6hKUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUFjhXuSw5OsTbIuyakj5j8wyQX9/KuSLJ50RSVJ45sy3JPMA84CjgCWACckWTJU7CTg9qp6HPBu4O2TrqgkaXzjXLkfBKyrqvVVdTdwPnD0UJmjgQ/1wx8Hnp8kk6umJGk6xgn3vYBbB8Y39NNGlqmqe4A7gD0nUUFJ0vTtOkaZUVfgNYMyJDkZOLkfvTPJ2jHWP5cWAN+ZzRVkx+3Acttn2f15++/P2w7bvP2PGafQOOG+AdhnYHxv4LYtlNmQZFfgYcD3hhdUVWcDZ49TsR1BklVVtXSu6zEX3Pb757bD/Xv7W9r2cbplVgL7Jdk3yXzgeGDZUJllwB/0w8cB/15Vv3DlLknaPqa8cq+qe5KcAqwA5gEfqKrVSc4AVlXVMuCfgfOSrKO7Yj9+NistSdq6cbplqKrlwPKhaacPDN8FvHiyVdsh7DRdSLPAbb//uj9vfzPbHntPJKk9Pn5Akhq004R7kjclef0MX/vlKeYvT/LwmdVss+UcM+Lu3R1SksVJbpjremjrZuM4JbklyYJtXMbDk/zJwPijk3x822s3GUkOSbI6ybVJnpjkdye03AOSHDkwftSoR7LsCHaacN8WVfXMKeYfWVXfn8CqjqF7RMMv6H8iKm1Xs9juHg78PNyr6raqOm6W1j
UTLwPeWVUHAI8EJhLuwAHAz8O9qpZV1dsmtOzJqqod9h9wGrAW+DzwMeD1/fTHAp8DrgauBJ7QT38kcDFwXf/vmf30O/u/jwKuAK4FbgAO6affAizoh1/bz7sB+PN+2mLgRuAcYDVwKfCgobo+k+6XQjf3y38scDnwVuCLwOuAhcAn6H5euhI4uH/t7sAH+mnXAEdvh327GLihH/7Vfr1vAD7Z79v/At4xUP5O4K/7/foV4JFz3T5muN27A5/pt+MGup/wXjgw/1DgXwe2+e19O/s83aM4LgfWA0dtp/puse0Br+zbzHV9u3pwP/1c4G+BLwDvortb/NL+GL8f+Oam9j6wnj8eOt4nAn+/lXPifODHfVs/c6g9nbiVdnQScFO/H88B3jvD4/bSfvrz++36en8OPRD4Q+47F/+lb6939HV9zdByLwCOHBg/FzgW2A34YL/ca4DnAvOBbwEb+2W9tN/W9w689j3Al/s2clw/fRfgff3xu4TuxynHzXrbmeuTbSsH86n9jn0w8FBgHfeF+2XAfv3w0+h+V7/pQG1qfPOAh206Sfu/rwNOG5i/Rz98C92daZvWuTvwkP5gHNg33HuAA/ryFwIvH1HncwcPWt+A3zcw/lHgWf3wIuDGfvitm5ZHd0V0E7D7LO/fxf1J8vi+8R7QN9T1dDeh7UYXAvv05Qt4YT/8DuCNc91GZrjdxwLnDIw/rD9hd+/H/2HgWBRwRD98MV1APgB4CnDtdqrvFtsesOdAubcArxpoh5cA8/rx9wCn98Mv6LdrONwX0j1DatP4Z4FnTXFO3DDcnvrhke0IeDTdufZL/X68kvHDfdRx243usSf799M+zH3n/7ncF66HApdsYbkvAj7UD8/vl/cguqz4YD/9CX0b2Y2BMB/Y1sFwv4guzJds2p909/4s76f/CnA72yHcd+RumUOAi6vqR1X1f/Q3TiV5CN1V8kVJrqW7EnlU/5rn0Z2cVNXPquqOoWWuBF6R5E3Ak6vqB0Pzn9Wv84dVdSfd1cch/bybq+rafvhqusY8jgsGhg8D3tvXexnw0CR7AL8FnNpPv5yuES0ac/nbYiHwabqw2LRtl1XVHdX9vHUN993qfDddYMD0tn9H83XgsCRvT3JI30Y+B7yw78J4Ad0+gW6bPzfwui9W1U/74cXbsc5bantPSnJlkq/TdUP82sBrLqqqn/XDzwY+AlBVn6ELl81U1UZgfZKnJ9mT7k3/S2z9nNiaUe3oILp9+L1+P1405vbD6OP2eLp9c1Nf5kP9tk7HZ4HnJXkg3ZNvr6iqH9Nt93kAVfUNujeo/cdY3qeq6t6qWkPXk0C/rIv66f9L94lq1u3o/cCjfqe5C/D96vrSprewqiuSPJvuBD4vyZlV9eGBIlt7kuVPBoZ/RvfuPo4fDgzvAjyjbzz3rbR7guaxVbW9n7VzB92VysF0V2Twi9u5qY38tPrLkKHpO5WquinJU+n6Tf8myaV0b8B/SvdRfuXAm/7gNt9Lv2+q6t7t/B3KltreucAxVXVdkhPprlA3GWx3MPpcGnYB8BLgG3SBXtvwdNdR7WjGT4rdwnEbvlN+Jsu9K8nlwG/TdbN8rJ81ie3O0N/take+cr8CeFGSB/VXty8E6K/ib07yYuiCMclT+tdcRtd3SJJ5SR46uMAkjwG+XVXn0N1V+xsj1nlMkgcn2Z3uI9uV06jzD4A9tjL/UuCUgfpseoNaAbxq04mU5MBprHNb3E33JfDvT+rXBDu6JI8GflRVHwHeSdcGLu//vpLNP2nt6PYA/ifJA+iu3Lfkik3zkxwBPGIL5T5J1x5O4L79sKVzYqq2PspXgeckeUT/5njsuC/cwnH7BrA4yeP6Yr9H9/3WsKnqej7wCrpPJCv6aYP7bH+6T9Jrx1jWKP8BHJtklySPZPM34Vmzw4Z7VX2NroFdS/dl0WDIvgw4Kcl1dFecm54v/2fAc/uPqVez+cdU6HbqtUmuoWtYfzdinefSNcKrgH
+qqmumUe3zgTckuSbJY0fMfzWwNMn1SdYAf9RPfzNdH+T1/c/e3jyNdW6Tqvoh8DvAa+j6MVv3ZOCrfRfYacBb+u6LS+g+ll+ytRfvYP6Srp3+G13QbclfAc9O8jW6LsBvjSpUVbfTd6FU1Vf7aSPPiar6LvClJDckOXOcylbVf9N9v3QV3RfUa+g+PY5j1HG7iy6UL+rP+XuBfxzx2uuBe5Jcl+Q1I+ZfSted8/nq/s8K6L4Andcv9wLgxKr6CV2XypL+J5YvHbPun6B7uOINdN3IVzH+ds+Yd6hK2m6SPKSq7uyv3C+me1bVxXNdr9k2sN170r1RHtz3v8+anbLfVNJO601JDqP70cClwKfmuD7byyX9jZLzgTfPdrCDV+6S1KQdts9dkjRzhrskNchwl6QGGe6S1CDDXZIaZLhLUoP+H47Jp0tra/pcAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "%matplotlib inline\n", + "\n", + "x = np.arange(5)\n", + "plt.bar(x, height= [accuracy_score(y_test, dtree_predicted),\n", + " accuracy_score(y_test, knn_predicted),\n", + " accuracy_score(y_test, svm_predicted),\n", + " accuracy_score(y_test, hard_voting_predicted),\n", + " accuracy_score(y_test, soft_voting_predicted)])\n", + "plt.xticks(x, ['decision tree','knn','svm','hard voting','soft voting']);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data_science/svm/svm.ipynb b/data_science/svm/svm.ipynb new file mode 100755 index 0000000..db0733a --- /dev/null +++ b/data_science/svm/svm.ipynb @@ -0,0 +1,401 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from sklearn.datasets import load_iris\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.metrics import classification_report\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.svm import SVC" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [], + "source": [ + "# load iris data\n", + "dataset = 
load_iris()\n", + "\n", + "# use 80% as train data, 20% as test data\n", + "X_train,X_test,y_train,y_test=train_test_split(dataset.data,dataset.target,test_size=0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Find best hyperparameters\n", + "RBF kernel SVM has two parameters.\n", + "1. C (cost): The C parameter trades off correct classification of training examples against maximization of the decision function’s margin. For larger values of C, a smaller margin will be accepted if the decision function is better at classifying all training points correctly. \n", + "\n", + "2. gamma: the gamma parameter defines how far the influence of a single training example reaches, with low values meaning ‘far’ and high values meaning ‘close’. The gamma parameters can be seen as the inverse of the radius of influence of samples selected by the model as support vectors.\n", + "\n", + "reference:\n", + "http://scikit-learn.org/stable/auto_examples/svm/plot_rbf_parameters.html" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Grid Search\n", + "Find the best hyperparameters using grid search." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [], + "source": [ + "def svc_param_selection(X, y, nfolds): # grid-search an RBF-kernel SVC on (X, y) with nfolds-fold CV; prints best params, returns the fitted GridSearchCV\n", + " svm_parameters = [\n", + " {'kernel': ['rbf'],\n", + " 'gamma': [0.00001,0.0001, 0.001, 0.01, 0.1, 1],\n", + " 'C': [0.01, 0.1, 1, 10, 100, 1000]\n", + " }\n", + " ]\n", + " \n", + " clf = GridSearchCV(SVC(), svm_parameters, cv=nfolds)\n", + " clf.fit(X, y)\n", + " print(clf.best_params_)\n", + " \n", + " return clf" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 10, 'gamma': 0.1, 'kernel': 'rbf'}\n" + ] + } + ], + "source": [ + "clf = svc_param_selection(X_train, y_train, 10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 1.00 1.00 1.00 7\n", + " 1 1.00 1.00 1.00 13\n", + " 2 1.00 1.00 1.00 10\n", + "\n", + "avg / total 1.00 1.00 1.00 30\n", + "\n", + "\n", + "accuracy : 1.0\n" + ] + } + ], + "source": [ + "y_true, y_pred = y_test, clf.predict(X_test)\n", + "\n", + "print(classification_report(y_true, y_pred))\n", + "print()\n", + "print(\"accuracy : \"+ str(accuracy_score(y_true, y_pred)) )" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ground_truthprediction
011
111
222
300
411
522
622
722
811
900
1011
1100
1200
1311
1411
1522
1611
1722
1811
1922
2000
2111
2200
2300
2411
2511
2622
2711
2822
2922
\n", + "
" + ], + "text/plain": [ + " ground_truth prediction\n", + "0 1 1\n", + "1 1 1\n", + "2 2 2\n", + "3 0 0\n", + "4 1 1\n", + "5 2 2\n", + "6 2 2\n", + "7 2 2\n", + "8 1 1\n", + "9 0 0\n", + "10 1 1\n", + "11 0 0\n", + "12 0 0\n", + "13 1 1\n", + "14 1 1\n", + "15 2 2\n", + "16 1 1\n", + "17 2 2\n", + "18 1 1\n", + "19 2 2\n", + "20 0 0\n", + "21 1 1\n", + "22 0 0\n", + "23 0 0\n", + "24 1 1\n", + "25 1 1\n", + "26 2 2\n", + "27 1 1\n", + "28 2 2\n", + "29 2 2" + ] + }, + "execution_count": 84, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Visualize true value with prediction value in pandas dataframe.\n", + "comparison = pd.DataFrame({'prediction':y_pred, 'ground_truth':y_true}) \n", + "comparison" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 13be66dccaa243db2cea75353def473341941b7a Mon Sep 17 00:00:00 2001 From: minsuk-heo Date: Wed, 3 Jun 2020 22:57:12 -0700 Subject: [PATCH 7/8] updated --- data_science/nlp/word2vec_gensim.ipynb | 244 +++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 data_science/nlp/word2vec_gensim.ipynb diff --git a/data_science/nlp/word2vec_gensim.ipynb b/data_science/nlp/word2vec_gensim.ipynb new file mode 100644 index 0000000..9e75f1d --- /dev/null +++ b/data_science/nlp/word2vec_gensim.ipynb @@ -0,0 +1,244 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ 
+ "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# pretrained Word2Vec download" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2020-01-20 22:14:56-- https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz\n", + "Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.166.53\n", + "Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.166.53|:443... connected.\n", + "HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n", + "\n", + " The file is already fully retrieved; nothing to do.\n", + "\n" + ] + } + ], + "source": [ + "!wget -P . -c \"https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz\"" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import gensim" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# load pretrained word2vec\n", + "model = gensim.models.KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('coffees', 0.721267819404602),\n", + " ('gourmet_coffee', 0.7057087421417236),\n", + " ('Coffee', 0.6900454759597778),\n", + " ('o_joe', 0.6891065835952759),\n", + " ('Starbucks_coffee', 0.6874972581863403)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# similar words\n", + "model.most_similar(positive=['friend'], topn=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('queen', 0.7118192911148071)]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], 
+ "source": [ + "# king + woman - man = queen\n", + "model.most_similar(positive=['king', 'woman'], negative=['man'], topn=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "300" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Word2Vec vector dimension\n", + "len(model['friend'])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-1.61132812e-01, -1.36718750e-01, -3.73046875e-01, 6.17187500e-01,\n", + " 1.08398438e-01, 2.72216797e-02, 1.00097656e-01, -1.51367188e-01,\n", + " -1.66015625e-02, 3.80859375e-01, 6.54296875e-02, -1.31835938e-01,\n", + " 2.53906250e-01, 9.08203125e-02, 2.86865234e-02, 2.53906250e-01,\n", + " -2.05078125e-01, 1.64062500e-01, 2.20703125e-01, -1.74804688e-01,\n", + " -2.01171875e-01, 1.30859375e-01, -3.22265625e-02, -2.41210938e-01,\n", + " -3.19824219e-02, 2.48046875e-01, -2.37304688e-01, 2.89062500e-01,\n", + " 1.64794922e-02, 1.29394531e-02, 1.72119141e-02, -3.53515625e-01,\n", + " -1.66992188e-01, -5.90820312e-02, -2.81250000e-01, 9.94873047e-03,\n", + " -1.94091797e-02, -3.22265625e-01, 1.73339844e-02, -5.83496094e-02,\n", + " -2.59765625e-01, 1.42669678e-03, 5.81054688e-02, 1.13769531e-01,\n", + " -8.64257812e-02, 3.54003906e-02, -4.29687500e-01, 2.86865234e-03,\n", + " 6.98852539e-03, 1.80664062e-01, -1.79687500e-01, 2.95410156e-02,\n", + " -1.56250000e-01, -2.08007812e-01, -9.08203125e-02, 4.15039062e-03,\n", + " 1.07421875e-01, 3.12500000e-01, -1.04980469e-01, -3.24218750e-01,\n", + " -1.24023438e-01, -7.05718994e-04, -1.05957031e-01, 2.12890625e-01,\n", + " 1.12304688e-01, -1.58203125e-01, -1.67968750e-01, -9.71679688e-02,\n", + " 1.53320312e-01, -1.11328125e-01, 3.22265625e-01, 2.28515625e-01,\n", + " 3.20312500e-01, -1.72119141e-02, -4.57031250e-01, 3.23486328e-03,\n", + " -1.76757812e-01, 
-5.00488281e-02, 3.05175781e-02, -2.75390625e-01,\n", + " -1.65039062e-01, -3.56445312e-02, 7.95898438e-02, 1.35742188e-01,\n", + " -8.64257812e-02, -7.32421875e-02, 1.36718750e-01, 2.33398438e-01,\n", + " 7.95898438e-02, 1.32446289e-02, -4.71191406e-02, 1.01074219e-01,\n", + " 2.37304688e-01, -1.81640625e-01, -2.14843750e-01, -1.65039062e-01,\n", + " -1.66015625e-02, -1.51367188e-01, 3.06640625e-01, -2.40234375e-01,\n", + " -2.29492188e-01, -1.29882812e-01, 8.97216797e-03, 1.97265625e-01,\n", + " 7.47070312e-02, -1.64031982e-03, 1.54296875e-01, -6.80541992e-03,\n", + " -1.12304688e-01, -7.61718750e-02, -8.74023438e-02, -1.31835938e-01,\n", + " -2.94921875e-01, -2.46093750e-01, 6.15234375e-02, -1.23046875e-01,\n", + " -8.34960938e-02, -8.39843750e-02, -1.61132812e-02, -4.30297852e-03,\n", + " -4.05273438e-02, -2.84423828e-02, 1.36718750e-01, 2.13623047e-02,\n", + " -2.81250000e-01, 2.40234375e-01, -3.75976562e-02, -9.66796875e-02,\n", + " 1.28906250e-01, 1.43554688e-01, -1.37695312e-01, -1.38549805e-02,\n", + " -4.12597656e-02, -4.51660156e-02, -3.75976562e-02, 1.89453125e-01,\n", + " 5.32226562e-02, 1.17675781e-01, -8.25195312e-02, -1.56250000e-01,\n", + " 1.47460938e-01, -2.63671875e-01, -2.79296875e-01, -4.31640625e-01,\n", + " -5.90820312e-02, 2.74658203e-03, 2.87109375e-01, -2.71606445e-03,\n", + " -2.46093750e-01, 2.74658203e-02, -9.08203125e-02, 6.54296875e-02,\n", + " -1.94335938e-01, -2.16064453e-02, 2.77343750e-01, 5.98144531e-02,\n", + " 2.33154297e-02, -1.37695312e-01, -5.39062500e-01, -1.64794922e-02,\n", + " -1.25976562e-01, -1.36718750e-01, 3.02734375e-02, 2.50000000e-01,\n", + " 5.53131104e-04, 1.36718750e-01, 2.96875000e-01, -5.10253906e-02,\n", + " 9.08203125e-02, -2.39257812e-01, 1.35742188e-01, 1.11328125e-01,\n", + " 1.96289062e-01, -1.54296875e-01, -3.37890625e-01, -3.36914062e-02,\n", + " -9.47265625e-02, -1.69921875e-01, -1.04003906e-01, 1.46484375e-01,\n", + " 4.54101562e-02, -4.12109375e-01, -2.47070312e-01, -6.10351562e-03,\n", + " 
4.55078125e-01, -2.35595703e-02, 4.93164062e-02, 1.42578125e-01,\n", + " 2.66113281e-02, 4.11987305e-03, -7.27539062e-02, 2.53906250e-02,\n", + " -3.39355469e-02, 7.91015625e-02, 2.87109375e-01, 3.88671875e-01,\n", + " -1.58691406e-02, -8.44726562e-02, -1.15722656e-01, -1.22558594e-01,\n", + " -1.02050781e-01, 1.32812500e-01, 2.21679688e-01, -2.03125000e-01,\n", + " 7.91015625e-02, 1.69677734e-02, 2.16796875e-01, 2.33398438e-01,\n", + " -2.08984375e-01, -1.36718750e-01, -2.45117188e-01, 3.93066406e-02,\n", + " -1.80664062e-01, 1.37695312e-01, 1.50390625e-01, -3.90625000e-02,\n", + " -1.32812500e-01, 2.75878906e-02, -1.78710938e-01, 1.55273438e-01,\n", + " 1.36718750e-01, -1.14257812e-01, -2.79296875e-01, -7.86132812e-02,\n", + " 3.08593750e-01, -5.32226562e-02, -1.65039062e-01, 5.83496094e-02,\n", + " 2.19726562e-01, -1.25000000e-01, 6.10351562e-02, -3.39355469e-02,\n", + " -3.16406250e-01, 2.14843750e-01, -4.12597656e-02, -1.94335938e-01,\n", + " 7.76367188e-02, -5.21850586e-03, 6.93359375e-02, 2.18750000e-01,\n", + " 1.71875000e-01, -1.97265625e-01, 1.07910156e-01, 8.25195312e-02,\n", + " 3.39355469e-02, -1.15722656e-01, -2.02941895e-03, 4.83398438e-02,\n", + " 1.50390625e-01, -2.73437500e-01, -9.61914062e-02, 3.39843750e-01,\n", + " 2.98828125e-01, 1.32812500e-01, -3.68652344e-02, -3.08593750e-01,\n", + " 2.94189453e-02, -1.31835938e-01, -7.12890625e-02, -2.57873535e-03,\n", + " -1.17187500e-01, 6.34765625e-03, -1.66992188e-01, 2.01171875e-01,\n", + " -1.33789062e-01, -1.77734375e-01, -1.09863281e-01, 5.06591797e-03,\n", + " -1.07910156e-01, -1.30859375e-01, -5.17578125e-02, 2.57812500e-01,\n", + " 5.41992188e-02, -6.34765625e-03, 3.00598145e-03, 7.95898438e-02,\n", + " -2.37304688e-01, -8.05664062e-02, 6.07910156e-02, 9.27734375e-02,\n", + " 1.65039062e-01, -1.22558594e-01, 1.88476562e-01, 2.50000000e-01,\n", + " -1.42578125e-01, -7.91015625e-02, -1.78710938e-01, 1.52343750e-01,\n", + " -7.76367188e-02, 2.42187500e-01, 2.56347656e-02, -1.26953125e-01,\n", + " 
-1.25000000e-01, -3.19824219e-02, -1.27929688e-01, 1.49414062e-01,\n", + " -1.34277344e-02, 6.59179688e-02, 2.17773438e-01, 2.02148438e-01],\n", + " dtype=float32)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# print word2vec\n", + "model['friend']" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From e4eb0c86007a5523bccb89cb41c168711bd6e69f Mon Sep 17 00:00:00 2001 From: minsuk-heo Date: Wed, 3 Jun 2020 23:16:49 -0700 Subject: [PATCH 8/8] updated --- data_science/nlp/word2vec_gensim.ipynb | 175 +++++++++++-------------- 1 file changed, 80 insertions(+), 95 deletions(-) diff --git a/data_science/nlp/word2vec_gensim.ipynb b/data_science/nlp/word2vec_gensim.ipynb index 9e75f1d..5024407 100644 --- a/data_science/nlp/word2vec_gensim.ipynb +++ b/data_science/nlp/word2vec_gensim.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\"Open" + "\"Open" ] }, { @@ -16,16 +16,16 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "--2020-01-20 22:14:56-- https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz\n", - "Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.166.53\n", - "Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.166.53|:443... connected.\n", + "--2020-06-03 23:13:08-- https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz\n", + "Resolving s3.amazonaws.com (s3.amazonaws.com)... 
52.217.16.238\n", + "Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.16.238|:443... connected.\n", "HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n", "\n", " The file is already fully retrieved; nothing to do.\n", @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -58,20 +58,20 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[('coffees', 0.721267819404602),\n", - " ('gourmet_coffee', 0.7057087421417236),\n", - " ('Coffee', 0.6900454759597778),\n", - " ('o_joe', 0.6891065835952759),\n", - " ('Starbucks_coffee', 0.6874972581863403)]" + "[('pal', 0.7476358413696289),\n", + " ('friends', 0.7098034620285034),\n", + " ('buddy', 0.6972494125366211),\n", + " ('dear_friend', 0.6960037350654602),\n", + " ('acquaintance', 0.6843010187149048)]" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -92,7 +92,7 @@ "[('queen', 0.7118192911148071)]" ] }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -104,7 +104,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -113,7 +113,7 @@ "300" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -125,91 +125,76 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([-1.61132812e-01, -1.36718750e-01, -3.73046875e-01, 6.17187500e-01,\n", - " 1.08398438e-01, 2.72216797e-02, 
1.00097656e-01, -1.51367188e-01,\n", - " -1.66015625e-02, 3.80859375e-01, 6.54296875e-02, -1.31835938e-01,\n", - " 2.53906250e-01, 9.08203125e-02, 2.86865234e-02, 2.53906250e-01,\n", - " -2.05078125e-01, 1.64062500e-01, 2.20703125e-01, -1.74804688e-01,\n", - " -2.01171875e-01, 1.30859375e-01, -3.22265625e-02, -2.41210938e-01,\n", - " -3.19824219e-02, 2.48046875e-01, -2.37304688e-01, 2.89062500e-01,\n", - " 1.64794922e-02, 1.29394531e-02, 1.72119141e-02, -3.53515625e-01,\n", - " -1.66992188e-01, -5.90820312e-02, -2.81250000e-01, 9.94873047e-03,\n", - " -1.94091797e-02, -3.22265625e-01, 1.73339844e-02, -5.83496094e-02,\n", - " -2.59765625e-01, 1.42669678e-03, 5.81054688e-02, 1.13769531e-01,\n", - " -8.64257812e-02, 3.54003906e-02, -4.29687500e-01, 2.86865234e-03,\n", - " 6.98852539e-03, 1.80664062e-01, -1.79687500e-01, 2.95410156e-02,\n", - " -1.56250000e-01, -2.08007812e-01, -9.08203125e-02, 4.15039062e-03,\n", - " 1.07421875e-01, 3.12500000e-01, -1.04980469e-01, -3.24218750e-01,\n", - " -1.24023438e-01, -7.05718994e-04, -1.05957031e-01, 2.12890625e-01,\n", - " 1.12304688e-01, -1.58203125e-01, -1.67968750e-01, -9.71679688e-02,\n", - " 1.53320312e-01, -1.11328125e-01, 3.22265625e-01, 2.28515625e-01,\n", - " 3.20312500e-01, -1.72119141e-02, -4.57031250e-01, 3.23486328e-03,\n", - " -1.76757812e-01, -5.00488281e-02, 3.05175781e-02, -2.75390625e-01,\n", - " -1.65039062e-01, -3.56445312e-02, 7.95898438e-02, 1.35742188e-01,\n", - " -8.64257812e-02, -7.32421875e-02, 1.36718750e-01, 2.33398438e-01,\n", - " 7.95898438e-02, 1.32446289e-02, -4.71191406e-02, 1.01074219e-01,\n", - " 2.37304688e-01, -1.81640625e-01, -2.14843750e-01, -1.65039062e-01,\n", - " -1.66015625e-02, -1.51367188e-01, 3.06640625e-01, -2.40234375e-01,\n", - " -2.29492188e-01, -1.29882812e-01, 8.97216797e-03, 1.97265625e-01,\n", - " 7.47070312e-02, -1.64031982e-03, 1.54296875e-01, -6.80541992e-03,\n", - " -1.12304688e-01, -7.61718750e-02, -8.74023438e-02, -1.31835938e-01,\n", - " -2.94921875e-01, 
-2.46093750e-01, 6.15234375e-02, -1.23046875e-01,\n", - " -8.34960938e-02, -8.39843750e-02, -1.61132812e-02, -4.30297852e-03,\n", - " -4.05273438e-02, -2.84423828e-02, 1.36718750e-01, 2.13623047e-02,\n", - " -2.81250000e-01, 2.40234375e-01, -3.75976562e-02, -9.66796875e-02,\n", - " 1.28906250e-01, 1.43554688e-01, -1.37695312e-01, -1.38549805e-02,\n", - " -4.12597656e-02, -4.51660156e-02, -3.75976562e-02, 1.89453125e-01,\n", - " 5.32226562e-02, 1.17675781e-01, -8.25195312e-02, -1.56250000e-01,\n", - " 1.47460938e-01, -2.63671875e-01, -2.79296875e-01, -4.31640625e-01,\n", - " -5.90820312e-02, 2.74658203e-03, 2.87109375e-01, -2.71606445e-03,\n", - " -2.46093750e-01, 2.74658203e-02, -9.08203125e-02, 6.54296875e-02,\n", - " -1.94335938e-01, -2.16064453e-02, 2.77343750e-01, 5.98144531e-02,\n", - " 2.33154297e-02, -1.37695312e-01, -5.39062500e-01, -1.64794922e-02,\n", - " -1.25976562e-01, -1.36718750e-01, 3.02734375e-02, 2.50000000e-01,\n", - " 5.53131104e-04, 1.36718750e-01, 2.96875000e-01, -5.10253906e-02,\n", - " 9.08203125e-02, -2.39257812e-01, 1.35742188e-01, 1.11328125e-01,\n", - " 1.96289062e-01, -1.54296875e-01, -3.37890625e-01, -3.36914062e-02,\n", - " -9.47265625e-02, -1.69921875e-01, -1.04003906e-01, 1.46484375e-01,\n", - " 4.54101562e-02, -4.12109375e-01, -2.47070312e-01, -6.10351562e-03,\n", - " 4.55078125e-01, -2.35595703e-02, 4.93164062e-02, 1.42578125e-01,\n", - " 2.66113281e-02, 4.11987305e-03, -7.27539062e-02, 2.53906250e-02,\n", - " -3.39355469e-02, 7.91015625e-02, 2.87109375e-01, 3.88671875e-01,\n", - " -1.58691406e-02, -8.44726562e-02, -1.15722656e-01, -1.22558594e-01,\n", - " -1.02050781e-01, 1.32812500e-01, 2.21679688e-01, -2.03125000e-01,\n", - " 7.91015625e-02, 1.69677734e-02, 2.16796875e-01, 2.33398438e-01,\n", - " -2.08984375e-01, -1.36718750e-01, -2.45117188e-01, 3.93066406e-02,\n", - " -1.80664062e-01, 1.37695312e-01, 1.50390625e-01, -3.90625000e-02,\n", - " -1.32812500e-01, 2.75878906e-02, -1.78710938e-01, 1.55273438e-01,\n", - " 
1.36718750e-01, -1.14257812e-01, -2.79296875e-01, -7.86132812e-02,\n", - " 3.08593750e-01, -5.32226562e-02, -1.65039062e-01, 5.83496094e-02,\n", - " 2.19726562e-01, -1.25000000e-01, 6.10351562e-02, -3.39355469e-02,\n", - " -3.16406250e-01, 2.14843750e-01, -4.12597656e-02, -1.94335938e-01,\n", - " 7.76367188e-02, -5.21850586e-03, 6.93359375e-02, 2.18750000e-01,\n", - " 1.71875000e-01, -1.97265625e-01, 1.07910156e-01, 8.25195312e-02,\n", - " 3.39355469e-02, -1.15722656e-01, -2.02941895e-03, 4.83398438e-02,\n", - " 1.50390625e-01, -2.73437500e-01, -9.61914062e-02, 3.39843750e-01,\n", - " 2.98828125e-01, 1.32812500e-01, -3.68652344e-02, -3.08593750e-01,\n", - " 2.94189453e-02, -1.31835938e-01, -7.12890625e-02, -2.57873535e-03,\n", - " -1.17187500e-01, 6.34765625e-03, -1.66992188e-01, 2.01171875e-01,\n", - " -1.33789062e-01, -1.77734375e-01, -1.09863281e-01, 5.06591797e-03,\n", - " -1.07910156e-01, -1.30859375e-01, -5.17578125e-02, 2.57812500e-01,\n", - " 5.41992188e-02, -6.34765625e-03, 3.00598145e-03, 7.95898438e-02,\n", - " -2.37304688e-01, -8.05664062e-02, 6.07910156e-02, 9.27734375e-02,\n", - " 1.65039062e-01, -1.22558594e-01, 1.88476562e-01, 2.50000000e-01,\n", - " -1.42578125e-01, -7.91015625e-02, -1.78710938e-01, 1.52343750e-01,\n", - " -7.76367188e-02, 2.42187500e-01, 2.56347656e-02, -1.26953125e-01,\n", - " -1.25000000e-01, -3.19824219e-02, -1.27929688e-01, 1.49414062e-01,\n", - " -1.34277344e-02, 6.59179688e-02, 2.17773438e-01, 2.02148438e-01],\n", + "array([ 0.07080078, -0.21386719, 0.15332031, 0.09423828, -0.03442383,\n", + " 0.43359375, -0.16503906, -0.05786133, 0.17578125, -0.08203125,\n", + " 0.24511719, -0.19335938, -0.0255127 , -0.09619141, -0.125 ,\n", + " 0.02575684, 0.16796875, -0.03759766, 0.09472656, -0.04760742,\n", + " 0.20605469, 0.31835938, 0.15917969, -0.17089844, 0.09033203,\n", + " -0.1640625 , -0.15234375, 0.3125 , 0.06298828, -0.24902344,\n", + " 0.15625 , -0.04516602, -0.12890625, -0.00686646, -0.02160645,\n", + " 0.14453125, 0.2734375 , 
0.12695312, 0.10742188, 0.11376953,\n", + " 0.14355469, -0.00173187, 0.22851562, -0.03515625, 0.17089844,\n", + " 0.04516602, -0.07958984, -0.08886719, -0.01342773, -0.09667969,\n", + " -0.12597656, 0.10595703, 0.15332031, -0.03808594, 0.02246094,\n", + " 0.01428223, -0.03295898, 0.20703125, -0.03417969, 0.02233887,\n", + " 0.00244141, 0.13476562, -0.01403809, 0.13378906, 0.0201416 ,\n", + " 0.14746094, 0.00759888, -0.18652344, 0.16113281, 0.109375 ,\n", + " 0.14355469, 0.01623535, 0.01867676, 0.09179688, -0.33789062,\n", + " 0.19335938, -0.29101562, -0.00860596, 0.10644531, 0.359375 ,\n", + " 0.25585938, -0.03320312, 0.15625 , -0.24316406, -0.06738281,\n", + " 0.09033203, -0.125 , 0.21777344, -0.02380371, -0.06445312,\n", + " -0.14355469, 0.05664062, -0.12597656, 0.02172852, 0.03833008,\n", + " -0.17578125, -0.08349609, 0.21386719, -0.01855469, -0.23535156,\n", + " -0.14746094, -0.16113281, -0.03125 , -0.10107422, 0.07080078,\n", + " 0.01135254, -0.04370117, 0.07666016, 0.16503906, 0.04541016,\n", + " -0.13867188, 0.13085938, 0.13378906, -0.14453125, 0.12792969,\n", + " -0.06787109, -0.04296875, -0.03369141, 0.10302734, 0.22949219,\n", + " 0.14160156, -0.01153564, -0.00086212, -0.10449219, -0.03710938,\n", + " 0.01928711, 0.16699219, -0.06079102, 0.09814453, 0.0703125 ,\n", + " -0.39648438, -0.23242188, -0.04077148, 0.09570312, -0.0546875 ,\n", + " -0.09814453, 0.09082031, 0.03588867, 0.09228516, 0.3125 ,\n", + " 0.10595703, 0.18847656, -0.11230469, 0.00842285, 0.08935547,\n", + " 0.04663086, -0.25 , -0.03369141, 0.03808594, -0.03710938,\n", + " 0.42773438, 0.10839844, -0.01391602, -0.01965332, -0.04296875,\n", + " -0.11035156, 0.0390625 , 0.04541016, -0.20019531, -0.14355469,\n", + " -0.14257812, 0.03662109, 0.25 , 0.3671875 , -0.12304688,\n", + " -0.0859375 , 0.24902344, -0.21582031, 0.02648926, 0.17871094,\n", + " 0.29296875, 0.21582031, 0.1015625 , 0.00167084, -0.07177734,\n", + " 0.03686523, 0.22851562, -0.125 , 0.17285156, 0.22265625,\n", + " 0.21191406, 
0.03686523, 0.09570312, -0.00344849, 0.13183594,\n", + " -0.23925781, 0.00576782, 0.27148438, 0.10400391, 0.0098877 ,\n", + " -0.24511719, 0.21777344, -0.03027344, 0.23046875, 0.11816406,\n", + " 0.1640625 , -0.00109863, 0.00349426, -0.02197266, -0.09179688,\n", + " -0.10351562, 0.06933594, -0.13476562, -0.06201172, 0.14355469,\n", + " -0.10888672, -0.11328125, 0.2109375 , -0.10839844, -0.18261719,\n", + " -0.06689453, -0.265625 , -0.13378906, -0.04296875, -0.17773438,\n", + " 0.00689697, -0.00982666, -0.00640869, -0.12792969, 0.08203125,\n", + " -0.01367188, 0.02734375, 0.12597656, -0.00772095, -0.04614258,\n", + " -0.12255859, 0.16210938, 0.28320312, 0.04296875, -0.05175781,\n", + " -0.16210938, 0.14648438, -0.18359375, -0.24511719, 0.22167969,\n", + " 0.0546875 , -0.10302734, -0.07763672, -0.33984375, -0.05908203,\n", + " -0.0022583 , -0.11962891, -0.3046875 , 0.02233887, 0.02941895,\n", + " 0.37695312, -0.01721191, -0.05932617, 0.30273438, -0.13574219,\n", + " 0.14746094, 0.17089844, 0.16015625, 0.21484375, 0.01013184,\n", + " 0.06738281, -0.12109375, -0.12304688, -0.20117188, 0.02880859,\n", + " -0.00662231, -0.20410156, 0.02001953, -0.15136719, 0.16699219,\n", + " 0.14160156, -0.02331543, 0.14550781, -0.13476562, 0.04785156,\n", + " 0.14160156, 0.03808594, -0.12109375, 0.02770996, -0.0123291 ,\n", + " -0.20410156, -0.06445312, 0.06079102, -0.07519531, -0.28125 ,\n", + " 0.18261719, -0.25390625, -0.0456543 , 0.14160156, -0.0546875 ,\n", + " -0.01477051, -0.38085938, 0.14355469, 0.12255859, 0.14941406,\n", + " -0.03320312, 0.19433594, -0.34375 , -0.24902344, -0.00331116,\n", + " -0.05639648, -0.00079727, -0.21679688, -0.01977539, 0.10644531],\n", " dtype=float32)" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" }