
Training PyTorch deep learning models in Viya & dealing with multi-dimensional tabular data - Part II

Started ‎10-10-2024 by
Modified ‎10-10-2024 by

This is the second article in the series titled "Training PyTorch deep learning models in Viya & dealing with multi-dimensional tabular data". In the previous post, I discussed how to train a custom PyTorch model in Viya, using a deep learning model based on the Set Transformers architecture. In this article, I will focus on the next aspect of the topic: working with multi-dimensional tabular data.

To recap, the model takes a set of cells as input, with a binary indicator as the target value. Let's assume the number of cells in the set is N. In a typical deep learning framework, this translates to a tensor with dimensions (batch_size x N x dim), where "dim" is the number of channels (or features) per cell. The data is multi-dimensional in the sense that each sample has shape N x dim rather than 1 x dim.

When training a model in Viya, the data is stored in CAS tables. Typically, each row in a CAS table represents a single sample, but in our case a collection of rows (one per cell) represents a sample. To resolve this mismatch, we must address two key issues:

 

  1. Identify methods to store multi-dimensional data in CAS tables.
  2. Modify the wrapper class model definition to ensure that the data retrieved from CAS tables, as identified in step 1 above, is reshaped appropriately for the model to process.

There are two potential approaches to tackle the first point. The first approach is to flatten the input dataset, so that the collection of rows representing a single sample is consolidated into a single row. In this format, each row would represent a complete sample, and the corresponding binary indicator would also be included in the same row. This approach is illustrated in the figure below:

 

[Figure: PankajAttri_0-1728592275659.png]
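The flattening step described above can be sketched in plain Python. This is a minimal illustration, not the code used to prepare the original data: the tuple layout (patient_id, cell_index, features, label), the helper name flatten_samples, and the tiny two-feature cells are all assumptions made for the example.

```python
from collections import defaultdict

# Hypothetical long-format rows: (patient_id, cell_index, features, label).
# Names and sizes are illustrative only; real cells would have more features.
long_rows = [
    ("p1", 0, [0.1, 0.2], 1),
    ("p1", 1, [0.3, 0.4], 1),
    ("p2", 0, [0.5, 0.6], 0),
    ("p2", 1, [0.7, 0.8], 0),
]


def flatten_samples(rows):
    """Consolidate all cells of a patient into a single wide row."""
    grouped = defaultdict(list)
    labels = {}
    for patient_id, cell_index, features, label in rows:
        grouped[patient_id].append((cell_index, features))
        labels[patient_id] = label
    wide = []
    for patient_id, cells in grouped.items():
        cells.sort(key=lambda c: c[0])  # fixed cell order within a sample
        # Cell-major layout: all features of cell 0, then cell 1, and so on.
        flat = [value for _, feats in cells for value in feats]
        wide.append({"patient_id": patient_id,
                     "features": flat,
                     "target": labels[patient_id]})
    return wide


wide_rows = flatten_samples(long_rows)
for row in wide_rows:
    print(row)
```

The cell-major ordering matters: a later reshape back to (N x dim) only recovers the original cells if the features of each cell sit next to each other in the flattened row.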

 

The second approach involves stacking the cell sets (or rows) within the CAS table and repeating the target label for each row in a set. Additionally, a new column should be added to the table to facilitate data sorting. This step is necessary because CAS tables are distributed in memory and lack an inherent ordering structure. The new column ensures the data is properly sorted during the training step, before batch creation. Without sorting, rows from multiple samples (patients, in this example) might be combined into a single sample, leading to incorrect results. This approach is illustrated in the figure below.

[Figure: PankajAttri_2-1728593726488.png]
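To make the role of the sorting column concrete, here is a small plain-Python sketch of the stacked layout. The column names (sort_key, patient_id, feature, target) and the value of NUM_CELLS are assumptions for illustration; the shuffle merely stands in for a table with no guaranteed row order.

```python
import random

NUM_CELLS = 3  # m: cells per patient (illustrative value)

# Stacked layout: one row per cell, target repeated, plus a sort key.
stacked = []
for p, label in enumerate([1, 0]):
    for c in range(NUM_CELLS):
        stacked.append({
            "sort_key": p * NUM_CELLS + c,  # hypothetical ordering column
            "patient_id": f"p{p}",
            "feature": 10.0 * p + c,
            "target": label,
        })

# Simulate a table whose rows come back in arbitrary order.
random.Random(0).shuffle(stacked)

# Sorting on the key restores contiguous per-patient blocks before batching.
ordered = sorted(stacked, key=lambda r: r["sort_key"])
samples = [ordered[i:i + NUM_CELLS] for i in range(0, len(ordered), NUM_CELLS)]

for sample in samples:
    patients = {row["patient_id"] for row in sample}
    print(sorted(patients), [row["feature"] for row in sample])
```

After sorting, every chunk of NUM_CELLS consecutive rows belongs to exactly one patient, which is the property the reshape in the model wrapper relies on.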

 

 

For the second point, the model definition must be modified to accommodate the two approaches outlined for point 1. These changes are required in both the train_one_batch and score_one_batch functions of the wrapper class (as explained in the previous post), though the adjustments are essentially the same for both. Below are the necessary changes in the train_one_batch function, assuming the first approach from point 1 is followed.

 

Changes needed for Approach 1

    # Trains the model on one batch
    @torch.jit.export
    def train_one_batch(self,
                        x: List[torch.Tensor], target: List[torch.Tensor]) -> \
            Tuple[Tensor, List[Tensor], Tensor, List[Tensor]]:
        
        ######### START CHANGE FOR OPTION 1 ########################
        #####   Reshape each sample to convert  ####################
        ##### from single row to multiple rows  ####################
   
        data = x[0].reshape(x[0].shape[0], -1, 27)  # x[0].shape[0] is the batch size
                                                    # 27 is the number of features per cell
                                                    # -1 infers the number of cells, splitting each flattened row
                                                    # into the rows that together form the set of cells for a patient

        ######### FINISH CHANGE FOR OPTION 1 #######################
        ......
        ......
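As a standalone check of the reshape used above, the sketch below flattens a small batch and restores it. The cell dimension of 27 comes from the snippet; the values of NUM_CELLS and BATCH are illustrative assumptions, and nothing here depends on the Viya wrapper class.

```python
import torch

CELL_DIM = 27   # features per cell, as in the snippet above
NUM_CELLS = 4   # N: illustrative value
BATCH = 2       # illustrative value

# A batch in the shape the model expects: (BATCH x NUM_CELLS x CELL_DIM).
cells = torch.arange(BATCH * NUM_CELLS * CELL_DIM, dtype=torch.float32)
cells = cells.reshape(BATCH, NUM_CELLS, CELL_DIM)

# Approach 1 storage: one flattened row per sample.
flat = cells.reshape(BATCH, NUM_CELLS * CELL_DIM)

# The reshape from train_one_batch: -1 lets PyTorch infer NUM_CELLS.
restored = flat.reshape(flat.shape[0], -1, CELL_DIM)

print(restored.shape)  # torch.Size([2, 4, 27])
assert torch.equal(restored, cells)
```

The round trip only works because the flattened row is laid out cell by cell; if the table stored features in a different column order, the reshape would silently mix features from different cells.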
        
    

 

Changes needed for Approach 2

If you follow the second approach from point 1, adjust the batch size so that m x batch_size rows are passed to the train_one_batch and score_one_batch functions. Here, m is the number of cells in a patient's set (N in the notation above), while batch_size is the actual number of samples intended for training or scoring.

    @torch.jit.export
    def train_one_batch(self,
                        x: List[torch.Tensor], target: List[torch.Tensor]) -> \
            Tuple[Tensor, List[Tensor], Tensor, List[Tensor]]:
            
        data = x[0]
        ######### START CHANGE FOR OPTION 2 ########################
        #####   A batch will contain (m x batch_size) rows. ########
        ##### m represents the number_of_cells in a sample #########
        ##### (or set) and batch_size represents number of #########
        ##### samples to be fed to the model.              #########

        # Reshape incoming rows to (batch_size x number_of_cells x 27), 27 being the cell dimension
        data = data.reshape(-1, self.number_of_cells, 27)
        # Reshape the repeated targets to (batch_size x number_of_cells)
        target_class = target[0].reshape(-1, self.number_of_cells)
        # Keep one target per patient: unique over columns gives (batch_size x 1) when every
        # cell in a set carries the same label; squeeze() then drops the size-1 dimension
        target_class = torch.unique(target_class, dim=1).squeeze().double()
        ......
        ......
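The reshaping for the stacked layout can be tried outside the wrapper class with the sketch below. Note that it swaps in first-column indexing for the torch.unique call, as one possible alternative: it keeps a (batch_size,) shape even for a batch of one, and pairs naturally with an explicit check that labels agree within each set. NUM_CELLS, BATCH, and the random inputs are assumptions for illustration.

```python
import torch

CELL_DIM = 27   # features per cell, as in the snippet above
NUM_CELLS = 3   # m: illustrative value
BATCH = 2       # illustrative value

# Incoming rows for Approach 2: (BATCH * NUM_CELLS) rows, already sorted.
x0 = torch.randn(BATCH * NUM_CELLS, CELL_DIM)
t0 = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])  # label repeated per cell

data = x0.reshape(-1, NUM_CELLS, CELL_DIM)   # (BATCH x NUM_CELLS x CELL_DIM)
per_cell = t0.reshape(-1, NUM_CELLS)         # (BATCH x NUM_CELLS)

# Guard against unsorted input: every cell in a set must share one label.
labels_consistent = bool((per_cell == per_cell[:, :1]).all())
assert labels_consistent, "rows of different patients were mixed in a set"

# One target per patient, taken from the first cell of each set.
target_class = per_cell[:, 0].double()       # shape (BATCH,)

print(data.shape, target_class)
```

A consistency check of this kind catches only mixes between patients with different labels, so it complements the sort on the ordering column rather than replacing it.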

Additionally, the shape of the predictions vector returned from the score_one_batch function should align with the incoming batch size, i.e., m x batch_size. This means the prediction value must be duplicated for each cell within a patient's set. Post-processing will be required to either take the unique prediction or compute the average of the prediction values across all cells in the set to derive the final prediction for the entire sample or patient.
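A minimal sketch of this duplication and the matching post-processing, written with plain tensors: set_preds stands in for the model's per-patient output and its values are made up. In practice the post-processing would run on the scored table rather than on a tensor, so treat this only as an illustration of the arithmetic.

```python
import torch

NUM_CELLS = 3  # m: illustrative value

# Hypothetical per-patient predictions from the model, shape (batch_size,).
set_preds = torch.tensor([0.9, 0.2])

# Inside score_one_batch: repeat each prediction once per cell so the output
# has one value per incoming row, shape (m * batch_size,).
row_preds = set_preds.repeat_interleave(NUM_CELLS)

# Post-processing: collapse back to one prediction per patient by averaging
# over each block of NUM_CELLS rows (all values in a block are identical).
collapsed = row_preds.reshape(-1, NUM_CELLS).mean(dim=1)

print(row_preds.shape, collapsed.shape)
```

Because the rows within a set carry identical predictions, taking the first value, the unique value, or the mean all give the same per-patient result up to floating-point rounding.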

Hopefully, these posts have provided you with a clear understanding of how to train and deploy custom PyTorch models using multi-dimensional tabular data in Viya. The techniques outlined in this series should serve as a solid foundation for handling complex data structures in deep learning models.

Version history
Last update:
‎10-10-2024 05:04 PM