smurf.start_optimization

smurf.start_optimization(spots_X, celltype_X, cells_X_plus, nonzero_indices_dic, device, num_epochs=1000, learning_rate=0.1, print_each=100, epsilon=0.001, random_seed=42, print_memory=False)

Starts the optimization process to estimate cell-type proportions in each spot using PyTorch.

This function uses a custom neural network layer implemented in PyTorch to learn, for each spot, the cell weights (proportions) that best reconstruct the observed gene expression data. The loss measures the mismatch between the predicted and true cell-type expression profiles using cosine similarity, and the optimization minimizes it.
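The idea can be illustrated with a minimal sketch for a single spot. This is not the smurf implementation: the helper name `fit_cell_weights`, the softmax parameterization of the weights, and the choice of the Adam optimizer are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def fit_cell_weights(spot_expr, cell_expr, num_epochs=200,
                     learning_rate=0.1, random_seed=42):
    """Hypothetical helper: learn one weight per cell for a single spot.

    spot_expr: tensor of shape (n_genes,), observed spot expression.
    cell_expr: tensor of shape (n_cells, n_genes), one profile per cell.
    Returns a tensor of shape (n_cells,) with non-negative weights summing to 1.
    """
    torch.manual_seed(random_seed)
    n_cells = cell_expr.shape[0]
    # Unconstrained parameters; softmax maps them to proportions (an assumption).
    raw = torch.nn.Parameter(torch.zeros(n_cells))
    optimizer = torch.optim.Adam([raw], lr=learning_rate)

    for _ in range(num_epochs):
        optimizer.zero_grad()
        weights = torch.softmax(raw, dim=0)
        predicted = weights @ cell_expr  # weighted sum of cell profiles
        # Cosine similarity is 1 for identical directions, so minimize 1 - cos.
        loss = 1.0 - F.cosine_similarity(predicted, spot_expr, dim=0)
        loss.backward()
        optimizer.step()

    return torch.softmax(raw, dim=0).detach()


cells = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # two toy cells, two genes
spot = torch.tensor([3.0, 1.0])                 # toy spot expression
weights = fit_cell_weights(spot, cells, num_epochs=500)
```

Because cosine similarity ignores overall scale, only the direction of the reconstructed profile matters; in this toy case the learned weights should approach a 3:1 ratio between the two cells.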

Parameters:
  • spots_X (dict) – A dictionary where each key is a group identifier, and each value is a NumPy array of spot expression matrices for that group.

  • celltype_X (dict) – A dictionary where each key is a group identifier, and each value is a NumPy array of cell-type-specific weight matrices for that group.

  • cells_X_plus (dict) – A dictionary where each key is a group identifier, and each value is a NumPy array of cell expression matrices for that group.

  • nonzero_indices_dic (dict) – A dictionary where each key is a group identifier, and each value is a list of non-zero indices indicating cell presence in spots for that group.

  • device (str) – The device on which to perform the computation (e.g., 'cpu' or 'cuda').

  • num_epochs (int, optional) – The number of training epochs for the optimization. Defaults to 1000.

  • learning_rate (float, optional) – The learning rate for the optimizer. Defaults to 0.1.

  • print_each (int, optional) – Frequency of printing the training loss. Prints every print_each epochs. Defaults to 100.

  • epsilon (float, optional) – Threshold for early stopping based on minimal loss improvement. Defaults to 1e-3.

  • random_seed (int, optional) – Random seed for reproducibility. Defaults to 42.

  • print_memory (bool, optional) – Whether to print GPU memory usage during training. Requires pynvml. Defaults to False.
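How num_epochs, print_each, and epsilon might interact can be sketched with a small loop. The exact stopping rule used by smurf is not stated here; this sketch assumes training stops when the absolute change in loss between consecutive epochs falls below epsilon. The names `run_with_early_stopping` and `step_fn` are hypothetical.

```python
def run_with_early_stopping(step_fn, num_epochs=1000, print_each=100,
                            epsilon=1e-3):
    """Hypothetical loop: call step_fn once per epoch and stop early.

    step_fn: callable that runs one optimization step and returns the loss
             as a float.
    Returns the list of losses recorded before stopping.
    """
    history = []
    previous_loss = float("inf")
    for epoch in range(num_epochs):
        loss = step_fn()
        history.append(loss)
        # Report progress every print_each epochs.
        if (epoch + 1) % print_each == 0:
            print(f"epoch {epoch + 1}: loss = {loss:.6f}")
        # Assumed rule: stop once the loss changes by less than epsilon.
        if abs(previous_loss - loss) < epsilon:
            break
        previous_loss = loss
    return history


# Toy usage: losses are replayed from a fixed sequence.
toy_losses = iter([1.0, 0.5, 0.4995, 0.4990])
history = run_with_early_stopping(lambda: next(toy_losses), num_epochs=4)
```

With this rule, a larger epsilon ends training sooner at the cost of a less converged solution, while num_epochs remains the hard upper bound.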

Returns:

A dictionary spot_cell_dic where each key is a group identifier, and each value is a list of learned weights (cell proportions) for that group.

Return type:

dict

Dependencies:
  • This function requires the following packages and modules:
    • torch

    • torch.nn

    • torch.optim

    • torch.nn.functional (for cosine_similarity)

    • pynvml (optional, for GPU memory tracking if print_memory is True)
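Since pynvml is optional, calling code may want to check for it before setting print_memory=True. The sketch below shows one way to guard the import and read GPU memory usage; the helper name `used_gpu_memory_mib` is hypothetical, and this is not necessarily how smurf reports memory internally.

```python
import importlib.util


def used_gpu_memory_mib(device_index=0):
    """Hypothetical helper: return used GPU memory in MiB, or None.

    None is returned when pynvml is not installed or when NVML cannot be
    initialized or queried (for example, on a machine without an NVIDIA GPU).
    """
    if importlib.util.find_spec("pynvml") is None:
        return None
    import pynvml

    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        return None
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return info.used / 1024 ** 2
    except pynvml.NVMLError:
        return None
    finally:
        pynvml.nvmlShutdown()


memory = used_gpu_memory_mib()
# Only request memory printing when the query actually works.
print_memory = memory is not None
```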