# Multi-Class Cross-Entropy Loss Function

I am trying to use a multi-output cross-entropy loss function for the DSTL dataset. I looked at the loss functions from the Open Solution Mapping Challenge:

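``````python
import torch.nn as nn
import torch.nn.functional as F


def multiclass_segmentation_loss(output, target):
    # output: (N, C, H, W) logits; target: (N, 1, H, W) class indices
    target = target.squeeze(1).long()
    cross_entropy = nn.CrossEntropyLoss()
    return cross_entropy(output, target)


def cross_entropy(output, target, squeeze=False):
    # F.nll_loss expects log-probabilities (e.g. the output of log_softmax)
    if squeeze:
        target = target.squeeze(1)
    return F.nll_loss(output, target)


def multi_output_cross_entropy(outputs, targets):
    # Averages the loss over several separate model outputs
    losses = []
    for output, target in zip(outputs, targets):
        loss = cross_entropy(output, target, squeeze=True)
        losses.append(loss)
    return sum(losses) / len(losses)
``````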
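One detail worth noting in the quoted functions: `cross_entropy` calls `F.nll_loss`, which expects log-probabilities, while `nn.CrossEntropyLoss` applies `log_softmax` itself and takes raw logits. So the two are only interchangeable if the model ends with a `log_softmax` layer. A small check with made-up shapes (batch of 2, 10 classes, 4x4 pixels):

``````python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 10, 4, 4)          # raw network output, (N, C, H, W)
target = torch.randint(0, 10, (2, 4, 4))   # one class index per pixel, (N, H, W), int64

# nn.CrossEntropyLoss applies log_softmax over the class dimension internally...
ce = nn.CrossEntropyLoss()(logits, target)
# ...whereas F.nll_loss must be given log-probabilities explicitly.
nll = F.nll_loss(F.log_softmax(logits, dim=1), target)

print(torch.allclose(ce, nll))  # True
``````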

In my DSTL dataset generator, I generate a mask for each class and add it as a channel, so I end up with a 10-channel mask.

In the U-Net model, I set the number of input channels to 3 (RGB images only) and the number of output channels to 10.

These are the last few lines where I generate the image and mask:

``````python
# mask generator

raise ValueError('Could not generate concatenated mask!')

# swap color axis because
# numpy image: H x W x C
# torch image: C x H x W
image = resize(image, 256, 256).transpose((2, 0, 1))
image = torch.from_numpy(image).float()
``````

When I use the `multiclass_segmentation_loss` function, I get the following error:

``````
  File "/tool/python/conda/env/gis36/lib/python3.6/site-packages/torch/nn/functional.py", line 1334, in nll_loss
RuntimeError: invalid argument 1: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4 at /opt/conda/conda-bld/pytorch_1524590031827/work/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:14
``````

If I use `multi_output_cross_entropy` instead, I get this error:

``````
  File "/tool/python/conda/env/gis36/lib/python3.6/site-packages/torch/nn/functional.py", line 1341, in nll_loss
    out_size, target.size()))
ValueError: Expected target size (10, 256), got torch.Size([10, 256, 256])
``````
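Both errors come from the target shape. For segmentation, `nn.CrossEntropyLoss` and `F.nll_loss` expect an output of shape (N, C, H, W) and a target of shape (N, H, W) holding one class index per pixel, but the target here is a (N, 10, H, W) mask. The sketch below reproduces the shapes involved with a hypothetical batch of 4:

``````python
import torch

# Hypothetical batch shaped like the one in the question:
# 4 images, a 10-channel mask each, 256x256 pixels.
target = torch.zeros(4, 10, 256, 256)

# squeeze(1) only drops dimension 1 when its size is 1. Here it is 10,
# so the target stays 4-D, which the spatial NLL loss rejects.
print(target.squeeze(1).shape)  # torch.Size([4, 10, 256, 256])

# multi_output_cross_entropy zips over the batch, so each "output" is a single
# (10, 256, 256) sample. nll_loss reads a 3-D input as (N, C, d1) = (10, 256, 256)
# and therefore expects a (10, 256) target, but gets (10, 256, 256).
output_sample = torch.randn(10, 256, 256)
target_sample = target[0].squeeze(1)
print(output_sample.shape, target_sample.shape)
``````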

I would like to compute a per-pixel cross-entropy loss over all pixels of every image in the batch.

Could you tell me how to write this loss function and make sure that the input and target shapes match?

I think you actually want to use vanilla cross-entropy, since you have just one output (albeit with 10 classes). The multi-output version is for exotic situations with a fork-structured output, i.e. several separate output heads.
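A minimal sketch of that per-pixel loss with plain `nn.CrossEntropyLoss`, assuming the 10-channel mask is one-hot with exactly one active class per pixel. `per_pixel_cross_entropy` is a hypothetical helper name. It collapses the mask channels into class indices with `argmax`, producing the (N, H, W) target the loss expects. The usage example builds a fake mask with `F.one_hot`, which requires PyTorch 1.1 or later.

``````python
import torch
import torch.nn as nn
import torch.nn.functional as F


def per_pixel_cross_entropy(logits, mask):
    """Cross-entropy averaged over every pixel of every image in the batch.

    logits: raw U-Net output, shape (N, C, H, W), float.
    mask:   one-hot target, shape (N, C, H, W). Assumes exactly one
            active class per pixel (mutually exclusive classes).
    """
    # Collapse the one-hot channel axis into one class index per pixel: (N, H, W).
    target = mask.argmax(dim=1).long()
    # CrossEntropyLoss applies log_softmax over dim 1 and averages over N*H*W pixels.
    return nn.CrossEntropyLoss()(logits, target)


# Usage with shapes matching the question (batch of 2 to keep it small).
torch.manual_seed(0)
logits = torch.randn(2, 10, 256, 256, requires_grad=True)
classes = torch.randint(0, 10, (2, 256, 256))
mask = F.one_hot(classes, num_classes=10).permute(0, 3, 1, 2).float()  # (2, 10, 256, 256)
loss = per_pixel_cross_entropy(logits, mask)
loss.backward()
print(loss.item())
``````

If some pixels belong to no class (all ten channels zero), `argmax` silently assigns them class 0, so an explicit background channel may be needed. If classes can overlap at a pixel, the mutually exclusive assumption behind softmax cross-entropy does not hold.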

So I would just go with plain cross-entropy, or a weighted sum of cross-entropy and soft Dice loss.
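A sketch of the weighted sum suggested above, assuming the targets have already been converted to (N, H, W) class indices. `soft_dice_loss`, `ce_dice_loss`, the smoothing term `eps`, and the 0.5/0.5 weights are all hypothetical choices to tune, not values from the original answer. Here Dice is computed per class over the whole batch and then averaged over classes.

``````python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits, target, eps=1.0):
    """1 - mean soft Dice over classes. target: (N, H, W) int64 class indices."""
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                                      # (N, C, H, W)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()  # (N, C, H, W)
    dims = (0, 2, 3)  # sum over batch and pixels, keeping one score per class
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()


def ce_dice_loss(logits, target, ce_weight=0.5, dice_weight=0.5):
    """Weighted sum of cross-entropy and soft Dice (weights are hypothetical defaults)."""
    ce = F.cross_entropy(logits, target)
    return ce_weight * ce + dice_weight * soft_dice_loss(logits, target)


# Usage with random data
torch.manual_seed(0)
logits = torch.randn(2, 10, 64, 64, requires_grad=True)
target = torch.randint(0, 10, (2, 64, 64))
loss = ce_dice_loss(logits, target)
loss.backward()
``````

Cross-entropy gives a stable per-pixel gradient, while the Dice term is less sensitive to class imbalance, which is the usual motivation for combining the two.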