Import library
Matplotlib for graph visualization
import matplotlib.pyplot as plt
Data
Create the dataset and plot it with Matplotlib for visualization.
x1_blue = [-3, -2, -1, 2]
x2_blue = [-1, -3, -2, -2]
x1_red = [1, 2, 3]
x2_red = [1, 3, 2]
plt.figure(figsize=(10,10))
plt.xlabel('x1')
plt.ylabel('x2')
plt.title('Simple Perceptron Example')
plt.scatter(x1_blue,x2_blue, color=["blue"])
plt.scatter(x1_red,x2_red, color=["red"])
Parameters and Inputs
The weights (and, later, the learning rate) are parameters; x0, x1 and x2 are the inputs. x0 is fixed at 1 so that w0 acts as a bias term.
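w = [0, 1, 0.5]  # initial weights [w0, w1, w2]
x0 = 1  # constant bias input
x1 = x1_blue + x1_red  # first coordinate of every point (blue, then red)
x2 = x2_blue + x2_red  # second coordinate of every point (blue, then red)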
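To make the bias term explicit, each point can be written as one input vector [x0, x1, x2] with x0 = 1. This is a minimal sketch in plain Python; the names `X` and `labels` are illustrative and do not appear in the original notebook.

```python
x1_blue = [-3, -2, -1, 2]
x2_blue = [-1, -3, -2, -2]
x1_red = [1, 2, 3]
x2_red = [1, 3, 2]

x0 = 1  # constant bias input
x1 = x1_blue + x1_red
x2 = x2_blue + x2_red

# One input vector [x0, x1, x2] per point: blue points first, then red.
X = [[x0, a, b] for a, b in zip(x1, x2)]
labels = ["blue"] * len(x1_blue) + ["red"] * len(x1_red)

for row, label in zip(X, labels):
    print(row, label)
```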
Perceptron Basic Formula
$ y = \sum_{i=0}^{2} w_{i} x_{i} = w_{0} x_{0} + w_{1} x_{1} + w_{2} x_{2} $
If y > 0, the point is classified as red; otherwise it is classified as blue.
The initial weights are $ w_{0} = 0, w_{1} = 1, w_{2} = \frac{1}{2} $.
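Applying the formula to every point shows which ones the initial weights get wrong. This is a sketch in plain Python; the helper name `predict` is hypothetical, and it uses the rule stated above (y > 0 means red, otherwise blue).

```python
w = [0, 1, 0.5]  # initial weights [w0, w1, w2]

points = [(-3, -1), (-2, -3), (-1, -2), (2, -2), (1, 1), (2, 3), (3, 2)]
labels = ["blue"] * 4 + ["red"] * 3

def predict(w, x1, x2, x0=1):
    """Return the class of one point: red if y > 0, otherwise blue."""
    y = w[0] * x0 + w[1] * x1 + w[2] * x2
    return "red" if y > 0 else "blue"

# Keep only the points whose predicted class differs from their true class.
misclassified = [
    p for p, label in zip(points, labels) if predict(w, *p) != label
]
print(misclassified)  # [(2, -2)]
```

Only the blue point (2, -2) is on the wrong side: its y is 0 + 2 + 0.5 · (-2) = 1 > 0.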
Plot the Line
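Plot the line that separates the classes. It is the set of points where y = 0; with the initial weights this is 0 + x1 + 0.5 x2 = 0, i.e. x2 = -2 x1.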
plt.figure(figsize=(10,10))
plt.xlabel('x1')
plt.ylabel('x2')
plt.title('Simple Perceptron Example')
plt.scatter(x1_blue,x2_blue, color=["blue"])
plt.scatter(x1_red,x2_red, color=["red"])
plt.plot([-3, 3], [-2*-3, -2*3])  # boundary x2 = -2*x1, drawn from x1 = -3 to x1 = 3
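The plotted line is the decision boundary y = 0. Solving w0 + w1 x1 + w2 x2 = 0 for x2 gives x2 = -(w0 + w1 x1) / w2, assuming w2 is not zero. A small helper (the name `boundary_x2` is hypothetical) computes the endpoints to pass to `plt.plot`:

```python
def boundary_x2(w, x1):
    """Return x2 on the line w0 + w1*x1 + w2*x2 = 0 (assumes w[2] != 0)."""
    return -(w[0] + w[1] * x1) / w[2]

w = [0, 1, 0.5]  # initial weights
xs = [-3, 3]
ys = [boundary_x2(w, x) for x in xs]
print(ys)  # [6.0, -6.0]
```

With the updated weights [-0.25, 0.5, 1.0] the same helper gives x2 = 0.25 - 0.5 x1, the line plotted after the update.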
$ w'_{i} = w_{i} + n \cdot d \cdot x_{i} $
n is the learning rate (a parameter).
d is +1 if the misclassified point should lie above the line (class red) and -1 if it should lie below the line (class blue).
For this example, n = 0.25.
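The misclassified point in the plot above is the blue point at (2, -2): with the initial weights, y = 0 + 2 + 0.5 · (-2) = 1 > 0, so it is predicted red. It should lie below the line, so d = -1.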
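n = 0.25  # learning rate
d = -1    # the misclassified point is blue, so it should lie below the line
x_miss = [x0, 2, -2]  # inputs (x0, x1, x2) of the misclassified point (2, -2)
w0_new = w[0] + n * d * x_miss[0]
w1_new = w[1] + n * d * x_miss[1]
w2_new = w[2] + n * d * x_miss[2]
print(f"w0_new: {w0_new}, w1_new: {w1_new}, w2_new: {w2_new}")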
w0_new: -0.25, w1_new: 0.5, w2_new: 1.0
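The three updates above apply the same rule to each weight. A compact sketch that updates all weights in one pass (the function name `update` is hypothetical; `x` includes the bias input x0 = 1, and d is +1 for a point that should lie above the line, -1 for one that should lie below):

```python
def update(w, x, n, d):
    """Perceptron rule w'_i = w_i + n * d * x_i, applied to every weight.

    x includes the bias input x0 = 1; d is +1 if the misclassified point
    should lie above the line (red) and -1 if it should lie below (blue).
    """
    return [wi + n * d * xi for wi, xi in zip(w, x)]

w = [0, 1, 0.5]
print(update(w, x=[1, 2, -2], n=0.25, d=-1))  # [-0.25, 0.5, 1.0]
```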
After the Update
The new weights are
$ w_{0} = \frac{-1}{4}, w_{1} = \frac{1}{2}, w_{2} = 1$
plt.figure(figsize=(10,10))
plt.xlabel('x1')
plt.ylabel('x2')
plt.title('Simple Perceptron Example')
plt.scatter(x1_blue,x2_blue, color=["blue"])
plt.scatter(x1_red,x2_red, color=["red"])
plt.plot([-3, 3], [-3*-0.5+0.25, 3*-0.5+0.25])  # boundary x2 = 0.25 - 0.5*x1
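The notebook performs a single update by hand. Repeating the same step until no point is misclassified gives the usual perceptron training loop. The sketch below is an extension under stated assumptions, not part of the original example: the `train` and `predict` names are hypothetical, points are visited in a fixed order, and `max_epochs` caps the loop in case the data are not linearly separable.

```python
def predict(w, x):
    """Return +1 (red) if w . x > 0, otherwise -1 (blue); x includes x0 = 1."""
    y = sum(wi * xi for wi, xi in zip(w, x))
    return 1 if y > 0 else -1

def train(w, data, n=0.25, max_epochs=100):
    """Apply w_i += n * d * x_i to each misclassified point until none remain."""
    w = list(w)
    for _ in range(max_epochs):
        errors = 0
        for x, d in data:  # d is +1 for red points, -1 for blue points
            if predict(w, x) != d:
                w = [wi + n * d * xi for wi, xi in zip(w, x)]
                errors += 1
        if errors == 0:
            return w
    return w  # may not have converged within max_epochs

# Inputs [x0, x1, x2] with x0 = 1, paired with the target d.
data = [([1, -3, -1], -1), ([1, -2, -3], -1), ([1, -1, -2], -1), ([1, 2, -2], -1),
        ([1, 1, 1], 1), ([1, 2, 3], 1), ([1, 3, 2], 1)]
w = train([0, 1, 0.5], data)
print(w)  # [-0.25, 0.5, 1.0]
```

On this data the loop makes exactly one update, on the point (2, -2), and then classifies every point correctly, so it ends with the same weights as the manual step.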