import concurrent.futures
import json
import random
import tempfile
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import structlog
import whisper
from flair.data import Sentence
from flair.models import TextClassifier

from ...math.average import np_moving_average
from ...math.distribution import create_distribution
from ...mediautils.video import filter_moments
from ..common import find_moving_average_highlights


@dataclass
class TextGlob:
    start: float
    stop: float
    text: str
    sentiment: float


class SentimentEditor:
    def __init__(self, video_path, audio_path, params):
        self.logger = structlog.get_logger("sentiment")
        tempdir = tempfile.gettempdir()
        dest_location = (
            f"{tempdir}/{params['temp_file_name']}-"
            f"{params['model_size']}-sentiment.json"
        )
        # Transcribe the audio with whisper, caching the result so repeat
        # runs on the same file skip the expensive transcription step.
        if not Path(dest_location).is_file():
            self.logger.info("loading whisper model", size=params["model_size"])
            self.model = whisper.load_model(params["model_size"])
            self.logger.info("transcribing audio", path=audio_path)
            self.result = self.model.transcribe(audio_path)
            with open(dest_location, "w") as fp:
                json.dump(self.result, fp)
        else:
            self.logger.info("cached transcription found", path=dest_location)
            with open(dest_location, "r") as f:
                self.result = json.load(f)

        self.segments = [
            TextGlob(segment["start"], segment["end"], segment["text"], 0.0)
            for segment in self.result["segments"]
        ]

        # Score every transcript segment with flair's sentiment classifier,
        # mapping scores onto [-1, 1]: NEGATIVE labels flip the sign.
        classifier = TextClassifier.load("en-sentiment")
        self.sentiments = []
        self.logger.info(
            "calculating sentiment on segments", segments=len(self.segments)
        )
        for segment in self.segments:
            sentence = Sentence(segment.text)
            classifier.predict(sentence)
            sentsum = sentence.labels[0].score
            if sentence.labels[0].value == "NEGATIVE":
                sentsum = sentsum * -1
            segment.sentiment = sentsum
            self.sentiments.append(sentsum)
        self.sentiments = np.array(self.sentiments)

    def edit(self, large_window, small_window, params):
        # window_factor converts between segment counts and seconds: the
        # moving averages operate per segment, while the highlight finder
        # expects seconds per segment.
        end_time = self.segments[-1].stop
        window_factor = len(self.sentiments) / end_time
        long_ma = np_moving_average(self.sentiments, large_window)
        short_ma = np_moving_average(self.sentiments, small_window)
        highlights = find_moving_average_highlights(
            short_ma, long_ma, 1.0 / window_factor
        )
        return highlights, large_window, small_window

    def full_edit(self, costfunc, desired_time, params):
        desired = desired_time
        # Initial centers for the large- and small-window sizes
        large_window_center = random.uniform(30, 50)
        small_window_center = random.uniform(5, 15)
        # The spread multiplier (epsilon) controls how widely we sample
        # around the centers; it slowly decays as the search converges
        spread_multiplier = random.uniform(0.15, 0.18)
        # The decay rate: how quickly the spread multiplier shrinks per batch
        spread_decay = random.uniform(0.000001, 0.0001)
        parallelism = params["parallelism"]

        # The main loop:
        #   1. create distributions around the current window centers
        #   2. use workers to simultaneously create many candidate edits
        #   3. pick the best edit of the batch, i.e. the lowest "cost"
        #   4. if the best fits the desired time range, return it; otherwise
        #      recenter the distributions on the best candidate and repeat
        iterations = 0
        while True:
            large_distribution = create_distribution(
                large_window_center, spread_multiplier, parallelism
            )
            np.random.shuffle(large_distribution)
            small_distribution = create_distribution(
                small_window_center, spread_multiplier, parallelism
            )
            np.random.shuffle(small_distribution)

            # Fire off workers to generate edits; the larger value of each
            # pair is always used as the long window.
            moment_results = []
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(self.edit, max(pair), min(pair), params)
                    for pair in zip(large_distribution, small_distribution)
                ]
                failed = None
                for future in concurrent.futures.as_completed(futures):
                    try:
                        moment_results.append(list(future.result()))
                    except Exception as e:
                        self.logger.exception("error during editing", error=e)
                        failed = e
                if failed is not None:
                    raise failed

            # Cost each candidate by how far its total duration lands from
            # the desired duration, after dropping out-of-range moments.
            costs = []
            durations = []
            for result in moment_results:
                total_duration = 0
                result[0] = filter_moments(
                    result[0], params["mindur"], params["maxdur"]
                )
                for moment in result[0]:
                    total_duration = total_duration + moment.get_duration()
                costs.append(costfunc(desired, total_duration))
                durations.append(total_duration)

            index_min = min(range(len(costs)), key=costs.__getitem__)
            large_window_center = moment_results[index_min][1]
            small_window_center = moment_results[index_min][2]
            self.logger.info(
                "batch complete",
                best_large=large_window_center,
                best_small=small_window_center,
                duration=durations[index_min],
            )
            # Accept the best candidate if it lands within 5% of the target
            if desired * 0.95 < durations[index_min] < desired * 1.05:
                return moment_results[index_min][0]

            iterations = iterations + parallelism
            if iterations > 50000:
                self.logger.warning(
                    "could not find a viable edit in the target duration, "
                    "try other params",
                    target=desired,
                )
                return []

            spread_multiplier = spread_multiplier - spread_decay
            if spread_multiplier < 0:
                self.logger.warning("spread reached 0, resetting")
                large_window_center = random.uniform(30, 50)
                small_window_center = random.uniform(5, 15)
                spread_multiplier = random.uniform(0.15, 0.18)
                spread_decay = random.uniform(0.0001, 0.001)
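

# A minimal usage sketch (hypothetical: the file names, param values, and
# cost function below are assumptions for illustration, not part of this
# module's API). `full_edit` only requires that `costfunc` return a value
# where lower means better, so an absolute-difference cost suffices:
#
#     def abs_cost(desired, actual):
#         # Penalize the gap between candidate and target duration
#         return abs(desired - actual)
#
#     params = {
#         "temp_file_name": "clip",  # cache key for the transcription JSON
#         "model_size": "base",      # whisper model size
#         "parallelism": 8,          # candidate edits generated per batch
#         "mindur": 1.0,             # drop moments shorter than this (s)
#         "maxdur": 15.0,            # drop moments longer than this (s)
#     }
#     editor = SentimentEditor("clip.mp4", "clip.wav", params)
#     moments = editor.full_edit(abs_cost, desired_time=60, params=params)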