# PyTorch Tutorial - Linear Regression With PyTorch

import torch

# Step 1: prepare the dataset.
# Three samples lying exactly on the line y = 2x; each tensor is laid out as
# (3 samples, 1 feature) so it can be fed straight into nn.Linear(1, 1).
x_data = torch.tensor([1.0, 2.0, 3.0]).view(3, 1)
y_data = torch.tensor([2.0, 4.0, 6.0]).view(3, 1)

# Step 2: design the model.
"""
Our model class must inherit from torch.nn.Module, the base class for all neural networks.
__init__() sets up the layers, and forward() defines the computation; forward() must be implemented.
nn.Linear holds two learnable Tensor parameters: weight and bias.
nn.Module (and therefore nn.Linear and our own model) implements the magic method __call__(), which lets an instance be called like a function.
Therefore, model(data) ends up calling the forward() method.
"""

class LinearModel(torch.nn.Module):
    """Linear regression model wrapping a single ``torch.nn.Linear`` layer.

    Calling ``model(x)`` goes through ``nn.Module.__call__``, which
    dispatches to :meth:`forward`.

    Args:
        in_features: size of each input sample (default 1, as in this tutorial).
        out_features: size of each output sample (default 1).
    """

    def __init__(self, in_features=1, out_features=1):
        super().__init__()
        # nn.Linear(in, out) owns the learnable parameters `weight`
        # (shape (out, in)) and `bias` (shape (out,)), and computes
        # y = x @ weight.T + bias. With the defaults this is y = w * x + b.
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return the predictions for ``x``.

        The last dimension of ``x`` must equal ``in_features``.
        """
        return self.linear(x)

model = LinearModel()

# Step 3: construct the loss function and the optimizer.
# reduction="sum" sums the squared errors over the batch. The legacy
# spelling size_average=False is deprecated and maps to this same behavior.
criterion = torch.nn.MSELoss(reduction="sum")
# SGD takes an iterable of tensors to optimize; model.parameters() yields every
# learnable parameter registered on the model (here: linear.weight, linear.bias).
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Step 4: training cycle -- forward, backward, update.
for epoch in range(1000):
    y_pred = model(x_data)  # forward pass
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    # loss.backward() accumulates into each parameter's .grad, so the
    # previous iteration's gradients must be cleared first.
    optimizer.zero_grad()
    loss.backward()  # autograd: populate .grad on weight and bias
    optimizer.step()  # apply the SGD update to weight and bias

# Report the learned parameters. The training data lie exactly on y = 2x,
# so w should approach 2 and b should approach 0.
print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())  # was mistakenly reading .weight

# Step 5: test the model on an unseen input.
# nn.Linear accepts inputs of shape (*, in_features), so a 1-element tensor works.
x_test = torch.tensor([4.0])
# Inference only: no_grad() skips building an autograd graph, so the result
# can be printed directly instead of going through the legacy `.data` attribute.
with torch.no_grad():
    y_test = model(x_test)
print('y_pred= ', y_test)

# Sample output (last 20 epochs of training, then the learned parameters and prediction):
# 980 3.610836074585677e-10
# 981 3.6504843592410907e-10
# 982 3.5726088754017837e-10
# 983 3.4162894735345617e-10
# 984 3.445990159889334e-10
# 985 3.355609123900649e-10
# 986 3.370388412804459e-10
# 987 3.249738256272394e-10
# 988 3.1610625228495337e-10
# 989 3.1587887860951014e-10
# 990 3.1248248433257686e-10
# 991 3.1373303954751464e-10
# 992 3.1042191039887257e-10
# 993 3.0178171073202975e-10
# 994 3.0526337013725424e-10
# 995 2.9184832328610355e-10
# 996 2.9478997021215037e-10
# 997 2.864908310584724e-10
# 998 2.8178703814774053e-10
# 999 2.765574436125462e-10
# w= 2.0000112056732178
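# b= 2.0000112056732178  (NOTE: the run that produced this log printed model.linear.weight for "b=", so this is the weight, not the bias)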
# y_pred= tensor([8.0000])