딥러닝

[딥러닝] CNN 모델

퓨어맨 2022. 7. 26. 09:30
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dropout

cnn_model = Sequential()

# 1.특성추출부(Conv - 특징이 되는 정보를 부각시켜 추출함)
cnn_model.add(Conv2D(input_shape=(224,224,3),
              # 필터(돋보기)의 개수 -> 추출하는 특징의 개수를 설정
              filters=128,
              # 필터의 크기 설정(행, 열)
              kernel_size=(3,3),
              # same : 원본 데이터의 크기에 맞게 알아서 패딩을 적용(valid : 패딩 적용 X)
              padding='same',
              activation='relu'
              ))
# 2.특성추출부(Pooling - 불필요한 정보 삭제)
# pool_size : 디폴트 값이 2(필터 크기가 2 x 2)
cnn_model.add(MaxPool2D())

cnn_model.add(Conv2D(filters=256,
              kernel_size=(3,3),
              padding='same',
              activation='relu'
              ))
cnn_model.add(MaxPool2D())

# Dropout : 신경망의 전체 뉴런중 일부(20%)를 학습이 불가하도록 만들어주는 명령
#  -> 신경망의 복잡도를 낮춰서 좀 더 가볍게 동작시키고 과대적합을 해소하는데 도움을 줌
cnn_model.add(Dropout(0.2))

cnn_model.add(Conv2D(filters=128,
              kernel_size=(3,3),
              padding='same',
              activation='relu'
              ))
cnn_model.add(MaxPool2D())

cnn_model.add(Conv2D(filters=64,
              kernel_size=(3,3),
              padding='same',
              activation='relu'
              ))
cnn_model.add(MaxPool2D())

# 분류기(MLP)
cnn_model.add(Flatten())
cnn_model.add(Dense(128, activation='relu'))
cnn_model.add(Dense(64, activation='relu'))
cnn_model.add(Dense(32, activation='relu'))

cnn_model.add(Dense(3, activation='softmax'))

cnn_model.summary()

 

cnn_model.compile(loss='sparse_categorical_crossentropy',
                   optimizer='Adam',
                   metrics=['acc']
                   )
                   
h = cnn_model.fit(X_train, y_train,
                  validation_split=0.2,
                  batch_size=128,
                  epochs=50
                  )
                  
plt.figure(figsize=(15,5))

# train 데이터
plt.plot(h.history['acc'],
         label='acc',
         )

# val 데이터
plt.plot(h.history['val_acc'],
         label='val_acc',
         )

plt.legend()
plt.show()