(beta) Dynamic Quantization on an LSTM Word Language Model
Author: James Reed
Edited by: Seth Weidman
Introduction
Quantization involves converting the weights and activations of your model from float to int, which can result in smaller model size and faster inference with only a small hit to accuracy.
In this tutorial, we’ll apply the easiest form of quantization - dynamic quantization - to an LSTM-based next-word prediction model, closely following the word language model from the PyTorch examples.
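To build intuition for what “converting from float to int” means, the short sketch below quantizes a toy tensor with the per-tensor affine scheme that dynamic quantization uses for weights: each float value is stored as an int8 via a scale and a zero point, and dequantized back with a small rounding error. This is an illustration only; the tensor and the scale/zero-point choices here are made up and are not part of the tutorial’s pipeline.

import torch

# Toy illustration of per-tensor affine quantization:
#   q = round(x / scale) + zero_point   (stored as int8)
#   x ~ (q - zero_point) * scale        (recovered, up to rounding error)
x = torch.randn(4)                  # a hypothetical float32 weight vector
scale = float(x.abs().max()) / 127  # spread the observed range over the int8 range
zero_point = 0                      # roughly symmetric data -> zero point of 0

q = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=torch.qint8)
print(q.int_repr())    # the int8 storage
print(q.dequantize())  # close to x, up to a rounding error of about scale / 2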
# imports
import os
from io import open
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
1. Define the model
Here we define the LSTM model architecture, following the model from the word language model example.
class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))
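Before loading the pre-trained weights, it can be useful to sanity-check the architecture with a tiny, randomly initialized instance; the vocabulary size and dimensions below are arbitrary and chosen only for illustration.

# Shape check on a small, randomly initialized model (sizes are arbitrary).
tiny = LSTMModel(ntoken=20, ninp=8, nhid=16, nlayers=2)
inp = torch.randint(20, (5, 3), dtype=torch.long)  # (sequence length, batch size)
hid = tiny.init_hidden(3)
out, hid = tiny(inp, hid)
print(out.shape)  # torch.Size([5, 3, 20]): one score per vocabulary word at each position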
2. Load in the text data
Next, we load the Wikitext-2 dataset into a Corpus, again following the preprocessing from the word language model example.
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids
model_data_filepath = 'data/'
corpus = Corpus(model_data_filepath + 'wikitext-2')
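As a quick sanity check (not part of the original example), we can confirm that the corpus loaded correctly by looking at the vocabulary size and the first few tokens of the test split:

# Optional sanity check on the loaded corpus.
print(len(corpus.dictionary))  # vocabulary size: 33278 for Wikitext-2
print(corpus.test.shape)       # the test split is one long 1-D tensor of token ids
print([corpus.dictionary.idx2word[int(i)] for i in corpus.test[:8]])  # first few tokens as words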
3. Load the pre-trained model
This is a tutorial on dynamic quantization, a quantization technique that is applied after a model has been trained. Therefore, we’ll simply load some pre-trained weights into this model architecture; these weights were obtained by training for five epochs using the default settings in the word language model example.
ntokens = len(corpus.dictionary)
model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu')
        )
    )
model.eval()
print(model)
Out:
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): Linear(in_features=256, out_features=33278, bias=True)
)
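Before quantizing, it is also worth seeing where the parameters live; most of them sit in the embedding and the decoder. The snippet below is an optional check, not part of the original example.

# Optional: count parameters per submodule. Most of the parameters are in the
# embedding (encoder) and the output projection (decoder).
for name, module in model.named_children():
    n_params = sum(p.numel() for p in module.parameters())
    print('{:10s} {:>12,d} parameters'.format(name, n_params))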
Now let’s generate some text to ensure that the pre-trained model is working properly - similarly to before, we follow the text-generation procedure from the word language model example.
input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000
with open(model_data_filepath + 'out.txt', 'w') as outf:
    with torch.no_grad():  # no tracking history
        for i in range(num_words):
            output, hidden = model(input_, hidden)
            word_weights = output.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input_.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

            if i % 100 == 0:
                print('| Generated {}/{} words'.format(i, 1000))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)
Out:
| Generated 0/1000 words
| Generated 100/1000 words
| Generated 200/1000 words
| Generated 300/1000 words
| Generated 400/1000 words
| Generated 500/1000 words
| Generated 600/1000 words
| Generated 700/1000 words
| Generated 800/1000 words
| Generated 900/1000 words
b'@-@' b'loss' b'questions' b'through' b'Desmond' b'Francisco' b'<unk>' b'James' b'Rubinstein' b'for' b'runway' b'(' b'deciding' b')' b'a' b'half' b'goal' b'of' b'city' b'Ke'
b'1235' b'Category' b'16' b'.' b'<eos>' b'1874' b':' b'20' b'in' b'May' b'2008' b'.' b'In' b'May' b'1945' b',' b'the' b'Hygiene' b'Hockey' b'Race'
b'gave' b'12' b'@,@' b'000' b',' b'in' b'13' b'%' b'a' b'week' b'.' b'The' b'same' b'term' b'of' b'a' b'character' b'may' b'be' b'kept'
b'in' b'A' b'and' b'over' b'the' b'posterior' b'race' b'multi' b'@-@' b'Muslim' b'year' b'due' b'to' b'link' b'on' b'an' b'four' b'@-@' b'1860s' b'conversion'
b'up' b'to' b'over' b'39' b'metres' b'(' b'27' b'@.@' b'5' b'\xc2\xb0' b')' b'before' b'a' b'real' b'improvement' b'they' b'were' b'purchased' b'in' b'Signs'
b'.' b'Hindenburg' b'continue' b'to' b'influence' b'large' b'Kramer' b'Odessa' b'and' b'Burnley' b'(' b'Steven' b'Falling' b')' b'.' b'After' b'the' b'other' b'until' b'1950'
b'features' b'his' b'gravity' b'with' b'eye' b'under' b'O' b"'Malley" b'discredited' b'norms' b'in' b'<unk>' b',' b'the' b'well' b'@-@' b'operative' b'and' b'first' b'as'
b'"' b'PC' b'"' b'following' b'nothing' b'known' b'as' b'"' b'the' b'Harvard' b'mountain' b'"' b'(' b'Westchester' b')' b'.' b'The' b'Peach' b'mantelli' b'Football'
b'Society' b'believed' b'their' b'fathers' b'\xe2\x80\x94' b'but' b'race' b'sequences' b'was' b'used' b'for' b'success' b'directly' b'written' b'by' b'van' b'B.' b'Shore' b'.' b'The'
b'film' b'was' b'performed' b'on' b'July' b'27' b'on' b'15' b'October' b'2008' b',' b'with' b'Deane' b'and' b'Nic' b'sharing' b'the' b'conclusion' b'to' b'sign'
b'renovation' b'Union' b'as' b'Brilliant' b'Cider' b'them' b'.' b'Hornung' b'dedicated' b'for' b'a' b'tour' b'when' b'he' b'performed' b'its' b'children' b'overall' b'in' b'a'
b'starters' b'in' b'the' b'human' b'again' b'while' b'he' b'left' b'his' b'interests' b'to' b'solve' b'March' b'2' b'Hunters' b'.' b'The' b'recording' b'remained' b'down'
b'in' b"'" b'Love' b'region' b"'" b'(' b'now' b'known' b'to' b'do' b')' b'.' b'<eos>' b'The' b'title' b'@-@' b'established' b'version' b'of' b'quartet'
b'video' b',' b'providers' b',' b'and' b'Ito' b'Mitsuda' b'featured' b'Ceres' b'.' b'In' b'recent' b'ways' b'\xe2\x80\x94' b'Tower' b'listed' b'a' b'chelicerae' b'to' b'6th'
b'Stoneman' b'Arts' b'since' b'2002' b'that' b'these' b'most' b'of' b'his' b'short' b'instruments' b'have' b'portrayed' b'since' b'2007' b',' b'and' b'was' b'for' b'its'
b'production' b'at' b'<unk>' b'Lion' b'capping' b'.' b'<unk>' b'story' b'Edwards' b'advances' b'in' b'sequence' b'1923' b'in' b'Brussels' b'.' b'Their' b'final' b'films' b'was'
b'determined' b'as' b'the' b'2012' b'commander' b'of' b'the' b'Bellamy' b'.' b'Election' b'in' b'France' b'felt' b'Gwen' b'of' b'Hornung' b'was' b'heard' b'for' b'its'
b'character' b',' b'which' b'by' b'the' b'nationalist' b'<unk>' b'Anne' b'of' b'Caesarea' b'in' b'2001' b'.' b'<eos>' b'<eos>' b'=' b'=' b'Development' b'=' b'='
b'<eos>' b'<eos>' b'Many' b'of' b'its' b'songs' b',' b'and' b'organised' b'355' b',' b'depending' b'on' b'Fickett' b',' b'is' b'split' b'Waterways' b'by' b'post'
b'thousand' b'acoustic' b'rooms' b',' b'after' b'providing' b'special' b'CGT' b'audiences' b',' b'while' b'some' b'or' b'four' b'stars' b'are' b'investigate' b'.' b'The' b'long'
b'level' b'of' b'a' b'sharp' b'<unk>' b'appears' b'on' b'what' b'is' b'him' b'to' b'be' b'an' b'member' b'of' b'his' b'kakapo' b',' b'a' b'idea'
b'from' b'the' b'nucleoplasm' b'.' b'After' b'this' b',' b'they' b'are' b'Winnebago' b'to' b'find' b'according' b'to' b'a' b'male' b'bus' b',' b'with' b'records'
b'of' b'a' b'few' b'different' b'snake' b'or' b'a' b'sculpture' b'(' b'c' b')' b',' b'which' b'introduced' b'off' b'in' b'one' b'million' b'centuries' b'off'
b'the' b'size' b'of' b'biblical' b'prays' b'structures' b'in' b'order' b'to' b'bolster' b'these' b'understory' b'<unk>' b'.' b'As' b'a' b'result' b',' b'the' b'band'
b'gives' b'its' b'first' b'season' b'of' b'the' b'twentieth' b'version' b'of' b'colour' b',' b'more' b'Paki' b'.' b'It' b'also' b'allows' b'1' b'mechanical' b'buildings'
b'\xe2\x80\x94' b'including' b'CRIA' b',' b'<unk>' b',' b'promises' b',' b'and' b'specific' b'number' b'have' b'been' b'affected' b'by' b'Pasupathy' b'.' b'The' b'most' b'recent'
b'scrolls' b'in' b'Hemisphere' b'inspired' b'or' b'<unk>' b'stars' b'associates' b'pot' b'here' b'with' b'<unk>' b'at' b'the' b'other' b'layer' b'.' b'It' b'is' b'present'
b'to' b'fire' b'medieval' b'by' b'those' b'barges' b'is' b'as' b'other' b'.' b'Their' b'first' b'walking' b'losses' b';' b'does' b'not' b'be' b'distinguished' b','
b'by' b'every' b'tentative' b'patch' b'towards' b'a' b'head' b'trademark' b'art' b'for' b'a' b'standing' b'species' b'.' b'Nests' b'then' b'objected' b'to' b'a' b'body'
b'of' b'year' b',' b'or' b'has' b'probably' b'also' b'first' b'all' b'the' b'eggs' b'of' b'modern' b'forever' b',' b'breaks' b'in' b'the' b'Eucalyptus' b'Palace'
b'and' b'the' b'Onondaga' b'species' b'(' b'dark' b'on' b'St' b'cherry' b',' b'which' b'is' b'in' b'case' b'as' b'a' b'figurines' b')' b'.' b'There'
b'are' b'no' b'black' b'found' b'to' b'handle' b'which' b'demon' b'to' b'a' b'nest' b';' b'they' b'are' b'great' b'a' b'ion' b'effect' b'.' b'A'
b'large' b'margin' b'of' b'Common' b'collapses' b'and' b'older' b'impact' b'are' b'Presbyterians' b'.' b'The' b'eastern' b'and' b'sixth' b'quantity' b'is' b'so' b'Ireland' b'and'
b'loyal' b'to' b'a' b'lack' b'of' b'large' b',' b'in' b'they' b'prefer' b'its' b'substance' b'.' b'The' b'ability' b'to' b'be' b'important' b'for' b'anything'
b'numbers' b'.' b'<unk>' b'nearly' b'anterior' b'when' b'they' b'occurs' b'around' b'a' b'or' b'@-@' b'toed' b',' b'so' b'by' b'his' b'recovery' b'time' b','
b'grasses' b'used' b'research' b'brown' b'Mexican' b'students' b'play' b'several' b'other' b'types' b'open' b'land' b',' b'and' b'Selma' b'activity' b'is' b'body' b'in' b'standing'
b',' b'and' b'around' b'11' b'million' b'regions' b'are' b'red' b',' b'even' b'as' b'they' b'are' b'tissue' b'.' b'Consequently' b',' b'his' b'body' b','
b'especially' b'as' b'her' b'truly' b'fate' b'between' b'the' b'kakapo' b'or' b'fusion' b',' b'is' b'prone' b'to' b'farmers' b'.' b'<unk>' b'may' b'be' b'regulated'
b'by' b'common' b'incurred' b',' b'deep' b'beds' b',' b'small' b'common' b'<unk>' b',' b'eucalypts' b',' b'Ward' b',' b'golf' b',' b'and' b'owl' b'gather'
b'throughout' b'the' b'second' b'columns' b'of' b'Inari' b'.' b'Common' b'maintenance' b'distributed' b'from' b'a' b'wealthy' b'note' b'of' b'scandal' b'and' b'shrubs' b'of' b'this'
b'species' b'by' b'fish' b'to' b'their' b'upgrades' b'.' b'It' b'was' b'from' b'some' b'ceremonies' b'mountain' b'acid' b',' b'it' b'has' b'only' b'been' b'even'
b'taught' b'to' b'be' b'very' b'males' b'due' b'to' b'Antrim' b'"' b'God' b'"' b',' b'which' b'<unk>' b'proved' b'to' b'breed' b'around' b'by' b'fuller'
b'during' b'the' b'bird' b"'s" b'moral' b'Brussels' b'.' b'They' b'typically' b'are' b'accepted' b'by' b'a' b'sizable' b'planet' b'@-@' b'shaped' b'craftsman' b'when' b'their'
b'crescent' b'is' b'<unk>' b'below' b'.' b'That' b'does' b'so' b'at' b'any' b'of' b'these' b'plays' b',' b'a' b'planet' b'of' b'contemporary' b',' b'more'
b'of' b'Iguanodon' b'.' b'This' b'images' b'allows' b'certainty' b'that' b'also' b'feed' b'on' b'their' b'way' b'in' b'the' b'history' b'of' b'large' b'areas' b':'
b'the' b'under' b'regeneration' b'<unk>' b'are' b'increasingly' b'more' b'rapid' b',' b'almost' b'more' b'of' b'the' b'other' b'hydrogen' b',' b'much' b'imposing' b'kills' b','
b'immigration' b'in' b'<unk>' b',' b'emphasis' b'together' b'and' b'they' b'endanger' b'off' b'.' b'<unk>' b'may' b'be' b'heard' b'only' b'exposed' b',' b'with' b'a'
b'cabin' b'on' b'5' b'an' b'24' b'@.@' b'2' b'miles' b'(' b'31' b'@.@' b'5' b'to' b'8' b'@.@' b'8' b'in' b')' b'of' b'its'
b'origin' b'.' b'Such' b'juveniles' b'are' b'so' b'clear' b',' b'measuring' b'only' b'in' b'small' b'1898' b'.' b'<eos>' b'The' b'average' b'(' b'high' b'here'
b'had' b'happened' b',' b'crushed' b'only' b'@-@' b'spored' b')' b'and' b'have' b'a' b'variety' b'of' b'yellow' b'species' b'per' b'breeding' b'feathers' b'(' b'e.g.'
It’s no GPT-2, but it looks like the model has started to learn the structure of language!
We’re almost ready to demonstrate dynamic quantization. We just need to define a few more helper functions:
bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1
# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    return data.view(bsz, -1).t().contiguous()
test_data = batchify(corpus.test, eval_batch_size)
# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
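To make the shapes concrete, here is a small worked example of batchify and get_batch on a toy tensor; it is for illustration only and plays no role in the evaluation below.

# Toy illustration of batchify / get_batch.
toy = torch.arange(12)
toy_batched = batchify(toy, 2)            # shape (6, 2): column 0 holds tokens 0-5, column 1 holds tokens 6-11
data, target = get_batch(toy_batched, 0)  # data: rows 0-4; target: the same positions shifted by one, flattened
print(toy_batched)
print(data)
print(target)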
4. Test dynamic quantization
Finally, we can call torch.quantization.quantize_dynamic on the model! Specifically,

- We specify that we want the nn.LSTM and nn.Linear modules in our model to be quantized
- We specify that we want weights to be converted to int8 values
import torch.quantization
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
Out:
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)
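Under the hood, quantize_dynamic swaps out only the module types we listed; the Embedding (and the Dropout) stay in floating point, and since inplace defaults to False the original model is left untouched. An optional check, not part of the original tutorial:

# Optional check: only nn.LSTM and nn.Linear were replaced by dynamically quantized
# versions; the embedding is still a regular nn.Embedding.
print(type(model.rnn), '->', type(quantized_model.rnn))
print(type(model.decoder), '->', type(quantized_model.decoder))
print(type(model.encoder), '->', type(quantized_model.encoder))  # unchanged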
The model looks the same; how has this benefited us? First, we see a significant reduction in model size:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')
print_size_of_model(model)
print_size_of_model(quantized_model)
Out:
Size (MB): 113.944608
Size (MB): 79.739098
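These numbers line up with a rough back-of-envelope estimate: the embedding is not quantized and stays in float32, while the LSTM and Linear weights shrink from 4 bytes to 1 byte each (biases, scales, and zero points remain in float and add a small overhead). A sketch of that estimate, for illustration only:

# Rough size estimate: float32 everywhere vs. float32 embedding + int8 LSTM/Linear weights.
# Small overheads (biases, scales, zero points, serialization) are ignored.
emb = sum(p.numel() for p in model.encoder.parameters())
lstm = sum(p.numel() for p in model.rnn.parameters())
lin = sum(p.numel() for p in model.decoder.parameters())
print('estimated fp32 size (MB):', 4 * (emb + lstm + lin) / 1e6)       # close to the ~114 MB above
print('estimated quantized size (MB):', (4 * emb + lstm + lin) / 1e6)  # close to the ~80 MB above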
Second, we see faster inference time, with no difference in evaluation loss:
Note: we set the number of threads to one for a single-threaded comparison, since quantized models run single-threaded.
torch.set_num_threads(1)
def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))
time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)
Out:
loss: 5.167
elapsed time (seconds): 197.3
loss: 5.168
elapsed time (seconds): 103.4
Running this locally on a MacBook Pro, inference takes about 200 seconds without quantization and about 100 seconds with it.
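If you are curious how much each module type contributes, one quick experiment (not part of the tutorial) is to quantize the LSTM and the Linear layer separately and re-run the same measurements:

# Quick experiment: quantize only one module type at a time and compare.
linear_only = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
lstm_only = torch.quantization.quantize_dynamic(model, {nn.LSTM}, dtype=torch.qint8)

print_size_of_model(linear_only)
print_size_of_model(lstm_only)
time_model_evaluation(linear_only, test_data)
time_model_evaluation(lstm_only, test_data)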