ale/src/editors/sentiment/editor.py
2024-02-01 21:49:08 -05:00

179 lines
7.0 KiB
Python

import concurrent.futures
import json
import random
import tempfile
from dataclasses import dataclass
from pathlib import Path
import time
import numpy as np
import structlog
import whisper
from flair.data import Sentence
from flair.models import TextClassifier
from ...math.average import np_moving_average
from ...math.distribution import create_distribution
from ...mediautils.video import filter_moments
from ..common import find_moving_average_highlights
@dataclass
class TextGlob:
    """One transcript segment together with its sentiment score."""

    # Segment boundaries, copied from the transcription's "start"/"end" values.
    start: float
    stop: float
    # Transcribed text of the segment.
    text: str
    # Signed sentiment score; filled in after classification.
    sentiment: float
class SentimentEditor:
    """Pick highlight moments from a recording based on transcript sentiment.

    On construction the audio is transcribed with Whisper (or a cached
    transcription is loaded from the system temp directory) and every
    transcript segment is scored with flair's ``en-sentiment`` classifier.
    ``edit`` turns one pair of moving-average windows into highlights, and
    ``full_edit`` searches over window sizes until the highlights add up to
    roughly the requested duration.
    """

    # Soft evaluation budget for ``full_edit``; the running count is halved
    # every time the search restarts from fresh random centers.
    _ITERATION_BUDGET = 50000
    # Hard ceiling on evaluations across all restarts. This count is never
    # reduced, so the search is guaranteed to stop even when the halved soft
    # count can never climb past the soft budget (small ``parallelism``).
    _MAX_TOTAL_ITERATIONS = 4 * _ITERATION_BUDGET

    def __init__(self, video_path, audio_path, params):
        """Transcribe ``audio_path`` and score the sentiment of each segment.

        Args:
            video_path: Not used by this class.
            audio_path: Audio file passed to the Whisper model's
                ``transcribe``.
            params: Settings mapping. Reads ``"temp_file_name"`` and
                ``"model_size"``; both form the cache file name and
                ``"model_size"`` also selects the Whisper model.
        """
        self.logger = structlog.get_logger("sentiment")
        # Stays None when the transcription comes from the cache.
        self.model = None
        self.result = self._load_or_transcribe(audio_path, params)

        self.segments = [
            TextGlob(segment["start"], segment["end"], segment["text"], 0)
            for segment in self.result["segments"]
        ]

        classifier = TextClassifier.load("en-sentiment")
        self.logger.info(
            "calculating sentiment on segments", segments=len(self.segments)
        )
        scores = []
        for segment in self.segments:
            segment.sentiment = self._score(classifier, segment.text)
            scores.append(segment.sentiment)
        self.sentiments = np.array(scores)

    def _load_or_transcribe(self, audio_path, params):
        """Return the transcription, reusing the temp-dir cache when it parses."""
        tempdir = tempfile.gettempdir()
        dest_location = f"{tempdir}/{params['temp_file_name']}-{params['model_size']}-sentiment.json"
        cache_file = Path(dest_location)

        if cache_file.is_file():
            try:
                with open(cache_file, "r") as fp:
                    cached = json.load(fp)
            except json.JSONDecodeError:
                # A truncated/corrupt cache (e.g. from an interrupted run)
                # must not block every later run; fall through and rebuild it.
                self.logger.warning(
                    "cached transcription unreadable, regenerating",
                    path=dest_location,
                )
            else:
                self.logger.info("cached transcription found", path=dest_location)
                return cached

        self.logger.info("loading whisper model", size=params["model_size"])
        self.model = whisper.load_model(params["model_size"])
        self.logger.info("transcribing audio", path=audio_path)
        result = self.model.transcribe(audio_path)

        # Write next to the final location and rename afterwards so a crash
        # mid-write never leaves a half-written file under the cache name.
        partial = cache_file.with_name(cache_file.name + ".tmp")
        with open(partial, "w") as fp:
            json.dump(result, fp)
        partial.replace(cache_file)
        return result

    @staticmethod
    def _score(classifier, text):
        """Return the top label's score for ``text``, negated for "NEGATIVE"."""
        sentence = Sentence(text)
        classifier.predict(sentence)
        if not sentence.labels:
            # NOTE(review): the classifier may attach no label (presumably for
            # text without tokens); treat such segments as neutral.
            return 0.0
        top = sentence.labels[0]
        if top.value == "NEGATIVE":
            return top.score * -1
        return top.score

    def edit(self, large_window, small_window, params):
        """Compute highlights for one (large, small) moving-average window pair.

        Args:
            large_window: Window passed to ``np_moving_average`` for the long
                average.
            small_window: Window passed to ``np_moving_average`` for the short
                average.
            params: Not used by this method.

        Returns:
            A tuple ``(highlights, large_window, small_window)``. When there
            are no segments, no sentiment scores, or the last segment stops
            at time 0 or earlier, ``highlights`` is an empty list.
        """
        if not self.segments or len(self.sentiments) == 0:
            return [], large_window, small_window
        end_time = self.segments[-1].stop
        if end_time <= 0:
            # Would otherwise divide by zero below.
            return [], large_window, small_window

        # Segments per second; its inverse (seconds per segment) is what the
        # highlight finder receives as its third argument.
        window_factor = len(self.sentiments) / end_time
        long_ma = np_moving_average(self.sentiments, large_window)
        short_ma = np_moving_average(self.sentiments, small_window)
        highlights = find_moving_average_highlights(
            short_ma, long_ma, 1.0 / window_factor
        )
        return highlights, large_window, small_window

    def _run_batch(self, large_distribution, small_distribution, params):
        """Run ``edit`` in a thread pool for every window pair.

        Returns each result as a mutable list ``[highlights, large, small]``.
        If any worker raised, every failure is logged and the last one is
        re-raised once all workers have finished.
        """
        results = []
        failed = None
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            for first, second in zip(large_distribution, small_distribution):
                # The larger value of each pair is always the large window.
                if first > second:
                    large, small = first, second
                else:
                    large, small = second, first
                futures.append(executor.submit(self.edit, large, small, params))
            for future in concurrent.futures.as_completed(futures):
                try:
                    results.append(list(future.result()))
                except Exception as e:
                    self.logger.exception("error during editing", error=e)
                    failed = e
        if failed is not None:
            raise failed
        return results

    @staticmethod
    def _total_duration(moments):
        """Return the sum of ``get_duration()`` over ``moments``."""
        return sum(moment.get_duration() for moment in moments)

    def full_edit(self, costfunc, desired_time, params):
        """Search window sizes until the highlights last about ``desired_time``.

        Args:
            costfunc: Called as ``costfunc(desired_time, total_duration)``;
                the candidate with the lowest value wins each batch.
            desired_time: Target total duration of the returned moments.
            params: Settings mapping. Reads ``"parallelism"`` (candidates per
                batch), ``"mindur"`` and ``"maxdur"`` (passed to
                ``filter_moments``).

        Returns:
            The filtered moments of the first batch winner whose total
            duration lies strictly within 5% of ``desired_time``, or an empty
            list when there is nothing to edit or the iteration budget runs
            out.

        Raises:
            Exception: Whatever a worker running ``edit`` raised.
        """
        desired = desired_time
        if not self.segments:
            self.logger.warning("no transcript segments, nothing to edit")
            return []

        # Generate center of large window and small window size
        large_window_center = random.uniform(30, 50)
        small_window_center = random.uniform(5, 15)
        # The spread multiplier, or epsilon, slowly decays as we approach the
        # center of the gradient
        spread_multiplier = random.uniform(0.15, 0.18)
        # The decay rate, or how quickly our spread multiplier decreases as we
        # approach the center of the gradient
        spread_decay = random.uniform(0.000001, 0.0001)
        parallelism = params["parallelism"]

        # The main loop of the program starts here:
        #   - create distributions around the current centers
        #   - use workers to simultaneously create many possible edits
        #   - find the best edit of the lot (lowest "cost")
        #   - if the best fits within our desired time range, return it;
        #     otherwise recenter the distributions on the best and repeat
        iterations = 0
        total_iterations = 0
        while True:
            large_distribution = create_distribution(
                large_window_center, spread_multiplier, parallelism
            )
            np.random.shuffle(large_distribution)
            small_distribution = create_distribution(
                small_window_center, spread_multiplier, parallelism
            )
            np.random.shuffle(small_distribution)

            # Fire off workers to generate edits
            moment_results = self._run_batch(
                large_distribution, small_distribution, params
            )

            costs = []
            durations = []
            for result in moment_results:
                result[0] = filter_moments(
                    result[0], params["mindur"], params["maxdur"]
                )
                total_duration = self._total_duration(result[0])
                costs.append(costfunc(desired, total_duration))
                durations.append(total_duration)

            index_min = min(range(len(costs)), key=costs.__getitem__)
            best = moment_results[index_min]
            best_duration = durations[index_min]
            large_window_center = best[1]
            small_window_center = best[2]
            self.logger.info(
                "batch complete",
                best_large=large_window_center,
                best_small=small_window_center,
                duration=best_duration,
                sp_multiplier=spread_multiplier,
            )

            if desired * 0.95 < best_duration < desired * 1.05:
                return best[0]

            iterations = iterations + parallelism
            total_iterations = total_iterations + parallelism
            if (
                iterations > self._ITERATION_BUDGET
                or total_iterations > self._MAX_TOTAL_ITERATIONS
            ):
                self.logger.warning(
                    "could not find a viable edit in the target duration, try other params",
                    target=desired,
                )
                return []

            spread_multiplier = spread_multiplier - spread_decay
            if spread_multiplier < 0:
                self.logger.warning("spread reached 0, resetting")
                time.sleep(3)
                large_window_center = random.uniform(30, 50)
                small_window_center = random.uniform(5, 15)
                spread_multiplier = random.uniform(0.15, 0.18)
                spread_decay = random.uniform(0.0001, 0.001)
                # Grant the fresh start extra budget; total_iterations is
                # deliberately left untouched so the search still terminates.
                iterations = int(iterations / 2)