PyTorch Tutorial - Gradient Descent

Gradient Descent (full batch: one update per pass over the whole data set)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import  matplotlib.pyplot as plt

# prepare the training data: three samples of the target line y = 2x
x_data = [1.0, 2.0, 3.0]
y_data = [2.0 * x for x in x_data]

# initial guess for the single model weight (the value that fits the data is 2.0)
w = 1.0

# define the linear model
def forward(x):
return w*x

# cost function: mean squared error over the whole data set
def cost(xs, ys):
    """Return the mean squared error of ``forward`` over paired samples.

    Args:
        xs: input values.
        ys: target values, paired element-wise with ``xs``.

    Returns:
        mean((forward(x) - y) ** 2) over all pairs.

    Raises:
        ValueError: if ``xs`` is empty, or ``xs`` and ``ys`` differ in length
            (``zip`` would silently drop the extra items while the divisor
            still counted them, giving a wrong mean).
    """
    n = len(xs)
    if n == 0:
        raise ValueError("cost() needs at least one sample")
    if len(ys) != n:
        raise ValueError("cost() needs xs and ys of equal length")
    total = sum((forward(x) - y) ** 2 for x, y in zip(xs, ys))
    return total / n

# define the gradient of the cost with respect to w
def gradient(xs, ys):
    """Return d(cost)/dw averaged over the whole data set.

    For cost = mean((w*x - y) ** 2) the derivative with respect to ``w`` is
    mean(2*x*(w*x - y)). The prediction goes through ``forward`` so the
    gradient stays in sync with the model used by ``cost``.

    Args:
        xs: input values.
        ys: target values, paired element-wise with ``xs``.

    Raises:
        ValueError: if ``xs`` is empty or ``xs`` and ``ys`` differ in length.
    """
    n = len(xs)
    if n == 0:
        raise ValueError("gradient() needs at least one sample")
    if len(ys) != n:
        raise ValueError("gradient() needs xs and ys of equal length")
    total = sum(2 * x * (forward(x) - y) for x, y in zip(xs, ys))
    return total / n

# training process: one weight update per pass over the full data set
LEARNING_RATE = 0.01
EPOCHS = 100

epoch_list = []
cost_list = []
print('predict (before training):', 4, forward(4))
for epoch in range(EPOCHS):
    cost_val = cost(x_data, y_data)      # cost at the current (pre-update) w
    grad_val = gradient(x_data, y_data)
    w -= LEARNING_RATE * grad_val
    # NOTE: w is printed after this epoch's update, cost_val from before it.
    print('epoch:', epoch, "w=", w, 'loss:', cost_val)
    epoch_list.append(epoch)
    cost_list.append(cost_val)

print('predict (after training):', 4, forward(4))

# visualization: cost recorded at the start of each epoch
plt.plot(epoch_list, cost_list)
plt.ylabel("cost")
plt.xlabel("epoch")
plt.show()

Stochastic Gradient Descent

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import matplotlib.pyplot as plt

# prepare the training dataset: samples of y = 2x. The model w*x has no bias
# term, so targets such as y = x + 1 could not be fitted; with y = 2x the
# weight can converge to exactly 2.0.
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# initialize w (initial guess)
w = 1.0

# define the linear model
def forward(x):
    """Return the model's prediction ``w * x`` for input ``x``.

    ``w`` is the module-level weight; it is looked up at call time, so the
    prediction always reflects the latest update from the training loop.
    """
    return w * x

# define the loss function: squared error of a single sample
def loss(x, y):
    """Return the squared error ``(forward(x) - y) ** 2`` for one sample.

    Args:
        x: input value.
        y: target value.
    """
    y_pred = forward(x)
    return (y_pred - y) ** 2

# define the per-sample gradient used by stochastic gradient descent
def gradient(x, y):
    """Return d(loss)/dw for a single sample.

    For loss = (w*x - y) ** 2 the derivative with respect to ``w`` is
    2*x*(w*x - y). The prediction goes through ``forward`` so the gradient
    stays in sync with the model used by ``loss``.
    """
    return 2 * x * (forward(x) - y)

# training process: one weight update per sample (stochastic gradient descent)
LEARNING_RATE = 0.01
EPOCHS = 100

epoch_list = []
loss_list = []
print('predict (before training):', 4, forward(4))
for epoch in range(EPOCHS):
    for x, y in zip(x_data, y_data):
        loss_val = loss(x, y)  # loss at the current w, before this sample's update
        grad = gradient(x, y)
        w -= LEARNING_RATE * grad
        print("\tgradient:", x, y, grad)
    # loss_val is the loss of the last sample of this epoch, not an epoch average
    print("process: ", epoch, "w=", w, "loss=", loss_val)
    epoch_list.append(epoch)
    loss_list.append(loss_val)
print("predict (after training)", 4, forward(4))

# visualization: last-sample loss recorded for each epoch
plt.plot(epoch_list, loss_list)
plt.ylabel("loss")
plt.xlabel("epoch")
plt.show()