diff --git a/src/process/twitter_visualization.py b/src/process/twitter_visualization.py index 5ca6211..ce8dde8 100644 --- a/src/process/twitter_visualization.py +++ b/src/process/twitter_visualization.py @@ -1,6 +1,8 @@ """ TODO: Module Docstring """ +import statistics + from matplotlib import pyplot as plt from tabulate import tabulate @@ -81,8 +83,8 @@ def view_covid_tweets_pop(users: list[ProcessedUser], if len(covid) == 0 or len(tweets) == 0: continue # Get the average popularity for COVID-related tweets - covid_avg = sum(t.popularity for t in covid) / len(covid) - global_avg = sum(t.popularity for t in tweets) / len(tweets) + covid_avg = statistics.mean(t.popularity for t in covid) + global_avg = statistics.mean(t.popularity for t in tweets) # Get the relative popularity user_popularity.append((u.username, covid_avg / global_avg)) @@ -99,10 +101,25 @@ def view_covid_tweets_pop(users: list[ProcessedUser], print(f"20 Users of whose COVID-related posts are the most popular:") print(tabulate([[u[0], f'{u[1]:.2f}'] for u in user_popularity[:20]], ['Username', 'Popularity Ratio'])) + print() + + # Calculate statistics + x_list = [f[1] for f in user_popularity] + s = get_statistics(x_list) + print(f'With outliers, ') + print(f'- mean: {s.mean:.2f}, median: {s.median:.2f}, stddev: {s.stddev:.2f}') + print() # Remove outliers print('As there are many outliers in the popularity ratio, they are removed in graphing.') - x_list = remove_outliers([f[1] for f in user_popularity]) + print() + x_list = remove_outliers(x_list) + + # Calculate statistics without outliers + s = get_statistics(x_list) + print(f'Without outliers, ') + print(f'- mean: {s.mean:.2f}, median: {s.median:.2f}, stddev: {s.stddev:.2f}') + print() # Graph histogram plt.title(f'COVID-related popularity ratios for {sample_name}') diff --git a/src/utils.py b/src/utils.py index 1c50dec..31934d3 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,10 +2,11 @@ import dataclasses import inspect import json import os +import statistics from dataclasses import dataclass from datetime import datetime, date from pathlib import Path -from typing import Union +from typing import Union, NamedTuple import json5 import numpy as np @@ -124,6 +125,22 @@ def remove_outliers(points: list[float], z_threshold: float = 3.5) -> list[float return [points[v] for v in range(len(x)) if not is_outlier[v]] +class Stats(NamedTuple): + mean: float + median: float + stddev: float + + +def get_statistics(points: list[float]) -> Stats: + """ + Calculate statistics for a set of points + + :param points: Input points + :return: Statistics + """ + return Stats(statistics.mean(points), statistics.median(points), statistics.stdev(points)) + + class EnhancedJSONEncoder(json.JSONEncoder): def default(self, o):