Model types

Models can be supervised, where the data is labeled, and unsupervised where the data is not labeled.

There are three types of ML models:

  • Binary classification - predicts values that can only have two states
  • Multi-class - group into a defined set of classification option (Multinomial Logistic Regression)
  • Regression - predict a numeric value

Selecting the Right Model

There are numerous models built-in to SageMaker and some are even pre-trained. Here are the algorithms supported by SageMaker, by data type by use case. Legend: S = Supervised; U = Unsupervised; T = Text; I = Image/video

Tabular algorithms by Use Case

Classification and Regression

All these can be used for binary/multi-class classification AND regression of a numeric or continue value using a supervised model.

  • K-Nearest Neighbors - simple classification or regression; input: K!
  • Factorization Machine - works well with sparse data; pairs of data (user -> page); click predictions and item recommendations;
  • XGBoost uses decision tree; takes CSV and libsvm; memory bound; GPU training on single instance using gpu_hist
  • Linear learner for labeled data classification. RecordIO/Float32; file or pipe mode

Reduce dimensionality

  • Object2Vec (DL) - enables embeddings, the conversion of high-dimensional objects into low-dimensional space. Example: identify duplicate support tickets or find the correct routing based on similarity of text in the tickets; single instance or single/multiple GPU instances; P3 based inference.
  • Principal Component Analysis (PCA) (U) - it isn’t dropping columns that are useless/weak relationship with label/target, it is the combination of highly correlated data into a fewer number of features; RecordIO/CSV using file or pipe mode

Time-series forecast

  • DeepAR forecasting (S) - time-series predictions based on historical data for a behavior, predict a future behavior across multiple time series streams. Input all data in JSON (can be in Parquet). Start with CPU instances for training; multi-machine or single; can use GPU for larger models or large mini-batch sizes (>512). Prefer DeepAR over Autoregressive Integrated Moving Average (ARIMA) and Error, Trend, Seasonality (ETS) forecasts.

Anomaly Detection

  • Random cut forest (RCF) (U) - Detect abnormal behavior in application: spot when an IoT sensor is sending abnormal readings; input: RecordIO-protobuf or CSV using pipe or file mode; training/inference: no GPU required

Detect intrusion by IP

  • IP Insights (U) - used to protect apps from suspicious users by IP address; inputs CSV including usernames and AccountIDs; 1 to many GPU for training

Clustering or grouping

  • K-Means (U) - Group similar objects/data together: find high-, medium-, and low-spending customers from their transaction histories; input: RecordIO-protobuf or CSV using pipe or file mode; CPU inference/training

Topic modeling

Organize a set of documents into topics (not known in advance): tag a document as belonging to a medical category based on the terms used in the document; num_topics is the definition of how many topics to find for both algos

  • Latent Dirichlet Allocation - unsupervised; input: train channel, optional test channel using RecordIO-protobuf or CSV using pipe or file mode containing counts for each word in vocabulary; single CPU based inference/training (so likely cheaper)

  • Neural Topic Model (NTM) - unsupervised - 4 channel input: train channel of RecordIO-protobuf or CSV using pipe or file mode containing tokenized integers and a dictionary (the aux channel)


  • Collaborative filtering - part of Tensorflow 2.0 uses user, item, rating and similarity to other users to create recommendations. Content filtering is similar but not as effective.

Text analysis algorithms

Classify text

BlazingText (S) - base use case is to assign pre-defined categories to sentences in a corpus; Training data must be labeled; input is one sentence per line with __label__ and label prepended or “augmented manifest text format”. Can also implement word2vec for word embedding with tokenized input; has three modes continuous bag of words, skip-gram, and batch skip-gram (for distributed computation)

Language translation, summarization, speech to text

Sequence to Sequence (seq2seq) - input tokenized text file, convert to protobuf; requires vocabulary file; can optimize on accuracy, BLEU score, perplexity; DL so requires GPU; single instance only but does support multiple GPU

Images classifications algorithms

image and multi-label classification (image overall)

MXNet or TensorFlow - Label/tag an image based on the content of the image: alerts about adult content in an image; MXNet includes full training mode and transfer learning model (initialize with pretrained weights, top layer with random weights, fine tune with new data); TensorFlow allows the top layer to be trained for finer tuning; GPU/instances multiple/single for training; M class inference possible likely needs GPU

object detection

detect specific people or objects (bounding box) - MXNet or TensorFlow - Detect people and objects in an image: police review a large photo gallery for a missing person; MXNet input: RecordIO; GPU/instances multiple/single for training

detailed pixel-level tagging

Semantic Segmentation - computer vision (tag pixels into a segmentation mask); input is JPG and PNG ideally in Pipe mode. Built on MXNet; can be trained incrementally. GPU required on single instance for training; inference on CPU or GPU


Transfer Learning

Transfer Learning saves time using pre-trained models; often called foundational models. Four approaches:

  • Fine tune with more data - using the same data type and input size and a low learning rate; network is initialized with pre-trained weights and the output layer is initialized with random weights
  • Add another layer - freeze weights then add new a layer. (fine tuning and layer adding works too)
  • Retrain from scratch - giant amount of work
  • Use as is


Hyperparameters control how a model is trained - they are the switches, nobs, and dials. Common hyperparameters include:

  • epoch count - the number of times the data is passed through the algorithm
  • batch size - the size of a subset of the training dataset to process at a time
  • mini-batch - the amount of a batch sent to a compute environment in the training fleet
  • learning rate - how big a step the model should take when executing optimization
  • optimizer - which algorithm to use for optimization; Gradient descent, Stochastic Gradient decent, and ADAM are common choices with ADAM the most most common choice.
  • predictor type - Specifies the type of target variable as a binary classification, multiclass classification, or regression.

Generally, the learning rate should be increased as the mini-batch size increases.

If the control of the model is less important, automatic hyperparameter optimization can be used - likely implemented as an AI process. Enabling early stopping can help avoid overfitting using an objective metric evaluated at the end of each epoch using TensorFlow, MXNet, Chainer, PyTorch, and Spark.


Regularization is desensitizing the algorithm to the data to prevent over-fitting by adding bias to the model using a regression penalty. Regularization is used mitigate over-fitting models and when the dataset has more features than samples (the data is wider than it is long).

There are two types of regularization that are applicable to non-DL models:

  • L1 (Lasso) which eliminates useless features to generate sparse output; computationally inefficient
  • L2 (Ridge) which reduces the influence of some of the features. Computationally efficient; assumes all features are important. Also called weight decay

Use L2 regularization when there are more features than samples.

Regularization for DL

Several approaches for regularization in DL:

  • Remove layers, decrease neurons
  • Dropout - drops random number of neurons each epoch (common in CNN) - Dropout layers force the network to spread out its learning throughout the network, and can prevent over-fitting resulting from learning concentrating in one spot.
  • Stop early - detects drop in validation dataset accuracy then stops


Bias is how tight the model fit is; low-bias is over-fit. Variance is how much predictions vary from actual data; high-variance is over-fit

Easiest way to think about this:

  • under-fit = high bias/low variance; too cold
  • generalizes well or balanced = moderate bias and variance; just right
  • over-fit = low bias/ high variance; too hot

Ideally, models balance bias and variance. There is an inverse relationship between bias and variance.

Confusion Matrix

A confusion matrix is a tool for visualizing the performance of a single or multi-class model.

Model » TRUE True positive (TP) False positive (FP)
output » FALSE False negative (FN) True negative (TN)

Accuracy = TP + TN / Total Recall = TP / (TP + FN) Precision = TP / (TP + FP)

Specificity = TN / (TN + FP) False positive rate = 1 - specificity


Does the model make correct inferences?

Accuracy is number of correct predictions / # of predictions. For accuracy, the # of correct predictions is TP + TN. This measure requires balanced data, data where TP and TN both occur with some regularity.


Recall (true positive rate or sensitivity) is TP / (TP + FN); ranges from 0 to 1, where higher is better; measures actual positives measured as positives. In other words, how well does the model find fraud, assuming that it’s ok to find non-fraudulent transactions.


Precision (positive predictive value) is TP / (TP + FP); ranges from 0 to 1, where higher is better; measures actual positives from predicted positives. In other words, how well does the model do at making sure that content is safe for children, assuming we are ok if lots of content that is safe is left out.


F1 is the harmonized average of precision and recall F1 = 2 ((precision * recall) / (precision + recall))

Specificity & False Positive Rate

Less commonly metrics include specificity which is TN / (TN + FP). The false positive rate is 1 - Specificity. This is the false alarm rate so lower is better!


ROC (Receiver Operating Characteristic) is useful for providing data on which model works best for the requirements. ROC curve is a graphical plot that shows the diagnostic ability of a binary classifier system as its discrimination threshold is varied.

AUC (Area under the curve) AUC measures the ability of the model to predict a higher score for positive examples as compared to negative examples. AUC is useful for understanding how effective the algorithm is at classifying data; the higher the better; a value of 1 is a unicorn un-achievable score.



accuracy - given balanced data, does model work? recall - How many of the positives got predicted? use when cost of FN is high precision - How many the positive predictions were correct? use when the cost of FP is high false negative rate

Correlation coefficient - negative = poor correlation; positive = good correlation; zero = no correlation; 1 = perfect positive correlation; (AKA Pearson’s correlation coefficient); Bayesian for dependent variables; naive Bayesian for independent variables


Getting good data about the situation: Use AUC metrics against hyperparameters on a scatter chart to evaluate the best model version.

DL: Recall low/FN important? add weight to FN

DL Precision low/FP important? add weight to FP.

Good training, bad validation? Over-fit

Poor training, poor validation? Add training data, increase passes

Over-fit - use fewer features, decrease n-gram size, increase regularization (L2); in a neural net, a Dropout, early stopping, adding noise can be used to regularize.

under-fit - add domain specific features, Cartesian products, increase n-grams size, decrease regularization (L2); in a neural net, increase amount of data OR increase epochs OR increase depth of network

Neural network, loss function oscillating - learning rate too high.

Neural network, loss function slowly decreasing - learning rate too low.

Model Specific

Regression? Mean Absolute Percentage Error (MAPE) OR Root-mean-square error (RMSE)

Classification threshold analysis? ROC/AUC; use scale_pos_weight to adjust the discrimination threshold.

Training k-NN, low precision, low accuracy? data needs normalization

Time-series analysis - trend can be uptrend, downtrend, sideways; seasonality is regular changes that happen annually

Reduce size of DL model - model pruning removes weights that don’t contribute much to the overall model accuracy. Activation functions do not effect model size. SageMaker Debugger to do this.