# Source code for fee.visualize.gender_cluster_tsne
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
class GCT():
    """Gender Cluster tSNE (GCT) plot.

    Splits a set of word vectors into two KMeans clusters and draws a 2-D
    tSNE projection of those vectors, with each point styled according to
    the cluster it was assigned to.
    """

    def __init__(self, E, random_state=0):
        """Gender Cluster tSNE Plot: plot the tSNE visualization for the
        neighbourhood of a word, color coded by computed cluster.

        Args:
            E (WE class object): Word embeddings object. Only its ``v(word)``
                method is used here, to look up the vector of a word.
            random_state (int): seed forwarded to both KMeans and TSNE for
                reproducibility.
        """
        self.E = E
        self.random_state = random_state

    def cluster(self, vecs):
        """Apply 2-cluster KMeans over `vecs`.

        Args:
            vecs (np.array): list of word vectors to cluster

        Returns:
            Cluster label for each vector, as returned by
            ``KMeans.fit_predict`` with ``n_clusters=2``.
        """
        kmeans = KMeans(n_clusters=2, random_state=self.random_state)
        return kmeans.fit_predict(vecs)

    def visualize(self, vecs, words, labels, title, annotate, figsize,
                  dpi, colors):
        """Main GCT visualization driver function.

        Projects `vecs` to two dimensions with tSNE, scatters the points and
        shows the figure via ``plt.show()``.

        Args:
            vecs (np.array): list of word vectors to project
            words (list): list of words (list of string)
            labels (list): cluster labels (list of 0s and 1s)
            title (str): title of the plot; no title is set when None
            annotate (bool): annotate each point in scatter plot
            figsize (tuple): size of the figure, passed straight to
                ``plt.subplots``; matplotlib reads it as (width, height)
                in inches
            dpi (int): dpi of the figure
            colors (list): list of two matplotlib compatible colors;
                ``colors[0]`` is used for truthy labels (marker '.') and
                ``colors[1]`` for falsy labels (marker 'x')

        Returns:
            bool: always True once the figure has been shown.
        """
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

        # TSNE's default perplexity is 30 and recent scikit-learn releases
        # reject a perplexity that is not strictly smaller than the number
        # of samples. Clamp it so small word lists can still be plotted;
        # inputs with more than 30 vectors keep the default value.
        n_samples = len(vecs)
        perplexity = float(min(30, max(n_samples - 1, 1)))
        X_embedded = TSNE(n_components=2,
                          perplexity=perplexity,
                          random_state=self.random_state
                          ).fit_transform(vecs)

        # zip stops at the shortest of the three sequences.
        for point, label, word in zip(X_embedded, labels, words):
            if label:
                ax.scatter(point[0], point[1], marker='.', c=colors[0])
            else:
                ax.scatter(point[0], point[1], marker='x', c=colors[1])
            if annotate:
                ax.annotate(word, (point[0], point[1]))

        if title is not None:
            # Set the title on the axes created above instead of relying on
            # pyplot's implicit "current axes" state.
            ax.set_title(title)
        plt.show()
        return True

    def run(self, word_list, title=None, dpi=300, annotate=False,
            figsize=(8, 5), colors=None):
        """Run the GCT visualization.

        Looks up the vector of every word through ``self.E.v``, clusters the
        vectors into two groups and plots their tSNE projection.

        Args:
            word_list (list): list of words (list of string)
            title (str): title of the plot
            dpi (int): dpi of the figures
            annotate (bool): annotate each point in scatter plot
            figsize (tuple): size of the figure, (width, height) in inches
                as interpreted by matplotlib
            colors (list): list of two matplotlib compatible colors.
                Defaults to ``['k', 'r']`` when None.

        Raises:
            ValueError: if fewer than two words are given, since two
                clusters cannot be formed.
        """
        # None sentinel instead of a mutable list default shared by calls.
        if colors is None:
            colors = ['k', 'r']

        words = list(word_list)
        if len(words) < 2:
            raise ValueError(
                "word_list must contain at least two words to form "
                "two clusters."
            )

        vecs = [self.E.v(w) for w in words]
        labels = self.cluster(vecs)
        self.visualize(vecs, words, labels, title, annotate, figsize,
                       dpi, colors)