import torch


def harmonic_mean(x):
    # n / sum(1/x_i): dominated by the smallest value in the group
    # (the clamp avoids a zero division on an exact match)
    return x.numel() / (1.0 / x.clamp_min(1e-12)).sum()


def classify(query, processed_schema):
    query_embed = get_embeddings([query])
    classes = list(processed_schema.keys())
    per_class = []
    for k, v in processed_schema.items():
        # distance from the query embedding to every example embedding of class k
        dists = torch.tensor([torch.dist(item, query_embed).item() for item in v])
        # use the harmonic mean because a strong match should bias the entire
        # group: with an arithmetic mean, even a dead-on match only nudges the
        # average, whereas under the harmonic mean one near-perfect example
        # (near-zero distance) pulls the whole class toward a perfect score
        per_class.append(harmonic_mean(dists))
    logits = torch.stack(per_class)
    # invert because you want the distance to be low (close together in that
    # high-dimensional space) but the logit to be high
    logits = logits.max() - logits
    # softmax to normalize into a probability distribution over the classes
    probs = torch.nn.functional.softmax(logits, dim=0)
    classification = {}
    for i, c in enumerate(classes):
        classification[c] = probs[i].item()
    return classification
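
# Why the harmonic mean? A quick toy comparison (illustrative numbers, not
# real distances): one near-perfect match among otherwise mediocre distances
# dominates the harmonic mean but barely moves the arithmetic mean.
dists = torch.tensor([0.01, 0.9, 1.1, 1.0])
print(harmonic_mean(dists))  # tensor(0.0388) -- the strong match drags the group down
print(dists.mean())          # tensor(0.7525) -- the average barely notices it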
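
# And a minimal sketch of calling classify. The 8-dim embeddings, the stubbed
# get_embeddings, and the shape of processed_schema (class name -> list of
# example embeddings) are assumptions for illustration, not the real pipeline.
torch.manual_seed(0)

def get_embeddings(texts):
    # stand-in for the real embedding call: one 8-dim vector per input string
    return torch.randn(len(texts), 8)

processed_schema = {
    "weather": [torch.randn(8) for _ in range(3)],
    "sports": [torch.randn(8) for _ in range(3)],
}

print(classify("will it rain tomorrow?", processed_schema))
# illustrative output, e.g. {'weather': 0.57, 'sports': 0.43}; the values
# always sum to 1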