Why do we need the extra dimension in image representation?

Why do we need the extra dimension of size 1 in the MNIST dataset when all the data can be expressed without it?

The additional dimension represents the image channel.
MNIST images are grayscale, so there’s only one channel and that dimension has length 1.

I know, but I don’t need that extra dimension, as shown in the image.

It’s just a convention to represent images that way, even when the channel dimension looks redundant.
Convolutional modules also require it. It doesn’t matter for linear layers, though, since you flatten the image anyway.
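
To make that concrete, here is a minimal sketch, assuming PyTorch (the thread doesn't name the framework) and a made-up random batch standing in for MNIST. It shows the channel axis being added for a convolution, and the same data flattening to identical inputs for a linear layer either way. The layer sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical batch of 4 grayscale 28x28 images stored WITHOUT a channel axis.
images = torch.rand(4, 28, 28)

# Conv2d expects (batch, channels, height, width), so insert a channel axis of length 1.
with_channel = images.unsqueeze(1)            # shape: (4, 1, 28, 28)
conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)
conv_out = conv(with_channel)                 # shape: (4, 8, 26, 26)

# A linear layer only sees flattened pixels, so the channel axis makes no difference.
flat_without = images.flatten(start_dim=1)        # shape: (4, 784)
flat_with = with_channel.flatten(start_dim=1)     # shape: (4, 784)
same_values = torch.equal(flat_without, flat_with)

linear = nn.Linear(28 * 28, 10)
linear_out = linear(flat_with)                # shape: (4, 10)

print(with_channel.shape, conv_out.shape, linear_out.shape, same_values)
```

If you already have data without the channel axis, `unsqueeze(1)` adds it without copying any pixel values, so keeping the dimension costs nothing.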