In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

[![View on GitHub][github-badge]][github-cifar10-tf-nn-lrn] [![Open In Colab][colab-badge]][colab-cifar10-tf-nn-lrn] [![Open in Binder][binder-badge]][binder-cifar10-tf-nn-lrn]

[github-badge]: https://img.shields.io/badge/View-on%20GitHub-blue?logo=GitHub
[colab-badge]: https://colab.research.google.com/assets/colab-badge.svg
[binder-badge]: https://static.mybinder.org/badge_logo.svg

[github-cifar10-tf-nn-lrn]: https://github.com/mbrukman/reimplementing-ml-papers/blob/main/alexnet/AlexNet_for_CIFAR-10_in_Keras_with_tf_nn_LocalResponseNormalization.ipynb
[colab-cifar10-tf-nn-lrn]: https://colab.research.google.com/github/mbrukman/reimplementing-ml-papers/blob/main/alexnet/AlexNet_for_CIFAR-10_in_Keras_with_tf_nn_LocalResponseNormalization.ipynb
[binder-cifar10-tf-nn-lrn]: https://mybinder.org/v2/gh/mbrukman/reimplementing-ml-papers/main?filepath=alexnet/AlexNet_for_CIFAR-10_in_Keras_with_tf_nn_LocalResponseNormalization.ipynb

In [None]:
%%bash

readonly GH_USER="mbrukman"
readonly GH_REPO="reimplementing-ml-papers"
readonly GH_BRANCH="main"
readonly BASE_URL="https://raw.githubusercontent.com/${GH_USER}/${GH_REPO}/${GH_BRANCH}"

# Download the local modules imported below: the AlexNet CIFAR-10 model
# definition, the LocalResponseNormalization layer it is built from, and our
# library for loading the CIFAR-10 dataset. Files already present are reused.
for path in alexnet/{alexnet_cifar10,local_response_normalization}.py \
            datasets/cifar-10/cifar10_keras.py ; do
  module="$(basename "${path}")"
  if ! [ -f "${module}" ]; then
    # --fail makes curl exit non-zero on HTTP errors (e.g. 404) instead of
    # saving the server's error page as if it were the module. Downloading to a
    # temp file and renaming only on success ensures a failed or partial
    # download never leaves a file that the existence check above would then
    # treat as valid on every later run.
    if curl --fail --silent --show-error --location \
         -o "${module}.tmp" "${BASE_URL}/${path}"; then
      mv "${module}.tmp" "${module}"
    else
      rm -f "${module}.tmp"
      echo "ERROR: failed to download ${BASE_URL}/${path}" >&2
      exit 1
    fi
  fi
done

In [None]:
from tensorflow import keras

# Local imports downloaded above.
from alexnet_cifar10 import AlexNet
from cifar10_keras import CIFAR10

In [None]:
# Seed the Python, NumPy and TensorFlow RNGs *before* building the model so that
# weight initialization and the shuffling done by `fit` are repeatable across
# Restart & Run All. Some GPU kernels are nondeterministic, so results may still
# vary slightly between runs on GPU.
SEED = 42
keras.utils.set_random_seed(SEED)

model = AlexNet()
model.summary()

Model: "CIFAR-10-TF-NN-LRN"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Conv1 (Conv2D)              (None, 32, 32, 64)        4864      
                                                                 
 MaxPool1 (MaxPooling2D)     (None, 15, 15, 64)        0         
                                                                 
 LRN1 (LocalResponseNormali  (None, 15, 15, 64)        0         
 zation)                                                         
                                                                 
 Conv2 (Conv2D)              (None, 15, 15, 64)        102464    
                                                                 
 LRN2 (LocalResponseNormali  (None, 15, 15, 64)        0         
 zation)                                                         
                                                                 
 MaxPool2 (MaxPooling2D)     (None, 7, 7, 64)   

In [None]:
%%capture --no-stderr

# This will download the CIFAR-10 dataset via the Keras library, which writes to
# stdout, so we silence it above to avoid extraneous output.
cifar10_data = CIFAR10()

In [None]:
# Configure training: the Adam optimizer (learning rate 0.001) minimizing
# categorical cross-entropy, reporting categorical accuracy as the metric.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=[keras.metrics.CategoricalAccuracy()],
)

In [None]:
# Train the model on the training split. Per the helper names, the inputs are
# scaled to [0, 1] and the labels are categorical (presumably one-hot), matching
# the CategoricalCrossentropy loss used in compile.
# Keep the returned History (per-epoch loss/metrics in `history.history`) for
# later inspection instead of discarding it and displaying its bare object repr.
history = model.fit(cifar10_data.x_train_scale_0_1(),
                    cifar10_data.y_train_categorical(),
                    epochs=20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7fb8c04226d0>

In [None]:
# Evaluate the trained model on the test split. `return_dict=True` labels each
# result with its metric name (`loss`, `categorical_accuracy`) instead of
# returning an unlabeled [loss, accuracy] list.
model.evaluate(cifar10_data.x_test_scale_0_1(),
               cifar10_data.y_test_categorical(),
               return_dict=True)



[1.2246522903442383, 0.7002999782562256]