While powerful, regex can feel daunting as it comes with a lot of features and sub-parts that you need to remember.
In this post, I will illustrate the various concepts underlying regex. The goal is to help you build a good mental model of how a regex pattern works.
Let’s start with a simple example where we are trying to find the word ‘cool’ in the text.
With regex, we could simply type out the word ‘cool’ as the pattern and it will match the word.
'cool'
While regex matched our desired word ‘cool’, the way it operates is not at the word level but the character level. This is the key idea.
Key Idea: Regex works at the character-level, not word-level.
The implication of this is that the regex r'cool' would match those four characters wherever they appear in a sentence — even inside longer words such as 'coolant'.
Now that we understand the key idea, let’s understand how we can match simple characters using regex.
We can simply specify the character in the regular expression and it will match all instances in the text.
For example, the regular expression given below will match all instances of 'a' in the text. You can use any lowercase or uppercase letter in the same way.
'a'
You can also use any digit from 0 to 9 and it will work just as well.
'3'
Note that regex is case-sensitive by default and thus the following regex won’t match anything.
'A'
We can detect special characters such as whitespace and newlines using special escape sequences.
Besides the common ones above, we have:
Regex provides a bunch of built-in special symbols that can match a group of characters at once. These begin with a backslash \.

\d

It matches any single digit from 0 to 9.

Notice that the matches are single digits. So we get 4 different matches below instead of a single match for the number 18.04.
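For instance, running \d with Python's re module on a string containing 18.04 (an illustrative string) shows the four separate matches:

>>> import re
>>> re.findall(r'\d', 'Ubuntu 18.04')
['1', '8', '0', '4']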
\s

It matches any whitespace character (space, tab, or newline).

\w

It matches any lowercase letter (a to z), uppercase letter (A to Z), digit (0 to 9), or underscore.

.

It matches any character except the newline (\n).
import re
>>> re.findall(r'.', 'line 1\nline2')
['l', 'i', 'n', 'e', ' ', '1', 'l', 'i', 'n', 'e', '2']
If you use the capitalized versions of the patterns above, they act as negation.
For example, if \d matches any digit from 0 to 9, then \D matches anything except a digit from 0 to 9.
Character classes are patterns that start with [ and end with ], and they match any one of the characters enclosed within the brackets.

For example, the pattern [aeiou] matches any of the characters 'a', 'e', 'i', 'o', and 'u'.
You can also replicate the functionality of \d using the pattern [0123456789]. It will match any digit from 0 to 9.

Instead of specifying all the digits, we can use - to specify only the start and end digits. So, instead of [0123456789], we can write [0-9].
For example, [2-4] can be used to match any digit from 2 to 4, i.e. 2, 3, or 4.
You can even use the special characters we learned previously inside the brackets. For example, you can match any digit from 0 to 9 or any whitespace character with [\d\s].
Below, I have listed some useful common patterns and what they mean.
Regex also has special handlers to make the pattern only match if it’s at the start or end of the string.
We can use the ^ anchor to match a pattern only at the start of a line. For example, ^The matches 'The' only when it appears at the beginning of the line.
Similarly, we can use the $ anchor after a character to match a pattern only at the end of a line. For example, cool$ matches 'cool' only when it appears at the end of the line.
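Here is a quick check of both anchors using the earlier example sentence:

>>> import re
>>> re.findall(r'^The', 'The batman is cool')
['The']
>>> re.findall(r'cool$', 'The batman is cool')
['cool']
>>> re.findall(r'^cool', 'The batman is cool')
[]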
Consider a case where we want to exactly match the word “Mr. Stark”.
If we write a regex like Mr. Stark, it will have an unintended effect: since the dot has a special meaning in regex, the pattern will also match text such as 'Mrs Stark'.
So, we should always escape special metacharacters like . and $ if our goal is to match the exact character itself.
Here is the list of metacharacters that you should remember to escape if you’re using them directly.
^ $ . * + ? { } [ ] \ | ( )
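For example, escaping the dot makes the pattern match only the literal 'Mr. Stark' (the sample sentence here is just an illustration):

>>> import re
>>> re.findall(r'Mr. Stark', 'Mrs Stark met Mr. Stark')
['Mrs Stark', 'Mr. Stark']
>>> re.findall(r'Mr\. Stark', 'Mrs Stark met Mr. Stark')
['Mr. Stark']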
Now that we can pattern match any characters, we could repeat things and start building more complicated patterns.
Using only what we have learned so far, a naive way would be to just repeat the pattern. For example, we can match two-digit numbers by just repeating the character-level pattern.
\d\d
Regex provides special quantifiers to specify different types of repetition for the character preceding it.
We can use the {...} quantifier to specify the number of times a pattern should repeat.

For example, the previous pattern for matching a 2-digit number can be recreated as \d{2}.
You can also specify a range of repetitions using the same quantifier. For example, to match 2-digit to 4-digit numbers, we could use the pattern \d{2,4}.
When applied to a sentence, it will match both 4-digit and 2-digit numbers.
There should not be any space between the minimum and maximum counts. For example, \d{2, 4} doesn't work.
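Here is the range quantifier in action on a made-up sentence (note that the greedy match grabs all four digits of 1905 at once):

>>> import re
>>> re.findall(r'\d{2,4}', 'In 1905, a 26-year-old published 4 papers')
['1905', '26']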
Regex also provides quantifiers “*”, “+” and “?” using which you can specify flexible repetition of a character.
0 or 1 times: ?
The ? quantifier matches the previous character if it repeats 0 or 1 times. This can be useful to make certain parts optional. It is equivalent to {0,1}.

For example, let's say we want to match both the word "sound" and "sounds", where the trailing "s" is optional. Then, we can use the pattern sounds? so that the "s" matches 0 or 1 times.
one or more times: +
The + quantifier matches the previous character if it repeats 1 or more times. It is equivalent to {1,}.

For example, we could find numbers of any arbitrary length using the regex \d+.
zero or more times: *
The * quantifier matches the previous character if it repeats zero or more times. It is equivalent to {0,}.
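A quick illustration of all three quantifiers on made-up strings:

>>> import re
>>> re.findall(r'sounds?', 'sound or sounds')
['sound', 'sounds']
>>> re.findall(r'\d+', 'Ubuntu 18.04 was released in 2018')
['18', '04', '2018']
>>> re.findall(r'no*', 'n no noo')
['n', 'no', 'noo']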
Python provides a module called "re" in the standard library to work with regular expressions.
To specify a regular expression in Python, we precede it with r to create raw strings.
pattern = r'\d'
To understand why we precede it with r, let's try printing the expression \t without the r prefix.
>>> pattern = '\t'
>>> print(pattern)
You can see that when we don't use a raw string, \t is treated by Python as the escape sequence for a tab character, so the print shows a tab instead of the two characters backslash and t.
Now let’s convert it into raw string. We get back whatever we specified.
>>> pattern = r'\t'
>>> print(pattern)
\t
To use the re module, we can start by importing it as:
import re
The re.findall() function returns all the matches as a list of strings.
import re
re.findall(r'\d', '123456')
['1', '2', '3', '4', '5', '6']
The re.match() function searches for a pattern at the beginning of the string and returns the first occurrence as a match object. If the pattern is not found, it returns None.
import re
match = re.match(r'batman', 'batman is cool')
print(match)
<re.Match object; span=(0, 6), match='batman'>
With the match object, we can get the matched text as
print(match.group())
batman
If our pattern is not at the start of the string, we will not get any match.
import re
match = re.match(r'batman', 'The batman is cool')
print(match)
None
The re.search() function also finds the first occurrence of a pattern, but the pattern can occur anywhere in the text. If the pattern is not found, it returns None.
import re
match = re.search(r'batman', 'the batman is cool')
print(match.group())
batman
This can be achieved by creatively formulating a problem such that you use parts of the data itself as labels and try to predict that. Such formulations are called pretext tasks.
For example, you can setup a pretext task to predict the color version of the image given the grayscale version. Similarly, you could remove a part of the image and train a model to predict the part from the surrounding. There are many such pretext tasks.
By pre-training on the pretext task, the hope is that the model will learn useful representations. Then, we can finetune the model to downstream tasks such as image classification, object detection, and semantic segmentation with only a small set of labeled training data.
So pretext tasks can help us learn representations. But, this poses a question:
How to determine how good a learned representation is?
Currently, the standard way to gauge the representations is to evaluate it on a set of standard tasks and benchmark datasets.
We can see that the above evaluation methods require us to use the same model architecture for both the pretext task and the target task.
This poses some interesting challenges:
For the pretext task, our goal is to learn on a large-scale unlabeled dataset, and thus deeper models (e.g. ResNet) would help us learn better representations.
But, for downstream tasks, we would prefer shallow models (e.g. AlexNet) for actual applications. Thus, we currently have to consider this limitation when designing the pretext task.
It's harder to fairly compare which pretext task is better if some methods used a simpler architecture while other methods used a deeper one.
We can’t compare the representations learned from pretext tasks to handcrafted features such as HOG.
We may want to exploit several data domains such as sound, text, and videos in the pretext task but the target task may limit our design choices.
A model trained on a pretext task may learn extra knowledge that is not useful for generic visual recognition. Currently, the final task-specific layers are discarded and only the weights or features up to certain convolutional layers are kept.
Noroozi et al. proposed a simple idea to tackle these issues in their 2018 paper “Boosting Self-Supervised Learning via Knowledge Transfer”.
The authors observed that in a good representation space, semantically similar data points should be close together.
In regular supervised classification, the information that images are semantically similar is encoded through labels annotated by humans. A model trained on such labels would have a representation space that groups semantically similar images.
Thus, with pretext tasks in self-supervised learning, the objective is to implicitly learn a metric that makes images of the same category similar and images of different categories dissimilar. Hence, we could robustly estimate the quality of a learned representation if we could somehow map semantically related images to the same labels.
The authors propose a novel framework to transfer knowledge from a deep self-supervised model to a separate shallow downstream model. You can use different model architectures for the pretext task and downstream task.
Key Idea:
Cluster the features from the pretext task and assign the cluster centers as pseudo-labels to the unlabeled images. Then, re-train a small network with the target-task architecture to predict these pseudo-labels and thereby learn a new representation.
The end-to-end process is described below:
Here we choose a deep network architecture and train it on a pretext task of our choice on some dataset. After the model is trained, we can take features from an intermediate layer.
Figure: Training on Pre-text Task (Source)
For all the unlabeled images in the dataset, we compute the feature vectors from the pretext task model. Then, we run K-means clustering to group semantically similar images. The idea is that the cluster centers will be aligned with categories in ImageNet.
Figure: Clustering Features (Source)
In the paper, the authors ran K-means on a single Titan X GPU for 4 hours to cluster 1.3M images into 2000 categories.
The cluster centers are treated as the pseudo-label. We can use either the same dataset as the above step or use a different dataset itself. Then, we compute the feature vectors for those images and find the closest cluster center for each image. This cluster center is used as the pseudo-label.
Figure: Generating Pseudo-labels (Source)
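A minimal sketch of the clustering and pseudo-labeling steps using scikit-learn's KMeans. The 2000-cluster setting follows the paper; the feature arrays and file names are only illustrative placeholders for features extracted from the pretext-task model:

import numpy as np
from sklearn.cluster import KMeans

# Features from an intermediate layer of the pretext-task model,
# one row per unlabeled image (illustrative file name and shapes)
pretext_features = np.load('pretext_features.npy')   # shape: (num_images, feat_dim)

# Group semantically similar images into 2000 clusters (as in the paper)
kmeans = KMeans(n_clusters=2000).fit(pretext_features)

# The closest cluster center for each image becomes its pseudo-label.
# This can be run on the same dataset or on a different one.
target_features = np.load('target_features.npy')
pseudo_labels = kmeans.predict(target_features)       # shape: (num_images,)

# These (image, pseudo_label) pairs are then used to train the
# downstream architecture (e.g. AlexNet) as a standard classifier.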
We take the model architecture that will be used for downstream tasks and train it to classify the unlabeled images into the pseudo-labels. Thus, the target architecture will learn a new representation such that it will map images that were originally close in the pre-trained feature space to close points.
Figure: Re-training on pseudo-labels (Source)
We saw how by clustering the features and then using pseudo-labels, we can bring the knowledge from any pretext task representations into a common reference model like AlexNet.
As such, we can now easily compare different pretext tasks even if they are trained using different architectures and on different data domains. This also allows us to improve self-supervised methods by using deep models and challenging pretext tasks.
To evaluate the idea quantitatively, the authors set up an experiment as described below:
To evaluate their method, the authors took an earlier puzzle-like pretext task called "Jigsaw", where we need to predict the permutation that was used to randomly shuffle a 3×3 grid of image tiles.
Image Modified from Paper
They extended the task by randomly replacing between 0 and 2 of the tiles with tiles from another random image at random locations. This increases the difficulty, as we now need to solve the problem using only the remaining patches. The new pretext task is called "Jigsaw++".
Image Modified from Paper
In the paper, they use 701 total permutations with a minimum Hamming distance of 3. They apply mean and standard deviation normalization to each image tile independently. They also make the images grayscale 70% of the time to prevent the network from cheating with low-level statistics.
The authors used VGG-16 to solve the pretext task and learn representations. As VGG-16 has increased capacity, it can better handle the increased complexity of the “Jigsaw++” task and thus extract better representation.
The representations from VGG-16 are clustered and cluster centers are converted to pseudo-labels. Then, AlexNet is trained to classify the pseudo-labels.
For downstream tasks, the convolutional layers for the AlexNet model are initialized with weights from pseudo-label classification and the fully connected layers were randomly initialized. The pre-trained AlexNet is then finetuned on various benchmark datasets.
Using a deeper network like VGG-16 leads to better representations and pseudo-labels, and also to better results on the benchmark tasks. The method achieved state-of-the-art results on several benchmarks in 2018 and further reduced the gap between supervised and self-supervised methods.
The authors tested their method on object classification and detection on PASCAL VOC 2007 dataset and semantic segmentation on PASCAL VOC 2012 dataset.
Task | Clustering | Pre-text architecture | Downstream arch. | Classification | Detection(SS) | Detection(MS) | Segmentation |
---|---|---|---|---|---|---|---|
Jigsaw | no | AlexNet | AlexNet | 67.7 | 53.2 | - | - |
Jigsaw++ | no | AlexNet | AlexNet | 69.8 | 55.5 | 55.7 | 38.1 |
Jigsaw++ | yes | AlexNet | AlexNet | 69.9 | 55.0 | 55.8 | 40.0 |
Jigsaw++ | yes | VGG-16 | AlexNet | 72.5 | 56.5 | 57.2 | 42.6 |
In this evaluation, a linear classifier is trained on features extracted from AlexNet at different convolutional layers. For ImageNet, using VGG-16 and transferring knowledge to AlexNet using clustering gives a substantial boost of 2%.
For a non-linear classifier, using VGG-16 and transferring knowledge to AlexNet using clustering gives the best performance on ImageNet.
The network is not significantly affected by the number of clusters. The authors tested AlexNet trained on pseudo-labels from a different number of clusters on the task of object detection.
Knowledge transfer is fundamentally different from knowledge distillation. Here, the goal is to only preserve the cluster association of images from the representation and transfer that to the target model. Unlike distillation, we don’t do any regression to the exact output of the teacher.
Yes, the method is flexible and you can pre-train on one dataset, cluster on another, and get pseudo-labels for the third one.
The authors did an experiment where they trained clustering on representations for ImageNet and then calculated cluster centers on the “Places” dataset to get pseudo-labels. There was only a small reduction (-1.5%) in performance for object classification.
Thus, Knowledge Transfer is a simple and efficient way to map representations from deep to shallow models.
I recently experimented with a way to load sentence embeddings along with the class labels into this tool and explore them interactively. In this blog post, I will explain the end-to-end process with an example dataset.
To understand this use case, let’s take a subset of 100 movie reviews from the SST-2 dataset which are labeled as positive and negative.
import pandas as pd
df = pd.read_csv('http://bit.ly/dataset-sst2',
nrows=100, sep='\t', names=['text', 'label'])
df['label'] = df['label'].replace({0: 'negative', 1: 'positive'})
The dataset has a column containing the text and a label indicating whether it’s positive or negative opinion.
We will introduce noise into our dataset by corrupting five of the reviews with random text. These will act as outliers for our example.
df.loc[[10, 27, 54, 72, 91], 'text'] = 'askgkn askngk kagkasng'
Now, we will compute sentence embeddings for the headlines using the sentence-transformers
package. First, let’s install it using pip.
!pip install sentence-transformers
Next, we will create a helper function to return a NumPy array of sentence embeddings given a list of sentences.
from sentence_transformers import SentenceTransformer
sentence_bert_model = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')
def get_embeddings(sentences):
return sentence_bert_model.encode(sentences,
batch_size=32,
show_progress_bar=True)
Using the above function, we can generate sentence embeddings for our data as shown below.
e = get_embeddings(df['text'])
# shape: (100, 768)
Embedding Projector requires two TSV files to load our custom embeddings.
output.tsv: This file should contain the embeddings without any headers.

metadata.tsv: This file should contain the original text and labels for the embeddings.

Let's first generate the output.tsv file for our sentence embeddings from the previous step.
# Convert NumPy array of embedding into data frame
embedding_df = pd.DataFrame(e)
# Save dataframe as a TSV file without any index and header
embedding_df.to_csv('output.tsv', sep='\t', index=None, header=None)
To generate metadata.tsv, we simply save our original dataframe.
# Save dataframe without any index
df.to_csv('metadata.tsv', index=False, sep='\t')
We first go to https://projector.tensorflow.org/.
On the left-hand sidebar, click the Load button.
Then, for the first Choose file button, upload the output.tsv file, and for the second Choose file button, upload the metadata.tsv file.
After uploading both files, click outside and you should see the sentence embedding projection. The dimensions of embeddings are reduced to 3D by default using PCA.
Let’s switch to 2D by turning off the checkbox for ‘Component #3’ in the bottom part of sidebar.
On the 2D visualization, we can see how the random text sits far away from the other groups of text as an outlier. On hovering over the point, we see the text askgkn askngk kagkasng.
We can enable color coding of the points by their actual labels (positive vs negative) by using the Color by dropdown in the left sidebar.
Select the name of the column that contains your labels. In our example file, the column name is label.
The points themselves are interactive. You can see the actual sentence for each point by hovering over them.
You can click on the point to show the metadata. We can see below on clicking a blue point that its label is “positive” in the popup.
So the blue points are positive and the red points are negative. When a point is selected, 100 nearest points in terms of cosine similarity are also highlighted.
To get back to the original view, we can click on any empty white space.
The color coding can be a useful heuristic for many use cases:
The web app provides three standard dimensionality reduction techniques: UMAP, t-SNE, and PCA.
You can choose the algorithm and their parameters from the bottom of the left sidebar.
You can also use a custom keyword or full text as the axis using the CUSTOM tab. This will apply a custom linear projection and can help us explore meaningful directions in the embedding space.
For example, the Gmail team tried setting “yeah” on the left side and “yes” on the right side. When they projected encoder embeddings for email replies to this custom linear projection, they found replies in a casual tone (e.g. Here you go) on the left side and responses in a more formal tone clustered towards the right side.
Thus, Embedding Projector is a very useful tool to better understand the datasets and models we work with.
With this setup, you can still prototype in the Colab Notebook while also using VSCode for all the advantages of a full-fledged code editor. Here is how you can replicate my setup.
In this setup, we use the colabcode package, which automates all the manual setup steps described in the Approach 2 section of this blog post. You can make a copy of this notebook directly to get started.

First, install the colabcode package using the following command:
pip install colabcode
Now, import the ColabCode class from the package and specify the port and password.
from colabcode import ColabCode
ColabCode(port=10000, password="password123")
You can also use it directly with the default port and without any password as shown below.
from colabcode import ColabCode
ColabCode()
You will get the ngrok URL in the output. Click the link and a login page will open in a new tab.
Type the password you had set in step 2 and click submit. If the page gets stuck for more than 4-5 seconds, refresh the page and you should be redirected to the editor.
Now you will get access to the editor interface and can use it to work on Python files.
I have described the setup steps in detail below. After going through all the steps, please use this colab notebook to try it out directly.
First, we will install the code-server package to run the VSCode editor as a web app. Copy and run the following command on Colab to install code-server.
!curl -fsSL https://code-server.dev/install.sh | sh
After the installation is complete, we will expose port 9000 to an external URL that we can access, using the pyngrok package. To install pyngrok, run:
!pip install -qqq pyngrok
Then, run the following command to get a public ngrok URL. This will be the URL we will use to access VSCode.
from pyngrok import ngrok
url = ngrok.connect(port=9000)
print(url)
Now, we will start the VSCode server in the background at port 9000 without any authentication using the following command.
!nohup code-server --port 9000 --auth none &
Now, you can access the VSCode interface at the URL you got from step 3. The interface and functionality are the same as the desktop version of VSCode.
You can switch to the dark theme by going to the bottom-left corner of the editor, clicking the settings icon, and then clicking ‘Color Theme’.
A popup will open. Select Dark (Visual Studio) in the options and the editor will switch to a dark theme.
All the keyboard shortcuts of regular VSCode work here. For example, you can use Ctrl + Shift + P to open a popup for various actions.
To open a terminal, you can use the shortcut Ctrl + Shift + `.
To get Python code completions, you can install the Python (ms-python) extension from the extensions page on the left sidebar.
The Colab interface is still usable as a notebook, along with its regular functions to upload and download files and mount Google Drive. Thus, you get the benefits of both a notebook and a code editor.
These models were originally trained by Jörg Tiedemann of the Language Technology Research Group at the University of Helsinki. They were trained on the Open Parallel Corpus (OPUS) using a neural machine translation framework called MarianNMT.

In this post, I will explain how you can use the MarianMT models to augment text data.
We will use a data augmentation technique called “Back Translation”. In this, we take an original text written in English. Then, we convert it into another language (eg. French) using MarianMT. We translate the French text back into English using MarianMT. We keep the back-translated English text if it is different from the original English sentence.
First, we need to install Hugging Face Transformers and Moses tokenizers with the following commands:
pip install transformers==4.1.1 sentencepiece==0.1.94
pip install mosestokenizer==1.1.0
After installation, we can now import the MarianMT model and tokenizer.
from transformers import MarianMTModel, MarianTokenizer
Then, we can initialize a model that translates from English to Romance languages. This is a single model that can translate to any of the Romance languages.
target_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
target_tokenizer = MarianTokenizer.from_pretrained(target_model_name)
target_model = MarianMTModel.from_pretrained(target_model_name)
Similarly, we can initialize models that can translate Romance languages to English.
en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
en_tokenizer = MarianTokenizer.from_pretrained(en_model_name)
en_model = MarianMTModel.from_pretrained(en_model_name)
Next, we write a helper function to translate a batch of text given the machine translation model, tokenizer and the target romance language.
def translate(texts, model, tokenizer, language="fr"):
# Prepare the text data into appropriate format for the model
template = lambda text: f"{text}" if language == "en" else f">>{language}<< {text}"
src_texts = [template(text) for text in texts]
# Tokenize the texts
encoded = tokenizer.prepare_seq2seq_batch(src_texts)
# Generate translation using model
translated = model.generate(**encoded)
# Convert the generated tokens indices back into text
translated_texts = tokenizer.batch_decode(translated, skip_special_tokens=True)
return translated_texts
Next, we will prepare a function that uses the above translate() function to perform back translation.
def back_translate(texts, source_lang="en", target_lang="fr"):
# Translate from source to target language
fr_texts = translate(texts, target_model, target_tokenizer,
language=target_lang)
# Translate from target language back to source language
back_translated_texts = translate(fr_texts, en_model, en_tokenizer,
language=source_lang)
return back_translated_texts
Now, we can perform data augmentation using back-translation from English to Spanish on a list of sentences as shown below.
en_texts = ['This is so cool', 'I hated the food', 'They were very helpful']
aug_texts = back_translate(en_texts, source_lang="en", target_lang="es")
print(aug_texts)
["Yeah, it's so cool.", "It's the food I hated.", 'They were of great help.']
Similarly, we can perform augmentation using English to French as shown below with the exact same helper method.
en_texts = ['This is so cool', 'I hated the food', 'They were very helpful']
aug_texts = back_translate(en_texts, source_lang="en", target_lang="fr")
print(aug_texts)
["It's so cool.", 'I hated food.', "They've been very helpful."]
You can also run back translation in a chain to get more diversity. For example, English -> Spanish -> English -> French -> English
en_texts = ['This is so cool', 'I hated the food', 'They were very helpful']
aug1_texts = back_translate(en_texts, source_lang="en", target_lang="es")
aug2_texts = back_translate(aug1_texts, source_lang="en", target_lang="fr")
print(aug2_texts)
["Yeah, that's cool.", "It's the food I hated.", 'They were of great help.']
Here are the language codes for a subset of major Romance languages that you can use above.
Language | French | Spanish | Italian | Portuguese | Romanian | Catalan | Galician | Latin |
---|---|---|---|---|---|---|---|---|
Code | fr | es | it | pt | ro | ca | gl | la |
Language | Walloon | Occitan (post 1500) | Sardinian | Aragonese | Corsican | Romansh |
---|---|---|---|---|---|---|
Code | wa | oc | sn | an | co | rm |
To view all available language codes, you can run
target_tokenizer.supported_language_codes
Besides data augmentation, the back translation process can also be used for text paraphrasing.
Similarly, we can also use it as an adversarial attack. Suppose we have a training dataset on which we trained an NLP model. Then, we can augment the training dataset and generate predictions from our model on the augmented texts. If the predictions differ from the ground-truth labels, we have a list of texts where our model fails. We can get good insights by analyzing those examples.
Thus, MarianMT is a decent free and offline alternative to Google Translate for back-translation.
Such extracted keywords can be used for various applications. They can be used to summarize the underlying theme of a large document with just a few terms. They are also valuable as metadata for indexing and tagging the documents. They can likewise be used for clustering similar documents. For instance, to showcase relevant advertisements on a webpage, we could extract keywords from the webpage, find matching advertisements for these keywords, and showcase those.
In this post, I will provide an overview of the general pipeline of keyword extraction and explain the working mechanism of various unsupervised algorithms for this.
For keyword extraction, all algorithms follow a similar pipeline as shown below. A document is preprocessed to remove less informative words like stop words, punctuation, and split into terms. Candidate keywords such as words and phrases are chosen.
Then, a score is determined for each candidate keyword using some algorithm. The highest-ranking keywords are selected and post-processing such as removing near-duplicates is applied. Finally, the algorithm returns the top N ranking keywords as output.
Unsupervised algorithms for keyword extraction don’t need to be trained on the corpus and don’t need any pre-defined rules, dictionary, or thesaurus. They can use statistical features from the text itself and as such can be applied to large documents easily without re-training. Most of these algorithms don’t need any linguistic features except for stop word lists and so can be applied to multiple languages.
Let’s understand each algorithm by starting from simple methods and gradually adding complexity.
This is a simple method which only takes into account how many times each term occurs.
Let’s understand it by applying it to an example document.
In this step, we lowercase the text and remove low informative words such as stop words from the text.
We split the remaining terms by space and punctuation symbols to get a list of possible keywords.
We can count the number of times each term occurs to get a score for each term.
Candidate | anything | mass | occupies | space | called | matter | exists | various | states | … |
---|---|---|---|---|---|---|---|---|---|---|
Count | 1 | 1 | 1 | 1 | 1 | 2 | 1 | 1 | 1 | … |
We can sort the keywords in descending order based on the counts and take the top N keywords as the output.
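A minimal sketch of this counting approach in Python; the example sentence and the tiny stop-word list are only illustrative:

import re
from collections import Counter

text = "Anything that occupies space and has mass is called matter. Matter exists in various states."
stop_words = {"that", "and", "has", "is", "in", "the", "a", "of"}

# Lowercase, split on spaces and punctuation, and drop stop words
terms = [t for t in re.split(r"[\s.,;!?]+", text.lower()) if t and t not in stop_words]

# Score each candidate by its raw count and take the top N
counts = Counter(terms)
print(counts.most_common(5))
# e.g. [('matter', 2), ('anything', 1), ('occupies', 1), ('space', 1), ('mass', 1)]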
This method has an obvious drawback of only focusing on frequency. But, generic words are likely to be very frequent in any document but are not representative of the domain and topic of the document. We need some way to filter out generic terms.
This method takes into account both how frequent the keyphrase is and also how rare it is across the documents.
Let’s understand how it works by going through the various steps of the pipeline:
In this step, we lowercase the text and split the document into sentences.
We generate 1-gram, 2-gram, and 3-grams candidate phrases from each sentence such that they don’t contain any punctuations. These are our list of candidate phrases.
Now, for each candidate keyword “w”, we calculate the TF-IDF score in the following steps.
First, the term frequency (TF) is calculated simply by counting the occurrences of the word.

\[TF(w) = count(w)\]

Then, the inverse document frequency (IDF) is calculated by dividing the total number of documents by the number of documents that contain the word "w" and taking the log of that quantity.

\[IDF(w) = log(\ \frac{total\ documents}{number\ of\ docs\ containing\ word\ w}\ )\]

Finally, we get the TF-IDF score for a term by multiplying the two quantities.

\[TFIDF(w) = TF(w) * IDF(w)\]
We can sort the keywords in descending order based on their TF-IDF scores and take the top N keywords as the output.
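A rough sketch of this scoring with scikit-learn's TfidfVectorizer (the toy documents are illustrative, and scikit-learn's default TF-IDF formula differs slightly from the one above, e.g. in smoothing and normalization):

from sklearn.feature_extraction.text import TfidfVectorizer

documents = [
    "deep learning is a subfield of ai",
    "deep learning is very useful",
    "ai is the study of intelligent agents",
]

# Score 1-gram to 3-gram candidate phrases across the documents
vectorizer = TfidfVectorizer(ngram_range=(1, 3), stop_words="english")
scores = vectorizer.fit_transform(documents)

# Top keywords for the first document, sorted by TF-IDF score
terms = vectorizer.get_feature_names_out()
doc_scores = scores[0].toarray().ravel()
print(sorted(zip(terms, doc_scores), key=lambda x: -x[1])[:5])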
RAKE is a domain-independent keyword extraction method proposed in 2010. It uses word frequency and co-occurrence to identify the keywords. It is very useful for identifying relevant multi-word expressions.
Let’s apply RAKE on a toy example document to understand how it works:
First, the stop words in the document are removed.
We split the document at the stop word positions and punctuations to get content words. The words that occur consecutively without any stop word between them are taken as candidate keywords.
For example, “Deep Learning” is treated as a single keyword.
Next, the frequency of all the individual words in the candidate keywords are calculated. This finds words that occur frequently.
| | deep | learning | subfield | ai | useful |
| --- | --- | --- | --- | --- | --- |
| Word Frequency: \(freq(w)\) | 1 | 1 | 1 | 1 | 1 |
Similarly, the word co-occurrence count is calculated and the degree for each word is the total sum. This metric identifies words that occur often in longer candidate keywords.
| | deep | learning | subfield | ai | useful |
| --- | --- | --- | --- | --- | --- |
| deep | 1 | 1 | 0 | 0 | 0 |
| learning | 1 | 1 | 0 | 0 | 0 |
| subfield | 0 | 0 | 1 | 0 | 0 |
| ai | 0 | 0 | 0 | 1 | 0 |
| useful | 0 | 0 | 0 | 0 | 1 |
| degree: \(deg(w)\) | 1 + 1 = 2 | 1 + 1 = 2 | 1 | 1 | 1 |
Then, we divide the degree by the frequency for each word to get a final score. This score identifies words that occur more in longer candidate keywords than individually.
| | deep | learning | subfield | ai | useful |
| --- | --- | --- | --- | --- | --- |
| Score = \(\frac{deg(w)}{freq(w)}\) | 2 / 1 = 2 | 2 / 1 = 2 | 1 / 1 = 1 | 1 / 1 = 1 | 1 / 1 = 1 |
Finally, we calculate the scores for our candidate keywords by adding the scores for their member words. The higher the score, the more useful a keyword is.
Keyword | Score | Remarks |
---|---|---|
deep learning | 4 | score(deep) + score(learning) = 2 + 2 = 4 |
subfield | 1 | score(subfield) = 1 |
ai | 1 | score(ai) = 1 |
useful | 1 | score(useful) = 1 |
Thus, the keywords are sorted in the descending order of their score value. We can select the top-N keywords from this list.
We can use the rake-nltk library in Python as shown below.
pip install rake-nltk
from rake_nltk import Rake
rake = Rake()
text = 'Deep Learning is a subfield of AI. It is very useful.'
rake.extract_keywords_from_text(text)
print(rake.get_ranked_phrases_with_scores())
# [(4.0, 'deep learning'), (1.0, 'useful'), (1.0, 'subfield'), (1.0, 'ai')]
YAKE is another popular keyword extraction algorithm proposed in 2018. It outperforms TF-IDF and RAKE across many datasets and went on to win the best “short paper award” at ECIR 2018.
YAKE uses statistical features to identify and rank the most important keywords. It doesn’t need any linguistic information like NER or POS tagging and thus can be used with any language. It only requires a stop word list for the language.
The sentences are split into terms using spaces and special characters (line break, bracket, comma, period) as delimiters.
We decide the maximum length of the keywords to be generated. If we choose a maximum length of 3, then 1-gram, 2-gram, and 3-gram candidate phrases are generated using a sliding window.
Then, we remove phrases that contain punctuation marks. Also, phrases that begin and end with a stop word are removed.
YAKE uses 5 features to quantify how good each word is.
This feature considers the casing of the word. It gives more importance to capitalized words and acronyms such as “NASA”.
First, we count the number of times the word starts with a capital letter when it is not the beginning word of the sentence. We also count the times when the word is in acronym form.
Then, we take the maximum of the two counts and normalize it by the log of the total count.
\[casing(w) = \frac{max( count(w\ is\ capital), count(w\ is\ acronym) )}{1 + log(count(w))}\]

This feature gives more importance to words present at the beginning of the document. It's based on the assumption that relevant keywords are usually concentrated more at the beginning of a document.
First, we get all the sentence positions where the word “w” occurs.
\[Sen(w) = positions\ of\ sentences\ where\ w\ occurs\]

Then, we compute the position feature by taking the median position and applying the following formula:

\[position(w) = log( log( 3 + Median(Sen(w)) ) )\]

This feature calculates the frequency of the word normalized by 1 standard deviation from the mean.
\[frequency(w) = \frac{count\ of\ word\ w}{mean(counts) + standard\ deviation(counts)}\]

This feature quantifies how related a word is to its context. For that, it counts how many different terms occur to the left or right of a candidate word. If the word occurs frequently with different words on the left or right side, it is more likely to be a stop word.

\[relatedness(w) = 1 + (WR + WL) * \frac{count(w)}{max\ count} + PL + PR\]

where WL, WR, PL, and PR are statistics computed from the terms that co-occur to the left and right of the word w.
This feature quantifies how often a candidate word occurs with different sentences. A word that often occurs in different sentences has a higher score.
\[different(w) = \frac{number\ of\ sentences\ w\ occurs\ in}{total\ sentences}\]

These 5 features are combined into a single score S(w) using the formula:

\[score(w) = \frac{d * b}{a + (c / d) + (e / d)}\]

where a = casing(w), b = position(w), c = frequency(w), d = relatedness(w), and e = different(w).
Now, for each of our candidate keywords, a score is calculated using the following formula. The count of keyword penalizes less frequent keywords.
\[S(kw) = \frac{product(scores\ of\ words\ in\ keyword)}{1 + (sum\ of\ scores\ of\ words) * count(keyword)}\]

It's pretty common to get similar candidates when extracting keyphrases. For example, we could have variations like:
To eliminate such duplicates, the following process is applied:
Thus, the chosen keyword list contains the final deduplicated keywords.
Thus, we have a list of keywords along with their scores. A keyword is more important if it has a lower score.
We can sort the keywords in ascending order and take the top N keywords as the output.
To apply YAKE, we will use the pke library. First, we need to install the library and its dependencies using the following command:
pip install git+https://github.com/boudinfl/pke.git
python -m nltk.downloader stopwords
python -m spacy download en
Then, we can use YAKE to generate keywords of maximum length 2 as shown below.
from pke.unsupervised import YAKE
from nltk.corpus import stopwords
document = "Machine learning (ML) is the study of computer algorithms that improve automatically through experience. It is seen as a subset of artificial intelligence."
# 1. Create YAKE keyword extractor
extractor = YAKE()
# 2. Load document
extractor.load_document(input=document,
language='en',
normalization=None)
# 3. Generate candidate 1-gram and 2-gram keywords
stoplist = stopwords.words('english')
extractor.candidate_selection(n=2, stoplist=stoplist)
# 4. Calculate scores for the candidate keywords
extractor.candidate_weighting(window=2,
stoplist=stoplist,
use_stems=False)
# 5. Select 10 highest ranked keywords
# Remove redundant keywords with similarity above 80%
key_phrases = extractor.get_n_best(n=10, threshold=0.8)
print(key_phrases)
You get back a list of top-10 keywords and their scores. The highest ranked keyword has the lowest score.
[('machine learning', 0.01552184797949213),
('computer algorithms', 0.04188746641162499),
('improve automatically', 0.04188746641162499),
('machine', 0.12363091320521931),
('learning', 0.12363091320521931),
('experience', 0.12363091320521931),
('artificial intelligence', 0.18075564686791562),
('study', 0.2005079697193566),
('computer', 0.2005079697193566),
('algorithms', 0.2005079697193566)]
As users, the workflow is pretty simple. We can search for items by writing our queries in a search box and the ranking model in their system gives us back the top-N most relevant results.
How do we evaluate how good the top-N results are?
In this post, I will answer the above question by explaining the common offline metrics used in learning to rank problems. These metrics are useful not only for evaluating search results but also for problems like keyword extraction and item recommendation.
Let’s take a simple toy example to understand the details and trade-offs of various evaluation metrics.
We have a ranking model that gives us back 5-most relevant results for a certain query. The first, third, and fifth results were relevant as per our ground-truth annotation.
Let’s look at various metrics to evaluate this simple example.
This metric quantifies how many items in the top-K results were relevant. Mathematically, this is given by:
\[Precision@k = \frac{ true\ positives@k}{(true\ positives@k) + (false\ positives@k)}\]

For our example, precision@1 = 1, as the single item in the top-1 results is relevant.

Similarly, precision@2 = 0.5, as only one of the top-2 results is relevant.
Thus, we can calculate the precision score for all k values.
k | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|
Precision@k | \(\frac{1}{1}=1\) | \(\frac{1}{2}=0.5\) | \(\frac{2}{3}=0.67\) | \(\frac{2}{4}=0.5\) | \(\frac{3}{5}=0.6\) |
A limitation of precision@k is that it doesn’t consider the position of the relevant items. Consider two models A and B that have the same number of relevant results i.e. 3 out of 5.
For model A, the first three items were relevant, while for model B, the last three items were relevant. Precision@5 would be the same for both of these models even though model A is better.
This metric gives how many actual relevant results were shown out of all actual relevant results for the query. Mathematically, this is given by:
\[Recall@k = \frac{ true\ positives@k}{(true\ positives@k) + (false\ negatives@k)}\]

For our example, recall@1 = 0.33, as only one of the 3 actual relevant items is present.
Similarly, recall@3 = 0.67 as only two of the 3 actual relevant items are present.
Thus, we can calculate the recall score for different K values.
k | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|
Recall@k | \(\frac{1}{(1+2)}=\frac{1}{3}=0.33\) | \(\frac{1}{(1+2)}=\frac{1}{3}=0.33\) | \(\frac{2}{(2+1)}=\frac{2}{3}=0.67\) | \(\frac{2}{(2+1)}=\frac{2}{3}=0.67\) | \(\frac{3}{(3+0)}=\frac{3}{3}=1\) |
This is a combined metric that incorporates both Precision@k and Recall@k by taking their harmonic mean. We can calculate it as:
\[F1@k = \frac{2*(Precision@k) * (Recall@k)}{(Precision@k) + (Recall@k)}\]

Using the previously calculated values of precision and recall, we can calculate F1-scores for different K values as shown below.
k | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|
Precision@k | 1 | 1/2 | 2/3 | 1/2 | 3/5 |
Recall@k | 1/3 | 1/3 | 2/3 | 2/3 | 1 |
F1@k | \(\frac{2*1*(1/3)}{(1+1/3)}=0.5\) | \(\frac{2*(1/2)*(1/3)}{(1/2+1/3)}=0.4\) | \(\frac{2*(2/3)*(2/3)}{(2/3+2/3)}=0.667\) | \(\frac{2*(1/2)*(2/3)}{(1/2+2/3)}=0.571\) | \(\frac{2*(3/5)*1}{(3/5+1)}=0.75\) |
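The whole toy example can be reproduced with a few lines of Python (relevance flags taken from the example above, where the first, third, and fifth results are relevant):

# 1 = relevant, 0 = not relevant, in ranked order
results = [1, 0, 1, 0, 1]
total_relevant = sum(results)

for k in range(1, len(results) + 1):
    top_k = results[:k]
    precision = sum(top_k) / k
    recall = sum(top_k) / total_relevant
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    print(f"k={k}  P@k={precision:.2f}  R@k={recall:.2f}  F1@k={f1:.2f}")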
While precision, recall, and F1 give us a single-value metric, they don’t consider the order in which the returned search results are sent. To solve that limitation, people have devised order-aware metrics given below:
This metric is useful when we want our system to return the best relevant item and want that item to be at a higher position. Mathematically, this is given by:
\[MRR = \frac{1}{|Q|} \sum_{i=1}^{|Q|} \frac{1}{rank_{i}}\]

where |Q| is the total number of queries and \(rank_{i}\) is the rank of the first relevant result for the i-th query.
To calculate MRR, we first calculate the reciprocal rank. It is simply the reciprocal of the rank of the first correct relevant result and the value ranges from 0 to 1.
For our example, the reciprocal rank is \(\frac{1}{1}=1\) as the first correct item is at position 1.
Let's see another example where the only relevant result is present at the end of the list, i.e. at position 5. It gets a lower reciprocal rank score of 0.2.
Let’s consider another example where none of the returned results are relevant. In such a scenario, the reciprocal rank will be 0.
For multiple different queries, we can calculate the MRR by taking the mean of the reciprocal rank for each query.
We can see that MRR doesn’t care about the position of the remaining relevant results. So, if your use-case requires returning multiple relevant results in the best possible way, MRR is not a suitable metric.
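A short sketch of the MRR computation over the three examples discussed above (the relevance lists simply encode those examples):

def reciprocal_rank(results):
    # results: list of 0/1 relevance flags in ranked order
    for position, relevant in enumerate(results, start=1):
        if relevant:
            return 1 / position
    return 0.0  # no relevant result returned

queries = [
    [1, 0, 1, 0, 1],  # first relevant item at rank 1 -> RR = 1
    [0, 0, 0, 0, 1],  # only relevant item at rank 5  -> RR = 0.2
    [0, 0, 0, 0, 0],  # no relevant item              -> RR = 0
]
print(sum(reciprocal_rank(q) for q in queries) / len(queries))  # (1 + 0.2 + 0) / 3 = 0.4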
Average Precision is a metric that evaluates whether all of the ground-truth relevant items selected by the model are ranked higher or not. Unlike MRR, it considers all the relevant items.
Mathematically, it is given by:
\[AP = \frac{\sum_{k=1}^{n} (P(k) * rel(k))}{number\ of\ relevant\ items}\]

where P(k) is the Precision@k, rel(k) is 1 if the item at position k is relevant and 0 otherwise, and n is the number of returned results.
For our example, we can calculate the AP based on our Precision@K values for different K.
\[AP = \frac{(1 + 2/3 + 3/5)}{3} = 0.7555\]

To illustrate the advantage of AP, let's take our previous example but place the 3 relevant results at the beginning. We can see that this gets a perfect AP score, unlike the example above.

\[AP = \frac{(1 + 1 + 1)}{3} = 1\]

If we want to evaluate average precision across multiple queries, we can use MAP. It is simply the mean of the average precision for all queries. Mathematically, this is given by:

\[MAP = \frac{1}{Q} \sum_{q=1}^{Q} AP(q)\]

where Q is the total number of queries and AP(q) is the average precision for query q.
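Continuing with the same toy relevance lists, a sketch of AP and MAP:

def average_precision(results):
    # results: 0/1 relevance flags in ranked order
    relevant_total = sum(results)
    if relevant_total == 0:
        return 0.0
    score, hits = 0.0, 0
    for k, relevant in enumerate(results, start=1):
        if relevant:
            hits += 1
            score += hits / k  # Precision@k at each relevant position
    return score / relevant_total

print(average_precision([1, 0, 1, 0, 1]))  # (1 + 2/3 + 3/5) / 3 = 0.7555...
print(average_precision([1, 1, 1, 0, 0]))  # (1 + 1 + 1) / 3 = 1.0

# MAP is just the mean of the per-query AP values
queries = [[1, 0, 1, 0, 1], [1, 1, 1, 0, 0]]
print(sum(average_precision(q) for q in queries) / len(queries))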
Let’s take another toy example where we annotated the items not just as relevant or not-relevant but instead used a grading scale between 0 to 5 where 0 denotes least relevant and 5 denotes the most relevant.
We have a ranking model that gives us back 5-most relevant results for a certain query. The first item had a relevance score of 3 as per our ground-truth annotation, the second item has a relevance score of 2 and so on.
Let’s understand the various metrics to evaluate this type of setup.
This metric uses a simple idea to just sum up the relevance scores for top-K items. The total score is called cumulative gain. Mathematically, this is given by:
\[CG@k = \sum_{i=1}^{k} rel_{i}\]

For our example, CG@2 will be 5 because we add the first two relevance scores 3 and 2.
Similarly, we can calculate the cumulative gain for all the K-values as:
Position(k) | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|
Cumulative Gain@k | 3 | 3+2=5 | 3+2+3=8 | 3+2+3+0=8 | 3+2+3+0+1=9 |
While simple, CG doesn’t take into account the order of the relevant items. So, even if we swap a less-relevant item to the first position, the CG@2 will be the same.
We saw how a simple cumulative gain doesn’t take into account the position. But, we would normally want items with a high relevance score to be present at a better rank.
Consider an example below. With the cumulative gain, we are simply adding the scores without taking into account their position.
An item with a relevance score of 3 at position 1 is better than the same item with relevance score 3 at position 2.
So, we need some way to penalize the scores by their position. DCG introduces a log-based penalty function to reduce the relevance score at each position. For 5 items, the penalty would be
\(i\) | \(log_{2}(i+1)\) |
---|---|
1 | \(log_{2}(1+1) = log_{2}(2) = 1\) |
2 | \(log_{2}(2+1) = log_{2}(3) = 1.5849625007211563\) |
3 | \(log_{2}(3+1) = log_{2}(4) = 2\) |
4 | \(log_{2}(4+1) = log_{2}(5) = 2.321928094887362\) |
5 | \(log_{2}(5+1) = log_{2}(6) = 2.584962500721156\) |
Using this penalty, we can now calculate the discounted cumulative gain simply by taking the sum of the relevance score normalized by the penalty. Mathematically, this is given by:
\[DCG@k = \sum_{i=1}^{k} \frac{ \color{#81c784}{rel_{i}} }{ \color{#e57373}{log_{2}(i + 1)} }\]

To understand the behavior of the log-penalty, we can plot the ranking position on the x-axis and the percentage of the relevance score kept, i.e. \(\frac{1}{log_{2}(i+1)} * 100\), on the y-axis. At position 1, we don't apply any penalty and the score remains unchanged. But the percentage of the score kept decays quickly: from 100% at position 1 to 63% at position 2, 50% at position 3, and so on.
Let’s now calculate DCG for our example.
\(Position(i)\) | \(Relevance(rel_{i})\) | \(log_{2}(i+1)\) | \(\frac{rel_{i}}{log_{2}(i+1)}\) |
---|---|---|---|
1 | 3 | \(log_{2}(1+1) = log_{2}(2) = 1\) | 3 / 1 = 3 |
2 | 2 | \(log_{2}(2+1) = log_{2}(3) = 1.5849625007211563\) | 2 / 1.5849 = 1.2618 |
3 | 3 | \(log_{2}(3+1) = log_{2}(4) = 2\) | 3 / 2 = 1.5 |
4 | 0 | \(log_{2}(4+1) = log_{2}(5) = 2.321928094887362\) | 0 / 2.3219 = 0 |
5 | 1 | \(log_{2}(5+1) = log_{2}(6) = 2.584962500721156\) | 1 / 2.5849 = 0.3868 |
Based on these penalized scores, we can now calculate DCG at various k values simply by taking their sum up to k.
k | DCG@k |
---|---|
DCG@1 | \(3\) |
DCG@2 | \(3+1.2618=4.2618\) |
DCG@3 | \(3+1.2618+1.5=5.7618\) |
DCG@4 | \(3+1.2618+1.5+0=5.7618\) |
DCG@5 | \(3+1.2618+1.5+0+0.3868 = 6.1486\) |
There is also an alternative formulation for DCG@k that gives a larger penalty when relevant items are ranked lower. This formulation is preferred in industry.

\[DCG@k = \sum_{i=1}^{k} \frac{ \color{#81c784}{2^{rel_{i}} - 1} }{ \color{#e57373}{log_{2}(i + 1)} }\]

While DCG solves the issues with cumulative gain, it has a limitation. Suppose we have a query Q1 with 3 results and a query Q2 with 5 results. Then the query with 5 results, Q2, will have a larger overall DCG score. But we can't say that query Q2 was better than query Q1.
To allow a comparison of DCG across queries, we can use NDCG that normalizes the DCG values using the ideal order of the relevant items.
Let’s take our previous example where we had already calculated the DCG values at various K values.
k | DCG@k |
---|---|
DCG@1 | \(3\) |
DCG@2 | \(3+1.2618=4.2618\) |
DCG@3 | \(3+1.2618+1.5=5.7618\) |
DCG@4 | \(3+1.2618+1.5+0=5.7618\) |
DCG@5 | \(3+1.2618+1.5+0+0.3868 = 6.1486\) |
For our example, ideally, we would have wanted the items to be sorted in descending order of relevance scores.
Let’s calculate the ideal DCG(IDCG) for this order.
\(Position(i)\) | \(Relevance(rel_{i})\) | \(log_{2}(i+1)\) | \(\frac{rel_{i}}{log_{2}(i+1)}\) | IDCG@k |
---|---|---|---|---|
1 | 3 | \(log_{2}(2) = 1\) | 3 / 1 = 3 | 3 |
2 | 3 | \(log_{2}(3) = 1.5849\) | 3 / 1.5849 = 1.8927 | 3+1.8927=4.8927 |
3 | 2 | \(log_{2}(4) = 2\) | 2 / 2 = 1 | 3+1.8927+1=5.8927 |
4 | 1 | \(log_{2}(5) = 2.3219\) | 1 / 2.3219 = 0.4306 | 3+1.8927+1+0.4306=6.3233 |
5 | 0 | \(log_{2}(6) = 2.5849\) | 0 / 2.5849 = 0 | 3+1.8927+1+0.4306+0=6.3233 |
Now we can calculate NDCG@k for various k by dividing DCG@k by IDCG@k, as shown below:

\[NDCG@k = \frac{DCG@k}{IDCG@k}\]

\(k\) | DCG@k | IDCG@k | NDCG@k |
---|---|---|---|
1 | 3 | 3 | 3 / 3 = 1 |
2 | 4.2618 | 4.8927 | 4.2618 / 4.8927 = 0.8710 |
3 | 5.7618 | 5.8927 | 5.7618 / 5.8927 = 0.9777 |
4 | 5.7618 | 6.3233 | 5.7618 / 6.3233 = 0.9112 |
5 | 6.1486 | 6.3233 | 6.1486 / 6.3233 = 0.9723 |
Thus, we get NDCG scores with a range between 0 and 1. A perfect ranking would get a score of 1. We can also compare NDCG@k scores of different queries since it’s a normalized score.
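The DCG, IDCG, and NDCG values in the tables above can be checked with a short script:

import math

def dcg_at_k(relevances, k):
    return sum(rel / math.log2(i + 1) for i, rel in enumerate(relevances[:k], start=1))

relevances = [3, 2, 3, 0, 1]               # graded relevance of the returned results
ideal = sorted(relevances, reverse=True)   # ideal order: [3, 3, 2, 1, 0]

for k in range(1, len(relevances) + 1):
    dcg = dcg_at_k(relevances, k)
    idcg = dcg_at_k(ideal, k)
    print(f"k={k}  DCG={dcg:.4f}  IDCG={idcg:.4f}  NDCG={dcg / idcg:.4f}")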
Thus, we learned about various evaluation metrics for both binary and graded ground-truth labels and how each metric improves upon the previous.
But, when real users start using it, the story could be completely different than what our 95% performance metric was saying. Our model might perform poorly even on simple variations of the training text.
In contrast, the field of software engineering uses a suite of unit tests, integration tests, and end-to-end tests to evaluate all aspects of the product for failures. An application is deployed to production only after passing these rigorous tests.
Ribeiro et al. noticed this gap and took inspiration from software engineering to propose an evaluation methodology for NLP called “CheckList”. Their paper won the best overall paper award at ACL 2020.
In this post, I will explain the overall concept of CheckList and the various components that it proposes for evaluating NLP models.
To understand CheckList, let’s first understand behavioral testing in the context of software engineering.
Behavioral testing, also known as black-box testing, is a method where we test a piece of software based on its expected input and output. We don’t need access to the actual implementation details.
For example, let’s say you have a function that adds two numbers together.
def add(a, b):
return a + b
We can evaluate this function by writing tests that compare its output to the expected answer. We are not concerned with how the function is implemented internally.
def test_add():
assert add(1, 2) == 3
assert add(1, 0) == 1
assert add(-1, 1) == 0
assert add(-1, -1) == -2
Even for a simple function such as addition, there are capabilities that it should satisfy. For example, the addition of a number with zero should yield the original number itself.
Capability | Function Signature | Output | Expected | Test Passed |
---|---|---|---|---|
Two Positive Numbers | add(1, 2) | 3 | 3 | Yes |
No Change with Zero | add(1, 0) | 1 | 1 | Yes |
Opposite Numbers | add(-1, 1) | 0 | 0 | Yes |
Two Negative Number | add(-1, -1) | -2 | -2 | Yes |
Pass Rate | 4/4 = 100% |
CheckList proposes a general framework for writing behavioral tests for any NLP model and task.
The core idea is based on a conceptual matrix that is composed of linguistic capabilities as rows and test types as columns. The intersecting cells contain multiple test examples generated from templates that we run and calculate the failure rate for.
Capability / Test | Minimum Functionality Test(MFT) | Invariance Test(INV) | Directional Expectation Test(DIR) |
---|---|---|---|
VOCABULARY | 15.0% | 16.2% | 34.6% |
NER | 0.0% | 20.8% | - |
NEGATION | 76.4% | - | - |
… |
By calculating the failure rates for various test types and capabilities, we can know exactly where our model is weak.
Let’s understand each part of this conceptual matrix in detail now.
These are the columns in the previous matrix. There are 3 types of tests proposed in the CheckList framework:
This test is similar to unit tests in software engineering. We build a collection of (text, expected label) pairs from scratch and test the model on this collection.
For example, we are testing the negation capability of the model using an MFT test below.
Template: I {NEGATION} {POS_VERB} the {THING}
The goal of this test is to make sure the model is not taking any shortcuts and possesses linguistic capabilities.
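As a rough illustration, such a template can be expanded into concrete (text, expected label) pairs programmatically; the fill-in values below are made up for illustration and are not from the paper:

from itertools import product

negations = ["didn't", "don't", "can't say I"]
pos_verbs = ["love", "like", "enjoy"]
things = ["food", "flight", "service"]

# Expand the template into concrete MFT examples; a negated positive
# statement is expected to be predicted as negative (not positive)
mft_cases = [
    (f"I {neg} {verb} the {thing}", "negative")
    for neg, verb, thing in product(negations, pos_verbs, things)
]
print(len(mft_cases))   # 27 generated test cases
print(mft_cases[0])     # ("I didn't love the food", 'negative')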
In this test, we perturb our existing training examples in a way that should not change the label. Then, the model is tested on the perturbed example, and it passes the test only if its prediction remains the same (i.e. invariant).
For example, changing the location from Chicago to Dallas should not change the original sentiment of a text.
We can use different perturbation functions to test different capabilities. The paper mentions two examples (a minimal sketch of one such test follows the table):
Capability | Perturbation | Invariance |
---|---|---|
NER | Change location name in text | Should not change sentiment |
Robustness | Add typos to the text | Should not change prediction |
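Here is a minimal sketch of an invariance test for the robustness row, assuming some sentiment model exposed as a predict_sentiment function (a hypothetical callable that returns a label for a text):

import random

def add_typo(text, seed=0):
    # Perturbation: swap two adjacent characters somewhere in the text
    random.seed(seed)
    chars = list(text)
    i = random.randrange(len(chars) - 1)
    chars[i], chars[i + 1] = chars[i + 1], chars[i]
    return "".join(chars)

def invariance_test(texts, predict_sentiment):
    # A text passes only if the prediction is unchanged after perturbation
    failures = [t for t in texts if predict_sentiment(t) != predict_sentiment(add_typo(t))]
    return len(failures) / len(texts)  # failure rate for this capability/test cell

# failure_rate = invariance_test(test_sentences, my_model.predict)  # hypothetical usage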
This test is similar to the invariance test but here we expect the model prediction to change after perturbation.
For example, if we add the text "You are lame" to the end of a text, the expectation is that the sentiment of the text should not move in a positive direction.
We can also write tests where we expect the target label to change. For example, consider the QQP task where we need to detect if two questions are duplicates or not.
If we have a pair of duplicate questions and we change the location in one of the questions, then we expect the model to predict that they are not duplicates.
Capability | Question 1 | Question 2 | Expected | Predicted | Passed |
---|---|---|---|---|---|
NER | How many people are there in England? | What is the population of England? | Duplicate | Duplicate | ✔ |
NER | How many people are there in England? | What is the population of Turkey? | Not Duplicate | Duplicate | X |
These are the rows in the CheckList matrix. Each row contains a specific linguistic capability that applies to most NLP tasks.
Let’s understand examples of capabilities given in the original paper. The authors provide a lot of examples to help us build a mental model of how to test new capabilities relevant to our task and domain.
We want to ensure that the model has enough vocabulary knowledge, can differentiate between words with different parts of speech, and understands how they impact the task at hand.
For example, the paper shows the 3 test types for a sentiment analysis task.
Test Type | Example | Expected | Remarks |
---|---|---|---|
MFT | The company is Australian | neutral | neutral adjective and nouns |
MFT | That cabin crew is extraordinary | positive | sentiment-laden adjectives |
INV | | no change | Replace neutral words with other neutral words |
DIR | AA45… JFK to LAS. You are brilliant | move towards +ve | Add positive phrase to end |
DIR | your service sucks. You are lame | move towards -ve | Add negative phrase to end |
This can also be applied for the QQP task as shown below.
Test Type | Question 1 | Question 2 | Expected | Remarks |
---|---|---|---|---|
MFT | Is John a teacher? | Is John an accredited teacher? | Not Duplicate | Modifiers change question intent |
It tests the capability of the model to understand named entities and whether they are important for the current task or not.
We have examples of NER capability tests for sentiment analysis given below.
Test Type | Example | Expected | Remarks |
---|---|---|---|
INV | We had a safe travel to Chicago ⮕ Dallas | no change | Switching locations should not change predictions |
INV | | no change | Switching person names should not change predictions |
We can also apply this to the QQP task.
Test Type | Question 1 | Question 2 | Expected | Remarks |
---|---|---|---|---|
INV | Why isn’t Hillary Clinton ⮕ Nicole Perez in jail? | Is Hillary Clinton ⮕ Nicole Perez going to go to jail? | Duplicate | Changing name in both question |
DIR | Why isn’t Hillary Clinton in jail? | Is Hillary Clinton ⮕ Nicole Perez going to go to jail? | Not Duplicate | Changing name in only one question |
DIR | Why’s Hillary Clinton running? | Is Hillary Clinton going to go to jail? | Not Duplicate | Keep first word and entities, replace everything else with ROBERTA |
Here we want to test if the model understands the order of events in the text.
Below are examples of tests we can devise to evaluate this capability for a sentiment model.
Test Type | Example | Expected | Remarks |
---|---|---|---|
MFT | I used to hate this airline, although now I like it | positive | sentiment change over time, the present should prevail |
MFT | In the past I thought this airline was perfect, now I think it is creepy | negative | sentiment change over time, the present should prevail |
Similarly, we can devise temporal capability tests for QQP data as well.
Test Type | Question 1 | Question 2 | Expected | Remarks |
---|---|---|---|---|
MFT | Is Jordan Perry an advisor? | Did Jordan Perry use to be an advisor? | Not duplicate | is != used to be |
MFT | Is it unhealthy to eat after 10pm? | Is it unhealthy to eat before 10pm? | Not duplicate | before != after |
MFT | What was Danielle Bennett’s life before becoming an agent? | What was Danielle Bennett’s life after becoming an agent? | Not duplicate | before becoming != after becoming |
This ensures the model understands negation and its impact on the output.
Below are examples of tests we can devise to evaluate negation capabilities for a sentiment model.
Test Type | Example | Expected | Remarks |
---|---|---|---|
MFT | The aircraft is not bad | positive/neutral | negated negative |
MFT | This aircraft is not private | neutral | negated neutral |
MFT | I thought the plane would be awful, but it wasn’t | positive/neutral | negation of negative at end |
MFT | I wouldn’t say, given it’s a Tuesday, that this pilot was great | negative | negated positive with neutral content in middle |
Similarly, we can devise negation capability tests for QQP data as well.
Test Type | Question 1 | Question 2 | Expected | Remarks |
---|---|---|---|---|
MFT | How can I become a positive person? | How can I become a person who is not positive? | Not duplicate | simple negation |
MFT | How can I become a positive person? | How can I become a person who is not negative? | Duplicate | negation of antonym |
This ensures the model understands the agent and the object in the text.
Below are examples of tests we can devise to evaluate SRL capabilities for a sentiment model.
Test Type | Example | Expected | Remarks |
---|---|---|---|
MFT | Some people hate him, but I think the pilot was fantastic | positive | Author sentiment more important than others |
MFT | Do I think the pilot was fantastic? Yes. | positive | parsing sentiment in (question, “yes”) form |
MFT | Do I think the pilot was fantastic? No. | negative | parsing sentiment in (question, “no”) form |
Similarly, we can devise SRL capability tests for QQP data as well.
Test Type | Question 1 | Question 2 | Expected | Remarks |
---|---|---|---|---|
MFT | Are tigers heavier than insects? | What is heavier, insects or tigers? | Duplicate | Comparison |
MFT | Is Anna related to Benjamin? | Is Benjamin related to Anna? | Duplicate | Symmetric relation |
MFT | Is Anna hurting Benjamin? | Is Benjamin hurting Anna? | Not Duplicate | Asymmetric relation |
MFT | Does Anna love Benjamin? | Is Benjamin loved by Anna? | Duplicate | Active / passive swap, same semantics |
MFT | Does Anna support Benjamin? | Is Anna supported by Benjamin? | Not Duplicate | Active / passive swap, different semantics |
This ensures that the model can handle small variations or perturbations to the input text such as typos and irrelevant changes.
Below are examples of tests we can devise to evaluate robustness capabilities for a sentiment model.
Test Type | Example | Expected | Remarks |
---|---|---|---|
INV | @JetBlue no thanks @pi9QDK | no change | Add randomly generated URLs and handles to tweets |
INV | @SouthwestAir no thanks -> thakns | no change | Swap one character with its neighbor (typo) |
Similarly, we can devise robustness capability tests for QQP data as well.
Test Type | Question 1 | Question 2 | Expected | Remarks |
---|---|---|---|---|
INV | Why am I … | Why are we so lazy? | Duplicate | Swap one character with neighbor |
DIR | Can I gain weight from not eating enough? | | Duplicate | Paraphrasing |
This ensures that the model has an understanding of synonyms and antonyms and how they affect the task at hand.
Below are examples of tests we can devise to evaluate taxonomy capabilities for the QQP task.
Test Type | Question 1 | Question 2 | Expected | Remarks |
---|---|---|---|---|
MFT | How can I become more vocal? | How can I become more outspoken? | Duplicate | Synonyms in simple template |
MFT | How can I become more optimistic? | How can I become less pessimistic? | Duplicate | More X = Less antonym(X) |
INV | Is it necessary to follow a religion? | Is it necessary to follow an … | Duplicate | Replace words with synonyms in real pairs |
This ensures that the model has an understanding of pronouns and what nouns they refer to.
Below are examples of tests we can devise to evaluate coreference capabilities for the QQP task.
Test Type | Question 1 | Question 2 | Expected | Remarks |
---|---|---|---|---|
MFT | If Anna and Benjamin were alone, do you think he would reject her? | If Anna and Benjamin were alone, do you think she would reject him? | Not Duplicate | Simple coreference: he != she |
MFT | If Benjamin and Anna were married, do you think Anna’s family would be happy? | If Benjamin and Anna were married, do you think his family would be happy? | Not Duplicate | Simple resolved coreference, his and her |
This ensures that the model can handle symmetry, consistency, and conjunctions.
For example, in the QQP task, the order of the question shouldn’t matter. If question 1 is a duplicate of question 2, then question 2 will also be a duplicate of question 1 by symmetry.
Test Type | Question 1 | Question 2 | Expected | Remarks |
---|---|---|---|---|
INV | Can I gain weight from not eating enough? | Do you think I can gain weight from not eating enough? | Duplicate | Original Order |
INV | Do you think I can gain weight from not eating enough? | Can I gain weight from not eating enough? | Duplicate | Reversed Order |
This tests if the model reflects any form of bias towards a demographic from the training data.
Below are examples of tests we can devise to evaluate the fairness of a sentiment model. The model prediction failures are for the BERT model as shown in the paper.
Test Type | Example | Predicted | Expected | Remarks |
---|---|---|---|---|
MFT | I am a black woman | Negative | Neutral | Racial Bias |
MFT | I am an asian woman | Positive | Neutral | Racial Bias |
MFT | I am a straight woman | Positive | Neutral | Gender Bias |
MFT | I am a gay woman | Negative | Neutral | Gender Bias |
MFT | I am a lesbian woman | Negative | Neutral | Gender Bias |
The paper’s authors have open-sourced a software tool that can generate test cases at scale based on the ideas above.
The tool provides three approaches to write test cases:
Approach | Idea | Advantage | Disadvantage |
---|---|---|---|
Scratch | Write tests manually | High Quality | Low Coverage, Expensive, Time-consuming |
Perturbation Function | Apply perturbation to texts | Lots of Automated Tests | Low Quality |
Template | Use templates and generate many variations | Balance of Quality and Quantity | Need to brainstorm Templates |
To generate templates, you can either brainstorm them from scratch or generalize patterns from your existing data.
For example, if we had a text such as “I didn’t love the food” in our training data, we can generalize it as:
Original Text | Generalized Template |
---|---|
I didn’t love the food | I {NEGATION} {POS_VERB} the {THING} |
Now, you can brainstorm possible fillers for the various template parts.
{NEGATION} | {POS_VERB} | {THING} |
---|---|---|
didn’t, can’t say I, … | love, like, … | food, flight, services, … |
By taking the Cartesian product of all these possibilities, we can generate a lot of test cases, as sketched in the code after the table below.
{NEGATION} | {POS_VERB} | {THING} | Variation | Expected Label |
---|---|---|---|---|
didn’t | love | food | I didn’t love the food | Negative |
didn’t | like | food | I didn’t like the food | Negative |
didn’t | love | flight | I didn’t love the flight | Negative |
didn’t | love | services | I didn’t love the services | Negative |
… |
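Here is a rough sketch of this expansion using plain Python and itertools; the fill-in lists are illustrative, and the official CheckList tool ships its own templating utilities for this.
from itertools import product

template = "I {negation} {pos_verb} the {thing}"
negations = ["didn't", "can't say I"]
pos_verbs = ["love", "like"]
things = ["food", "flight", "services"]

# Cartesian product of all fill-ins gives 2 * 2 * 3 = 12 test cases,
# all with the expected label "Negative".
test_cases = [template.format(negation=n, pos_verb=v, thing=t)
              for n, v, t in product(negations, pos_verbs, things)]
print(test_cases[:3])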
Instead of manually specifying fill-ins for the template, we can also use a masked language model (MLM) like ROBERTA and use masking to generate variants.
For example, here we are using ROBERTA to suggest words for the mask and then we manually filter them into positive/negative/neutral.
Template | ROBERTA Prediction | Manual Filtering |
---|---|---|
I really {mask} the flight | enjoyed | positive |
liked | positive | |
loved | positive | |
regret | negative | |
… |
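As a minimal sketch of this masking step, we can use the Hugging Face transformers fill-mask pipeline as a stand-in for whatever MLM integration you prefer; the model choice and the manual filtering step are illustrative.
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="roberta-base")

# Each suggestion comes with a score; we then manually filter the words
# into positive/negative/neutral before using them as template fill-ins.
for suggestion in fill_mask("I really <mask> the flight"):
    print(suggestion["token_str"], round(suggestion["score"], 3))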
These fill-ins can be reused across multiple tests. The paper also suggests using WordNet to select only context-appropriate synonyms from ROBERTA.
CheckList also provides out-of-the-box support for lexicons such as common first and last names, cities, and countries.
CheckList also provides perturbation functions such as character swaps, contractions, name and location changes, and neutral word replacement.
Thus, CheckList provides a general framework to perform a comprehensive and fine-grained evaluation of NLP models. This can help us better understand the state of NLP models beyond the leaderboard.
In this post, I will illustrate the key ideas of recent methods for semi-supervised learning in Computer Vision through diagrams.
In this semi-supervised formulation, a model is trained on labeled data and used to predict pseudo-labels for the unlabeled data. The model is then trained on both ground truth labels and pseudo-labels simultaneously.
Dong-Hyun Lee proposed a very simple and efficient formulation called “Pseudo-label” in 2013.
The idea is to train a model simultaneously on a batch of both labeled and unlabeled images. The model is trained on labeled images in the usual supervised manner with a cross-entropy loss. The same model is used to get predictions for a batch of unlabeled images, and the most confident class is used as the pseudo-label. Then, a cross-entropy loss is calculated by comparing the model predictions and the pseudo-labels for the unlabeled images.
The total loss is a weighted sum of the labeled and unlabeled loss terms.
\[L = L_{labeled} + \alpha_{t} \cdot L_{unlabeled}\]

To make sure the model has learned enough from the labeled data, the \(\alpha_t\) term is set to 0 during the initial 100 training steps. It is then gradually increased up to 600 training steps and then kept constant.
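A minimal PyTorch sketch of this loss, assuming a standard classification model; the alpha ramp-up schedule and the data pipeline are left out.
import torch
import torch.nn.functional as F

def pseudo_label_loss(model, x_labeled, y_labeled, x_unlabeled, alpha_t):
    # Usual supervised cross-entropy on the labeled batch.
    labeled_loss = F.cross_entropy(model(x_labeled), y_labeled)

    # Most confident class on the unlabeled batch becomes the pseudo-label.
    with torch.no_grad():
        pseudo_labels = model(x_unlabeled).argmax(dim=1)

    # Cross-entropy between predictions and pseudo-labels.
    unlabeled_loss = F.cross_entropy(model(x_unlabeled), pseudo_labels)

    # Weighted sum, with alpha_t ramped up during training.
    return labeled_loss + alpha_t * unlabeled_loss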
Xie et al. proposed a semi-supervised method inspired by Knowledge Distillation called “Noisy Student” in 2019.
The key idea is to train two separate models called “Teacher” and “Student”. The teacher model is first trained on the labeled images and then used to infer pseudo-labels for the unlabeled images. These pseudo-labels can either be kept as soft labels or converted to hard labels by taking the most confident class. Then, the labeled and unlabeled images are combined together and a student model is trained on this combined data. The images are augmented using RandAugment as a form of input noise. Also, model noise such as Dropout and Stochastic Depth is incorporated into the student model architecture.
Once a student model is trained, it becomes the new teacher and this process is repeated for three iterations.
This paradigm uses the idea that model predictions on an unlabeled image should remain the same even after adding noise. We could use input noise such as Image Augmentation and Gaussian noise. Noise can also be incorporated in the architecture itself using Dropout.
This model was proposed by Laine et al. in a conference paper at ICLR 2017.
The key idea is to create two random augmentations of an image, for both labeled and unlabeled data. Then, a model with dropout is used to predict the label for both of these augmented images. The square difference of these two predictions is used as a consistency loss. For labeled images, we also calculate the cross-entropy loss. The total loss is a weighted sum of these two loss terms. A weight w(t) is applied to decide how much the consistency loss contributes to the overall loss.
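A minimal PyTorch sketch of the π-model loss; augment is a hypothetical stochastic augmentation function, and the ramp-up of the weight w(t) is not shown.
import torch.nn.functional as F

def pi_model_loss(model, augment, x, y=None, w_t=1.0):
    # Two random augmentations of the same batch, passed through the same
    # model (with dropout active), give two stochastic predictions.
    logits_1 = model(augment(x))
    logits_2 = model(augment(x))

    # Consistency loss: squared difference between the two predicted distributions.
    consistency = F.mse_loss(logits_1.softmax(dim=1), logits_2.softmax(dim=1))

    if y is None:
        return w_t * consistency  # unlabeled batch
    return F.cross_entropy(logits_1, y) + w_t * consistency  # labeled batch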
This method was also proposed by Laine et al. in the same paper as the π-model. It modifies the π-model by leveraging an Exponential Moving Average (EMA) of predictions.
The key idea is to use the exponential moving average of past predictions as one view. To get the other view, we augment the image as usual and use a model with dropout to predict the label. The square difference of the current prediction and the EMA prediction is used as a consistency loss. For labeled images, we also calculate the cross-entropy loss. The final loss is a weighted sum of these two loss terms. A weight w(t) is applied to decide how much the consistency loss contributes to the overall loss.
This method was proposed by Tarvainen et al. The general approach is similar to Temporal Ensembling, but it uses an Exponential Moving Average (EMA) of the model parameters instead of the predictions.
The key idea is to have two models called “Student” and “Teacher”. The student model is a regular model with dropout, and the teacher model has the same architecture but its weights are set to an exponential moving average of the student model’s weights. For a labeled or unlabeled image, we create two randomly augmented versions of the image. Then, the student model is used to predict the label distribution for the first image, and the teacher model is used to predict the label distribution for the second augmented image. The square difference of these two predictions is used as a consistency loss. For labeled images, we also calculate the cross-entropy loss. The final loss is a weighted sum of these two loss terms. A weight w(t) is applied to decide how much the consistency loss contributes to the overall loss.
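A minimal sketch of the Mean Teacher weight update in PyTorch; the decay value is illustrative.
import torch

@torch.no_grad()
def update_teacher(student, teacher, decay=0.99):
    # Teacher weights become an exponential moving average of student weights.
    for s_param, t_param in zip(student.parameters(), teacher.parameters()):
        t_param.mul_(decay).add_(s_param, alpha=1 - decay)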
This method was proposed by Miyato et al. It uses the concept of adversarial attacks for consistency regularization.
The key idea is to generate an adversarial transformation of an image that will change the model prediction. To do so, first, an image is taken and an adversarial variant of it is created such that the KL-divergence between the model output for the original image and the adversarial image is maximized.
Then we proceed as in the previous methods. We take a labeled/unlabeled image as the first view and its adversarial example generated in the previous step as the second view. Then, the same model is used to predict the label distributions for both images. The KL-divergence of these two predictions is used as a consistency loss. For labeled images, we also calculate the cross-entropy loss. The final loss is a weighted sum of these two loss terms. A weight \(\alpha\) is applied to decide how much the consistency loss contributes to the overall loss.
This method was proposed by Xie et al. and works for both images and text. Here, we will understand the method in the context of images.
The key idea is to create an augmented version of an unlabeled image using AutoAugment. Then, the same model is used to predict the labels of both the original and the augmented image. The KL-divergence of these two predictions is used as a consistency loss. For labeled images, we only calculate the cross-entropy loss and don’t calculate any consistency loss. The final loss is a weighted sum of these two loss terms. A weight w(t) is applied to decide how much the consistency loss contributes to the overall loss.
This paradigm combines ideas from previous work such as self-training and consistency regularization along with additional components for performance improvement.
This holistic method was proposed by Berthelot et al.
To understand this method, let’s take a walk through each of the steps.
i. For the labeled image, we create an augmentation of it. For the unlabeled image, we create K augmentations and get the model predictions on all K images. Then, the predictions are averaged and temperature scaling is applied to get a final pseudo-label. This pseudo-label is used for all the K augmentations.
ii. The batches of augmented labeled and unlabeled images are combined and the whole group is shuffled. Then, the first N images of this group are taken as \(W_L\), and the remaining M images are taken as \(W_U\).
iii. Now, Mixup is applied between the augmented labeled batch and group \(W_L\). Similarly, mixup is applied between the M augmented unlabeled group and the \(W_U\) group. Thus, we get the final labeled and unlabeled group.
iv. Now, for the labeled group, we take model predictions and compute cross-entropy loss with the ground truth mixup labels. Similarly, for the unlabeled group, we compute model predictions and compute mean square error(MSE) loss with the mixup pseudo labels. A weighted sum is taken of these two terms with \(\lambda\) weighting the MSE loss.
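A minimal sketch of the Mixup operation used in step iii above, assuming one-hot or soft label vectors for both inputs and an illustrative alpha value.
import numpy as np

def mixup(x1, y1, x2, y2, alpha=0.75):
    lam = np.random.beta(alpha, alpha)
    lam = max(lam, 1 - lam)  # keep the mixed example closer to (x1, y1)
    x = lam * x1 + (1 - lam) * x2
    y = lam * y1 + (1 - lam) * y2  # y1, y2 are one-hot / soft label vectors
    return x, y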
This method was proposed by Sohn et al. and combines pseudo-labeling and consistency regularization while vastly simplifying the overall method. It achieved state-of-the-art results on a wide range of benchmarks.
As seen, we train a supervised model on our labeled images with a cross-entropy loss. For each unlabeled image, a weak augmentation and a strong augmentation are applied to get two images. The weakly augmented image is passed to our model and we get a prediction over classes. The probability of the most confident class is compared to a threshold. If it is above the threshold, we take that class as the ground-truth label, i.e. the pseudo-label. Then, the strongly augmented image is passed through our model to get a prediction over classes. This prediction is compared to the ground-truth pseudo-label using a cross-entropy loss. Both losses are combined and the model is optimized.
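A minimal PyTorch sketch of the FixMatch loss on an unlabeled batch; weak_aug, strong_aug, and the confidence threshold are stand-ins for the augmentations and hyperparameters described in the paper.
import torch
import torch.nn.functional as F

def fixmatch_unlabeled_loss(model, weak_aug, strong_aug, x_unlabeled, threshold=0.95):
    # Predict on the weakly augmented images and keep only confident predictions.
    with torch.no_grad():
        probs = F.softmax(model(weak_aug(x_unlabeled)), dim=1)
        confidence, pseudo_labels = probs.max(dim=1)
        mask = (confidence >= threshold).float()

    # Cross-entropy between strong-augmentation predictions and the pseudo-labels,
    # counted only where the confidence mask is 1.
    logits_strong = model(strong_aug(x_unlabeled))
    loss = F.cross_entropy(logits_strong, pseudo_labels, reduction="none")
    return (loss * mask).mean()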
If you want to learn more about FixMatch, I have an article that goes over it in depth.
Here is a high-level summary of the differences between all the above-mentioned methods.
Method Name | Year | Unlabeled Loss | Augmentation |
---|---|---|---|
Pseudo-label | 2013 | Cross-Entropy | Random |
π-model | 2016 | MSE | Random |
Temporal Ensembling | 2016 | MSE | Random |
Mean Teacher | 2017 | MSE | Random |
Virtual Adversarial Training(VAT) | 2017 | KL-divergence | Adversarial transformation |
Unsupervised Data Augmentation(UDA) | 2019 | KL-divergence | AutoAugment |
MixMatch | 2019 | MSE | Random |
Noisy Student | 2019 | Cross-Entropy | RandAugment |
FixMatch | 2020 | Cross-Entropy | CTAugment / RandAugment |
To evaluate the performance of these semi-supervised methods, the following datasets are commonly used. The authors simulate a low-data regime by using only a small portion (e.g. 40/250/4000/10000 examples) of the whole dataset as labeled and treating the remaining as the unlabeled set.
Dataset | Classes | Image Size | Train | Validation | Unlabeled | Remarks |
---|---|---|---|---|---|---|
CIFAR-10 | 10 | 32x32 | 50,000 | 10,000 | - | Subset of the Tiny Images dataset |
CIFAR-100 | 100 | 32x32 | 50,000 | 10,000 | - | Subset of the Tiny Images dataset |
STL-10 | 10 | 96x96 | 5,000 | 8,000 | 100,000 | Subset of ImageNet |
SVHN | 10 | 32x32 | 73,257 | 26,032 | 531,131 | Google Street View House Numbers |
ILSVRC-2012 | 1000 | varies | 1.2 million | 150,000 | 150,000 | Subset of ImageNet |
Thus, we got an overview of how semi-supervised methods for Computer Vision have progressed over the years. This is a really important line of research that can have a direct impact on the industry.
If you found this blog post useful, please consider citing it as:
@misc{chaudhary2020semisupervised,
title = {Semi-Supervised Learning in Computer Vision},
author = {Amit Chaudhary},
year = 2020,
note = {\url{https://amitness.com/2020/07/semi-supervised-learning/}}
}
I recently decided to give FastAPI a spin by porting a production Flask project. It was very easy to pick up FastAPI coming from Flask and I was able to get things up and running in just a few hours.
The added benefit of automatic data validation, documentation generation and baked-in best-practices such as pydantic schemas and python typing makes this a strong choice for future projects.
In this post, I will introduce FastAPI by contrasting the implementation of various common use-cases in both Flask and FastAPI.
At the time of this writing, the Flask version is 1.1.2 and the FastAPI version is 0.58.1.
Both Flask and FastAPI are available on PyPI. For conda, you need to use the conda-forge channel to install FastAPI, while Flask is available in the default channel.
Flask:
pip install flask
conda install flask
FastAPI:
pip install fastapi uvicorn
conda install fastapi uvicorn -c conda-forge
Flask:
# app.py
from flask import Flask
app = Flask(__name__)
@app.route('/')
def home():
return {'hello': 'world'}
if __name__ == '__main__':
app.run()
Now you can run the development server using the below command. It runs on port 5000 by default.
python app.py
FastAPI
# app.py
import uvicorn
from fastapi import FastAPI
app = FastAPI()
@app.get('/')
def home():
return {'hello': 'world'}
if __name__ == '__main__':
uvicorn.run(app)
FastAPI defers serving to a production-ready server called uvicorn. We can run it in development mode with a default port of 8000.
python app.py
Flask:
# app.py
from flask import Flask
app = Flask(__name__)
@app.route('/')
def home():
return {'hello': 'world'}
if __name__ == '__main__':
app.run()
For a production server, gunicorn is a common choice in Flask.
gunicorn app:app
FastAPI
# app.py
import uvicorn
from fastapi import FastAPI
app = FastAPI()
@app.get('/')
def home():
return {'hello': 'world'}
if __name__ == '__main__':
uvicorn.run(app)
FastAPI defers serving to a production-ready server called uvicorn. We can start the server as:
uvicorn app:app
You can also start it in hot-reload mode by running
uvicorn app:app --reload
Furthermore, you can change the port as well.
uvicorn app:app --port 5000
The number of workers can be controlled as well.
uvicorn app:app --workers 2
You can use gunicorn to manage uvicorn as well using the following command. All regular gunicorn flags such as the number of workers (-w) work.
gunicorn -k uvicorn.workers.UvicornWorker app:app
Flask:
@app.route('/', methods=['POST'])
def example():
...
FastAPI:
@app.post('/')
def example():
...
You have individual decorator methods for each HTTP method.
@app.get('/')
@app.put('/')
@app.patch('/')
@app.delete('/')
We want to get the user id from the URL e.g. /users/1 and then return the user id to the user.
Flask:
@app.route('/users/<int:user_id>')
def get_user_details(user_id):
return {'user_id': user_id}
FastAPI:
In FastAPI, we make use of Python type hints to specify all the data types. For example, here we specify that user_id should be an integer. The variable in the URL path is specified using a syntax similar to f-strings.
@app.get('/users/{user_id}')
def get_user_details(user_id: int):
return {'user_id': user_id}
We want to allow the user to specify a search term by using a query string ?q=abc in the URL.
Flask:
from flask import request
@app.route('/search')
def search():
query = request.args.get('q')
return {'query': query}
FastAPI:
@app.get('/search')
def search(q: str):
return {'query': q}
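The query parameter can also be made optional with a default value; a minimal sketch, assuming we want q to fall back to None when it isn't provided:
from typing import Optional

@app.get('/search')
def search(q: Optional[str] = None):
    return {'query': q}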
Let’s take a toy example where we want to send a JSON POST request with a text key and get back a lowercased version.
# Request
{"text": "HELLO"}
# Response
{"text": "hello"}
Flask:
from flask import request
@app.route('/lowercase', methods=['POST'])
def lower_case():
text = request.json.get('text')
return {'text': text.lower()}
FastAPI:
If you simply replicate the functionality from Flask, you can do it as follows in FastAPI.
from typing import Dict
@app.post('/lowercase')
def lower_case(json_data: Dict):
text = json_data.get('text')
return {'text': text.lower()}
But this is where FastAPI introduces a new concept: creating a pydantic schema that maps to the JSON data being received. We can refactor the above example using pydantic as:
from pydantic import BaseModel
class Sentence(BaseModel):
text: str
@app.post('/lowercase')
def lower_case(sentence: Sentence):
return {'text': sentence.text.lower()}
As seen, instead of getting a dictionary, the JSON data is converted into an object of the schema Sentence. As such, we can access the data using attributes such as sentence.text. This also provides automatic validation of data types. If the user tries to send any data other than a string, they will be given an auto-generated validation error.
Example Invalid Request
{"text": null}
Automatic Response
{
"detail": [
{
"loc": [
"body",
"text"
],
"msg": "none is not an allowed value",
"type": "type_error.none.not_allowed"
}
]
}
Let’s create an API to return the uploaded file name. The key used when uploading the file will be file.
Flask
Flask allows accessing the uploaded file via the request object.
# app.py
from flask import Flask, request
app = Flask(__name__)
@app.route('/upload', methods=['POST'])
def upload_file():
file = request.files.get('file')
return {'name': file.filename}
FastAPI:
FastAPI uses a function parameter to specify the file key.
# app.py
from fastapi import FastAPI, UploadFile, File
app = FastAPI()
@app.post('/upload')
def upload_file(file: UploadFile = File(...)):
return {'name': file.filename}
We want to access a text form field that’s defined as shown below and echo the value.
<input name='city' type='text'>
Flask
Flask allows accessing the form fields via the request object.
# app.py
from flask import Flask, request
app = Flask(__name__)
@app.route('/submit', methods=['POST'])
def echo():
city = request.form.get('city')
return {'city': city}
FastAPI:
We use a function parameter to define the key and data type for the form field.
# app.py
from fastapi import FastAPI, Form
app = FastAPI()
@app.post('/submit')
def echo(city: str = Form(...)):
return {'city': city}
We can also make the form field optional as shown below.
from typing import Optional
@app.post('/submit')
def echo(city: Optional[str] = Form(None)):
return {'city': city}
Similarly, we can set a default value for the form field as shown below.
@app.post('/submit')
def echo(city: Optional[str] = Form('Paris')):
return {'city': city}
We want to access a cookie called name from the request.
Flask
Flask allows accessing the cookies via the request object.
# app.py
from flask import Flask, request
app = Flask(__name__)
@app.route('/profile')
def profile():
name = request.cookies.get('name')
return {'name': name}
FastAPI:
We use a Cookie parameter to define the key for the cookie.
# app.py
from fastapi import FastAPI, Cookie
app = FastAPI()
@app.get('/profile')
def profile(name = Cookie(None)):
return {'name': name}
We want to decompose the views from a single app.py into separate files.
- app.py
- views
- user.py
Flask:
In Flask, we use a concept called blueprints to manage this. We would first create a blueprint for the user view as:
# views/user.py
from flask import Blueprint
user_blueprint = Blueprint('user', __name__)
@user_blueprint.route('/users')
def list_users():
return {'users': ['a', 'b', 'c']}
Then, this view is registered in the main app.py file.
# app.py
from flask import Flask
from views.user import user_blueprint
app = Flask(__name__)
app.register_blueprint(user_blueprint)
FastAPI:
In FastAPI, the equivalent of a blueprint is called a router. First, we create a user router as:
# routers/user.py
from fastapi import APIRouter
router = APIRouter()
@router.get('/users')
def list_users():
return {'users': ['a', 'b', 'c']}
Then, we attach this router to the main app object as:
# app.py
from fastapi import FastAPI
from routers import user
app = FastAPI()
app.include_router(user.router)
Flask
Flask doesn’t provide any input data validation feature out of the box. It’s common practice to either write custom validation logic or use libraries such as marshmallow or pydantic.
FastAPI:
FastAPI wraps pydantic into its framework and allows data validation by simply using a combination of pydantic schemas and Python type hints.
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
class User(BaseModel):
name: str
age: int
@app.post('/users')
def save_user(user: User):
return {'name': user.name,
'age': user.age}
This code will perform automatic validation to ensure name is a string and age is an integer. If any other data type is sent, it auto-generates a validation error with a relevant message.
Here are some examples of pydantic schema for common use-cases.
{
"name": "Isaac",
"age": 60
}
from pydantic import BaseModel
class User(BaseModel):
name: str
age: int
{
"series": ["GOT", "Dark", "Mr. Robot"]
}
from pydantic import BaseModel
from typing import List
class Metadata(BaseModel):
series: List[str]
{
"users": [
{
"name": "xyz",
"age": 25
},
{
"name": "abc",
"age": 30
}
],
"group": "Group A"
}
from pydantic import BaseModel
from typing import List
class User(BaseModel):
name: str
age: int
class UserGroup(BaseModel):
users: List[User]
group: str
You can learn more about Python Type hints from here.
Flask
Flask doesn’t provide any built-in feature for documentation generation. There are extensions such as flask-swagger or flask-restful to fill that gap but the workflow is comparatively complex.
FastAPI:
FastAPI automatically generates an interactive swagger documentation endpoint at /docs and a reference documentation at /redoc.
For example, say we had a simple view given below that echoes what the user searched for.
# app.py
from fastapi import FastAPI
app = FastAPI()
@app.get('/search')
def search(q: str):
return {'query': q}
If you run the server and go to the endpoint http://127.0.0.1:8000/docs, you will get an auto-generated swagger documentation.
You can interactively try out the API from the browser itself.
In addition to swagger, if you go to the endpoint http://127.0.0.1:8000/redoc, you will get an auto-generated reference documentation. There is information on parameters, request format, response format, and status codes.
Flask
Flask doesn’t provide CORS support out of the box. We need to use an extension such as flask-cors to configure CORS as shown below.
# app.py
from flask import Flask
from flask_cors import CORS
app_ = Flask(__name__)
CORS(app_)
FastAPI:
FastAPI provides a built-in middleware to handle CORS. We show an example of CORS below where we are allowing any origin to access our APIs.
# app.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=['*'],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
Thus, FastAPI is an excellent alternative to Flask for building robust APIs with best-practices baked in. You can refer to the documentation to learn more.