Train a Deep Learning Model
Deep learning methods have changed the way we analyze images in bioimage analysis, but they are still not widely used in the community. This is arguably due to the lack of easy-to-use tools and the difficulty of adapting a published model to a new setting, such as a new microscope or a new cell type.
With Arkitekt we aim to make it easy both to use already trained deep learning models in your analysis pipeline and to adapt a model to your specific needs. In this tutorial we will show you how to train your own deep learning model using CARE, a popular method for image restoration, and how to use it in your analysis pipeline. We advise you to first read the CARE paper to understand the methodology of CARE, and to work through both the first steps of the Arkitekt tutorial and the Deep Learning Bridge tutorial.
This tutorial aims to cover the following topics:
- Our conceptual understanding of deep learning training
- What a training Context is and how to use it
- How to train a CARE model
What we are trying to achieve
In this tutorial we would like to train a deep learning model that can be used to restore images that have been acquired with a confocal microscope. We are using the training dataset from the CARE paper, a dataset of confocal images of Tribolium castaneum embryos that vary in their signal-to-noise ratio. Our aim is to train a model that, in a second step, helps us restore images acquired with the same microscope.
For the purpose of this tutorial we will only train a model on a subset of the data, and we will use the same
data for training and validation. This is NOT AT ALL how you should train a model,
and we are doing this only for the purpose of this tutorial. Training, Testing, and Validation should always be done on different datasets. This is just a demo.
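For a real project, the split into training, validation and test data could be sketched as follows. This is a minimal example in plain Python and not part of Arkitekt; the function name `split_dataset`, the fractions and the pair identifiers are assumptions made up for illustration.

```python
import random


def split_dataset(items, val_fraction=0.1, test_fraction=0.1, seed=0):
    """Shuffle the items once and cut them into disjoint train/val/test lists."""
    if not 0 <= val_fraction + test_fraction < 1:
        raise ValueError("val_fraction + test_fraction must be in [0, 1)")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # fixed seed: reproducible split
    n_val = round(len(shuffled) * val_fraction)
    n_test = round(len(shuffled) * test_fraction)
    val = shuffled[:n_val]
    test = shuffled[n_val:n_val + n_test]
    train = shuffled[n_val + n_test:]
    return train, val, test


# Hypothetical identifiers for 20 low/high SNR image pairs (not real file names).
pairs = [f"pair_{i:02d}" for i in range(20)]
train, val, test = split_dataset(pairs)
print(len(train), len(val), len(test))
```

Splitting on whole image pairs (rather than on patches cut from them) keeps patches of the same image from ending up on both sides of the split.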
Training with Arkitekt
With Arkitekt we wanted to make it easy to train deep learning models, but also to give developers the flexibility to adapt the full training process to their needs. This is why we have a very simple conceptual understanding of training: training a deep learning model is simply a Node that returns a Model, which in a second step can be used to make predictions (in a Prediction Node). This means that training a model is as simple as running a Node that returns a Model. The developer can then decide how to train the model, and how to save it in a way that makes sense for their use case. Importantly, with Arkitekt we didn't want to limit the developer to specific model types like image-to-image or classification, but to give them the flexibility to interact with every datatype and model type. But with great power comes great responsibility, which is why we have developed Contexts.
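The idea of "a node that returns a model, and a second node that consumes it" can be illustrated with a toy example in plain Python. Nothing here is Arkitekt API: `ToyModel`, `train_node` and `predict_node` are hypothetical names, and the "model" is just a single gain factor fitted by least squares rather than a neural network.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyModel:
    """Stand-in for a stored Model: here just one learned gain factor."""
    gain: float


def train_node(pairs: List[Tuple[List[float], List[float]]]) -> ToyModel:
    """'Training node': takes (noisy, clean) pairs and returns a model.

    Fits clean ~= gain * noisy by least squares over all samples.
    """
    num = sum(n * c for noisy, clean in pairs for n, c in zip(noisy, clean))
    den = sum(n * n for noisy, _ in pairs for n in noisy)
    if den == 0:
        raise ValueError("cannot fit a gain to all-zero inputs")
    return ToyModel(gain=num / den)


def predict_node(model: ToyModel, image: List[float]) -> List[float]:
    """'Prediction node': takes a model and an image and returns a new image."""
    return [model.gain * value for value in image]


pairs = [([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])]
model = train_node(pairs)                    # first step: training returns a model
restored = predict_node(model, [4.0, 5.0])   # second step: the model is an input
print(model, restored)
```

The point of the sketch is only the shape of the two functions: what happens inside the training function is entirely up to the developer.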
Training Context
When training a deep learning model, we need data to train on, which often comes in the form of a dataset that relates two specific datatypes. For example, in the case of CARE we need a dataset that relates an Image to another Image. This is where Contexts come in.
Contexts in Arkitekt are a way to define arbitrary relationships between different datatypes inside the platform. This means that in a Context you can relate a ROI with an Image, an Image with a Table, or a Table with a Metric; it all depends on the use case and on the type of relationship you want to define. In the case of CARE we want to define a relationship between two Image datatypes, which is what we will do in this tutorial. We just want to note that Contexts are not limited to this use case and can be used in many different ways. Importantly, you can also have multiple Contexts that relate the same images in different ways (for example a Context that relates Image A and Image B, and a Context that relates Image A and Image C, or Image B and Table A).
One more thing: relationships inside a Context can be named, so that you can for example have a Context that relates Image A as "is High SNR of" Image B and a Context that relates Image A as "is Low SNR of" Image B.
Developers can then use these names to define their training process.
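To make the idea of named relationships concrete, here is a minimal data-structure sketch in plain Python. It is not how Arkitekt stores Contexts; the classes `Relation` and `Context`, their methods and the relationship name `measured_in` are assumptions made up for illustration.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass(frozen=True)
class Relation:
    """A named, directed relationship: <left> is <name> of <right>."""
    left: str
    name: str
    right: str


@dataclass
class Context:
    """A named collection of relationships between arbitrary items."""
    name: str
    relations: List[Relation] = field(default_factory=list)

    def relate(self, left: str, name: str, right: str) -> None:
        self.relations.append(Relation(left, name, right))

    def pairs(self, name: str) -> List[Tuple[str, str]]:
        """Return all (left, right) pairs linked by the given relationship name."""
        return [(r.left, r.right) for r in self.relations if r.name == name]


ctx = Context("Tribolium Denoising")
ctx.relate("Image A", "gt", "Image B")           # Image A is ground truth of Image B
ctx.relate("Image B", "measured_in", "Table A")  # hypothetical second relationship

# A training routine would only pick up the relationships it cares about.
training_pairs = ctx.pairs("gt")
print(training_pairs)
```

Because the training code filters by name, other relationships can live in the same Context without interfering with training.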
Let's get to this
Let's see how we can train a CARE model using Arkitekt. Of course we will need CARE for that, but let's wait for a moment and first see how we can define a Context that relates two Image datatypes.
Indeed, this relationship has nothing to do with CARE, and we can use it for many different things. But let's start with the data that we will use for training. We will use the same data as in the CARE paper, which you can download here and here. We will use the same data for training and validation, but as we said before, this is just for the purpose of this tutorial.
These steps should seem familiar, just a few notes:
- We are batch uploading the data by selecting both files and multi-dropping them on the web interface.
- Then we are batch converting the data through the Convert Omero node, which will convert the data to the Image datatype.
- We now have two Image datatypes in our project, which we can see in the Data tab.
Creating a Context
Now that we have our data, we can create a Context that relates the two Image datatypes. To do this we could either use the Create Context button, or have a Node that returns a Context datatype and creates this Context for us. However, we can of course also do this directly in the Data tab, using the familiar drag-and-drop pattern.
A few notes:
- You can always relate items in the Data tab by dragging one item on top of the other. This will open a dialog that will ask you to relate the two items by specifying a Context and a Relationship.
- You will see a left-to-right relationship displayed, and you can give an arbitrary name to the relationship. In this case we are using gt for ground truth, as our high SNR image is our ground truth for the low SNR image. Just read it out in your head: Image A is ground truth of Image B inside the context of Tribolium Denoising.
- If the relationship or the Context did not exist before, it will be created for you; otherwise you will find the existing Context in the dropdown.
- Once submitted, your images are related. You can see this by clicking on the newly created Context in the Data tab, where you will see the two images related by the relationship you specified.
Training a CARE model
And that's it, you created a training dataset for CARE. You can of course relate more items in the same context, but for now let's leave it at that. Now we can train a CARE model (on the GPU if you have one).
For this we will need the KARE app, which is just CARE but enabled as an Arkitekt app (this freaking K is for Arkitekt, get it?). You can install it by clicking on the Install Repo button below.
Make sure to appify and deploy the plugin.
This app provides only two nodes, Predict CARE and Train CARE, which allow you to train and predict CARE models with your data. The Train CARE node will take a Context as input and will return a Model as output, and the Predict CARE node will take a Model and an Image as input and will return a Predicted Image as output. Nodes all the way down. Let's see how we can train a CARE model.
Sadly, with modern deep learning frameworks there is not yet a reliable way to see if another process is using the GPU, and some deep learning methods like Stardist and CARE sometimes linger in GPU memory even after the process has finished. So if you run into the problem that you can't train a CARE model because the GPU is already in use, simply stop any other deep learning process that might be running. We are working on a more user-friendly solution for this (also for better multi-GPU support).
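Outside of the deep learning frameworks, on machines with NVIDIA drivers, the `nvidia-smi` command line tool can list the compute processes that currently hold GPU memory. The following is a minimal sketch under several assumptions: `nvidia-smi` is installed and on the path, and it is run on the host (from inside a container it may not show processes of other containers). The helper names `parse_compute_apps` and `list_gpu_processes` are made up for this example, and the sample text is invented, not real output.

```python
import subprocess

QUERY = [
    "nvidia-smi",
    "--query-compute-apps=pid,process_name,used_memory",
    "--format=csv,noheader,nounits",
]


def parse_compute_apps(csv_text):
    """Parse 'pid, process_name, used_memory' lines into dictionaries."""
    apps = []
    for line in csv_text.strip().splitlines():
        try:
            pid, rest = line.split(",", 1)
            name, used = rest.rsplit(",", 1)
            apps.append(
                {"pid": int(pid), "name": name.strip(), "used_mib": used.strip()}
            )
        except ValueError:
            continue  # skip lines that do not match the expected layout
    return apps


def list_gpu_processes():
    """Return the parsed process list, or None if nvidia-smi is unavailable."""
    try:
        result = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    return parse_compute_apps(result.stdout)


# Invented sample text in the queried layout, to show the parsing step.
sample = "1234, python, 3050\n5678, /usr/bin/python3, 512\n"
print(parse_compute_apps(sample))
```

If a lingering training process shows up in that list, stopping it (or restarting its container) should free the GPU again.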
A few notes:
- You know the drill. Search for the Train CARE node and reserve it to make it available.
- Navigate back to the Data tab and run it directly from the Context you just created.
- You will be prompted by the Assign Dialog, which will now ask you for training parameters. These are specific to CARE, and you can read more about them in the CARE paper. We just leave the default parameters for now, but be aware that you can change them to your needs. The default parameters are not optimal for this dataset, but we are just doing this for the purpose of this tutorial.
- Once you submit the dialog, the training will start. You can see the progress in the Dashboard tab, and you will see the Train CARE node turning green once the training is done. You can also monitor the progress right there by looking at the progress bar.
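To get a feeling for what the default parameters mean, the following back-of-the-envelope sketch computes how much work one training run does. The defaults for epochs, patches per image, train steps per epoch and validation split are the ones of the Train CARE node listed at the end of this tutorial. The batch size of 16 is an assumption (we believe it is the CSBDeep default, but check your configuration), as is rounding the validation split to the nearest integer. The function name `training_budget` is made up.

```python
def training_budget(
    n_image_pairs,
    patches_per_image=1024,
    validation_split=0.1,
    trainsteps_per_epoch=400,
    epochs=100,
    batch_size=16,  # assumption, not a parameter of the Train CARE node
):
    """Rough count of patches and optimisation steps for one training run."""
    total_patches = n_image_pairs * patches_per_image
    n_val = round(total_patches * validation_split)
    n_train = total_patches - n_val
    draws_per_epoch = trainsteps_per_epoch * batch_size
    return {
        "total_patches": total_patches,
        "n_train": n_train,
        "n_val": n_val,
        "draws_per_epoch": draws_per_epoch,
        "passes_per_epoch": draws_per_epoch / n_train,
        "total_steps": trainsteps_per_epoch * epochs,
    }


# One low/high SNR image pair, as in this tutorial.
budget = training_budget(n_image_pairs=1)
for key, value in budget.items():
    print(f"{key}: {value}")
```

With a single image pair, each epoch sees every training patch several times, which is one reason why reducing the number of epochs or train steps is a reasonable way to shorten a demo run.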
Inspecting the Model
Training CARE with the default parameters can take a LOOOONG time, and depending on your hardware you might have to wait for a while until the progress bar moves. You can of course always check the output of the deep learning process in the PluginApps tab, by clicking on the running kare container and inspecting the output of the training process inside the container.
While we are working on providing a few feedback graphs that illustrate the training process, we are pretty much settled on the idea that we don't want to provide a full-fledged training interface inside Arkitekt. We think that this is best left to the deep learning frameworks and to software projects like wandb that are specifically designed for this purpose.
We rather think Arkitekt should be seen as a tool for transfer learning and model adaptation, and not as a tool for developing deep learning models from scratch. But hey, we are open to feedback, and you can always just integrate your favorite deep learning framework into Arkitekt and use it as you like.
Using the Model
Your training is going to take time. But at some point it will end, promise! Once it is done you can use the model to make predictions.
For this we will now reserve the Predict CARE node, which in this case will take a Model and an Image as input and return an image.
And that's it, you trained a CARE model and used it to make predictions. You can now use this model in your analysis pipeline, and you can also adapt it to your specific needs. And of course you can download the model and share it with others, just like we did with the models that we are using in our Showcases.
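In the KARE code at the end of this tutorial, the stored model is simply a zip archive of the trained model folder, which the prediction side unpacks again before loading. That round trip can be sketched with the standard library alone. The folder layout and the two file names below are placeholders for illustration; a real model folder contains whatever the training framework writes.

```python
import shutil
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)

    # Stand-in for a trained model folder (placeholder file names and content).
    model_dir = tmp / "trainedmodels" / "my_model"
    model_dir.mkdir(parents=True)
    (model_dir / "config.json").write_text('{"axes": "ZYX"}')
    (model_dir / "weights_best.h5").write_bytes(b"placeholder")

    # Training side: pack the folder into one file that can be stored or shared.
    archive = shutil.make_archive(str(tmp / "active_model"), "zip", str(model_dir))
    archive_name = Path(archive).name

    # Prediction side: unpack into a fresh cache folder before loading the model.
    target = tmp / "modelcache" / "restored"
    shutil.unpack_archive(archive, str(target))
    restored_files = sorted(p.name for p in target.iterdir())

print(archive_name, restored_files)
```

Anyone who receives the archive can unpack it the same way and load the model with the framework it was trained with.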
The developers perspective
If you are a developer and you want to know how to integrate your own deep learning app into Arkitekt, you can read our Developer Tutorial on how to do this here.
But here is also ALL the code for the Train CARE and Predict CARE nodes, which is just a few lines and should be easy to understand.
import shutil
import uuid

from arkitekt import register
from arkitekt.tqdm import tqdm
from csbdeep.data import RawData, create_patches
from csbdeep.io import load_training_data
from csbdeep.models import CARE, Config
from mikro.api.schema import (
    ContextFragment,
    LinkableModels,
    ModelFragment,
    ModelKind,
    RepresentationFragment,
    create_model,
    from_xarray,
    links,
)


@register()
def train_care_model(
    context: ContextFragment,
    epochs: int = 100,
    patches_per_image: int = 1024,
    trainsteps_per_epoch: int = 400,
    validation_split: float = 0.1,
) -> ModelFragment:
    """Train Care Model

    Trains a CARE model on the image pairs of a specific context.

    Args:
        context (ContextFragment): The context
        epochs (int, optional): Number of epochs. Defaults to 100.
        patches_per_image (int, optional): Number of patches per image. Defaults to 1024.
        trainsteps_per_epoch (int, optional): Number of trainsteps per epoch. Defaults to 400.
        validation_split (float, optional): Validation split. Defaults to 0.1.

    Returns:
        ModelFragment: The Model
    """
    training_data_id = f"context_data{context.id}"

    # HERE WE ARE GETTING THE DATA FROM THE CONTEXT
    x = links(
        LinkableModels.GRUNNLAG_REPRESENTATION,  # only works with images
        LinkableModels.GRUNNLAG_REPRESENTATION,  # that are related to each other
        "gt",  # through the `gt` relationship
        context=context,  # inside the given context
    )

    # We are computing the data from both sides of the relationship
    X = [t.left.data.sel(t=0, c=0).compute() for t in x]
    Y = [t.right.data.sel(t=0, c=0).compute() for t in x]

    # This is all CARE specific, and you can read more about it in the CARE paper
    raw_data = RawData.from_arrays(X, Y, axes="ZYX")
    print(raw_data)

    X, Y, XY_axes = create_patches(
        raw_data=raw_data,
        patch_size=(16, 64, 64),
        n_patches_per_image=patches_per_image,
        save_file=f"data/{training_data_id}.npz",
    )

    (X, Y), (X_val, Y_val), axes = load_training_data(
        f"data/{training_data_id}.npz",
        validation_split=validation_split,
        verbose=True,
    )

    config = Config(axes, train_steps_per_epoch=trainsteps_per_epoch)
    care_model = CARE(config, training_data_id, basedir=".trainedmodels")

    # Here we are training the model one epoch at a time. We are using a patched
    # tqdm progress bar that updates the progress bar in the Arkitekt interface
    # after every epoch
    for i in tqdm(range(epochs)):
        care_model.train(X, Y, validation_data=(X_val, Y_val), epochs=1)

    shutil.make_archive(
        "active_model", "zip", f".trainedmodels/{training_data_id}"
    )

    # Here we are creating a model from the trained model, and we are saving it
    # as an Arkitekt Model
    model = create_model(
        "active_model.zip",
        kind=ModelKind.TENSORFLOW,
        name=f"Care Model of {context.name}",
        contexts=[context],
    )

    # Remove the temporary patch data
    shutil.rmtree("data")
    return model


@register()
def predict(
    representation: RepresentationFragment, model: ModelFragment
) -> RepresentationFragment:
    """Predict Care

    Use a CARE model and an image to generate a restored image

    Args:
        representation (RepresentationFragment): The image
        model (ModelFragment): The model

    Returns:
        RepresentationFragment: The predicted image
    """
    random_dir = str(uuid.uuid4())

    # We are downloading the model
    with model.data as f:
        # We are unpacking the model into a temporary directory
        shutil.unpack_archive(f, f".modelcache/{random_dir}")

    image_data = representation.data.sel(c=0, t=0).data.compute()

    # We are loading the model from the temporary directory
    care_model = CARE(config=None, name=random_dir, basedir=".modelcache")

    # We are predicting the image, using the same axes order as in training
    restored = care_model.predict(image_data, "ZYX")

    # We are creating a new representation from the predicted image
    generated = from_xarray(
        restored,
        name=f"Care denoised of {representation.name}",
        tags=["denoised"],
        origins=[representation],
    )

    # Remove the temporary model directory
    shutil.rmtree(f".modelcache/{random_dir}")
    return generated