Shared classifier for 3 neural networks

I want to create three different VGG networks that share a single classifier. Each network contains only the convolutional layers; their outputs are then combined and fed into one shared classifier.
For a clearer explanation, see this image:

I have no idea how to do this in PyTorch. Do you have an example I can study?

Create 3 model classes:

  • VGGA
  • VGGB
  • VGGC

Then create a fourth model that uses VGGA, VGGB and VGGC as submodules and concatenates their outputs.

Can you take a shot at writing the code for it? I can then suggest improvements.


Thanks. Sure, I will write some code this evening or, at the latest, tomorrow.

class VGGBlock(nn.Module):
    """VGG stage: two 3x3 convolutions (each followed by optional BatchNorm and ReLU), then 2x2 max pooling.

    The convolutions keep the spatial size (3x3 kernel, stride 1, padding 1);
    the final max pooling halves it (rounding down).

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        batch_norm: if True, a BatchNorm2d follows each convolution;
            otherwise an identity module is used in its place.
    """

    def __init__(self, in_channels, out_channels, batch_norm=False):
        # nn.Module.__init__ must run before any submodule is assigned;
        # assigning a submodule without it raises AttributeError.
        super(VGGBlock, self).__init__()

        conv2_params = {'kernel_size': (3, 3),
                        'stride': (1, 1),
                        'padding': 1}

        self._batch_norm = batch_norm

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv2_params)
        # nn.Identity instead of a lambda: it is a real module, so it shows up
        # in repr() and the whole model can be pickled / torch.save'd.
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv2_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    def batch_norm(self):
        """Return whether this block was built with batch normalization."""
        return self._batch_norm

    def forward(self, x):
        """Apply conv -> bn -> relu twice, then 2x2 max pooling."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = self.max_pooling(x)

        return x
class VGG16(nn.Module):
    """Convolution-only VGG trunk: four VGGBlocks followed by a flatten.

    Channel widths grow 64 -> 128 -> 256 -> 512 across the four blocks.

    Args:
        input_size: ``(channels, width, height)`` of one input sample.
        num_classes: accepted but not used by this trunk.
        batch_norm: forwarded to every VGGBlock.
    """

    def __init__(self, input_size, num_classes=1, batch_norm=False):
        super(VGG16, self).__init__()

        channels, width, height = input_size
        self.in_channels, self.in_width, self.in_height = channels, width, height

        # Consecutive entries give each block's (in_channels, out_channels).
        widths = (channels, 64, 128, 256, 512)
        for index, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, "block_%d" % index, VGGBlock(c_in, c_out, batch_norm=batch_norm))

    def input_size(self):
        """Return the ``(channels, width, height)`` tuple given at construction."""
        return (self.in_channels, self.in_width, self.in_height)

    def forward(self, x):
        """Run the four blocks in order and flatten everything but the batch dimension."""
        for block in (self.block_1, self.block_2, self.block_3, self.block_4):
            x = block(x)
        return torch.flatten(x, 1)
class VGG16Classifier(nn.Module):
    """Shared classifier head on top of three independent convolutional trunks.

    Each trunk processes its own input; the flattened features of the three
    trunks are concatenated along the feature dimension (dim=1) and passed
    through one classifier.

    Args:
        vgg_a, vgg_b, vgg_star: feature extractors that expose an
            ``input_size()`` method and return ``(batch, features)`` tensors
            (e.g. ``VGG16``).
        num_classes: output size of the default classifier.
        classifier: optional custom head; it must accept the concatenated
            features. When None, a Linear/ReLU MLP is built whose input width
            is inferred from the trunks.
        batch_norm: accepted for backward compatibility; currently unused.

    Raises:
        ValueError: if the three trunks report different input sizes.
    """

    def __init__(self, vgg_a, vgg_b, vgg_star, num_classes=1, classifier=None, batch_norm=False):
        super(VGG16Classifier, self).__init__()

        trunks = (vgg_a, vgg_b, vgg_star)
        # input_size is a method: comparing the bound methods themselves (as a
        # plain `a.input_size == b.input_size` would) is False for distinct models.
        sizes = {tuple(vgg.input_size()) for vgg in trunks}
        if len(sizes) != 1:
            raise ValueError(f"all three networks must share the same input size, got {sizes}")

        self._vgg_a = vgg_a
        self._vgg_b = vgg_b
        self._vgg_star = vgg_star

        if classifier is None:
            # The head sees the concatenation of all three trunks' features.
            in_features = sum(self._feature_size(vgg) for vgg in trunks)
            # ReLUs between the layers: without a non-linearity, stacked
            # Linear layers collapse into a single linear map.
            classifier = nn.Sequential(
                nn.Linear(in_features, 2048),
                nn.ReLU(inplace=True),
                nn.Linear(2048, 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, num_classes),
            )
        self.classifier = classifier

    @staticmethod
    def _feature_size(vgg):
        """Return the flattened feature width ``vgg`` produces for one sample."""
        was_training = vgg.training
        # eval() + no_grad(): the probe must not update BatchNorm running
        # statistics or build an autograd graph.
        vgg.eval()
        try:
            param = next(vgg.parameters(), None)
            device = param.device if param is not None else None
            # NOTE(review): input_size() is (channels, width, height) while conv
            # layers read (N, C, H, W); identical for square inputs — confirm
            # the intended order for non-square ones.
            probe = torch.zeros(1, *vgg.input_size(), device=device)
            with torch.no_grad():
                return int(vgg(probe).shape[1])
        finally:
            vgg.train(was_training)

    def forward(self, x1, x2, x3):
        """Run each input through its own trunk and classify the joined features.

        x1 goes to vgg_a, x2 to vgg_b and x3 to vgg_star; the three must share
        the same batch size so their features can be concatenated on dim 1.
        """
        features = torch.cat((self._vgg_a(x1), self._vgg_b(x2), self._vgg_star(x3)), dim=1)
        return self.classifier(features)
# Three separately constructed VGG16 trunks for (channels=1, 32, 32) inputs,
# joined by one shared classifier head.
model1, model2, model_star = (VGG16((1, 32, 32), batch_norm=True) for _ in range(3))
model_combo = VGG16Classifier(model1, model2, model_star)