What Is Transfer Learning?
Four Steps:
- Load data
- Build model
- Train and Test
- Transfer Learning
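In the transfer-learning step, a network pretrained on a large dataset is reused and only a new classification head is trained on the small dataset. The following is a minimal sketch of that idea. It assumes Keras's built-in ResNet50 (matching the 224 x 224 ResNet input used below) and one-hot labels such as those produced by `preprocess`. The function name `build_transfer_model` is hypothetical.

```python
import tensorflow as tf

def build_transfer_model(num_classes, input_shape=(224, 224, 3), weights="imagenet"):
    # Pretrained convolutional backbone without its ImageNet classifier head;
    # pooling="avg" turns the final feature map into one vector per image
    base = tf.keras.applications.ResNet50(
        include_top=False, weights=weights, input_shape=input_shape, pooling="avg"
    )
    base.trainable = False  # freeze the pretrained features

    inputs = tf.keras.Input(shape=input_shape)
    features = base(inputs, training=False)  # keep BatchNorm layers in inference mode
    outputs = tf.keras.layers.Dense(num_classes)(features)  # new trainable head (logits)
    model = tf.keras.Model(inputs, outputs)

    # Labels are one-hot, so categorical cross-entropy on logits
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model
```

Only the Dense layer's kernel and bias are trainable here. After the new head converges, a common second stage is fine-tuning: unfreeze some top layers of the backbone and train with a smaller learning rate. Keras's ResNet50 weights were trained on inputs prepared with `tf.keras.applications.resnet50.preprocess_input`, which differs from a mean/std normalization. Matching the backbone's expected preprocessing is usually worth checking.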
Load Data
- Images and labels
  - X: image file paths, e.g. ['1.png', '2.png', '3.png']
  - Y: integer class labels, e.g. [4, 9, 1]
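data = tf.data.Dataset.from_tensor_slices((X, Y))
# shuffle() and batch() require sizes; 1000 and 32 are example values
data = data.shuffle(buffer_size=1000).map(preprocess).batch(32)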
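The X and Y lists above can be collected from disk before building the dataset. The sketch below assumes a directory layout with one subfolder per class (`root/<class_name>/<image>`), where the class index follows the sorted folder names. The helper `load_paths_and_labels` is hypothetical.

```python
import os

def load_paths_and_labels(root, exts=(".png", ".jpg", ".jpeg")):
    # Assumed layout: root/<class_name>/<image files>
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    # Label = position of the class folder in sorted order
    name_to_label = {name: i for i, name in enumerate(class_names)}
    paths, labels = [], []
    for name in class_names:
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith(exts):  # skip non-image files
                paths.append(os.path.join(folder, fname))
                labels.append(name_to_label[name])
    return paths, labels, name_to_label
```

The returned `paths` and `labels` play the roles of X and Y in `tf.data.Dataset.from_tensor_slices((X, Y))`. Keeping `name_to_label` makes it possible to map predictions back to class names.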
Preprocessing
- Read and resize
  - 224 x 224 for ResNet
- Data augmentation
  - Rotate/flip
  - Crop
- Normalize
  - Mean, std
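import tensorflow as tf

# ImageNet per-channel statistics (an assumption; swap in your dataset's own mean/std)
IMG_MEAN = tf.constant([0.485, 0.456, 0.406])
IMG_STD = tf.constant([0.229, 0.224, 0.225])
NUM_CLASSES = 5  # one-hot depth; every label must be < NUM_CLASSES

def normalize(x):
    # x: [224, 224, 3] in [0, 1] -> per-channel (x - mean) / std
    return (x - IMG_MEAN) / IMG_STD

def preprocess(x, y):
    # x: path to the image file, y: integer class label of the image
    x = tf.io.read_file(x)
    # decode_image handles both PNG and JPEG; channels=3 gives RGB
    x = tf.io.decode_image(x, channels=3, expand_animations=False)
    # Resize slightly larger than 224 so the random crop below has room to move
    x = tf.image.resize(x, [244, 244])
    # x = tf.image.random_flip_left_right(x)
    x = tf.image.random_flip_up_down(x)
    x = tf.image.random_crop(x, [224, 224, 3])
    # x: [0, 255] => [0, 1], then standardize with mean/std
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=NUM_CLASSES)
    return x, y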
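The normalize step standardizes each channel as x' = (x / 255 - mean_c) / std_c. Many pipelines reuse ImageNet's statistics, (0.485, 0.456, 0.406) and (0.229, 0.224, 0.225). The following minimal NumPy sketch shows how a dataset's own per-channel mean and std could be computed instead. It assumes images are already scaled to [0, 1], and the function names are hypothetical.

```python
import numpy as np

def channel_mean_std(images):
    # images: array of shape [N, H, W, 3], values already scaled to [0, 1]
    images = np.asarray(images, dtype=np.float64)
    mean = images.mean(axis=(0, 1, 2))  # one mean per RGB channel
    std = images.std(axis=(0, 1, 2))    # one std per RGB channel
    return mean, std

def normalize_np(images, mean, std):
    # Per-channel standardization, broadcast over the last (channel) axis
    return (np.asarray(images, dtype=np.float64) - mean) / std
```

The statistics should come from the training split only. The same values are then reused for validation and test images, so that all three splits are transformed identically.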
Build Model
- Inherit from keras.Model
- Define the forward pass in call()
- Compile with an optimizer and a loss
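The three steps above can be sketched with Keras model subclassing. This is a minimal example, assuming a tiny convolutional network, 5 classes, and one-hot labels. The class name `SmallCNN` and its layer sizes are hypothetical, chosen only to keep the example small.

```python
import tensorflow as tf

class SmallCNN(tf.keras.Model):
    # Step 1: inherit from keras.Model and create layers in __init__
    def __init__(self, num_classes):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(16, 3, activation="relu")
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(num_classes)  # outputs logits

    # Step 2: the forward pass, conv -> global average pooling -> dense
    def call(self, inputs, training=False):
        x = self.conv(inputs)
        x = self.pool(x)
        return self.fc(x)

model = SmallCNN(num_classes=5)
# Step 3: attach an optimizer and a loss matching one-hot labels
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
logits = model(tf.zeros([2, 224, 224, 3]))  # first call builds the weights; shape (2, 5)
```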
Train and Test
- Split data into train, validation, and test sets
- Early stopping
See the code for details.
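The split and early-stopping ideas can be sketched in plain Python. This example assumes a 60/20/20 split with a fixed shuffle seed and early stopping on validation loss. Both `split_dataset` and `EarlyStopper` are hypothetical helpers for a custom training loop. With `model.fit`, Keras provides `tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=..., restore_best_weights=True)` for the same purpose.

```python
import random

def split_dataset(paths, labels, train=0.6, val=0.2, seed=0):
    # Shuffle once with a fixed seed, then cut into train / validation / test
    pairs = list(zip(paths, labels))
    random.Random(seed).shuffle(pairs)
    n_train = int(len(pairs) * train)
    n_val = int(len(pairs) * val)
    return (pairs[:n_train],
            pairs[n_train:n_train + n_val],
            pairs[n_train + n_val:])

class EarlyStopper:
    # Signal a stop once val loss fails to improve by min_delta for `patience` epochs
    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: remember it and reset the counter
            self.wait = 0
        else:
            self.wait += 1        # no improvement this epoch
        return self.wait >= self.patience
```

The test set is used only once, after training and model selection on the validation set are finished. This keeps the reported accuracy an unbiased estimate.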