Concatenating Layers in Keras
In Keras, concatenating layers allows you to combine the outputs of multiple layers into a single output. This is particularly useful for:
- Merging features from different branches of a network.
- Combining outputs from convolutional layers with fully connected layers.
- Implementing skip connections, as in U-Net and DenseNet (residual networks such as ResNet add their skip connections rather than concatenating them).
Methods for Concatenating Layers
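1. Using the `Concatenate` layer or the `concatenate()` function
The `Concatenate` layer, along with its functional wrapper `concatenate()`, is the most common way to concatenate layers in Keras’s functional API. It takes a list of input tensors and combines them along a specified axis, which defaults to the last axis (`axis=-1`).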
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model
# Input layers
input_a = Input(shape=(10,))
input_b = Input(shape=(5,))
# Define layers
dense_a = Dense(8, activation='relu')(input_a)
dense_b = Dense(4, activation='relu')(input_b)
# Concatenate layers
merged = Concatenate(axis=1)([dense_a, dense_b])
# Output layer
output = Dense(1, activation='sigmoid')(merged)
# Define model
model = Model(inputs=[input_a, input_b], outputs=output)
In this example, two input layers, `input_a` and `input_b`, are defined with different shapes. Two dense layers, `dense_a` and `dense_b`, process the respective inputs. The `Concatenate` layer combines their outputs along axis 1 (the feature axis), producing a merged tensor with 8 + 4 = 12 features per sample. Finally, a dense layer with a sigmoid activation serves as the output layer.
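The lowercase `concatenate()` helper can stand in for the layer class when building the same model. The following is a minimal sketch, assuming TensorFlow with its bundled Keras is installed; the variable `merged_width` is introduced here purely for illustration.

```python
from tensorflow.keras.layers import Input, Dense, concatenate
from tensorflow.keras.models import Model

# Two branches with different feature widths
input_a = Input(shape=(10,))
input_b = Input(shape=(5,))
dense_a = Dense(8, activation='relu')(input_a)
dense_b = Dense(4, activation='relu')(input_b)

# Functional form; axis=-1 (the last axis) is the same as axis=1 for 2D tensors
merged = concatenate([dense_a, dense_b], axis=-1)

# The size of the last dimension is the sum of the branch widths (8 + 4)
merged_width = merged.shape[-1]

# Output layer and model definition, as in the example above
output = Dense(1, activation='sigmoid')(merged)
model = Model(inputs=[input_a, input_b], outputs=output)
```

Both forms produce the same merged tensor; the helper only saves instantiating the `Concatenate` layer explicitly.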
2. Using the `Lambda` layer
The `Lambda` layer can also be used to concatenate layers. It applies a custom function to the input tensors.
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model
# Input layers
input_a = Input(shape=(10,))
input_b = Input(shape=(5,))
# Define layers
dense_a = Dense(8, activation='relu')(input_a)
dense_b = Dense(4, activation='relu')(input_b)
# Concatenate using Lambda layer
merged = Lambda(lambda x: tf.concat(x, axis=1))([dense_a, dense_b])
# Output layer
output = Dense(1, activation='sigmoid')(merged)
# Define model
model = Model(inputs=[input_a, input_b], outputs=output)
This code defines a lambda function that uses `tf.concat` to concatenate the input tensors along axis 1. The `Lambda` layer applies this function to the outputs of `dense_a` and `dense_b`, achieving the concatenation.
Axis for Concatenation
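The `axis` parameter of the `Concatenate` layer (and of the `concatenate()` function) defines the dimension along which the tensors are concatenated. All other dimensions of the inputs must match. What each axis means depends on the rank and layout of the tensors: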
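Axis | Concatenation |
---|---|
0 | Concatenates along the batch dimension. Rarely useful inside a model, since it changes the number of samples. |
1 | Concatenates along the feature dimension of 2D `(batch, features)` tensors. For `channels_last` image tensors of shape `(batch, height, width, channels)`, this is the height. |
2 | Concatenates along a spatial dimension of image tensors (the width for `channels_last`). |
-1 | Concatenates along the last dimension (the default). For `channels_last` image tensors, this is the channel dimension. |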
Choosing the appropriate axis depends on the specific architecture and desired output.
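The shape arithmetic behind the choice of axis can be sketched in plain Python. `concat_shape` is a hypothetical helper written for this illustration, not part of Keras; it assumes shapes are given as tuples, with `None` standing for an unknown size such as the batch dimension.

```python
def concat_shape(shapes, axis=-1):
    """Return the shape that results from concatenating tensors of the given shapes.

    Illustrative helper only (not a Keras API). None marks an unknown size.
    """
    rank = len(shapes[0])
    if any(len(shape) != rank for shape in shapes):
        raise ValueError("all inputs must have the same rank")
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    axis %= rank  # normalize negative axes, so -1 becomes the last dimension

    result = list(shapes[0])
    for shape in shapes[1:]:
        for dim, (kept, new) in enumerate(zip(result, shape)):
            if dim == axis:
                # Sizes add up along the concatenation axis
                result[dim] = None if kept is None or new is None else kept + new
            elif kept != new:
                # Every other dimension must match exactly
                raise ValueError(f"dimension {dim} differs: {kept} vs {new}")
    return tuple(result)


# Dense branches: feature widths add up along axis 1
dense_shape = concat_shape([(None, 8), (None, 4)], axis=1)

# channels_last feature maps: channels add up along the last axis
image_shape = concat_shape([(None, 32, 32, 16), (None, 32, 32, 8)], axis=-1)
```

Here `dense_shape` is `(None, 12)` and `image_shape` is `(None, 32, 32, 24)`. Calling the helper on the two dense branches with `axis=0` raises a `ValueError`, because their feature widths (8 and 4) differ.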