What is the relation between nn.Module and nn.Linear?

I can’t understand how a class can inherit from nn.Module and then use nn.Linear.

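Inheritance is one of the core ideas of object-oriented programming.

In this case, nn.Linear has nn.Module as its base class. That means it reuses the parts that are common to all modules. A class of your own that inherits from nn.Module can then hold nn.Linear layers as attributes and call them.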
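
A quick way to check this relationship yourself (just a sketch; the variable names are only for illustration):

```python
import torch.nn as nn

# nn.Linear is declared as a subclass of nn.Module
is_subclass = issubclass(nn.Linear, nn.Module)

# so every nn.Linear instance is also an nn.Module instance
layer = nn.Linear(5, 15)
is_instance = isinstance(layer, nn.Module)

# named_parameters() is defined on nn.Module and inherited by nn.Linear
param_names = [name for name, _ in layer.named_parameters()]

print(is_subclass, is_instance, param_names)
```

The parameter listing comes from nn.Module; nn.Linear only adds its own weight, bias, and forward computation on top.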

Is nn.Module a Python module or a class?

It’s a class.

It’s not the same thing as a Python module (a file or package that you import); the two just share the name.
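
If you want to see the difference, the standard inspect module can tell them apart (a small sketch, nothing PyTorch-specific beyond the import):

```python
import inspect
import torch.nn as nn

# torch.nn is a Python module (a package you import)
nn_is_python_module = inspect.ismodule(nn)

# nn.Module is a class defined inside that package
module_is_class = inspect.isclass(nn.Module)
module_is_python_module = inspect.ismodule(nn.Module)

print(nn_is_python_module, module_is_class, module_is_python_module)
```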


But why don’t we inherit from nn.Linear rather than nn.Module?

Because nn.Linear is already a fairly specialized module.

It accepts inputs, holds a weight matrix (plus a bias vector) as parameters, and transforms the inputs into outputs.

nn.Linear(5, 15) transforms a vector of 5 inputs into 15 outputs.
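
As a rough illustration of the shapes involved (the random input is just a placeholder):

```python
import torch
import torch.nn as nn

layer = nn.Linear(5, 15)

x = torch.randn(5)   # a vector of 5 inputs
y = layer(x)         # a vector of 15 outputs

# the weight is stored as (out_features, in_features)
weight_shape = tuple(layer.weight.shape)
bias_shape = tuple(layer.bias.shape)

print(tuple(y.shape), weight_shape, bias_shape)
```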

Inheriting from it only makes sense if you want to extend it in some way (though I think this module is really an “end” one, where further inheritance rarely makes sense).
Inheriting from nn.Module makes more sense, because we get a lot of useful common functionality (parameter registration, moving to a device, saving and loading state), and we can focus on writing our own forward method.
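
Here is a minimal sketch of that pattern. The class name TinyNet and the layer sizes are made up for the example; the point is only that the class inherits from nn.Module and uses nn.Linear layers as building blocks:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical example class
    def __init__(self):
        super().__init__()  # sets up nn.Module's bookkeeping
        # nn.Linear layers assigned as attributes are registered as submodules
        self.fc1 = nn.Linear(5, 15)
        self.fc2 = nn.Linear(15, 2)

    def forward(self, x):
        # the only logic we have to write ourselves
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet()
out = model(torch.randn(4, 5))  # a batch of 4 samples with 5 features each

# parameters() comes from nn.Module and collects both layers' weights and biases
num_params = sum(p.numel() for p in model.parameters())
print(tuple(out.shape), num_params)
```

So the custom class "is a" Module (inheritance) and "has" Linear layers (composition), which is why both appear in the same piece of code.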
