Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
2
2021-193 User-friendly enhanced machine learning-based railway management system
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
2021-193
2021-193 User-friendly enhanced machine learning-based railway management system
Commits
3eb3c071
Commit
3eb3c071
authored
Jul 04, 2021
by
Weerasooriya W.K.M-IT18085822
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
9c700478
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
123 additions
and
0 deletions
+123
-0
train.py
train.py
+123
-0
No files found.
train.py
0 → 100644
View file @
3eb3c071
import
json
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
torch.utils.data
import
Dataset
,
DataLoader
from
nltk_utils
import
bag_of_words
,
tokenize
,
stem
from
model
import
NeuralNet
with
open
(
'intents.json'
,
'r'
)
as
f
:
intents
=
json
.
load
(
f
)
all_words
=
[]
tags
=
[]
xy
=
[]
for
intent
in
intents
[
'intents'
]:
tag
=
intent
[
'tag'
]
# add to tag list
tags
.
append
(
tag
)
for
pattern
in
intent
[
'patterns'
]:
# tokenize each word in the sentence
w
=
tokenize
(
pattern
)
# add to our words list
all_words
.
extend
(
w
)
# add to xy pair
xy
.
append
((
w
,
tag
))
ignore_words
=
[
'?'
,
'.'
,
'!'
]
all_words
=
[
stem
(
w
)
for
w
in
all_words
if
w
not
in
ignore_words
]
all_words
=
sorted
(
set
(
all_words
))
tags
=
sorted
(
set
(
tags
))
print
(
tags
)
X_train
=
[]
y_train
=
[]
for
(
pattern_sentence
,
tag
)
in
xy
:
# X: bag of words for each pattern_sentence
bag
=
bag_of_words
(
pattern_sentence
,
all_words
)
X_train
.
append
(
bag
)
# y: PyTorch CrossEntropyLoss needs only class labels, not one-hot
label
=
tags
.
index
(
tag
)
y_train
.
append
(
label
)
X_train
=
np
.
array
(
X_train
)
y_train
=
np
.
array
(
y_train
)
class
ChatDataset
(
Dataset
):
def
__init__
(
self
):
self
.
n_samples
=
len
(
X_train
)
self
.
x_data
=
X_train
self
.
y_data
=
y_train
# support indexing such that dataset[i] can be used to get i-th sample
def
__getitem__
(
self
,
index
):
return
self
.
x_data
[
index
],
self
.
y_data
[
index
]
# we can call len(dataset) to return the size
def
__len__
(
self
):
return
self
.
n_samples
num_epochs
=
1000
learning_rate
=
0.001
batch_size
=
8
input_size
=
len
(
X_train
[
0
])
hidden_size
=
8
output_size
=
len
(
tags
)
print
(
input_size
,
output_size
)
dataset
=
ChatDataset
()
train_loader
=
DataLoader
(
dataset
=
dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
0
)
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
model
=
NeuralNet
(
input_size
,
hidden_size
,
output_size
)
.
to
(
device
)
criterion
=
nn
.
CrossEntropyLoss
()
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
learning_rate
)
for
epoch
in
range
(
num_epochs
):
for
(
words
,
labels
)
in
train_loader
:
words
=
words
.
to
(
device
)
labels
=
labels
.
to
(
dtype
=
torch
.
long
)
.
to
(
device
)
# Forward pass
outputs
=
model
(
words
)
# if y would be one-hot, we must apply
# labels = torch.max(labels, 1)[1]
loss
=
criterion
(
outputs
,
labels
)
# Backward and optimize
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
if
(
epoch
+
1
)
%
100
==
0
:
print
(
f
'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}'
)
print
(
f
'final loss: {loss.item():.4f}'
)
data
=
{
"model_state"
:
model
.
state_dict
(),
"input_size"
:
input_size
,
"hidden_size"
:
hidden_size
,
"output_size"
:
output_size
,
"all_words"
:
all_words
,
"tags"
:
tags
}
FILE
=
"data.pth"
torch
.
save
(
data
,
FILE
)
print
(
f
'training complete. file saved to {FILE}'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment