[딥러닝] 기억하는 신경망 : RNN, 그리고 개선 모델 (LSTM, GRU)
이번에는 RNN의 개념과 이를 기반으로 하는 LSTM, GRU에 대해서 살펴보려고 한다. 해당 포스팅에서는 아래의 내용들을 다룬다.
1. RNN
- RNN의 구조
- RNN의 활용 예시
- RNN의 등장 배경과 한계점
2. LSTM
- LSTM의 등장 배경과 구조
3. GRU
- GRU의 등장 배경과 구조
4. 시계열 데이터 학습 실습 : RNN 기반 모델 비교
RNN(순환 신경망, Recurrent Neural Network)
과거의 정보를 사용해서 현재 및 미래의 입력에 대한 신경망의 예측 성능을 개선하는 인공 신경망 구조이다. 즉 순차적인 문맥을 파악하는 신경망이다.
순차적 구성 요소가 복잡한 의미와 구문 규칙에 따라 상호 연관되는 단어, 문장 또는 시계열 데이터 등의 순차적 데이터의 처리에 강점을 가지고 있다.
RNN의 구조
RNN은 순차적 데이터에서 이전 순서의 데이터의 상태를 참고한다고 하였다.
예를 들어서 생각해보면 만약 내일 주식 가격을 예측해야 한다고 생각해보자. 아무 정보 없이 예측해야 한다면 예측이 어려울 것이다.
하지만 이전 7일동안 10000원대였다면 내일의 주식도 10000원일 확률이 높을 것이라는 것을 알 수 있다.
이를 기반으로 RNN의 구조를 살펴보자.
x는 Input Layer의 입력 벡터, y는 Output Layer의 출력 벡터이다. 중간의 h는 Hidden Layer이자 Memory Cell이다.
Hidden Layer은 현재의 State를 기억해 두었다가 다음 순서의 데이터를 처리할 때 Hidden Layer에 이전 State의 값을 활용한다.
따라서 Memory Cell이라고도 불린다.
이처럼 현재 Hidden Layer의 상태를 기억하고 다시 다음 Hidden Layer의 연산에서 활용한다는 점에서 순환 신경망이라는 이름이 붙었다.
RNN은 이전 단계의 출력을 현재 단계의 입력으로 포함하여, 데이터 간의 시간적 관계와 종속성을 학습하고 유지할 수 있다!
시점에 따른 RNN 구조
위의 RNN 구조를 시간에 따라서 펼쳐보면 위의 그림과 같다.
입력 𝑥1를 처리할 때를 살펴보면 은닉층의 상태(State) ℎ1를 출력층으로 보낼 뿐만 아니라 다음 입력 𝑥2의 Hidden State ℎ2의 입력으로도 활용되는 것을 알 수 있다.
따라서 입력 𝑥2를 처리할 때 이전 입력에 대한 결과 ℎ1를 활용하게 되고, 𝑥3를 처리할 때에는 ℎ1과 ℎ2를 활용하게 된다는 것을 알 수 있다.
파라미터(가중치)의 공유
\(W_{xh}\) (입력에서 은닉 상태로의 가중치), \(W_{hh}\) (은닉 상태에서 은닉 상태로의 가중치), \(W_{hy}\) (은닉 상태에서 출력으로의 가중치)가 공유된다.
이것은 RNN의 핵심 아이디어 중 하나로서 시간 스텝마다 동일한 파라미터를 사용하여 다음과 같은 장점이 있다.
1. 파라미터 수의 감소와 정규화
가중치를 공유함으로써 모델이 학습시켜야 하는 파라미터의 수가 줄어든다.또한 파라미터의 수가 적기 때문에 과적합을 방지하고 일반화 능력을 향상시킬 수 있다.
2. 순차 구조 포착 가능
각 시점의 입력에 동일한 가중치를 적용함으로써 일관된 방식으로 데이터를 처리할 수 있게 된다. 따라서 순차적인 데이터 구조에 집중하여 학습할 수 있으며 시간적 패턴과 의존성을 효과적으로 포착할 수 있다.
3. 가변 길이 데이터 처리
입력 길이에 상관 없이 동일한 가중치를 사용하기 때문이다. 예를 들어 길이가 10인 시퀀스와, 20인 시퀀스에서 \(W_{xh}\), \(W_{hh}\), \(W_{hy}\)가 공유된다.
출력과 Hidden State의 값
\(t\)시점의 출력 : \(y_t = W_{hy}h_t\)
\(t\)시점의 Hidden State : \(h_t = tanh(W_{hh}h_{t-1} + W_{xh}x_t)\)
RNN의 활용
왼쪽 사진과 같이 4개의 문장을 RNN을 통해서 학습시켰다고 생각해보자.
이후 사용자가 "I", "play"까지 입력한 3번째로 올 단어를 예측한다고 가정해보자.
만약 전통적인 Feed Foward 모델을 사용했다면 3번째로 올 단어를 예측하기 어려울 것이다. 왜냐하면 각 입력이 독립적으로 처리되며 이전 입력을 고려하지 않기 때문이다. 즉 문맥을 이해하지 못하기 때문이다.
RNN은 "I play"까지의 입력을 통해 \(h_1\) Hidden State \(h_2\)를 계산했다. \(h_2\) 기반으로 다음에 올 단어의 확률 분포를 계산하여 가장 가능성이 높은 단어를 선택하게 된다.
출력 \(y_2\)
- 높은 확률 단어: "soccer", "basketball"
- 낮은 확률 단어: "eat", "love", "chicken", "you"
입력 "I play" 후, RNN은 다음 단어로 "soccer"와 "basketball"을 높은 확률로 예측할 것이다.
이는 RNN이 학습된 문장들에서 "I play" 다음에 "soccer"와 "basketball"이 나온 적이 있기 때문이다.
결과적으로 이전 시퀀스의 Hidden State를 활용하여 다음 단어의 확률 분포를 계산하고, 이를 통해 가장 가능성 높은 단어를 예측한다.
RNN 아키텍쳐 유형
일대다 : 하나의 입력을 여러 출력으로 채널링한다. 단일 키워드로 문장을 생성하여 이미지 캡션과 같이 사용할 수 있다.
다대일 : 입력 문서가 긍정적인지 부정적인지를 판별하는 감성 분류, 메일이 스팸 메일인지 판별하는 분류 등에 사용할 수 있다.
다대다 : 여러 입력이 출력에 매핑된다. 사용자가 문장을 입력하면 대답 문장을 출력하는 챗봇, 입력 문장으로부터 번역된 문장을 출력하는 번역기 등에서 활용될 수 있다.
기존의 한계점과 RNN의 등장 배경
1. 고정된 입력 크기
- 고정된 크기의 입력 데이터만 처리 가능, 가변 길이의 시퀀스 데이터 처리에 어려움
- RNN: 시퀀스의 각 요소를 순차적으로 처리하며, 가변 길이의 데이터에서도 유연하게 작동할 수 있는 구조를 가짐
2. 시간 종속성 무시
- 각 입력 데이터가 독립적이라 시간적 종속성이나 순서를 반영하지 못함
- RNN: 이전 단계의 출력을 현재 단계의 입력으로 포함하여, 데이터 간의 시간적 관계와 종속성을 학습하고 유지할 수 있음
3. 연속 데이터 처리의 어려움
- 시계열 데이터나 자연어와 같은 연속 데이터에서 패턴을 효과적으로 학습하지 못함
- RNN: 순환 구조를 통해 시퀀스 데이터의 연속적인 패턴을 학습하고, 과거 정보를 기억하여 다음 예측에 활용
RNN의 문제점과 한계
1. 긴 문장을 처리할 때 Gradient 제어가 어렵다.
- 자연어는 그 길이를 예측하기 힘들고 RNN은 Weight가 누적되어 State에 반영한다는 특성이 있다.
- \((W_{hh})^l\), \(l\)(Input의 길이)가 길어질수록 문제가 발생할 수 있다
- \(W_{hh}\) > 1 : 소수를 계속 곱하면 0에 수렴한다.
- Gradient Vanishing :초기 gradient가 작을 경우 sequence의 마지막에는 gradient가 소실될 수 있음
- \(W_{hh}\) < 1 : 수를 계속 곱하면 수가 매우 커진다.
- Gradient Exploding : 초기 gradient가 클 경우 sequence의 마지막에는 gradient가 너무 커질 수 있음
2. 장기 의존성 문제(Long-term Dependency problem)
- Gradient Vanishing으로 인해서 문장이 너무 길어지면 문장의 첫부분의 의미가 남아있지 않게 된다.
LSTM (Long Short Term Memory)
1997년 Hinton이 제안한 RNN의 개선 모델.
RNN의 Gradient 소실 및 폭발에 의한 장기 의존성 문제(Long-term dependency problem)를 완화하기 위해 고안된 신경망. 중요한 정보를 기억하고 불필요한 정보를 잊는 기능(Forget)을 통해 긴 시퀀스에서 효과적으로 학습이 가능하도록 하였다.
- 장기 의존성 문제의 "완화", 해결이 아니다.
LSTM은 셀의 연결체로 모델을 구성하였으며 Hidden State와 함께 장기 기억을 위한 Cell State의 개념을 사용하였다.
LSTM의 구조
셀 상태 (Cell State) : 장기 기억
- 추가적인 가중치 연산 없이 정보가 흐를 수 있게 하여 중요한 정보를 장기간 유지할 수 있다.
- 시퀀스의 길이에 관계없이 정보를 계속 전달할 수 있는 경로로서 그래디언트 소실 문제를 완화한다.
- Cell State는 이전 Cell State의 일부를 잊고(Forget Gate)와 새로운 정보를 추가하여(Input Gate)를 통해 업데이트된다.
은닉 상태 (Hidden State) : 단기 기억
- 현재 시점에서 LSTM의 출력으로, 다음 시점의 입력으로도 사용된다.
- 단기적인 정보를 유지하고, Cell State와 결합하여 시퀀스의 각 단계에서 중요한 정보를 출력한다.
- Hidden State는 Output Gate를 통해서 현재 Cell State의 정보를 기반으로 결정된다.
- Forget Gate 이전 Cell memory \(C_{t-1}\)의 영향도, 즉 얼마만큼 삭제할지 결정
- Input Gate : 입력 \(x_t\)를 Cell memory \(C_{t-1}\)에 추가할지 말지를 결정하고 추가할 거라면 얼마만큼 추가할지 결정
- Output Gate : 이전 Cell의 Hidden state \(h_{t-1}\), 입력 \(x_t\), Cell memory \(C_{t-1}\)을 취합해서 이 Cell의 Hidden state \(h_t\)를 결정
Forget Gate
이전 Cell memory \(C_{t-1}\)의 영향도, 즉 얼마만큼 삭제할지 결정
입력: 이전 Cell의 Hidden state \(h_{t-1}\)와 입력 \(x_t\)
과정 : Sigmoid 함수를 이용해서 0 (삭제) 또는 1 (유지)를 출력
→ \(W_{h}h_{t-1} + b_h + W_{x}x_t + b_x\)
→ W와 b는 입력과 이전 Hidden State에 대한 가중치와 Bias
출력 : \(f_t\) = 0~1 사이의 값
→ \(if\) \(f_t\) = 0, 이전 cell memory \(C_{t-1}\) 을 완전 삭제
→ \(if\) \(f_t\) = 1,이전 cell memory \(C_{t-1}\)을 완전 보존
Input Gate
입력 \(x_t\)를 Cell memory \(C_{t-1}\)에 추가할지 말지를 결정하고 추가할 거라면 얼마만큼 추가할지 결정
입력: 이전 Cell의 Hidden state \(h_{t-1}\)와 입력 \(x_t\)
→ \(W_{h}h_{t-1} + b_h + W_{x}x_t + b_x\)
과정
1. Sigmoid 함수 : 입력을 추가할지 말지를 결정 (0 or 1)
2. tanh 함수 : 출력을 (-1, 1 )사이의 값으로 변환 (얼마만큼 추가)
출력: Sigmoid 함수 결과 * tanh 함수 결과
Output Gate
이전 Cell의 Hidden state \(h_{t-1}\), 입력 \(x_t\) , cell memory \(C_{t-1}\)을 취합 → 이 Cell의 Hidden state \(h_{t}\)를 결정
입력: 이전 cell의 hidden state \(h_{t-1}\),와 입력 \(x_t\), cell memory \(C_{t}\)
→ \(W_{h}h_{t-1} + b_h + W_{x}x_t + b_x\)
과정
1. Sigmoid 함수: \(h_{t}\)를 출력할지 말지 결정 (0 or 1)
2. tanh 함수: \(C_{t}\) 를 (-1, 1)로 대응
출력: (-1, 1)로 대응된 \(C_{t}\)
GRU (Gated Recurrent Unit)
조경현 박사가 제안한 LSTM의 간소화 모델. LSTM의 매개변수가 많기 때문에 계산이 오래 걸린다는 단점을 극복하기 위해 제안
1. Cell state가 없음 : Hidden State가 Cell State의 역할을 함께 수행한다.
- Hidden State : 장기, 단기 기억 모두 담당
2. Update gate : Forget gate + Input gate를 통합
3. Reset gate 추가
Reset Gate
이전 State의 입력 \(h_{t-1}\)을 삭제할지 말지 결정
입력: \(x_{t}\)와 \(h_{t-1}\)의 선형조합
→ \(W_{x}x_{t} + b_x + W_{h}h_{t-1} + b_h\)
과정: \(x_{t}\)와 \(h_{t-1}\)의 선형 조합에 대한 Sigmoid 함수 적용
출력: 0 (Reset) or 1 (Not reset)
→ 0이라면 \(h_{t-1}\)을 무시하고 연산을 진행함 (reset)
Update Gate
LSTM의 Forget gate와 Input Gate의 역할을 수행
hidden state를 업데이트할지 말지를 결정
입력 : \(x_{t}\)와 \(h_{t-1}\)의 선형조합
→ \(W_{x}x_{t} + b_x + W_{h}h_{t-1} + b_h\)
과정 : 입력 -> → sigmoid \(x_{t}\) → \(\tan h\)
1. LSTM의 Forget gate : Sigmoid 함수의 출력이 1이면, \(h_{t-1}\)에는 0이 곱해지지만, \(\tan h\)의 결과에는 1이 곱해짐
2. LSTM의 Input gate : Sigmoid 함수의 출력이 0이면, \(h_{t-1}\)에는 1이 곱해지고, \(\tan h\)의 결과에는 0이 곱해짐
출력
Forget일 경우 : \(h_{t-1}\)
Input일 경우 : (-1, 1) 사이의 \(x_{t}\)
RNN 기반 모델들의 사용
1. 자연어 처리 (NLP)
- 텍스트 생성: 특정 스타일이나 주제에 맞춰 텍스트 생성
- 번역: 기계 번역 시스템에서 소스 언어를 타겟 언어로 변환
- 감정 분석: 텍스트 데이터에서 감정 추출
2. 음성 인식
- 스피치 투 텍스트: 음성 신호를 텍스트로 변환
- 명령어 인식: 음성 명령을 이해하고 처리
3. 시계열 예측
- 주가 예측: 과거 주가 데이터를 기반으로 미래 주가 예측
- 날씨 예측: 기상 데이터 분석 및 예측
4. 이미지 캡셔닝
- 이미지 설명: 이미지 내용을 설명하는 텍스트 생성
시계열 데이터 학습 : RNN 기반 모델들의 비교
이제 앞에서 살펴본 내용들을 바탕으로 임의의 시계열 데이터를 세 개의 RNN 기반 모델들을 이용해서 학습시켜보고 학습 속도, Loss 등을 비교해보면서 성능을 비교해보려고 한다.
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
import os
os.chdir('drive/MyDrive/DL2024_201810776/week11/')
%load_ext autoreload
%autoreload 2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, LSTM, GRU, Dense
import matplotlib.pyplot as plt
from utils import plot_learning_curves, decode_imdb, sentence_to_vector
임의의 신호 시계열 데이터 생성¶
임의로 시간마다 변화하는 시계열 데이터를 생성한다. Input의 차원은 100 * 2고, 1개의 Output을 가진다.
def signal_two():
X_data = []
y_data = []
for i in range(2500):
lst = np.random.rand(100)
idx = np.random.choice(100, 2, replace = False)
zeros = np.zeros(100)
zeros[idx] = 1
X_data.append(np.array(list(zip(zeros, lst))))
y_data.append(np.prod(lst[idx]))
X_data = np.array(X_data)
y_data = np.array(y_data)
return X_data, y_data
X_data, y_data = signal_two()
X_data.shape, y_data.shape
((2500, 100, 2), (2500,))
idx = 33
plt.plot(X_data[idx].flatten())
plt.scatter(50, y_data[idx], color = 'red')
plt.show()
idx = 200
plt.plot(X_data[idx].flatten())
plt.scatter(50, y_data[idx], color = 'red')
plt.show()
RNN & LSTM & GRU 모델 생성¶
rnn_model = Sequential([
# Hidden Layer Node 30개, input size [100, 2]
SimpleRNN(30, return_sequences=True, input_shape=[100, 2]),
SimpleRNN(30), # 2개 층 사용
Dense(1) # 마지막 출력층은 단일 출력
])
rnn_model.compile(optimizer='adam', loss='mse')
rnn_model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= simple_rnn (SimpleRNN) (None, 100, 30) 990 simple_rnn_1 (SimpleRNN) (None, 30) 1830 dense (Dense) (None, 1) 31 ================================================================= Total params: 2851 (11.14 KB) Trainable params: 2851 (11.14 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
lstm_model = Sequential([
# Hidden Layer Node 30개, input size [100, 2]
LSTM(30, return_sequences=True, input_shape=[100, 2]),
LSTM(30), # 2개 층 사용
Dense(1) # 마지막 출력층은 단일 출력
])
lstm_model.compile(optimizer='adam', loss='mse')
lstm_model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= lstm (LSTM) (None, 100, 30) 3960 lstm_1 (LSTM) (None, 30) 7320 dense_1 (Dense) (None, 1) 31 ================================================================= Total params: 11311 (44.18 KB) Trainable params: 11311 (44.18 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
gru_model = Sequential([
# Hidden Layer Node 30개, input size [100, 2]
GRU(30, return_sequences=True, input_shape=[100, 2]),
GRU(30), # 2개 층 사용
Dense(1) # 마지막 출력층은 단일 출력
])
gru_model.compile(optimizer='adam', loss='mse')
gru_model.summary()
Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= gru (GRU) (None, 100, 30) 3060 gru_1 (GRU) (None, 30) 5580 dense_2 (Dense) (None, 1) 31 ================================================================= Total params: 8671 (33.87 KB) Trainable params: 8671 (33.87 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
모델 학습¶
rnn_history = rnn_model.fit(X_data, y_data, epochs=100, validation_split=0.2)
Epoch 1/100 63/63 [==============================] - 6s 61ms/step - loss: 0.0794 - val_loss: 0.0508 Epoch 2/100 63/63 [==============================] - 2s 37ms/step - loss: 0.0538 - val_loss: 0.0518 Epoch 3/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0509 - val_loss: 0.0500 Epoch 4/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0508 - val_loss: 0.0500 Epoch 5/100 63/63 [==============================] - 3s 42ms/step - loss: 0.0515 - val_loss: 0.0495 Epoch 6/100 63/63 [==============================] - 4s 66ms/step - loss: 0.0518 - val_loss: 0.0651 Epoch 7/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0516 - val_loss: 0.0496 Epoch 8/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0494 - val_loss: 0.0492 Epoch 9/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0508 - val_loss: 0.0498 Epoch 10/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0506 - val_loss: 0.0511 Epoch 11/100 63/63 [==============================] - 3s 54ms/step - loss: 0.0498 - val_loss: 0.0507 Epoch 12/100 63/63 [==============================] - 4s 61ms/step - loss: 0.0498 - val_loss: 0.0492 Epoch 13/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0496 - val_loss: 0.0577 Epoch 14/100 63/63 [==============================] - 2s 38ms/step - loss: 0.0509 - val_loss: 0.0490 Epoch 15/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0505 - val_loss: 0.0489 Epoch 16/100 63/63 [==============================] - 4s 56ms/step - loss: 0.0496 - val_loss: 0.0486 Epoch 17/100 63/63 [==============================] - 4s 65ms/step - loss: 0.0497 - val_loss: 0.0487 Epoch 18/100 63/63 [==============================] - 3s 43ms/step - loss: 0.0500 - val_loss: 0.0491 Epoch 19/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0489 - val_loss: 0.0508 Epoch 20/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0507 - val_loss: 0.0503 Epoch 21/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0508 - val_loss: 0.0492 Epoch 22/100 63/63 [==============================] - 4s 65ms/step - loss: 0.0508 - val_loss: 0.0490 Epoch 23/100 63/63 [==============================] - 3s 45ms/step - loss: 0.0499 - val_loss: 0.0490 Epoch 24/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0493 - val_loss: 0.0493 Epoch 25/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0496 - val_loss: 0.0483 Epoch 26/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0496 - val_loss: 0.0505 Epoch 27/100 63/63 [==============================] - 3s 44ms/step - loss: 0.0496 - val_loss: 0.0535 Epoch 28/100 63/63 [==============================] - 4s 66ms/step - loss: 0.0492 - val_loss: 0.0493 Epoch 29/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0490 - val_loss: 0.0502 Epoch 30/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0493 - val_loss: 0.0494 Epoch 31/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0502 - val_loss: 0.0483 Epoch 32/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0482 - val_loss: 0.0485 Epoch 33/100 63/63 [==============================] - 4s 64ms/step - loss: 0.0488 - val_loss: 0.0507 Epoch 34/100 63/63 [==============================] - 5s 75ms/step - loss: 0.0487 - val_loss: 0.0493 Epoch 35/100 63/63 [==============================] - 3s 45ms/step - loss: 0.0488 - val_loss: 0.0503 Epoch 36/100 63/63 [==============================] - 2s 37ms/step - loss: 0.0499 - val_loss: 0.0486 Epoch 37/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0476 - val_loss: 0.0502 Epoch 38/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0478 - val_loss: 0.0558 Epoch 39/100 63/63 [==============================] - 4s 59ms/step - loss: 0.0477 - val_loss: 0.0493 Epoch 40/100 63/63 [==============================] - 3s 50ms/step - loss: 0.0475 - val_loss: 0.0491 Epoch 41/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0491 - val_loss: 0.0496 Epoch 42/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0479 - val_loss: 0.0512 Epoch 43/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0493 - val_loss: 0.0492 Epoch 44/100 63/63 [==============================] - 2s 39ms/step - loss: 0.0473 - val_loss: 0.0490 Epoch 45/100 63/63 [==============================] - 4s 68ms/step - loss: 0.0476 - val_loss: 0.0495 Epoch 46/100 63/63 [==============================] - 3s 41ms/step - loss: 0.0477 - val_loss: 0.0528 Epoch 47/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0476 - val_loss: 0.0497 Epoch 48/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0472 - val_loss: 0.0544 Epoch 49/100 63/63 [==============================] - 2s 37ms/step - loss: 0.0477 - val_loss: 0.0534 Epoch 50/100 63/63 [==============================] - 3s 49ms/step - loss: 0.0472 - val_loss: 0.0494 Epoch 51/100 63/63 [==============================] - 4s 61ms/step - loss: 0.0469 - val_loss: 0.0496 Epoch 52/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0462 - val_loss: 0.0513 Epoch 53/100 63/63 [==============================] - 2s 37ms/step - loss: 0.0469 - val_loss: 0.0500 Epoch 54/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0464 - val_loss: 0.0495 Epoch 55/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0456 - val_loss: 0.0536 Epoch 56/100 63/63 [==============================] - 4s 65ms/step - loss: 0.0462 - val_loss: 0.0493 Epoch 57/100 63/63 [==============================] - 4s 64ms/step - loss: 0.0467 - val_loss: 0.0523 Epoch 58/100 63/63 [==============================] - 3s 44ms/step - loss: 0.0453 - val_loss: 0.0505 Epoch 59/100 63/63 [==============================] - 3s 53ms/step - loss: 0.0458 - val_loss: 0.0486 Epoch 60/100 63/63 [==============================] - 4s 56ms/step - loss: 0.0448 - val_loss: 0.0517 Epoch 61/100 63/63 [==============================] - 5s 82ms/step - loss: 0.0444 - val_loss: 0.0479 Epoch 62/100 63/63 [==============================] - 3s 41ms/step - loss: 0.0444 - val_loss: 0.0506 Epoch 63/100 63/63 [==============================] - 3s 41ms/step - loss: 0.0443 - val_loss: 0.0495 Epoch 64/100 63/63 [==============================] - 3s 49ms/step - loss: 0.0462 - val_loss: 0.0516 Epoch 65/100 63/63 [==============================] - 4s 65ms/step - loss: 0.0441 - val_loss: 0.0494 Epoch 66/100 63/63 [==============================] - 4s 67ms/step - loss: 0.0435 - val_loss: 0.0503 Epoch 67/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0438 - val_loss: 0.0492 Epoch 68/100 63/63 [==============================] - 2s 38ms/step - loss: 0.0422 - val_loss: 0.0495 Epoch 69/100 63/63 [==============================] - 4s 70ms/step - loss: 0.0429 - val_loss: 0.0496 Epoch 70/100 63/63 [==============================] - 7s 117ms/step - loss: 0.0448 - val_loss: 0.0496 Epoch 71/100 63/63 [==============================] - 3s 43ms/step - loss: 0.0436 - val_loss: 0.0509 Epoch 72/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0399 - val_loss: 0.0462 Epoch 73/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0399 - val_loss: 0.0443 Epoch 74/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0441 - val_loss: 0.0518 Epoch 75/100 63/63 [==============================] - 3s 49ms/step - loss: 0.0408 - val_loss: 0.0478 Epoch 76/100 63/63 [==============================] - 4s 62ms/step - loss: 0.0397 - val_loss: 0.0547 Epoch 77/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0418 - val_loss: 0.0489 Epoch 78/100 63/63 [==============================] - 2s 37ms/step - loss: 0.0396 - val_loss: 0.0464 Epoch 79/100 63/63 [==============================] - 2s 37ms/step - loss: 0.0408 - val_loss: 0.0569 Epoch 80/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0498 - val_loss: 0.0470 Epoch 81/100 63/63 [==============================] - 4s 61ms/step - loss: 0.0484 - val_loss: 0.0539 Epoch 82/100 63/63 [==============================] - 3s 52ms/step - loss: 0.0491 - val_loss: 0.0482 Epoch 83/100 63/63 [==============================] - 2s 38ms/step - loss: 0.0478 - val_loss: 0.0475 Epoch 84/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0450 - val_loss: 0.0506 Epoch 85/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0426 - val_loss: 0.0486 Epoch 86/100 63/63 [==============================] - 3s 43ms/step - loss: 0.0414 - val_loss: 0.0462 Epoch 87/100 63/63 [==============================] - 4s 66ms/step - loss: 0.0437 - val_loss: 0.0486 Epoch 88/100 63/63 [==============================] - 2s 37ms/step - loss: 0.0397 - val_loss: 0.0466 Epoch 89/100 63/63 [==============================] - 3s 44ms/step - loss: 0.0418 - val_loss: 0.0470 Epoch 90/100 63/63 [==============================] - 3s 42ms/step - loss: 0.0392 - val_loss: 0.0448 Epoch 91/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0390 - val_loss: 0.0442 Epoch 92/100 63/63 [==============================] - 4s 66ms/step - loss: 0.0417 - val_loss: 0.0475 Epoch 93/100 63/63 [==============================] - 3s 44ms/step - loss: 0.0415 - val_loss: 0.0466 Epoch 94/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0379 - val_loss: 0.0445 Epoch 95/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0379 - val_loss: 0.0426 Epoch 96/100 63/63 [==============================] - 2s 37ms/step - loss: 0.0351 - val_loss: 0.0421 Epoch 97/100 63/63 [==============================] - 3s 45ms/step - loss: 0.0374 - val_loss: 0.0432 Epoch 98/100 63/63 [==============================] - 4s 67ms/step - loss: 0.0339 - val_loss: 0.0398 Epoch 99/100 63/63 [==============================] - 2s 36ms/step - loss: 0.0342 - val_loss: 0.0414 Epoch 100/100 63/63 [==============================] - 2s 35ms/step - loss: 0.0305 - val_loss: 0.0365
lstm_history = lstm_model.fit(X_data, y_data, epochs=100, validation_split=0.2)
Epoch 1/100 63/63 [==============================] - 9s 86ms/step - loss: 0.0520 - val_loss: 0.0499 Epoch 2/100 63/63 [==============================] - 6s 92ms/step - loss: 0.0492 - val_loss: 0.0484 Epoch 3/100 63/63 [==============================] - 5s 83ms/step - loss: 0.0492 - val_loss: 0.0490 Epoch 4/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0497 - val_loss: 0.0488 Epoch 5/100 63/63 [==============================] - 6s 99ms/step - loss: 0.0496 - val_loss: 0.0482 Epoch 6/100 63/63 [==============================] - 5s 75ms/step - loss: 0.0496 - val_loss: 0.0487 Epoch 7/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0493 - val_loss: 0.0487 Epoch 8/100 63/63 [==============================] - 6s 104ms/step - loss: 0.0500 - val_loss: 0.0482 Epoch 9/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0495 - val_loss: 0.0481 Epoch 10/100 63/63 [==============================] - 5s 74ms/step - loss: 0.0495 - val_loss: 0.0487 Epoch 11/100 63/63 [==============================] - 7s 107ms/step - loss: 0.0491 - val_loss: 0.0480 Epoch 12/100 63/63 [==============================] - 5s 74ms/step - loss: 0.0496 - val_loss: 0.0484 Epoch 13/100 63/63 [==============================] - 5s 74ms/step - loss: 0.0490 - val_loss: 0.0482 Epoch 14/100 63/63 [==============================] - 7s 104ms/step - loss: 0.0491 - val_loss: 0.0487 Epoch 15/100 63/63 [==============================] - 5s 75ms/step - loss: 0.0493 - val_loss: 0.0478 Epoch 16/100 63/63 [==============================] - 5s 82ms/step - loss: 0.0489 - val_loss: 0.0485 Epoch 17/100 63/63 [==============================] - 8s 120ms/step - loss: 0.0490 - val_loss: 0.0479 Epoch 18/100 63/63 [==============================] - 5s 81ms/step - loss: 0.0487 - val_loss: 0.0478 Epoch 19/100 63/63 [==============================] - 6s 103ms/step - loss: 0.0487 - val_loss: 0.0481 Epoch 20/100 63/63 [==============================] - 5s 73ms/step - loss: 0.0487 - val_loss: 0.0479 Epoch 21/100 63/63 [==============================] - 5s 73ms/step - loss: 0.0487 - val_loss: 0.0477 Epoch 22/100 63/63 [==============================] - 7s 105ms/step - loss: 0.0489 - val_loss: 0.0478 Epoch 23/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0481 - val_loss: 0.0475 Epoch 24/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0483 - val_loss: 0.0473 Epoch 25/100 63/63 [==============================] - 7s 106ms/step - loss: 0.0485 - val_loss: 0.0480 Epoch 26/100 63/63 [==============================] - 5s 73ms/step - loss: 0.0478 - val_loss: 0.0498 Epoch 27/100 63/63 [==============================] - 5s 78ms/step - loss: 0.0529 - val_loss: 0.0483 Epoch 28/100 63/63 [==============================] - 6s 97ms/step - loss: 0.0488 - val_loss: 0.0494 Epoch 29/100 63/63 [==============================] - 5s 74ms/step - loss: 0.0483 - val_loss: 0.0477 Epoch 30/100 63/63 [==============================] - 5s 83ms/step - loss: 0.0477 - val_loss: 0.0477 Epoch 31/100 63/63 [==============================] - 6s 92ms/step - loss: 0.0480 - val_loss: 0.0475 Epoch 32/100 63/63 [==============================] - 5s 73ms/step - loss: 0.0475 - val_loss: 0.0460 Epoch 33/100 63/63 [==============================] - 6s 88ms/step - loss: 0.0476 - val_loss: 0.0476 Epoch 34/100 63/63 [==============================] - 6s 87ms/step - loss: 0.0471 - val_loss: 0.0468 Epoch 35/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0426 - val_loss: 0.0670 Epoch 36/100 63/63 [==============================] - 6s 94ms/step - loss: 0.0527 - val_loss: 0.0480 Epoch 37/100 63/63 [==============================] - 6s 91ms/step - loss: 0.0498 - val_loss: 0.0476 Epoch 38/100 63/63 [==============================] - 6s 98ms/step - loss: 0.0492 - val_loss: 0.0487 Epoch 39/100 63/63 [==============================] - 6s 99ms/step - loss: 0.0492 - val_loss: 0.0472 Epoch 40/100 63/63 [==============================] - 5s 74ms/step - loss: 0.0490 - val_loss: 0.0413 Epoch 41/100 63/63 [==============================] - 5s 83ms/step - loss: 0.0417 - val_loss: 0.0359 Epoch 42/100 63/63 [==============================] - 6s 96ms/step - loss: 0.0270 - val_loss: 0.0190 Epoch 43/100 63/63 [==============================] - 5s 73ms/step - loss: 0.0255 - val_loss: 0.0164 Epoch 44/100 63/63 [==============================] - 5s 87ms/step - loss: 0.0120 - val_loss: 0.0089 Epoch 45/100 63/63 [==============================] - 6s 88ms/step - loss: 0.0090 - val_loss: 0.0085 Epoch 46/100 63/63 [==============================] - 4s 71ms/step - loss: 0.0062 - val_loss: 0.0078 Epoch 47/100 63/63 [==============================] - 6s 91ms/step - loss: 0.0054 - val_loss: 0.0076 Epoch 48/100 63/63 [==============================] - 5s 84ms/step - loss: 0.0060 - val_loss: 0.0046 Epoch 49/100 63/63 [==============================] - 5s 73ms/step - loss: 0.0041 - val_loss: 0.0047 Epoch 50/100 63/63 [==============================] - 6s 99ms/step - loss: 0.0036 - val_loss: 0.0063 Epoch 51/100 63/63 [==============================] - 5s 77ms/step - loss: 0.0052 - val_loss: 0.0033 Epoch 52/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0032 - val_loss: 0.0038 Epoch 53/100 63/63 [==============================] - 6s 104ms/step - loss: 0.0031 - val_loss: 0.0051 Epoch 54/100 63/63 [==============================] - 5s 74ms/step - loss: 0.0033 - val_loss: 0.0027 Epoch 55/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0028 - val_loss: 0.0024 Epoch 56/100 63/63 [==============================] - 7s 107ms/step - loss: 0.0023 - val_loss: 0.0030 Epoch 57/100 63/63 [==============================] - 5s 77ms/step - loss: 0.0024 - val_loss: 0.0040 Epoch 58/100 63/63 [==============================] - 7s 117ms/step - loss: 0.0023 - val_loss: 0.0026 Epoch 59/100 63/63 [==============================] - 5s 86ms/step - loss: 0.0023 - val_loss: 0.0021 Epoch 60/100 63/63 [==============================] - 5s 73ms/step - loss: 0.0020 - val_loss: 0.0020 Epoch 61/100 63/63 [==============================] - 6s 100ms/step - loss: 0.0021 - val_loss: 0.0025 Epoch 62/100 63/63 [==============================] - 5s 76ms/step - loss: 0.0020 - val_loss: 0.0022 Epoch 63/100 63/63 [==============================] - 5s 74ms/step - loss: 0.0032 - val_loss: 0.0019 Epoch 64/100 63/63 [==============================] - 7s 105ms/step - loss: 0.0022 - val_loss: 0.0028 Epoch 65/100 63/63 [==============================] - 5s 74ms/step - loss: 0.0016 - val_loss: 0.0018 Epoch 66/100 63/63 [==============================] - 5s 73ms/step - loss: 0.0017 - val_loss: 0.0015 Epoch 67/100 63/63 [==============================] - 6s 104ms/step - loss: 0.0015 - val_loss: 0.0014 Epoch 68/100 63/63 [==============================] - 5s 76ms/step - loss: 0.0014 - val_loss: 0.0019 Epoch 69/100 63/63 [==============================] - 5s 79ms/step - loss: 0.0024 - val_loss: 0.0027 Epoch 70/100 63/63 [==============================] - 6s 97ms/step - loss: 0.0017 - val_loss: 0.0014 Epoch 71/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0014 - val_loss: 0.0021 Epoch 72/100 63/63 [==============================] - 5s 85ms/step - loss: 0.0016 - val_loss: 0.0014 Epoch 73/100 63/63 [==============================] - 6s 90ms/step - loss: 0.0013 - val_loss: 0.0017 Epoch 74/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0011 - val_loss: 0.0012 Epoch 75/100 63/63 [==============================] - 5s 86ms/step - loss: 0.0012 - val_loss: 0.0012 Epoch 76/100 63/63 [==============================] - 6s 88ms/step - loss: 0.0010 - val_loss: 0.0016 Epoch 77/100 63/63 [==============================] - 5s 73ms/step - loss: 0.0011 - val_loss: 0.0011 Epoch 78/100 63/63 [==============================] - 7s 110ms/step - loss: 0.0011 - val_loss: 0.0021 Epoch 79/100 63/63 [==============================] - 6s 101ms/step - loss: 0.0012 - val_loss: 0.0011 Epoch 80/100 63/63 [==============================] - 5s 75ms/step - loss: 8.8357e-04 - val_loss: 9.2832e-04 Epoch 81/100 63/63 [==============================] - 6s 89ms/step - loss: 8.5371e-04 - val_loss: 9.8483e-04 Epoch 82/100 63/63 [==============================] - 6s 88ms/step - loss: 9.2047e-04 - val_loss: 8.6136e-04 Epoch 83/100 63/63 [==============================] - 5s 72ms/step - loss: 0.0010 - val_loss: 0.0011 Epoch 84/100 63/63 [==============================] - 6s 90ms/step - loss: 8.3855e-04 - val_loss: 7.8609e-04 Epoch 85/100 63/63 [==============================] - 5s 87ms/step - loss: 9.9517e-04 - val_loss: 9.2863e-04 Epoch 86/100 63/63 [==============================] - 5s 73ms/step - loss: 8.9824e-04 - val_loss: 0.0012 Epoch 87/100 63/63 [==============================] - 6s 101ms/step - loss: 8.8566e-04 - val_loss: 9.0958e-04 Epoch 88/100 63/63 [==============================] - 5s 75ms/step - loss: 8.1512e-04 - val_loss: 8.5832e-04 Epoch 89/100 63/63 [==============================] - 5s 74ms/step - loss: 9.4467e-04 - val_loss: 7.0608e-04 Epoch 90/100 63/63 [==============================] - 7s 104ms/step - loss: 9.1719e-04 - val_loss: 9.5856e-04 Epoch 91/100 63/63 [==============================] - 5s 72ms/step - loss: 8.4016e-04 - val_loss: 7.5225e-04 Epoch 92/100 63/63 [==============================] - 5s 74ms/step - loss: 6.3936e-04 - val_loss: 6.3716e-04 Epoch 93/100 63/63 [==============================] - 6s 104ms/step - loss: 7.3913e-04 - val_loss: 0.0010 Epoch 94/100 63/63 [==============================] - 5s 73ms/step - loss: 8.7897e-04 - val_loss: 0.0017 Epoch 95/100 63/63 [==============================] - 5s 75ms/step - loss: 8.2503e-04 - val_loss: 0.0017 Epoch 96/100 63/63 [==============================] - 6s 99ms/step - loss: 0.0011 - val_loss: 9.2233e-04 Epoch 97/100 63/63 [==============================] - 5s 73ms/step - loss: 6.4775e-04 - val_loss: 5.3284e-04 Epoch 98/100 63/63 [==============================] - 6s 101ms/step - loss: 4.8493e-04 - val_loss: 6.5035e-04 Epoch 99/100 63/63 [==============================] - 7s 113ms/step - loss: 5.9518e-04 - val_loss: 6.0999e-04 Epoch 100/100 63/63 [==============================] - 5s 75ms/step - loss: 5.9284e-04 - val_loss: 5.2339e-04
gru_history = gru_model.fit(X_data, y_data, epochs=100, validation_split=0.2)
Epoch 1/100 63/63 [==============================] - 11s 124ms/step - loss: 0.0523 - val_loss: 0.0504 Epoch 2/100 63/63 [==============================] - 5s 76ms/step - loss: 0.0499 - val_loss: 0.0523 Epoch 3/100 63/63 [==============================] - 6s 97ms/step - loss: 0.0498 - val_loss: 0.0487 Epoch 4/100 63/63 [==============================] - 6s 87ms/step - loss: 0.0498 - val_loss: 0.0483 Epoch 5/100 63/63 [==============================] - 5s 76ms/step - loss: 0.0496 - val_loss: 0.0483 Epoch 6/100 63/63 [==============================] - 7s 110ms/step - loss: 0.0495 - val_loss: 0.0488 Epoch 7/100 63/63 [==============================] - 5s 78ms/step - loss: 0.0496 - val_loss: 0.0482 Epoch 8/100 63/63 [==============================] - 5s 78ms/step - loss: 0.0494 - val_loss: 0.0482 Epoch 9/100 63/63 [==============================] - 7s 107ms/step - loss: 0.0490 - val_loss: 0.0482 Epoch 10/100 63/63 [==============================] - 5s 78ms/step - loss: 0.0491 - val_loss: 0.0481 Epoch 11/100 63/63 [==============================] - 6s 91ms/step - loss: 0.0493 - val_loss: 0.0487 Epoch 12/100 63/63 [==============================] - 6s 93ms/step - loss: 0.0491 - val_loss: 0.0483 Epoch 13/100 63/63 [==============================] - 5s 78ms/step - loss: 0.0490 - val_loss: 0.0480 Epoch 14/100 63/63 [==============================] - 7s 106ms/step - loss: 0.0484 - val_loss: 0.0491 Epoch 15/100 63/63 [==============================] - 5s 78ms/step - loss: 0.0482 - val_loss: 0.0472 Epoch 16/100 63/63 [==============================] - 5s 82ms/step - loss: 0.0478 - val_loss: 0.0466 Epoch 17/100 63/63 [==============================] - 9s 138ms/step - loss: 0.0460 - val_loss: 0.0464 Epoch 18/100 63/63 [==============================] - 5s 79ms/step - loss: 0.0273 - val_loss: 0.0066 Epoch 19/100 63/63 [==============================] - 5s 77ms/step - loss: 0.0037 - val_loss: 0.0035 Epoch 20/100 63/63 [==============================] - 7s 110ms/step - loss: 0.0025 - val_loss: 0.0021 Epoch 21/100 63/63 [==============================] - 5s 77ms/step - loss: 0.0019 - val_loss: 0.0017 Epoch 22/100 63/63 [==============================] - 6s 93ms/step - loss: 0.0015 - val_loss: 0.0016 Epoch 23/100 63/63 [==============================] - 6s 93ms/step - loss: 0.0014 - val_loss: 0.0018 Epoch 24/100 63/63 [==============================] - 5s 77ms/step - loss: 0.0013 - val_loss: 0.0015 Epoch 25/100 63/63 [==============================] - 7s 106ms/step - loss: 0.0011 - val_loss: 0.0018 Epoch 26/100 63/63 [==============================] - 5s 77ms/step - loss: 0.0011 - val_loss: 0.0013 Epoch 27/100 63/63 [==============================] - 5s 77ms/step - loss: 9.8503e-04 - val_loss: 0.0015 Epoch 28/100 63/63 [==============================] - 7s 107ms/step - loss: 9.9584e-04 - val_loss: 0.0010 Epoch 29/100 63/63 [==============================] - 5s 78ms/step - loss: 9.0967e-04 - val_loss: 0.0011 Epoch 30/100 63/63 [==============================] - 6s 92ms/step - loss: 9.7673e-04 - val_loss: 0.0011 Epoch 31/100 63/63 [==============================] - 6s 95ms/step - loss: 9.4391e-04 - val_loss: 0.0012 Epoch 32/100 63/63 [==============================] - 5s 78ms/step - loss: 8.7120e-04 - val_loss: 8.7380e-04 Epoch 33/100 63/63 [==============================] - 7s 106ms/step - loss: 6.9260e-04 - val_loss: 0.0010 Epoch 34/100 63/63 [==============================] - 5s 81ms/step - loss: 7.2104e-04 - val_loss: 8.0142e-04 Epoch 35/100 63/63 [==============================] - 5s 78ms/step - loss: 7.0432e-04 - val_loss: 7.1026e-04 Epoch 36/100 63/63 [==============================] - 8s 133ms/step - loss: 6.5480e-04 - val_loss: 9.1472e-04 Epoch 37/100 63/63 [==============================] - 5s 83ms/step - loss: 6.5465e-04 - val_loss: 9.4294e-04 Epoch 38/100 63/63 [==============================] - 5s 78ms/step - loss: 7.6229e-04 - val_loss: 0.0014 Epoch 39/100 63/63 [==============================] - 7s 106ms/step - loss: 7.7934e-04 - val_loss: 6.3873e-04 Epoch 40/100 63/63 [==============================] - 5s 78ms/step - loss: 6.0807e-04 - val_loss: 9.5741e-04 Epoch 41/100 63/63 [==============================] - 5s 86ms/step - loss: 6.3156e-04 - val_loss: 0.0011 Epoch 42/100 63/63 [==============================] - 6s 100ms/step - loss: 6.0989e-04 - val_loss: 5.8265e-04 Epoch 43/100 63/63 [==============================] - 5s 77ms/step - loss: 6.4547e-04 - val_loss: 7.0458e-04 Epoch 44/100 63/63 [==============================] - 7s 109ms/step - loss: 5.3649e-04 - val_loss: 9.2800e-04 Epoch 45/100 63/63 [==============================] - 5s 84ms/step - loss: 5.2868e-04 - val_loss: 5.3463e-04 Epoch 46/100 63/63 [==============================] - 5s 80ms/step - loss: 4.6257e-04 - val_loss: 5.0950e-04 Epoch 47/100 63/63 [==============================] - 7s 111ms/step - loss: 4.6910e-04 - val_loss: 5.6282e-04 Epoch 48/100 63/63 [==============================] - 5s 78ms/step - loss: 4.5464e-04 - val_loss: 8.6205e-04 Epoch 49/100 63/63 [==============================] - 6s 90ms/step - loss: 6.5826e-04 - val_loss: 5.7842e-04 Epoch 50/100 63/63 [==============================] - 6s 97ms/step - loss: 4.1492e-04 - val_loss: 4.7834e-04 Epoch 51/100 63/63 [==============================] - 5s 79ms/step - loss: 4.2575e-04 - val_loss: 5.3386e-04 Epoch 52/100 63/63 [==============================] - 7s 104ms/step - loss: 4.5305e-04 - val_loss: 6.4115e-04 Epoch 53/100 63/63 [==============================] - 5s 82ms/step - loss: 4.3105e-04 - val_loss: 4.3264e-04 Epoch 54/100 63/63 [==============================] - 5s 80ms/step - loss: 4.0882e-04 - val_loss: 4.5356e-04 Epoch 55/100 63/63 [==============================] - 8s 133ms/step - loss: 4.2416e-04 - val_loss: 4.6622e-04 Epoch 56/100 63/63 [==============================] - 6s 88ms/step - loss: 3.4169e-04 - val_loss: 4.3814e-04 Epoch 57/100 63/63 [==============================] - 5s 77ms/step - loss: 3.5770e-04 - val_loss: 5.2064e-04 Epoch 58/100 63/63 [==============================] - 7s 111ms/step - loss: 3.5943e-04 - val_loss: 3.8141e-04 Epoch 59/100 63/63 [==============================] - 5s 80ms/step - loss: 3.2737e-04 - val_loss: 3.8554e-04 Epoch 60/100 63/63 [==============================] - 5s 83ms/step - loss: 4.1654e-04 - val_loss: 6.0449e-04 Epoch 61/100 63/63 [==============================] - 7s 104ms/step - loss: 3.6135e-04 - val_loss: 6.4644e-04 Epoch 62/100 63/63 [==============================] - 5s 78ms/step - loss: 3.6500e-04 - val_loss: 3.5036e-04 Epoch 63/100 63/63 [==============================] - 6s 97ms/step - loss: 4.2947e-04 - val_loss: 3.4373e-04 Epoch 64/100 63/63 [==============================] - 6s 90ms/step - loss: 3.1938e-04 - val_loss: 3.6149e-04 Epoch 65/100 63/63 [==============================] - 5s 78ms/step - loss: 3.2857e-04 - val_loss: 5.7637e-04 Epoch 66/100 63/63 [==============================] - 7s 110ms/step - loss: 3.1412e-04 - val_loss: 4.6957e-04 Epoch 67/100 63/63 [==============================] - 5s 77ms/step - loss: 2.8744e-04 - val_loss: 3.0571e-04 Epoch 68/100 63/63 [==============================] - 5s 77ms/step - loss: 3.6987e-04 - val_loss: 4.7171e-04 Epoch 69/100 63/63 [==============================] - 7s 105ms/step - loss: 2.8811e-04 - val_loss: 3.1594e-04 Epoch 70/100 63/63 [==============================] - 5s 79ms/step - loss: 3.3134e-04 - val_loss: 3.4860e-04 Epoch 71/100 63/63 [==============================] - 6s 94ms/step - loss: 2.8636e-04 - val_loss: 3.9410e-04 Epoch 72/100 63/63 [==============================] - 6s 92ms/step - loss: 2.9701e-04 - val_loss: 3.3203e-04 Epoch 73/100 63/63 [==============================] - 5s 77ms/step - loss: 2.4998e-04 - val_loss: 3.0161e-04 Epoch 74/100 63/63 [==============================] - 7s 109ms/step - loss: 2.7491e-04 - val_loss: 3.1272e-04 Epoch 75/100 63/63 [==============================] - 6s 101ms/step - loss: 2.7833e-04 - val_loss: 2.7460e-04 Epoch 76/100 63/63 [==============================] - 6s 100ms/step - loss: 2.5963e-04 - val_loss: 3.2516e-04 Epoch 77/100 63/63 [==============================] - 6s 89ms/step - loss: 2.5041e-04 - val_loss: 2.7449e-04 Epoch 78/100 63/63 [==============================] - 5s 77ms/step - loss: 2.2904e-04 - val_loss: 3.1260e-04 Epoch 79/100 63/63 [==============================] - 7s 111ms/step - loss: 2.4657e-04 - val_loss: 6.7247e-04 Epoch 80/100 63/63 [==============================] - 5s 78ms/step - loss: 2.6744e-04 - val_loss: 2.9775e-04 Epoch 81/100 63/63 [==============================] - 5s 79ms/step - loss: 3.0769e-04 - val_loss: 3.7119e-04 Epoch 82/100 63/63 [==============================] - 7s 107ms/step - loss: 2.9003e-04 - val_loss: 4.1077e-04 Epoch 83/100 63/63 [==============================] - 5s 78ms/step - loss: 2.4844e-04 - val_loss: 2.4148e-04 Epoch 84/100 63/63 [==============================] - 6s 90ms/step - loss: 1.9636e-04 - val_loss: 2.7925e-04 Epoch 85/100 63/63 [==============================] - 6s 95ms/step - loss: 2.0438e-04 - val_loss: 2.9553e-04 Epoch 86/100 63/63 [==============================] - 5s 77ms/step - loss: 1.9700e-04 - val_loss: 2.4795e-04 Epoch 87/100 63/63 [==============================] - 7s 105ms/step - loss: 2.0499e-04 - val_loss: 3.2470e-04 Epoch 88/100 63/63 [==============================] - 5s 82ms/step - loss: 1.9801e-04 - val_loss: 2.0676e-04 Epoch 89/100 63/63 [==============================] - 5s 78ms/step - loss: 2.0586e-04 - val_loss: 2.0084e-04 Epoch 90/100 63/63 [==============================] - 7s 108ms/step - loss: 3.1947e-04 - val_loss: 2.6646e-04 Epoch 91/100 63/63 [==============================] - 5s 77ms/step - loss: 2.4486e-04 - val_loss: 2.2623e-04 Epoch 92/100 63/63 [==============================] - 5s 83ms/step - loss: 1.9662e-04 - val_loss: 2.6710e-04 Epoch 93/100 63/63 [==============================] - 6s 102ms/step - loss: 2.3732e-04 - val_loss: 3.6315e-04 Epoch 94/100 63/63 [==============================] - 5s 87ms/step - loss: 1.9620e-04 - val_loss: 1.8894e-04 Epoch 95/100 63/63 [==============================] - 11s 177ms/step - loss: 1.7684e-04 - val_loss: 1.9143e-04 Epoch 96/100 63/63 [==============================] - 5s 78ms/step - loss: 2.1609e-04 - val_loss: 2.1834e-04 Epoch 97/100 63/63 [==============================] - 6s 92ms/step - loss: 1.5835e-04 - val_loss: 2.2991e-04 Epoch 98/100 63/63 [==============================] - 8s 120ms/step - loss: 1.5376e-04 - val_loss: 3.0917e-04 Epoch 99/100 63/63 [==============================] - 5s 77ms/step - loss: 1.6704e-04 - val_loss: 2.0586e-04 Epoch 100/100 63/63 [==============================] - 7s 108ms/step - loss: 1.7955e-04 - val_loss: 1.8844e-04
모델 비교 및 평가¶
Train과 Test의 Loss를 Plot시켜 RNN. LSTM, GRU의 성능을 비교한다
- Simple RNN : Loss가 Epochs에 따라 오히려 증가하는 것을 확인할 수 있는데 장기 의존성 문제로 인해 긴 시퀀스를 처리할 때 그레디언트가 소실되거나 폭주할 수 있다. 이 경우에는 역전파 과정에서 그레디언트가 제대로 전파되지 않아 학습이 어려워질 수 있다.
- LSTM, GRU : 반면 LSTM과 GRU에서는 Loss가 감소하였고 장기 의존성 문제를 해결하여 Test Set과 Train Set의 Loss의 차이가 거의 나지 않는 것을 확인할 수 있다.
- GRU가 LSTM에 비해서 학습 속도가 빨랐다. LSTM의 매개변수가 많기 때문에 계산이 오래 걸린다는 단점을 개선했다는 것을확인하였다.
plot_learning_curves(rnn_history, "Simple RNN")
plot_learning_curves(lstm_history, "LSTM")
plot_learning_curves(gru_history, "GRU")
해당 포스팅의 내용은 "상명대학교 민경하 교수님 "인공지능" 수업, 상명대학교 김승현 교수님 "딥러닝"수업을 기반으로 작성하였으며, 포스팅 자료는 해당 내용을 기반으로 재구성하여 만들어 사용하였습니다.
'Data Science > 머신러닝 & 딥러닝' 카테고리의 다른 글
[딥러닝] 기계 번역 : Seq2Seq와 Attention (+ 모델 학습시켜 다국어 번역해보기) (0) | 2024.06.11 |
---|---|
[딥러닝] NLP : 자연어 처리 기본 (+ 영화 리뷰글 긍정/부정 판단해보기) (0) | 2024.06.09 |
[딥러닝] CNN : ResNet 모델로 동물 이미지 분류하기(CIFAR 이미지셋) (0) | 2024.06.08 |
[딥러닝] CNN : 이미지 학습을 위한 신경망 (+ MNIST 손글씨 분류해보기) (1) | 2024.06.08 |
[딥러닝] 심층학습 시작 : 인공 신경망과 MLP (+ 신경망 모델 만들어보기) (1) | 2024.06.08 |