
Train a Deep Learning Model

Deep Learning powered methods in bioimage analysis have changed the way we analyze images, yet they are still not widely used in the community. This is arguably due to the lack of easy-to-use tools and the difficulty of adapting a published model to a new paradigm, such as a new microscope or a new cell type.

With Arkitekt we aimed to make it easy both to use already trained deep learning models in your analysis pipeline and to adapt a model to your specific needs. In this tutorial we will show you how to train your own deep learning model using CARE, a popular method for image restoration, and how to use it in your analysis pipeline. We advise you to first read the CARE paper to understand the methodology of CARE, and to follow both the first steps of the Arkitekt tutorial and the Deep Learning Bridge tutorial.

This tutorial aims to cover the following topics:

  • Our conceptual understanding of Deep Learning Training
  • What a training Context is and how to use it
  • How to train a CARE model

What we are trying to achieve

In this tutorial we would like to train a deep learning model that can restore images acquired with a confocal microscope. We are using the training dataset from the CARE paper, a dataset of confocal images of the Tribolium castaneum embryo that vary in their signal-to-noise ratio. We will aim to train a model that, in a second step, will help us restore images acquired with the same microscope.

warning

For the purpose of this tutorial we will only train a model on a subset of the data, and we will use the same data for training and validation. This is NOT AT ALL how you should train a model; we are doing it only for the purpose of this tutorial. Training, testing, and validation should always be done on different datasets. This is just a demo.

Training with Arkitekt

With Arkitekt we wanted to make it easy to train deep learning models, but also to give developers the flexibility to adapt the full training process to their needs. This is why we have a very simple conceptual understanding of training: training a deep learning model is simply a Node that returns a Model, which in a second step can be used to make predictions (in a Prediction Node). This means that training a model is as simple as running a Node that returns a Model. The developer can then decide how to train the model and how to save it in a way that makes sense for their use case. Importantly, with Arkitekt we didn't want to limit the developer to specific model types like image-to-image or classification, but to give them the flexibility to interact with every datatype and model type. But with great power comes great responsibility, which is why we have developed Contexts.
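
To make this concrete, here is the bare skeleton of a training Node (just a sketch with a placeholder name; the full, working CARE version is listed at the end of this tutorial):

from arkitekt import register
from mikro.api.schema import ContextFragment, ModelFragment


@register()
def train_my_model(context: ContextFragment) -> ModelFragment:
    # Load the training data from the context, train with the
    # framework of your choice, archive the weights and register
    # them as an Arkitekt Model
    ...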

Training Context

When training a deep learning model, we need data to train on, which often comes in the form of a dataset that relates two specific datatypes. For example, in the case of CARE we need a dataset that relates one Image to another Image. This is where Contexts come in.

Contexts in Arkitekt are a way to define arbitrary relationships between different Data Types inside the platform. This means that in a Context you can relate a ROI with an Image, an Image with a Table, or a Table with a Metric; it all depends on the use case and what type of relationship you want to define. In the case of CARE we want to define a relationship between two Image datatypes, which is what we will do in this tutorial. We just want to note that Contexts are not limited to this use case and can be used in many different ways. Importantly, you can also have multiple Contexts that relate the same images in different ways (for example, a Context that relates Image A and Image B and a Context that relates Image A and Image C, or Image B and Table A).

info

One more thing: relationships inside a Context can be named, so you can for example have a Context that relates Image A as the High SNR version of Image B, and another that relates Image A as the Low SNR version of Image B. Developers can then use these names to define their training process.
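
On the developer side, these named relationships are queried with links from mikro. A sketch, assuming context is the ContextFragment your Node received (this is exactly the pattern the Train CARE code at the end of this tutorial uses):

from mikro.api.schema import links, LinkableModels

# Fetch all image pairs that are related through the `gt`
# relationship inside the given Context
pairs = links(
    LinkableModels.GRUNNLAG_REPRESENTATION,
    LinkableModels.GRUNNLAG_REPRESENTATION,
    "gt",
    context=context,
)
for pair in pairs:
    input_image = pair.left.data    # one side of the relationship
    target_image = pair.right.data  # the other side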

Let's get to it

Let's see how we can train a CARE model using Arkitekt. Of course we will need CARE for that, but let's wait for a moment and first see how we can define a Context that relates two Image datatypes. Indeed, this relationship has nothing to do with CARE, and we can use it for many different things. But let's start with the data that we will use for training. We will use the same data as in the CARE paper, which you can download here and here. We will use the same data for training and validation, but as we said before, this is just for the purpose of this tutorial.

Uploading and converting the data

These steps should seem familiar; just a few notes:

  • We are batch uploading the data by selecting both files and multi-dropping them on the web interface.
  • Then we batch convert the data through the Convert Omero node, which will convert it to the Image datatype.
  • We now have two Image datatypes in our project, which we can see in the Data tab.

Creating a Context

Now that we have our data, we can create a Context that relates the two Image datatypes. To do this we could either use the Create Context button, or have a Node that returns a Context datatype and creates this Context for us. However, we can of course also do this directly in the Data tab, using the familiar drag-and-drop pattern.
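
The Node variant could look roughly like this (a sketch: create_context is our assumed name for whatever context-creation helper mikro exposes, so check mikro.api.schema for the actual call):

from arkitekt import register
from mikro.api.schema import ContextFragment
# NOTE: `create_context` is an assumed name -- check mikro.api.schema
from mikro.api.schema import create_context


@register()
def new_denoising_context(name: str) -> ContextFragment:
    # Create (and thereby return) a fresh Context that other
    # Nodes can then fill with relationships
    return create_context(name=name)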

Creating a Context and a Relationship

A few notes:

  • You can always relate items in the Data tab by dragging one item on top of the other. This will open a dialog that will ask you to relate the two items by specifying a Context and a Relationship.
  • You will see a left-to-right relationship displayed, and you can give the relationship an arbitrary name. In this case we are using gt for ground truth, as our high SNR image is the ground truth for the low SNR image. Just read it out loud in your head: Image A is ground truth of Image B inside the context of Tribolium Denoising.
  • If the relationship or the Context did not exist before, it will be created for you; otherwise you will find the existing Context in the dropdown.
  • Once submitted, your images are related. You can see this by clicking on the newly created Context in the Data tab, where you will see the two images related by the relationship you specified.

Training a CARE model

And that's it, you created a training dataset for CARE. You can of course relate more items in the same Context, but for now let's leave it at that. Now we can train a CARE model (on the GPU if you have one). For this we will need the KARE app, which is just CARE but enabled as an Arkitekt app (this freaking K is for Arkitekt, get it?). You can install it by clicking on the Install Repo button below.

arkitektio-apps/kare:master


Make sure to appify and deploy the plugin.

This app provides only two nodes, Predict CARE and Train CARE, which allow you to train CARE models and make predictions with your data. The Train CARE node will take a Context as input and return a Model as output, and the Predict CARE node will take a Model and an Image as input and return a predicted Image as output. Nodes all the way down. Let's see how we can train a CARE model.

note

Sadly, with modern deep learning frameworks there is not yet a reliable way to see if another process is using the GPU, and some deep learning methods like Stardist and CARE sometimes still linger in GPU memory even after the process has finished. So if you run into the problem that you can't train a CARE model because the GPU is already in use, simply stop any other deep learning process that might be running. We are working on a more user-friendly solution for this (and on better multi-GPU support).
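
If you are building your own GPU app and want to be a better neighbor, one common mitigation in TensorFlow-based code (CARE runs on TensorFlow) is to enable memory growth before the first model is built, so your process only claims GPU memory as it actually needs it. This is plain TensorFlow, not something KARE configures for you:

import tensorflow as tf

# Allocate GPU memory on demand instead of grabbing it all at once,
# so parallel processes are less likely to starve each other
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)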

Reserving and training a CARE Model

A few notes:

  • You know the drill. Search for the Train CARE node and Reserve it to make it available.
  • Navigate back to the Data tab and run it directly from your just created Context.
  • You will be prompted by the Assign dialog, which will now ask you for training parameters. These are specific to CARE, and you can read more about them in the CARE paper. We just leave the default parameters for now, but be aware that you can change them to your needs. The default parameters are not optimal for this dataset, but we are just doing this for the purpose of this tutorial.
  • Once you submit the dialog, the training will start. You can see the progress in the Dashboard tab, and the Train CARE node will turn green once the training is done. You can also monitor the progress right there by looking at the progress bar.

Inspecting the Model

Training CARE with the default parameters can take a LOOOONG time, and depending on your hardware you might have to wait for a while until the progress bar moves. You can of course always check the output of the deep learning process in the PluginApps tab, by clicking on the running kare container and inspecting the output of its training process.

Inspecting the deep learning output
info

While we are working on providing a few feedback graphs that illustrate the training process, we are pretty much settled on the idea that we don't want to provide a full-fledged training interface inside Arkitekt. We think that this is best left to the deep learning frameworks and software projects like wandb that are specifically designed for this purpose. We rather think Arkitekt should be seen as a tool for transfer learning and model adaptation, not as a tool for developing deep learning models from scratch. But hey, we are open to feedback, and you can always integrate your favorite deep learning framework into Arkitekt and use it as you like.
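
For example, nothing stops you from logging to wandb from inside your own training Node. A sketch, reusing the names (tqdm, model, X, Y, X_val, Y_val, epochs) from the Train CARE code at the end of this tutorial:

import wandb

run = wandb.init(project="tribolium-denoising")  # your wandb project
for i in tqdm(range(epochs)):
    # CARE's train() returns a Keras History holding this epoch's metrics
    history = model.train(X, Y, validation_data=(X_val, Y_val), epochs=1)
    run.log({k: v[-1] for k, v in history.history.items()})
run.finish()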

Using the Model

Your training is going to take time. But at some point it will end, promise! Once it is done, you can use the model to make predictions. For this we will now reserve the Predict CARE node, which takes a Model and an Image as input and returns an Image.

Inspecting the deep learning output

And that's it, you trained a CARE model and used it to make predictions. You can now use this model in your analysis pipeline, and you can also adapt it to your specific needs. And of course you can download the model and share it with others, just like we did with the models that we are using in our Showcases.

The developer's perspective

If you are a developer and you want to know how to integrate your own deep learning app into Arkitekt, you can read our Developer Tutorial on how to do this here. But you can also find below ALL the code for the Train CARE and Predict CARE nodes, which is just a few lines of code and should be easy to understand.

from arkitekt import register
from arkitekt.tqdm import tqdm
from mikro.api.schema import (
    from_xarray,
    ModelFragment,
    create_model,
    ModelKind,
    RepresentationFragment,
    links,
    LinkableModels,
    ContextFragment,
)
from csbdeep.data import RawData, create_patches
from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE
import shutil
import uuid


@register()
def train_care_model(
    context: ContextFragment,
    epochs: int = 100,
    patches_per_image: int = 1024,
    trainsteps_per_epoch: int = 400,
    validation_split: float = 0.1,
) -> ModelFragment:
    """Train Care Model

    Trains a CARE model on the image pairs of a specific context.

    Args:
        context (ContextFragment): The context
        epochs (int, optional): Number of epochs. Defaults to 100.
        patches_per_image (int, optional): Number of patches per image. Defaults to 1024.
        trainsteps_per_epoch (int, optional): Number of train steps per epoch. Defaults to 400.
        validation_split (float, optional): Validation split. Defaults to 0.1.

    Returns:
        ModelFragment: The Model
    """

    training_data_id = f"context_data{context.id}"

    # Here we are getting the data from the Context
    x = links(
        LinkableModels.GRUNNLAG_REPRESENTATION,  # only works with images
        LinkableModels.GRUNNLAG_REPRESENTATION,  # that are related to each other
        "gt",  # through the `gt` relationship
        context=context,  # inside the given context
    )

    # We are computing the data from both sides of the relationship
    X = [t.left.data.sel(t=0, c=0).compute() for t in x]
    Y = [t.right.data.sel(t=0, c=0).compute() for t in x]

    # This is all CARE specific, and you can read more about it in the CARE paper
    raw_data = RawData.from_arrays(X, Y, axes="ZYX")
    print(raw_data)

    X, Y, XY_axes = create_patches(
        raw_data=raw_data,
        patch_size=(16, 64, 64),
        n_patches_per_image=patches_per_image,
        save_file=f"data/{training_data_id}.npz",
    )

    (X, Y), (X_val, Y_val), axes = load_training_data(
        f"data/{training_data_id}.npz",
        validation_split=validation_split,
        verbose=True,
    )
    config = Config(axes, train_steps_per_epoch=trainsteps_per_epoch)

    model = CARE(config, training_data_id, basedir=".trainedmodels")

    # Here we are training the model. We are utilizing a patched tqdm progress
    # bar that updates the progress on the Arkitekt interface after every epoch
    for i in tqdm(range(epochs)):
        model.train(X, Y, validation_data=(X_val, Y_val), epochs=1)

    shutil.make_archive("active_model", "zip", f".trainedmodels/{training_data_id}")

    # Here we are saving the trained model as an Arkitekt Model
    model = create_model(
        "active_model.zip",
        kind=ModelKind.TENSORFLOW,
        name=f"Care Model of {context.name}",
        contexts=[context],
    )

    shutil.rmtree("data")
    return model


@register()
def predict(
    representation: RepresentationFragment, model: ModelFragment
) -> RepresentationFragment:
    """Predict Care

    Use a CARE model and an image to generate a restored image

    Args:
        representation (RepresentationFragment): The image
        model (ModelFragment): The model

    Returns:
        RepresentationFragment: The predicted image
    """

    random_dir = str(uuid.uuid4())

    # We are downloading the model into a temporary directory
    with model.data as f:
        # We are unpacking the model archive into a temporary directory
        shutil.unpack_archive(f, f".modelcache/{random_dir}")

        image_data = representation.data.sel(c=0, t=0).data.compute()
        # We are loading the model from the temporary directory
        care_model = CARE(config=None, name=random_dir, basedir=".modelcache")
        # We are predicting the restored image (note the ZYX axes, matching
        # the axes the model was trained with)
        restored = care_model.predict(image_data, "ZYX")

        # We are creating a new representation from the predicted image
        generated = from_xarray(
            restored,
            name=f"Care denoised of {representation.name}",
            tags=["denoised"],
            origins=[representation],
        )

    shutil.rmtree(f".modelcache/{random_dir}")

    return generated
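
A small design note on the training loop above: instead of calling model.train(..., epochs=epochs) once, we train one epoch at a time inside arkitekt.tqdm, so every iteration pushes a progress update to the Arkitekt interface. If you write your own training Node, this is the pattern to copy for live progress reporting.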