Train an ML model on a dataset
In the previous tutorial, we loaded an entire dataset into memory to perform a simple analysis.
Here, we’ll iterate over the artifacts within the dataset to train an ML model.
import lamindb as ln
import anndata as ad
import numpy as np
💡 lamindb instance: testuser1/test-scrna
ln.track()
💡 notebook imports: anndata==0.9.2 lamindb==0.64.2 numpy==1.26.2 torch==2.1.2
💡 saved: Transform(uid='Qr1kIHvK506rz8', name='Train an ML model on a dataset', short_name='scrna5', version='0', type=notebook, updated_at=2023-12-22 11:26:25 UTC, created_by_id=1)
💡 saved: Run(uid='D8BivzEtGyQEArC7ajEZ', run_at=2023-12-22 11:26:25 UTC, transform_id=5, created_by_id=1)
Preprocessing
Let us get our dataset:
dataset_v2 = ln.Dataset.filter(name="My versioned scRNA-seq dataset", version="2").one()
dataset_v2
Dataset(uid='yIwUH2JUgdJ8sandeDLl', name='My versioned scRNA-seq dataset', version='2', hash='BOAf0T5UbN_iOe3fQDyq', visibility=1, updated_at=2023-12-22 11:26:01 UTC, transform_id=2, run_id=2, initial_version_id=1, created_by_id=1)
PyTorch DataLoader
If you need to train your model on a list of artifacts, you can use mapped() with the PyTorch DataLoader.
It loads data into memory one batch at a time and thus lets you work with very large datasets.
from torch.utils.data import DataLoader, WeightedRandomSampler
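To see why only the current batch has to sit in memory, here is a minimal stdlib sketch of a backed matrix: the rows live in a binary file, and __getitem__ seeks to one row and reads just its bytes. The BackedMatrix class and its raw float64 row layout are hypothetical illustrations; they are not how AnnData lays out .h5ad files, only the same read-on-access idea.

```python
import os
import struct
import tempfile


class BackedMatrix:
    """Hypothetical on-disk matrix that reads a single row per __getitem__ call."""

    def __init__(self, path, n_vars):
        self.n_vars = n_vars
        self.row_size = 8 * n_vars  # bytes per row of little-endian float64 values
        self._file = open(path, "rb")
        self.n_rows = os.path.getsize(path) // self.row_size

    def __len__(self):
        return self.n_rows

    def __getitem__(self, i):
        # seek to row i and read only its bytes, never the whole file
        self._file.seek(i * self.row_size)
        raw = self._file.read(self.row_size)
        return list(struct.unpack(f"<{self.n_vars}d", raw))

    def close(self):
        self._file.close()


# write a small 3 x 4 matrix to a temporary file
rows = [[0.0, 1.5, 0.0, 2.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.5, 1.0]]
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp:
    for row in rows:
        tmp.write(struct.pack(f"<{len(row)}d", *row))

m = BackedMatrix(tmp.name, n_vars=4)
n_rows = len(m)
third = m[2]  # only these 32 bytes are read from disk
print(n_rows, third)  # 3 [0.0, 0.0, 3.5, 1.0]
m.close()
os.remove(tmp.name)
```

A data loader built on such an object only ever touches the rows of the batch it is assembling.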
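Let us create a MappedDataset object from the Dataset.
Under the hood, it performs a virtual inner join of the variables of the underlying AnnData objects: only the variables shared by all objects are used, and no joined copy of the data is built in memory.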
ds_mapped = dataset_v2.mapped(label_keys=["cell_type"])
The names of the intersected variables can be accessed via var_joint:
len(ds_mapped.var_joint)
749
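A virtual inner join only has to compute the shared variable names and, for each object, the column positions that select them. A minimal stdlib sketch, with made-up gene lists standing in for each AnnData's var_names; the helper virtual_inner_join is hypothetical, and the order lamindb uses for var_joint may differ from the first-list order chosen here.

```python
def virtual_inner_join(var_lists):
    """Return the shared variables (in the order of the first list) and,
    for each input, the column indices that select them."""
    shared = set(var_lists[0]).intersection(*var_lists[1:])
    var_joint = [v for v in var_lists[0] if v in shared]
    indices = []
    for names in var_lists:
        position = {name: j for j, name in enumerate(names)}
        indices.append([position[v] for v in var_joint])
    return var_joint, indices


var_lists = [
    ["CD3E", "CD4", "MS4A1", "NKG7"],
    ["NKG7", "CD4", "CD3E", "LYZ"],
]
var_joint, indices = virtual_inner_join(var_lists)
print(var_joint)  # ['CD3E', 'CD4', 'NKG7']
print(indices)    # [[0, 1, 3], [2, 1, 0]]

# project one row of the second object onto the joint variables at read time
row = [5.0, 1.0, 2.0, 9.0]  # values for NKG7, CD4, CD3E, LYZ
projected = [row[j] for j in indices[1]]
print(projected)  # [2.0, 1.0, 5.0]
```

Because only index lists are stored, each row can be aligned to the shared variables when it is read, without materializing a joined matrix.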
This is compatible with the PyTorch DataLoader because it implements __getitem__ (and __len__) over a list of backed AnnData objects.
ds_mapped[5]
[array([0. , 0. , 0. , 1.51316404, 0. ,
0. , 0. , 0. , 1.01904154, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 3.43718052, 0. , 0. ,
0. , 1.51316404, 0. , 0. , 1.51316404,
1.01904154, 0. , 0. , 1.01904154, 1.01904154,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 1.01904154, 0. ,
0. , 0. , 0. , 0. , 2.08965826,
0. , 1.51316404, 1.84239149, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
1.01904154, 0. , 0. , 0. , 0. ,
1.01904154, 0. , 0. , 0. , 0. ,
1.84239149, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
1.01904154, 0. , 0. , 1.01904154, 0. ,
1.84239149, 0. , 1.01904154, 0. , 0. ,
0. , 0. , 1.01904154, 0. , 0. ,
0. , 1.51316404, 0. , 0. , 2.28774452,
2.45300555, 0. , 0. , 1.01904154, 0. ,
0. , 0. , 0. , 1.01904154, 0. ,
0. , 0. , 1.01904154, 1.01904154, 1.01904154,
1.01904154, 1.84239149, 0. , 0. , 2.08965826,
1.01904154, 1.84239149, 0. , 0. , 0. ,
0. , 0. , 1.51316404, 0. , 1.51316404,
1.01904154, 0. , 0. , 1.01904154, 0. ,
0. , 0. , 1.01904154, 0. , 1.01904154,
0. , 0. , 1.01904154, 0. , 1.01904154,
1.51316404, 0. , 1.01904154, 0. , 0. ,
1.01904154, 1.01904154, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 2.28774452, 0. , 1.51316404, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 1.01904154, 0. ,
1.51316404, 1.01904154, 0. , 0. , 1.01904154,
1.84239149, 2.08965826, 1.01904154, 0. , 1.51316404,
1.01904154, 1.01904154, 1.01904154, 0. , 1.01904154,
0. , 0. , 0. , 0. , 0. ,
1.01904154, 1.51316404, 0. , 0. , 1.84239149,
0. , 1.01904154, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
5.14149761, 0. , 2.45300555, 1.01904154, 1.51316404,
0. , 0. , 1.01904154, 0. , 2.45300555,
1.01904154, 0. , 0. , 1.01904154, 0. ,
1.51316404, 0. , 1.84239149, 0. , 0. ,
1.51316404, 2.92881131, 0. , 0. , 0. ,
2.28774452, 1.01904154, 3.85087061, 1.51316404, 3.24989128,
2.45300555, 2.28774452, 0. , 1.84239149, 0. ,
2.71894431, 2.92881131, 0. , 1.01904154, 0. ,
1.01904154, 0. , 1.01904154, 1.01904154, 0. ,
2.28774452, 1.01904154, 0. , 1.84239149, 1.01904154,
1.01904154, 1.01904154, 0. , 0. , 0. ,
1.51316404, 1.01904154, 1.84239149, 0. , 0. ,
1.01904154, 1.84239149, 1.84239149, 0. , 0. ,
1.51316404, 0. , 1.01904154, 0. , 2.08965826,
1.01904154, 0. , 0. , 1.01904154, 0. ,
0. , 1.01904154, 1.01904154, 2.59478951, 1.01904154,
0. , 0. , 0. , 0. , 0. ,
0. , 1.51316404, 0. , 1.01904154, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 2.08965826, 0. , 0. ,
0. , 0. , 1.01904154, 0. , 1.01904154,
0. , 0. , 0. , 1.01904154, 0. ,
0. , 0. , 0. , 0. , 1.01904154,
0. , 1.51316404, 0. , 0. , 0. ,
0. , 0. , 0. , 1.51316404, 0. ,
1.01904154, 1.51316404, 0. , 1.01904154, 0. ,
1.51316404, 0. , 1.01904154, 0. , 1.01904154,
0. , 0. , 1.01904154, 0. , 0. ,
0. , 3.43718052, 0. , 0. , 0. ,
0. , 2.08965826, 0. , 0. , 1.01904154,
0. , 1.01904154, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 1.51316404,
1.01904154, 3.17876172, 0. , 1.01904154, 1.01904154,
1.01904154, 0. , 0. , 1.01904154, 0. ,
1.01904154, 0. , 0. , 0. , 0. ,
1.51316404, 1.51316404, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 1.51316404, 1.51316404, 1.01904154,
0. , 0. , 0. , 2.45300555, 0. ,
1.51316404, 0. , 0. , 0. , 0. ,
1.01904154, 1.01904154, 0. , 1.01904154, 0. ,
1.51316404, 0. , 0. , 1.01904154, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 1.01904154, 0. , 1.01904154, 0. ,
1.51316404, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
1.51316404, 0. , 0. , 0. , 1.01904154,
1.01904154, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 2.08965826, 0. ,
2.45300555, 1.01904154, 0. , 0. , 0. ,
1.51316404, 0. , 1.51316404, 1.01904154, 0. ,
0. , 0. , 0. , 1.01904154, 1.01904154,
0. , 0. , 0. , 1.01904154, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 2.28774452,
0. , 2.59478951, 1.01904154, 0. , 0. ,
0. , 0. , 1.01904154, 2.92881131, 1.01904154,
2.45300555, 0. , 0. , 1.01904154, 2.08965826,
0. , 0. , 0. , 0. , 0. ,
1.51316404, 0. , 1.01904154, 0. , 1.01904154,
0. , 0. , 1.01904154, 0. , 1.01904154,
0. , 0. , 0. , 0. , 0. ,
0. , 1.01904154, 1.51316404, 0. , 0. ,
0. , 0. , 1.01904154, 0. , 0. ,
0. , 0. , 1.01904154, 1.51316404, 1.01904154,
0. , 0. , 0. , 2.08965826, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 1.84239149,
0. , 0. , 0. , 0. , 0. ,
0. , 1.51316404, 0. , 0. , 0. ,
1.84239149, 0. , 0. , 1.84239149, 0. ,
0. , 0. , 0. , 0. , 1.01904154,
0. , 0. , 0. , 0. , 0. ,
0. , 1.51316404, 0. , 0. , 1.51316404,
0. , 2.45300555, 0. , 1.51316404, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 1.51316404, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 1.01904154,
1.01904154, 1.51316404, 0. , 0. , 0. ,
1.01904154, 1.01904154, 1.01904154, 0. , 0. ,
0. , 0. , 0. , 1.01904154, 1.01904154,
0. , 0. , 0. , 1.51316404, 0. ,
0. , 0. , 0. , 0. , 1.01904154,
1.84239149, 2.08965826, 0. , 0. , 0. ,
1.01904154, 0. , 2.92881131, 0. , 1.01904154,
0. , 1.01904154, 1.01904154, 0. , 1.84239149,
0. , 0. , 0. , 0. , 0. ,
1.01904154, 0. , 0. , 2.08965826, 0. ,
0. , 0. , 0. , 0. , 3.31629562,
0. , 2.45300555, 0. , 1.01904154, 0. ,
1.01904154, 0. , 0. , 0. , 0. ,
0. , 0. , 1.01904154, 0. , 0. ,
0. , 0. , 1.01904154, 0. , 0. ,
0. , 1.01904154, 0. , 0. , 1.01904154,
1.51316404, 1.84239149, 0. , 2.45300555, 1.01904154,
1.01904154, 0. , 1.01904154, 1.01904154, 0. ,
1.51316404, 0. , 1.01904154, 0. , 0. ,
0. , 1.01904154, 1.01904154, 0. , 0. ,
1.01904154, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 1.01904154,
0. , 2.08965826, 0. , 1.01904154, 1.01904154,
0. , 0. , 0. , 1.51316404, 0. ,
0. , 0. , 1.51316404, 1.01904154, 1.01904154,
0. , 0. , 0. , 0. , 1.01904154,
1.01904154, 1.01904154, 1.51316404, 0. , 1.01904154,
0. , 0. , 0. , 4.55217552, 1.01904154,
0. , 1.01904154, 0. , 4.14264011]),
31]
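Each item is a pair of expression values and an integer label. One way to expose several backed objects as a single map-style dataset is to keep cumulative row counts and translate a global index into an (object, local row) pair with a binary search. This is a hypothetical stdlib sketch of that idea, with plain lists standing in for backed AnnData objects; it is not lamindb's implementation, though PyTorch's own torch.utils.data.ConcatDataset resolves indices with the same bisect-based lookup.

```python
import bisect
from itertools import accumulate


class MultiObjectDataset:
    """Hypothetical map-style dataset over several row-indexable objects."""

    def __init__(self, objects, labels):
        self.objects = objects
        self.labels = labels  # one list of integer labels per object
        # cumulative row counts, e.g. [3, 5] for objects with 3 and 2 rows
        self.offsets = list(accumulate(len(o) for o in objects))

    def __len__(self):
        return self.offsets[-1]

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # find the object that holds global row idx
        k = bisect.bisect_right(self.offsets, idx)
        local = idx - (self.offsets[k - 1] if k > 0 else 0)
        return [self.objects[k][local], self.labels[k][local]]


objects = [[[0.0, 1.0], [2.0, 0.0], [0.0, 0.0]], [[4.0, 4.0], [1.0, 3.0]]]
labels = [[0, 1, 1], [2, 0]]
ds = MultiObjectDataset(objects, labels)
print(len(ds))  # 5
print(ds[3])    # [[4.0, 4.0], 2]
```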
The labels are encoded as integers; encoders maps each cell type to its code:
ds_mapped.encoders
[{'megakaryocyte': 0,
'lymphocyte': 1,
'CD16-negative, CD56-bright natural killer cell, human': 2,
'effector memory CD4-positive, alpha-beta T cell, terminally differentiated': 3,
'dendritic cell, human': 4,
'plasmacytoid dendritic cell': 5,
'B cell, CD19-positive': 6,
'animal cell': 7,
'gamma-delta T cell': 8,
'mast cell': 9,
'plasma cell': 10,
'macrophage': 11,
'alveolar macrophage': 12,
'plasmablast': 13,
'mucosal invariant T cell': 14,
'naive thymus-derived CD8-positive, alpha-beta T cell': 15,
'group 3 innate lymphoid cell': 16,
'CD8-positive, alpha-beta memory T cell, CD45RO-positive': 17,
'classical monocyte': 18,
'CD4-positive, alpha-beta T cell': 19,
'T follicular helper cell': 20,
'non-classical monocyte': 21,
'naive thymus-derived CD4-positive, alpha-beta T cell': 22,
'alpha-beta T cell': 23,
'CD16-positive, CD56-dim natural killer cell, human': 24,
'progenitor cell': 25,
'CD8-positive, CD25-positive, alpha-beta regulatory T cell': 26,
'regulatory T cell': 27,
'effector memory CD8-positive, alpha-beta T cell, terminally differentiated': 28,
'CD38-positive naive B cell': 29,
'dendritic cell': 30,
'memory B cell': 31,
'effector memory CD4-positive, alpha-beta T cell': 32,
'CD4-positive helper T cell': 33,
'CD14-positive, CD16-negative classical monocyte': 34,
'germinal center B cell': 35,
'cytotoxic T cell': 36,
'CD8-positive, alpha-beta memory T cell': 37,
'conventional dendritic cell': 38,
'naive B cell': 39}]
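An encoder like this is a dictionary from category name to integer code. A minimal sketch that assigns codes in order of first appearance (an assumption; the order lamindb chooses may differ) and builds the inverse mapping for decoding predictions; build_encoder is a hypothetical helper.

```python
def build_encoder(labels):
    """Map each distinct label to an integer, in order of first appearance."""
    encoder = {}
    for label in labels:
        encoder.setdefault(label, len(encoder))
    return encoder


cell_types = ["classical monocyte", "memory B cell", "classical monocyte", "mast cell"]
encoder = build_encoder(cell_types)
decoder = {code: label for label, code in encoder.items()}

codes = [encoder[c] for c in cell_types]
print(encoder)     # {'classical monocyte': 0, 'memory B cell': 1, 'mast cell': 2}
print(codes)       # [0, 1, 0, 2]
print(decoder[1])  # memory B cell
```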
Let us use a weighted sampler to counter the imbalance between cell types:
# the label key used for the weights doesn't have to be among the label_keys passed to mapped()
sampler = WeightedRandomSampler(
    weights=ds_mapped.get_label_weights("cell_type"), num_samples=len(ds_mapped)
)
dl = DataLoader(ds_mapped, batch_size=128, sampler=sampler)
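WeightedRandomSampler draws indices with probability proportional to their weights, with replacement by default. Assuming get_label_weights returns inverse-frequency weights (each observation weighted by one over the count of its label; treat this as an assumption about lamindb), the effect can be sketched with the standard library, using random.choices in place of the PyTorch sampler and a hypothetical inverse_frequency_weights helper.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each observation by 1 / (number of observations sharing its label)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


labels = ["T cell"] * 900 + ["mast cell"] * 100  # imbalanced toy labels
weights = inverse_frequency_weights(labels)

rng = random.Random(0)  # seeded for reproducibility
# draw as many indices as there are observations, with replacement,
# each with probability proportional to its weight
sampled = rng.choices(range(len(labels)), weights=weights, k=len(labels))
drawn = Counter(labels[i] for i in sampled)
print(drawn)  # roughly equal counts for both cell types, unlike the 900/100 input
```

Since every cell type's weights sum to one, each type is drawn about equally often, at the cost of repeating rare observations within an epoch.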
We can now iterate through the data loader:
for batch in dl:
    # each batch is [values, labels]: up to 128 rows of expression values and their integer labels
    pass
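With batch_size=128, the DataLoader groups the sampled indices into chunks of 128, fetches each item with __getitem__, and collates the [values, label] pairs so that values and labels are stacked separately. A simplified stdlib sketch of that loop; iterate_batches is a hypothetical helper, and the tensor conversion done by PyTorch's default collate function is omitted.

```python
def iterate_batches(dataset, indices, batch_size):
    """Yield (values, labels) batches, reading only the items of each batch."""
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        items = [dataset[i] for i in chunk]
        values = [x for x, _ in items]
        labels = [y for _, y in items]
        yield values, labels


# toy dataset: 5 items of the form [values, label]
dataset = [[[float(i), 0.0], i % 2] for i in range(5)]
batches = list(iterate_batches(dataset, indices=[4, 0, 3, 1, 2], batch_size=2))
print(len(batches))  # 3 (the last batch holds a single item)
print(batches[0])    # ([[4.0, 0.0], [0.0, 0.0]], [0, 0])
```

As in the DataLoader's default (drop_last=False), a final partial batch is kept.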
Close the connections in MappedDataset:
ds_mapped.close()
In practice, use a context manager, which closes the connections when the block exits:
with dataset_v2.mapped(label_keys=["cell_type"]) as ds_mapped:
    sampler = WeightedRandomSampler(
        weights=ds_mapped.get_label_weights("cell_type"), num_samples=len(ds_mapped)
    )
    dl = DataLoader(ds_mapped, batch_size=128, sampler=sampler)
    for batch in dl:
        pass
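The benefit of the with block is that close() runs even if the training loop raises. A minimal sketch of the context manager protocol, with a hypothetical ClosableDataset class standing in for MappedDataset (its is_open flag is an illustration, not lamindb's API):

```python
class ClosableDataset:
    """Hypothetical stand-in for an object holding open file connections."""

    def __init__(self):
        self.is_open = True

    def close(self):
        self.is_open = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # do not swallow exceptions


with ClosableDataset() as data:
    assert data.is_open
print(data.is_open)  # False

try:
    with ClosableDataset() as data2:
        raise RuntimeError("training failed")
except RuntimeError:
    pass
print(data2.is_open)  # False: closed despite the exception
```

For objects that only expose close(), contextlib.closing gives the same guarantee.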
# clean up test instance
!lamin delete --force test-scrna
!rm -r ./test-scrna
💡 deleting instance testuser1/test-scrna
✅ deleted instance settings file: /home/runner/.lamin/instance--testuser1--test-scrna.env
✅ instance cache deleted
✅ deleted '.lndb' sqlite file
❗ consider manually deleting your stored data: /home/runner/work/lamin-usecases/lamin-usecases/docs/test-scrna