-
Notifications
You must be signed in to change notification settings - Fork 0
/
4.KNN_Prediction_Grid.py
39 lines (35 loc) · 1.57 KB
/
4.KNN_Prediction_Grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import random
import scipy.stats as ss
import matplotlib.pyplot as plt
def make_prediction_grid(predictors, outcomes, limits, h, k):
    """Classify every point of a regular 2-D grid with ``knn_predict``.

    Parameters
    ----------
    predictors, outcomes
        Training data; forwarded unchanged to ``knn_predict``.
    limits
        ``(x_min, x_max, y_min, y_max)``. Grid coordinates are produced by
        ``np.arange``, so the upper bounds are exclusive.
    h
        Grid spacing used along both axes. Must be positive.
    k
        Forwarded unchanged to ``knn_predict``.

    Returns
    -------
    tuple
        ``(xx, yy, prediction_grid)``: the two ``np.meshgrid`` coordinate
        arrays and an integer array of the same shape holding the value
        ``knn_predict`` returned for each grid point.

    Raises
    ------
    ValueError
        If ``h`` is not positive.
    """
    if h <= 0:
        raise ValueError("h must be a positive grid spacing")

    x_min, x_max, y_min, y_max = limits
    xs = np.arange(x_min, x_max, h)
    ys = np.arange(y_min, y_max, h)
    xx, yy = np.meshgrid(xs, ys)

    prediction_grid = np.zeros(xx.shape, dtype=int)
    # meshgrid's default layout puts y along rows and x along columns, hence
    # the [row, col] == [j, i] indexing. The x-outer / y-inner visiting order
    # is kept deliberately in case knn_predict is stateful (e.g. random
    # tie-breaking) -- NOTE(review): confirm against knn_predict.
    for i, x in enumerate(xs):
        for j, y in enumerate(ys):
            p = np.array([x, y])
            prediction_grid[j, i] = knn_predict(p, predictors, outcomes, k)
    return (xx, yy, prediction_grid)
def plot_prediction_grid(xx, yy, prediction_grid, filename,
                         predictors=None, outcomes=None):
    """Plot the predictions for every grid point and save the figure.

    The prediction grid is drawn as a translucent background mesh and the
    observations are scattered on top, coloured by ``outcomes``.

    Parameters
    ----------
    xx, yy
        Coordinate arrays passed to ``plt.pcolormesh``; their min/max also
        set the axis limits.
    prediction_grid
        Values drawn as the background mesh.
    filename
        Path handed to ``plt.savefig``.
    predictors, outcomes
        Observations to overlay: columns 0 and 1 of ``predictors`` are the
        x/y positions and ``outcomes`` supplies the colour values. When
        omitted, the module-level variables of the same names are used, which
        is what this function always did before these parameters existed.

    Raises
    ------
    NameError
        If ``predictors`` or ``outcomes`` is neither passed in nor defined at
        module level.
    """
    # Resolve the observation data first so a missing input fails before any
    # figure is created.
    module_scope = globals()
    if predictors is None:
        predictors = module_scope.get("predictors")
    if outcomes is None:
        outcomes = module_scope.get("outcomes")
    if predictors is None or outcomes is None:
        raise NameError(
            "plot_prediction_grid needs 'predictors' and 'outcomes': pass "
            "them as arguments or define them at module level"
        )

    from matplotlib.colors import ListedColormap

    background_colormap = ListedColormap(
        ["hotpink", "lightskyblue", "yellowgreen"]
    )
    observation_colormap = ListedColormap(["red", "blue", "green"])

    plt.figure(figsize=(10, 10))
    plt.pcolormesh(xx, yy, prediction_grid,
                   cmap=background_colormap, alpha=0.5)
    plt.scatter(predictors[:, 0], predictors[:, 1],
                c=outcomes, cmap=observation_colormap, s=50)
    plt.xlabel('Variable 1')
    plt.ylabel('Variable 2')
    # Empty tuples remove the tick marks from both axes.
    plt.xticks(())
    plt.yticks(())
    plt.xlim(np.min(xx), np.max(xx))
    plt.ylim(np.min(yy), np.max(yy))
    plt.savefig(filename)
# Build the synthetic sample. These two names stay at module level because
# plot_prediction_grid may read them from there.
predictors, outcomes = generate_synth_data()
# Handy interactive sanity check: inspect predictors.shape and outcomes.shape.

# Experiment settings: neighbour count, grid spacing, plot bounds
# (x_min, x_max, y_min, y_max) and the output file.
k = 5
h = 0.1
limits = (-3, 4, -3, 4)
filename = "knn_synth_5.pdf"

# Classify the grid, render and save it, then display the figure.
xx, yy, prediction_grid = make_prediction_grid(
    predictors, outcomes, limits, h, k
)
plot_prediction_grid(xx, yy, prediction_grid, filename)
plt.show()