import concurrent.futures
import json
import random
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import structlog
import whisper
from flair.data import Sentence
from flair.models import TextClassifier

from ...math.average import np_moving_average
from ...math.distribution import create_distribution
from ...mediautils.video import filter_moments
from ..common import find_moving_average_highlights


@dataclass
class TextGlob:
    start: float
    stop: float
    text: str
    sentiment: float


class SentimentEditor:
    def __init__(self, video_path, audio_path, params):
        self.logger = structlog.get_logger("sentiment")
        tempdir = tempfile.gettempdir()
        dest_location = (
            f"{tempdir}/{params['temp_file_name']}-"
            f"{params['model_size']}-sentiment.json"
        )
        if not Path(dest_location).is_file():
            self.logger.info("loading whisper model", size=params["model_size"])
            self.model = whisper.load_model(params["model_size"])
            self.logger.info("transcribing audio", path=audio_path)
            self.result = self.model.transcribe(audio_path)
            with open(dest_location, "w") as fp:
                json.dump(self.result, fp)
        else:
            self.logger.info("cached transcription found", path=dest_location)
            with open(dest_location, "r") as f:
                self.result = json.load(f)

        self.segments = []
        for segment in self.result["segments"]:
            self.segments.append(
                TextGlob(segment["start"], segment["end"], segment["text"], 0)
            )

        classifier = TextClassifier.load("en-sentiment")
        self.sentiments = []
        self.logger.info(
            "calculating sentiment on segments", segments=len(self.segments)
        )
        for segment in self.segments:
            sentence = Sentence(segment.text)
            classifier.predict(sentence)
            # Signed score: NEGATIVE labels flip the sign so the sentiment
            # signal spans both positive and negative values.
            sentsum = sentence.labels[0].score
            if sentence.labels[0].value == "NEGATIVE":
                sentsum = sentsum * -1
            segment.sentiment = sentsum
            self.sentiments.append(sentsum)
        self.sentiments = np.array(self.sentiments)

    def edit(self, large_window, small_window, params):
        # Convert from segment-index space back to seconds.
        end_time = self.segments[-1].stop
        window_factor = len(self.sentiments) / end_time
        # Crossovers between the short and long moving averages of the
        # sentiment signal mark the starts and ends of highlights.
        long_ma = np_moving_average(self.sentiments, large_window)
        short_ma = np_moving_average(self.sentiments, small_window)
        highlights = find_moving_average_highlights(
            short_ma, long_ma, 1.0 / window_factor
        )
        return highlights, large_window, small_window

    def full_edit(self, costfunc, desired_time, params):
        desired = desired_time
        # Pick random initial centers for the large and small window sizes
        large_window_center = random.uniform(30, 50)
        small_window_center = random.uniform(5, 15)
        # The spread multiplier, or epsilon, slowly decays as we approach
        # the center of the gradient
        spread_multiplier = random.uniform(0.15, 0.18)
        # The decay rate, or how quickly the spread multiplier decreases
        # as we approach the center of the gradient
        spread_decay = random.uniform(0.000001, 0.0001)
        parallelism = params["parallelism"]

        # The main loop starts here:
        #   1. create distributions of window sizes around the current centers
        #   2. use workers to simultaneously create many possible edits
        #   3. find the best edit of the lot, determined by the lowest "cost"
        #   4. if the best fits within our desired time range, return it;
        #      otherwise re-center the distributions on the best and repeat
        complete = False
        iterations = 0
        while not complete:
            # Create distributions of large and small window sizes
            large_distribution = create_distribution(
                large_window_center, spread_multiplier, parallelism
            )
            np.random.shuffle(large_distribution)
            small_distribution = create_distribution(
                small_window_center, spread_multiplier, parallelism
            )
            np.random.shuffle(small_distribution)

            # Fire off workers to generate edits
            moment_results = []
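            # Each worker evaluates one (large, small) pair drawn from the
            # current distributions; the larger value of each pair always
            # serves as the long moving-average window so the crossover
            # logic in edit() stays well-defined.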
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = []
                pairs = list(zip(large_distribution, small_distribution))
                for pair in pairs:
                    futures.append(
                        executor.submit(self.edit, max(pair), min(pair), params)
                    )
                failed = None
                for future in concurrent.futures.as_completed(futures):
                    try:
                        moment_results.append(list(future.result()))
                    except Exception as e:
                        self.logger.exception("error during editing", error=e)
                        failed = e
                if failed is not None:
                    raise failed

            # Score each candidate edit by how far its total duration lands
            # from the desired duration.
            costs = []
            durations = []
            for result in moment_results:
                total_duration = 0
                result[0] = filter_moments(
                    result[0], params["mindur"], params["maxdur"]
                )
                for moment in result[0]:
                    total_duration = total_duration + moment.get_duration()
                costs.append(costfunc(desired, total_duration))
                durations.append(total_duration)

            # Re-center the search on the cheapest edit of the batch.
            index_min = min(range(len(costs)), key=costs.__getitem__)
            large_window_center = moment_results[index_min][1]
            small_window_center = moment_results[index_min][2]
            self.logger.info(
                "batch complete",
                best_large=large_window_center,
                best_small=small_window_center,
                duration=durations[index_min],
                sp_multiplier=spread_multiplier,
            )

            # Accept the edit if it lands within 5% of the desired duration.
            if desired * 0.95 < durations[index_min] < desired * 1.05:
                return moment_results[index_min][0]

            iterations = iterations + parallelism
            if iterations > 50000:
                self.logger.warning(
                    "could not find a viable edit in the target duration, "
                    "try other params",
                    target=desired,
                )
                return []

            spread_multiplier = spread_multiplier - spread_decay
            if spread_multiplier < 0:
                # The spread collapsed without an acceptable edit: restart
                # from fresh random centers with a faster decay rate.
                self.logger.warning("spread reached 0, resetting")
                time.sleep(3)
                large_window_center = random.uniform(30, 50)
                small_window_center = random.uniform(5, 15)
                spread_multiplier = random.uniform(0.15, 0.18)
                spread_decay = random.uniform(0.0001, 0.001)
                iterations = int(iterations / 2)
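

# A minimal usage sketch. The paths, parameter values, and cost function
# below are illustrative assumptions, not part of this module; the `params`
# keys mirror the ones consumed above. Since the module uses relative
# imports, it would be run with `python -m` from the package root.
if __name__ == "__main__":

    def quadratic_cost(desired: float, actual: float) -> float:
        # Penalize deviation from the desired duration symmetrically.
        return (desired - actual) ** 2

    example_params = {
        "temp_file_name": "demo",  # cache key for the transcription JSON
        "model_size": "base",      # whisper model size
        "parallelism": 8,          # candidate edits generated per batch
        "mindur": 2.0,             # shortest moment to keep, in seconds
        "maxdur": 30.0,            # longest moment to keep, in seconds
    }

    editor = SentimentEditor("input.mp4", "input.wav", example_params)
    moments = editor.full_edit(quadratic_cost, 60.0, example_params)
    for moment in moments:
        print(moment.get_duration())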