Source code for fee.visualize.neighbour_plot

import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from ..utils import get_g, cosine
import matplotlib as mpl
from sklearn.manifold import TSNE

def color_fader(c1, c2, mix=0):
    """Blend two colors linearly and return the result as a hex string.

    See: https://stackoverflow.com/questions/25668828/how-to-create-colour-gradient-in-python

    Args:
        c1: matplotlib compatible color returned when `mix` is 0.
        c2: matplotlib compatible color returned when `mix` is 1.
        mix (float): weight given to `c2`; `c1` is weighted by `1 - mix`.

    Returns:
        str: hex representation of the blended color.
    """
    start_rgb = np.array(mpl.colors.to_rgb(c1))
    end_rgb = np.array(mpl.colors.to_rgb(c2))
    # Weighted average of the two RGB triples.
    blended = (1 - mix) * start_rgb + mix * end_rgb
    return mpl.colors.to_hex(blended)
def generate_palette(c1, c2, n):
    """Build a list of `n` colors fading from `c1` to `c2`.

    Args:
        c1: matplotlib compatible color at the start of the palette.
        c2: matplotlib compatible color at the end of the palette.
        n (int): number of colors to generate.

    Returns:
        list: `n` colors as produced by `color_fader`, one per evenly
            spaced mix value between 0 and 1 (both ends included).
    """
    palette = []
    for fraction in np.linspace(0, 1, n):
        palette.append(color_fader(c1, c2, fraction))
    return palette
class NeighbourPlot():
    """t-SNE plot of a word's neighbourhood, colour coded by bias by projection.

    The embedding object `E` is expected to expose `E.vecs` (matrix of word
    vectors), `E.words` (words aligned with the rows of `E.vecs`) and
    `E.v(word)` (vector lookup); these are the only members used here.
    """

    def __init__(self, E, g=None, random_state=42):
        """Store the embedding, the gender direction and the t-SNE seed.

        Args:
            E (WE class object): Word embeddings object.
            g (np.array): gender direction; computed with `get_g(E)` when
                `None`.
            random_state (int): seed passed to t-SNE for reproducibility.
        """
        self.E = E
        if g is None:
            g = get_g(E)
        self.g = g
        self.random_state = random_state

    def get_neighbours(self, word, n=100):
        """Get the nearest neighbours of `word` by dot-product similarity.

        The `n` highest scoring vocabulary entries are selected and the
        single best one is dropped (presumably `word` itself, which holds
        when the vectors are unit-normalised -- TODO confirm), so at most
        `n - 1` words are returned.

        Args:
            word (str): word to compute neighbours for.
            n (int): size of the top-scoring window, including the dropped
                best match.

        Returns:
            list: neighbouring words, most similar first. Empty when
                `n <= 1`.
        """
        # Guard: with n == 0 the slice below would read `[-0:-1]`, i.e.
        # `[0:-1]`, and return almost the whole vocabulary.
        if n <= 1:
            return []
        scores = self.E.vecs.dot(self.E.v(word))
        # argsort is ascending: take the last n, drop the very best, then
        # flip so the most similar neighbour comes first.
        top_idx = np.argsort(scores)[-n:-1][::-1]
        return [self.E.words[i] for i in top_idx]

    def bias_by_projection_sort(self, words):
        """Sort `words` by bias by projection, lowest first.

        Bias is `cosine(E.v(word), g)`. Duplicate words are collapsed to a
        single entry.

        Args:
            words (list): list of words (str).

        Returns:
            list: the distinct words ordered by increasing bias.
        """
        bias = {w: cosine(self.E.v(w), self.g) for w in words}
        # Sorting the dict iterates its keys; the sort is stable, so ties
        # keep their first-seen order.
        return sorted(bias, key=bias.get)

    def visualize(self, words, ranks, title, figsize, dpi, colors, s,
                  annotate=False):
        """Draw the t-SNE scatter plot for `words` and show it.

        Args:
            words (list): words to plot; the first entry is treated as the
                query word and drawn as a black cross with a label.
            ranks (dict): maps each non-query word to the rank shown as a
                subscript in its annotation; only read when `annotate` is
                true.
            title (str): title of the plot, or `None` for no title.
            figsize (tuple): figure size passed to matplotlib, which
                interprets it as (width, height) in inches.
            dpi (int): dpi of the figure.
            colors (list): one matplotlib compatible color per entry of
                `words`; the entry at index 0 is not used.
            s (int): font size of the annotations; the query word uses a
                marker of size `s * 4` and a label of size `s * 2`.
            annotate (bool): annotate points in scatter plot?

        Returns:
            bool: always `True`.
        """
        assert len(words) == len(colors), "Size mismatch in colors and words"
        vecs = [self.E.v(w) for w in words]
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        X_embedded = TSNE(
            n_components=2, random_state=self.random_state
        ).fit_transform(vecs)
        for i, (x, w) in enumerate(zip(X_embedded, words)):
            if not i:
                # First entry is the query word itself.
                ax.scatter(x[0], x[1], marker='x', s=s*4, c='black')
                ax.annotate(w, (x[0], x[1]), size=s*2)
                continue
            ax.scatter(x[0], x[1], marker='.', c=colors[i])
            if annotate:
                # Mathtext label: the word with its rank as a subscript.
                ax.annotate(f"${w}_{{{ranks[w]}}}$", (x[0], x[1]), size=s)
        if title is not None:
            plt.title(title)
        plt.show()
        return True

    def run(self, word, title=None, n=100, dpi=300, figsize=(8, 5),
            colors=('blue', 'red'), fontsize=7, annotate=True):
        """Run the `NeighbourPlot` visualization for `word`.

        Args:
            word (str): word to compute neighbours of and make this plot.
            title (str): title of the plot.
            n (int): neighbourhood size passed to `get_neighbours`.
            dpi (int): dpi of the figure.
            figsize (tuple): figure size passed to matplotlib, which
                interprets it as (width, height) in inches.
            colors (sequence): two matplotlib compatible colors; points
                fade from the first to the second as bias increases.
            fontsize (int): matplotlib compatible font size for the
                annotations.
            annotate (bool): annotate points in scatter plot?
        """
        neighbours = self.get_neighbours(word, n)
        # Rank 0 is the most similar neighbour.
        ranks = {w: rank for rank, w in enumerate(neighbours)}
        neighbours_sorted = [word] + self.bias_by_projection_sort(neighbours)
        # Size the palette by the number of points actually plotted rather
        # than by `n`: the two differ when the vocabulary holds fewer than
        # `n` words, which previously tripped the length check in
        # `visualize`.
        palette = generate_palette(colors[0], colors[1],
                                   len(neighbours_sorted))
        self.visualize(neighbours_sorted, ranks, title, figsize, dpi,
                       palette, fontsize, annotate)