[Deep Learning] CNN: Classifying Animal Images with a ResNet Model (CIFAR Dataset)
2024. 6. 8. 16:54
In the previous post, we covered the concept of CNNs, surveyed various CNN models starting with LeNet-5, and trained a LeNet-5 model to classify MNIST handwritten digits.
https://sjh9708.tistory.com/223
In this post, let's build a ResNet model and use it to classify animal images.
1. Build a ResNet34 model from scratch in TensorFlow, train it on CIFAR-10 (an image dataset that includes several animal classes), and test its recognition
2. Load a pre-trained ResNet101 model and classify a real animal photo
ResNet34
ResNet is a CNN architecture that delivers excellent performance and efficiency on image recognition tasks.
It uses residual blocks (Residual blocks) to mitigate the gradient vanishing problem, making it possible to train much deeper neural networks.
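Concretely, instead of forcing a stack of layers to learn a target mapping H(x) directly, each residual block learns the residual F(x) = H(x) - x and outputs F(x) + x. Because the identity shortcut feeds the input straight into the addition, gradients can flow back through it unchanged, which is what keeps very deep networks trainable.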
We will build ResNet34, the 34-layer member of the ResNet family, from scratch using TensorFlow.
Classifying CIFAR Images and Real Animal Photos with a ResNet Model
CIFAR (Canadian Institute for Advanced Research)¶
- A dataset for image classification, widely used in machine learning and computer vision
- CIFAR-10 contains 60,000 32x32 color images across 10 classes (e.g., airplane, automobile, bird); a few samples are visualized in the sketch below
- CIFAR-100 contains images of the same size across 100 classes
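Before building anything, it helps to look at a few samples. The following is a minimal sketch (illustration only, not part of the original notebook) that loads CIFAR-10 and displays the first five training images with their class names:

import matplotlib.pyplot as plt
from tensorflow.keras.datasets import cifar10

# CIFAR-10 class names, indexed by label
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

(x_tr, y_tr), _ = cifar10.load_data()  # downloads the dataset on first call
fig, axes = plt.subplots(1, 5, figsize=(10, 2))
for ax, img, label in zip(axes, x_tr[:5], y_tr[:5]):
    ax.imshow(img)  # 32x32 RGB image
    ax.set_title(class_names[int(label[0])])
    ax.axis('off')
plt.show()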
In [2]:
from google.colab import drive
drive.mount('/content/drive')
import matplotlib.pyplot as plt
import numpy as np
import os
%load_ext autoreload
%autoreload 2
os.chdir('drive/MyDrive/DL2024_201810776/week10')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Loading and Preprocessing the CIFAR Data¶
In [13]:
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Layer, AveragePooling2D, Input, BatchNormalization, ReLU, Add, GlobalAveragePooling2D
from tensorflow.keras.utils import to_categorical
In [14]:
from tensorflow.keras.datasets import cifar10
# Load and preprocess the data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32') / 255  # scale pixel values from [0, 255] to [0, 1]
x_test = x_test.astype('float32') / 255
y_train = to_categorical(y_train, 10)  # integer labels -> one-hot vectors
y_test = to_categorical(y_test, 10)
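to_categorical turns each integer class index into a length-10 one-hot vector; for example:

# Class index 3 (cat) becomes a one-hot vector with a 1 in position 3:
print(to_categorical([3], 10))
# [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]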
Data Shape¶
- The shape of each input image is 32x32x3 (height x width x RGB channels)
In [15]:
x_train.shape
Out[15]:
(50000, 32, 32, 3)
Building the ResNet34 Model¶
- Let's construct the ResNet34 layers ourselves and build the model.
- Total number of convolution operations: 36 (1 + 6 + 9 + 13 + 7)
- Initial convolution layer: 1
- 64-filter blocks: 3 x 2 = 6
- 128-filter blocks: 4 x 2 + 1 (shortcut) = 9
- 256-filter blocks: 6 x 2 + 1 (shortcut) = 13
- 512-filter blocks: 3 x 2 + 1 (shortcut) = 7
In [21]:
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Layer, AveragePooling2D, Input, BatchNormalization, ReLU, Add, GlobalAveragePooling2D
from tensorflow.keras.utils import to_categorical
def residual_block(x, filters, stride=1):
    shortcut = x

    # First convolution
    x = Conv2D(filters, 3, padding='same', strides=stride)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    # Second convolution
    x = Conv2D(filters, 3, padding='same')(x)
    x = BatchNormalization()(x)

    # Shortcut connection: use a 1x1 projection when the spatial size or
    # channel count changes, otherwise add the identity directly
    if stride != 1 or shortcut.shape[-1] != filters:
        shortcut = Conv2D(filters, 1, strides=stride)(shortcut)
        shortcut = BatchNormalization()(shortcut)

    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x
def build_resnet34(input_shape, num_classes):
    inputs = Input(shape=input_shape)

    # Initial convolution layer
    x = Conv2D(64, 7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = MaxPooling2D(3, strides=2, padding='same')(x)

    # Residual blocks: 3 x 64, 4 x 128, 6 x 256, 3 x 512 filters
    x = residual_block(x, 64)
    x = residual_block(x, 64)
    x = residual_block(x, 64)

    x = residual_block(x, 128, stride=2)
    x = residual_block(x, 128)
    x = residual_block(x, 128)
    x = residual_block(x, 128)

    x = residual_block(x, 256, stride=2)
    x = residual_block(x, 256)
    x = residual_block(x, 256)
    x = residual_block(x, 256)
    x = residual_block(x, 256)
    x = residual_block(x, 256)

    x = residual_block(x, 512, stride=2)
    x = residual_block(x, 512)
    x = residual_block(x, 512)

    # Global average pooling and classification head
    x = GlobalAveragePooling2D()(x)
    x = Dense(num_classes, activation='softmax')(x)

    # Assemble the model
    model = Model(inputs=inputs, outputs=x)
    return model
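As a quick sanity check (a hypothetical snippet, not part of the original notebook), a stride-2 block should halve the spatial dimensions while the 1x1 projection switches the shortcut to the new channel count:

# Hypothetical shape check using the imports from the cell above
inp = Input(shape=(8, 8, 64))
out = residual_block(inp, 128, stride=2)
print(out.shape)  # (None, 4, 4, 128)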
Training the Model¶
- Since this is a multi-class classification problem, categorical cross-entropy is used as the loss (the numeric sketch below shows what it computes).
- Adam is used as the optimizer.
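For a single sample, categorical cross-entropy is the negative log of the probability the softmax assigns to the true class. A minimal NumPy sketch (illustration only):

import numpy as np

y_true = np.array([0, 0, 1, 0])          # one-hot ground truth
y_pred = np.array([0.1, 0.1, 0.7, 0.1])  # softmax output
loss = -np.sum(y_true * np.log(y_pred))  # equals -log(0.7)
print(loss)  # ~0.357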
In [22]:
model = build_resnet34((32, 32, 3), 10)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])  # Adam optimizer, cross-entropy loss
In [23]:
model.summary()
Model: "model_2" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_3 (InputLayer) [(None, 32, 32, 3)] 0 [] conv2d_72 (Conv2D) (None, 16, 16, 64) 9472 ['input_3[0][0]'] batch_normalization_72 (Ba (None, 16, 16, 64) 256 ['conv2d_72[0][0]'] tchNormalization) re_lu_66 (ReLU) (None, 16, 16, 64) 0 ['batch_normalization_72[0][0] '] max_pooling2d_2 (MaxPoolin (None, 8, 8, 64) 0 ['re_lu_66[0][0]'] g2D) conv2d_73 (Conv2D) (None, 8, 8, 64) 36928 ['max_pooling2d_2[0][0]'] batch_normalization_73 (Ba (None, 8, 8, 64) 256 ['conv2d_73[0][0]'] tchNormalization) re_lu_67 (ReLU) (None, 8, 8, 64) 0 ['batch_normalization_73[0][0] '] conv2d_74 (Conv2D) (None, 8, 8, 64) 36928 ['re_lu_67[0][0]'] batch_normalization_74 (Ba (None, 8, 8, 64) 256 ['conv2d_74[0][0]'] tchNormalization) add_32 (Add) (None, 8, 8, 64) 0 ['batch_normalization_74[0][0] ', 'max_pooling2d_2[0][0]'] re_lu_68 (ReLU) (None, 8, 8, 64) 0 ['add_32[0][0]'] conv2d_75 (Conv2D) (None, 8, 8, 64) 36928 ['re_lu_68[0][0]'] batch_normalization_75 (Ba (None, 8, 8, 64) 256 ['conv2d_75[0][0]'] tchNormalization) re_lu_69 (ReLU) (None, 8, 8, 64) 0 ['batch_normalization_75[0][0] '] conv2d_76 (Conv2D) (None, 8, 8, 64) 36928 ['re_lu_69[0][0]'] batch_normalization_76 (Ba (None, 8, 8, 64) 256 ['conv2d_76[0][0]'] tchNormalization) add_33 (Add) (None, 8, 8, 64) 0 ['batch_normalization_76[0][0] ', 're_lu_68[0][0]'] re_lu_70 (ReLU) (None, 8, 8, 64) 0 ['add_33[0][0]'] conv2d_77 (Conv2D) (None, 8, 8, 64) 36928 ['re_lu_70[0][0]'] batch_normalization_77 (Ba (None, 8, 8, 64) 256 ['conv2d_77[0][0]'] tchNormalization) re_lu_71 (ReLU) (None, 8, 8, 64) 0 ['batch_normalization_77[0][0] '] conv2d_78 (Conv2D) (None, 8, 8, 64) 36928 ['re_lu_71[0][0]'] batch_normalization_78 (Ba (None, 8, 8, 64) 256 ['conv2d_78[0][0]'] tchNormalization) add_34 (Add) (None, 8, 8, 64) 0 ['batch_normalization_78[0][0] ', 're_lu_70[0][0]'] re_lu_72 (ReLU) (None, 8, 8, 64) 0 ['add_34[0][0]'] conv2d_79 (Conv2D) (None, 4, 4, 128) 73856 ['re_lu_72[0][0]'] batch_normalization_79 (Ba (None, 4, 4, 128) 512 ['conv2d_79[0][0]'] tchNormalization) re_lu_73 (ReLU) (None, 4, 4, 128) 0 ['batch_normalization_79[0][0] '] conv2d_80 (Conv2D) (None, 4, 4, 128) 147584 ['re_lu_73[0][0]'] conv2d_81 (Conv2D) (None, 4, 4, 128) 8320 ['re_lu_72[0][0]'] batch_normalization_80 (Ba (None, 4, 4, 128) 512 ['conv2d_80[0][0]'] tchNormalization) batch_normalization_81 (Ba (None, 4, 4, 128) 512 ['conv2d_81[0][0]'] tchNormalization) add_35 (Add) (None, 4, 4, 128) 0 ['batch_normalization_80[0][0] ', 'batch_normalization_81[0][0] '] re_lu_74 (ReLU) (None, 4, 4, 128) 0 ['add_35[0][0]'] conv2d_82 (Conv2D) (None, 4, 4, 128) 147584 ['re_lu_74[0][0]'] batch_normalization_82 (Ba (None, 4, 4, 128) 512 ['conv2d_82[0][0]'] tchNormalization) re_lu_75 (ReLU) (None, 4, 4, 128) 0 ['batch_normalization_82[0][0] '] conv2d_83 (Conv2D) (None, 4, 4, 128) 147584 ['re_lu_75[0][0]'] batch_normalization_83 (Ba (None, 4, 4, 128) 512 ['conv2d_83[0][0]'] tchNormalization) add_36 (Add) (None, 4, 4, 128) 0 ['batch_normalization_83[0][0] ', 're_lu_74[0][0]'] re_lu_76 (ReLU) (None, 4, 4, 128) 0 ['add_36[0][0]'] conv2d_84 (Conv2D) (None, 4, 4, 128) 147584 ['re_lu_76[0][0]'] batch_normalization_84 (Ba (None, 4, 4, 128) 512 ['conv2d_84[0][0]'] tchNormalization) re_lu_77 (ReLU) (None, 4, 4, 128) 0 ['batch_normalization_84[0][0] '] conv2d_85 (Conv2D) (None, 4, 4, 128) 147584 
['re_lu_77[0][0]'] batch_normalization_85 (Ba (None, 4, 4, 128) 512 ['conv2d_85[0][0]'] tchNormalization) add_37 (Add) (None, 4, 4, 128) 0 ['batch_normalization_85[0][0] ', 're_lu_76[0][0]'] re_lu_78 (ReLU) (None, 4, 4, 128) 0 ['add_37[0][0]'] conv2d_86 (Conv2D) (None, 4, 4, 128) 147584 ['re_lu_78[0][0]'] batch_normalization_86 (Ba (None, 4, 4, 128) 512 ['conv2d_86[0][0]'] tchNormalization) re_lu_79 (ReLU) (None, 4, 4, 128) 0 ['batch_normalization_86[0][0] '] conv2d_87 (Conv2D) (None, 4, 4, 128) 147584 ['re_lu_79[0][0]'] batch_normalization_87 (Ba (None, 4, 4, 128) 512 ['conv2d_87[0][0]'] tchNormalization) add_38 (Add) (None, 4, 4, 128) 0 ['batch_normalization_87[0][0] ', 're_lu_78[0][0]'] re_lu_80 (ReLU) (None, 4, 4, 128) 0 ['add_38[0][0]'] conv2d_88 (Conv2D) (None, 2, 2, 256) 295168 ['re_lu_80[0][0]'] batch_normalization_88 (Ba (None, 2, 2, 256) 1024 ['conv2d_88[0][0]'] tchNormalization) re_lu_81 (ReLU) (None, 2, 2, 256) 0 ['batch_normalization_88[0][0] '] conv2d_89 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_81[0][0]'] conv2d_90 (Conv2D) (None, 2, 2, 256) 33024 ['re_lu_80[0][0]'] batch_normalization_89 (Ba (None, 2, 2, 256) 1024 ['conv2d_89[0][0]'] tchNormalization) batch_normalization_90 (Ba (None, 2, 2, 256) 1024 ['conv2d_90[0][0]'] tchNormalization) add_39 (Add) (None, 2, 2, 256) 0 ['batch_normalization_89[0][0] ', 'batch_normalization_90[0][0] '] re_lu_82 (ReLU) (None, 2, 2, 256) 0 ['add_39[0][0]'] conv2d_91 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_82[0][0]'] batch_normalization_91 (Ba (None, 2, 2, 256) 1024 ['conv2d_91[0][0]'] tchNormalization) re_lu_83 (ReLU) (None, 2, 2, 256) 0 ['batch_normalization_91[0][0] '] conv2d_92 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_83[0][0]'] batch_normalization_92 (Ba (None, 2, 2, 256) 1024 ['conv2d_92[0][0]'] tchNormalization) add_40 (Add) (None, 2, 2, 256) 0 ['batch_normalization_92[0][0] ', 're_lu_82[0][0]'] re_lu_84 (ReLU) (None, 2, 2, 256) 0 ['add_40[0][0]'] conv2d_93 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_84[0][0]'] batch_normalization_93 (Ba (None, 2, 2, 256) 1024 ['conv2d_93[0][0]'] tchNormalization) re_lu_85 (ReLU) (None, 2, 2, 256) 0 ['batch_normalization_93[0][0] '] conv2d_94 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_85[0][0]'] batch_normalization_94 (Ba (None, 2, 2, 256) 1024 ['conv2d_94[0][0]'] tchNormalization) add_41 (Add) (None, 2, 2, 256) 0 ['batch_normalization_94[0][0] ', 're_lu_84[0][0]'] re_lu_86 (ReLU) (None, 2, 2, 256) 0 ['add_41[0][0]'] conv2d_95 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_86[0][0]'] batch_normalization_95 (Ba (None, 2, 2, 256) 1024 ['conv2d_95[0][0]'] tchNormalization) re_lu_87 (ReLU) (None, 2, 2, 256) 0 ['batch_normalization_95[0][0] '] conv2d_96 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_87[0][0]'] batch_normalization_96 (Ba (None, 2, 2, 256) 1024 ['conv2d_96[0][0]'] tchNormalization) add_42 (Add) (None, 2, 2, 256) 0 ['batch_normalization_96[0][0] ', 're_lu_86[0][0]'] re_lu_88 (ReLU) (None, 2, 2, 256) 0 ['add_42[0][0]'] conv2d_97 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_88[0][0]'] batch_normalization_97 (Ba (None, 2, 2, 256) 1024 ['conv2d_97[0][0]'] tchNormalization) re_lu_89 (ReLU) (None, 2, 2, 256) 0 ['batch_normalization_97[0][0] '] conv2d_98 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_89[0][0]'] batch_normalization_98 (Ba (None, 2, 2, 256) 1024 ['conv2d_98[0][0]'] tchNormalization) add_43 (Add) (None, 2, 2, 256) 0 ['batch_normalization_98[0][0] ', 're_lu_88[0][0]'] re_lu_90 (ReLU) (None, 2, 2, 256) 0 ['add_43[0][0]'] conv2d_99 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_90[0][0]'] batch_normalization_99 
(Ba (None, 2, 2, 256) 1024 ['conv2d_99[0][0]'] tchNormalization) re_lu_91 (ReLU) (None, 2, 2, 256) 0 ['batch_normalization_99[0][0] '] conv2d_100 (Conv2D) (None, 2, 2, 256) 590080 ['re_lu_91[0][0]'] batch_normalization_100 (B (None, 2, 2, 256) 1024 ['conv2d_100[0][0]'] atchNormalization) add_44 (Add) (None, 2, 2, 256) 0 ['batch_normalization_100[0][0 ]', 're_lu_90[0][0]'] re_lu_92 (ReLU) (None, 2, 2, 256) 0 ['add_44[0][0]'] conv2d_101 (Conv2D) (None, 1, 1, 512) 1180160 ['re_lu_92[0][0]'] batch_normalization_101 (B (None, 1, 1, 512) 2048 ['conv2d_101[0][0]'] atchNormalization) re_lu_93 (ReLU) (None, 1, 1, 512) 0 ['batch_normalization_101[0][0 ]'] conv2d_102 (Conv2D) (None, 1, 1, 512) 2359808 ['re_lu_93[0][0]'] conv2d_103 (Conv2D) (None, 1, 1, 512) 131584 ['re_lu_92[0][0]'] batch_normalization_102 (B (None, 1, 1, 512) 2048 ['conv2d_102[0][0]'] atchNormalization) batch_normalization_103 (B (None, 1, 1, 512) 2048 ['conv2d_103[0][0]'] atchNormalization) add_45 (Add) (None, 1, 1, 512) 0 ['batch_normalization_102[0][0 ]', 'batch_normalization_103[0][0 ]'] re_lu_94 (ReLU) (None, 1, 1, 512) 0 ['add_45[0][0]'] conv2d_104 (Conv2D) (None, 1, 1, 512) 2359808 ['re_lu_94[0][0]'] batch_normalization_104 (B (None, 1, 1, 512) 2048 ['conv2d_104[0][0]'] atchNormalization) re_lu_95 (ReLU) (None, 1, 1, 512) 0 ['batch_normalization_104[0][0 ]'] conv2d_105 (Conv2D) (None, 1, 1, 512) 2359808 ['re_lu_95[0][0]'] batch_normalization_105 (B (None, 1, 1, 512) 2048 ['conv2d_105[0][0]'] atchNormalization) add_46 (Add) (None, 1, 1, 512) 0 ['batch_normalization_105[0][0 ]', 're_lu_94[0][0]'] re_lu_96 (ReLU) (None, 1, 1, 512) 0 ['add_46[0][0]'] conv2d_106 (Conv2D) (None, 1, 1, 512) 2359808 ['re_lu_96[0][0]'] batch_normalization_106 (B (None, 1, 1, 512) 2048 ['conv2d_106[0][0]'] atchNormalization) re_lu_97 (ReLU) (None, 1, 1, 512) 0 ['batch_normalization_106[0][0 ]'] conv2d_107 (Conv2D) (None, 1, 1, 512) 2359808 ['re_lu_97[0][0]'] batch_normalization_107 (B (None, 1, 1, 512) 2048 ['conv2d_107[0][0]'] atchNormalization) add_47 (Add) (None, 1, 1, 512) 0 ['batch_normalization_107[0][0 ]', 're_lu_96[0][0]'] re_lu_98 (ReLU) (None, 1, 1, 512) 0 ['add_47[0][0]'] global_average_pooling2d_2 (None, 512) 0 ['re_lu_98[0][0]'] (GlobalAveragePooling2D) dense_2 (Dense) (None, 10) 5130 ['global_average_pooling2d_2[0 ][0]'] ================================================================================================== Total params: 21315338 (81.31 MB) Trainable params: 21298314 (81.25 MB) Non-trainable params: 17024 (66.50 KB) __________________________________________________________________________________________________
In [24]:
model.fit(x_train, y_train, epochs=20, batch_size=128, validation_split=0.1)
Epoch 1/20
352/352 [==============================] - 149s 62ms/step - loss: 1.5190 - accuracy: 0.4599 - val_loss: 2.5403 - val_accuracy: 0.2354
Epoch 2/20
352/352 [==============================] - 21s 58ms/step - loss: 1.0879 - accuracy: 0.6126 - val_loss: 1.3812 - val_accuracy: 0.5130
Epoch 3/20
352/352 [==============================] - 20s 58ms/step - loss: 0.9227 - accuracy: 0.6766 - val_loss: 1.8749 - val_accuracy: 0.4850
Epoch 4/20
352/352 [==============================] - 20s 57ms/step - loss: 0.7835 - accuracy: 0.7258 - val_loss: 1.6271 - val_accuracy: 0.5264
Epoch 5/20
352/352 [==============================] - 21s 59ms/step - loss: 0.6738 - accuracy: 0.7661 - val_loss: 1.2822 - val_accuracy: 0.5680
Epoch 6/20
352/352 [==============================] - 20s 58ms/step - loss: 0.5981 - accuracy: 0.7909 - val_loss: 1.0986 - val_accuracy: 0.6388
Epoch 7/20
352/352 [==============================] - 21s 60ms/step - loss: 0.5363 - accuracy: 0.8114 - val_loss: 0.9936 - val_accuracy: 0.6724
Epoch 8/20
352/352 [==============================] - 20s 57ms/step - loss: 0.4801 - accuracy: 0.8328 - val_loss: 0.9565 - val_accuracy: 0.6984
Epoch 9/20
352/352 [==============================] - 21s 59ms/step - loss: 0.4532 - accuracy: 0.8421 - val_loss: 1.0334 - val_accuracy: 0.6740
Epoch 10/20
352/352 [==============================] - 20s 57ms/step - loss: 0.3547 - accuracy: 0.8745 - val_loss: 1.1955 - val_accuracy: 0.6704
Epoch 11/20
352/352 [==============================] - 20s 58ms/step - loss: 0.2835 - accuracy: 0.9008 - val_loss: 1.0760 - val_accuracy: 0.6910
Epoch 12/20
352/352 [==============================] - 20s 58ms/step - loss: 0.2432 - accuracy: 0.9158 - val_loss: 1.0495 - val_accuracy: 0.7142
Epoch 13/20
352/352 [==============================] - 20s 57ms/step - loss: 0.2080 - accuracy: 0.9262 - val_loss: 1.4214 - val_accuracy: 0.6650
Epoch 14/20
352/352 [==============================] - 20s 57ms/step - loss: 0.1808 - accuracy: 0.9360 - val_loss: 1.1311 - val_accuracy: 0.7242
Epoch 15/20
352/352 [==============================] - 21s 59ms/step - loss: 0.1542 - accuracy: 0.9452 - val_loss: 1.4610 - val_accuracy: 0.6748
Epoch 16/20
352/352 [==============================] - 20s 57ms/step - loss: 0.1607 - accuracy: 0.9445 - val_loss: 0.9664 - val_accuracy: 0.7462
Epoch 17/20
352/352 [==============================] - 20s 57ms/step - loss: 0.1196 - accuracy: 0.9583 - val_loss: 1.3573 - val_accuracy: 0.6872
Epoch 18/20
352/352 [==============================] - 20s 57ms/step - loss: 0.1089 - accuracy: 0.9612 - val_loss: 1.3623 - val_accuracy: 0.7188
Epoch 19/20
352/352 [==============================] - 20s 56ms/step - loss: 0.1697 - accuracy: 0.9432 - val_loss: 1.3120 - val_accuracy: 0.6938
Epoch 20/20
352/352 [==============================] - 20s 57ms/step - loss: 0.1287 - accuracy: 0.9566 - val_loss: 1.1119 - val_accuracy: 0.7462
Out[24]:
<keras.src.callbacks.History at 0x784a6574c7f0>
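Training accuracy reaches about 96% while validation accuracy plateaus around 75%, a gap that suggests overfitting. To visualize this, the learning curves can be plotted from the History object (a minimal sketch, assuming the fit call above is assigned as history = model.fit(...)):

# Plot training vs. validation accuracy per epoch
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()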
Prediction and Evaluation on the Test Set¶
In [25]:
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_accuracy:.3f}')
313/313 [==============================] - 4s 11ms/step - loss: 1.1489 - accuracy: 0.7362
Test accuracy: 0.736
In [27]:
# Predict class probabilities for the test set
y_pred = model.predict(x_test)
y_pred = np.array(y_pred)
313/313 [==============================] - 3s 9ms/step
In [28]:
# Map CIFAR-10 class indices to human-readable labels
# (named cifar10_labels so it does not shadow the cifar10 dataset module imported earlier)
cifar10_labels = {
    0: 'airplane',
    1: 'automobile',
    2: 'bird',
    3: 'cat',
    4: 'deer',
    5: 'dog',
    6: 'frog',
    7: 'horse',
    8: 'ship',
    9: 'truck',
}

# Show the first 10 test images with their true and predicted labels
for idx in range(len(y_pred[:10])):
    plt.title("Truth : {}, Predicted : {}".format(cifar10_labels[np.argmax(y_test[idx])], cifar10_labels[np.argmax(y_pred[idx])]))
    plt.imshow(x_test[idx])  # RGB images, so no grayscale colormap is needed
    plt.show()
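To see which classes the model confuses with one another, a confusion matrix can be computed from the same predictions (a minimal sketch, assuming scikit-learn is available in the environment):

# Rows are true classes, columns are predicted classes
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(np.argmax(y_test, axis=1), np.argmax(y_pred, axis=1))
print(cm)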
Trying a Pre-trained ResNet101¶
This time, let's use a fully trained ResNet101 model to predict what an animal image contains.
In [30]:
import tensorflow as tf
from tensorflow.keras.applications.resnet import ResNet101, preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
import numpy as np
# Load the image file
img_path = 'images/cat_224x224.jpg'  # path to the input image
img = image.load_img(img_path, target_size=(224, 224))  # resize to the model's expected input size

# Convert the image to an array and preprocess it
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)  # add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
x = preprocess_input(x)
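For the ResNet family, preprocess_input applies the same preprocessing the network was trained with on ImageNet: it converts the channel order from RGB to BGR and zero-centers each channel with the ImageNet channel means, without rescaling to [0, 1]. A quick check of the resulting tensor:

print(x.shape)  # (1, 224, 224, 3): a batch containing one image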
We will run the prediction on this cat photo.¶
In [36]:
plt.imshow(img)
Out[36]:
<matplotlib.image.AxesImage at 0x784b00738f10>
Loading the ResNet101 Model¶
In [37]:
# Load ResNet101 with pre-trained ImageNet weights
model = ResNet101(weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels.h5
179648224/179648224 [==============================] - 2s 0us/step
Prediction¶
- tiger_cat has the highest probability at 43%. In multi-class classification we take the class with the highest probability, so the model successfully identified the image as a cat.
In [38]:
# Run the prediction on the image
predictions = model.predict(x)

# Decode the prediction into human-readable ImageNet labels
decoded_predictions = decode_predictions(predictions, top=3)[0]
print('Predictions:')
for i, (imagenet_id, label, score) in enumerate(decoded_predictions):
    print(f"{i + 1}: {label} ({score:.2f})")
1/1 [==============================] - 3s 3s/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
35363/35363 [==============================] - 0s 0us/step
Predictions:
1: tiger_cat (0.43)
2: Egyptian_cat (0.27)
3: tabby (0.13)
This post is based on Professor Min Kyung-Ha's "Artificial Intelligence" course and Professor Kim Seung-Hyun's "Deep Learning" course at Sangmyung University; the material has been restructured from those lectures.