# fusion.py (2.4 KB)
  1. from qdrant_client.http import models
  2. def reciprocal_rank_fusion(
  3. responses: list[list[models.ScoredPoint]], limit: int = 10
  4. ) -> list[models.ScoredPoint]:
  5. def compute_score(pos: int) -> float:
  6. ranking_constant = (
  7. 2 # the constant mitigates the impact of high rankings by outlier systems
  8. )
  9. return 1 / (ranking_constant + pos)
  10. scores: dict[models.ExtendedPointId, float] = {}
  11. point_pile = {}
  12. for response in responses:
  13. for i, scored_point in enumerate(response):
  14. if scored_point.id in scores:
  15. scores[scored_point.id] += compute_score(i)
  16. else:
  17. point_pile[scored_point.id] = scored_point
  18. scores[scored_point.id] = compute_score(i)
  19. sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
  20. sorted_points = []
  21. for point_id, score in sorted_scores[:limit]:
  22. point = point_pile[point_id]
  23. point.score = score
  24. sorted_points.append(point)
  25. return sorted_points
  26. def distribution_based_score_fusion(
  27. responses: list[list[models.ScoredPoint]], limit: int
  28. ) -> list[models.ScoredPoint]:
  29. def normalize(response: list[models.ScoredPoint]) -> list[models.ScoredPoint]:
  30. if len(response) == 1:
  31. response[0].score = 0.5
  32. return response
  33. total = sum([point.score for point in response])
  34. mean = total / len(response)
  35. variance = sum([(point.score - mean) ** 2 for point in response]) / (len(response) - 1)
  36. if variance == 0:
  37. for point in response:
  38. point.score = 0.5
  39. return response
  40. std_dev = variance**0.5
  41. low = mean - 3 * std_dev
  42. high = mean + 3 * std_dev
  43. for point in response:
  44. point.score = (point.score - low) / (high - low)
  45. return response
  46. points_map: dict[models.ExtendedPointId, models.ScoredPoint] = {}
  47. for response in responses:
  48. if not response:
  49. continue
  50. normalized = normalize(response)
  51. for point in normalized:
  52. entry = points_map.get(point.id)
  53. if entry is None:
  54. points_map[point.id] = point
  55. else:
  56. entry.score += point.score
  57. sorted_points = sorted(points_map.values(), key=lambda item: item.score, reverse=True)
  58. return sorted_points[:limit]