import concurrent.futures
import json
import random
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import structlog
import whisper
from flair.data import Sentence
from flair.models import TextClassifier

from ...math.average import np_moving_average
from ...math.distribution import create_distribution
from ...mediautils.video import filter_moments
from ..common import find_moving_average_highlights
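

# A transcribed segment: its start/stop timestamps in seconds, the spoken
# text, and a signed sentiment score filled in by SentimentEditor.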
@dataclass
class TextGlob:
    start: float
    stop: float
    text: str
    sentiment: float
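

# Finds highlights in a video by transcribing its audio with Whisper, scoring
# each transcript segment with flair's "en-sentiment" classifier, and then
# searching for swings in the resulting sentiment signal.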
class SentimentEditor:
    def __init__(self, video_path, audio_path, params):
        self.logger = structlog.get_logger("sentiment")
        tempdir = tempfile.gettempdir()
        dest_location = f"{tempdir}/{params['temp_file_name']}-{params['model_size']}-sentiment.json"
        if not Path(dest_location).is_file():
            # No cached transcription yet: run whisper and cache the result.
            self.logger.info("loading whisper model", size=params["model_size"])
            self.model = whisper.load_model(params["model_size"])
            self.logger.info("transcribing audio", path=audio_path)
            self.result = self.model.transcribe(audio_path)

            with open(dest_location, "w") as fp:
                json.dump(self.result, fp)
        else:
            self.logger.info("cached transcription found", path=dest_location)

        with open(dest_location, "r") as f:
            self.result = json.load(f)

        self.segments = []
        for segment in self.result["segments"]:
            self.segments.append(
                TextGlob(segment["start"], segment["end"], segment["text"], 0)
            )

        # Score every segment; a NEGATIVE label flips the sign, so the signal
        # runs from -1 (confidently negative) to +1 (confidently positive).
        classifier = TextClassifier.load("en-sentiment")
        self.sentiments = []
        self.logger.info(
            "calculating sentiment on segments", segments=len(self.segments)
        )
        for segment in self.segments:
            sentence = Sentence(segment.text)
            classifier.predict(sentence)
            sentsum = sentence.labels[0].score
            if sentence.labels[0].value == "NEGATIVE":
                sentsum = sentsum * -1
            segment.sentiment = sentsum
            self.sentiments.append(sentsum)
        self.sentiments = np.array(self.sentiments)
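
    # Detects highlight moments by comparing a short and a long moving average
    # of the sentiment signal; window sizes are in segments, and crossings are
    # mapped back to seconds via the segments-per-second factor.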
    def edit(self, large_window, small_window, params):
        end_time = self.segments[-1].stop
        # Segments per second; its inverse converts positions in the sentiment
        # array back into seconds.
        window_factor = len(self.sentiments) / end_time
        long_ma = np_moving_average(self.sentiments, large_window)
        short_ma = np_moving_average(self.sentiments, small_window)
        highlights = find_moving_average_highlights(
            short_ma, long_ma, 1.0 / window_factor
        )
        return highlights, large_window, small_window
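
    # Random-search wrapper around edit(): repeatedly samples (large, small)
    # window pairs from distributions that tighten over time, keeps the batch's
    # cheapest candidate, and stops once an edit lands within 5% of the
    # desired duration.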
    def full_edit(self, costfunc, desired_time, params):
        desired = desired_time

        # Initial centers for the large- and small-window size distributions
        large_window_center = random.uniform(30, 50)
        small_window_center = random.uniform(5, 15)

        # The spread multiplier, or epsilon, controls how widely samples
        # scatter around the centers; it decays a little every iteration.
        spread_multiplier = random.uniform(0.15, 0.18)

        # The decay rate: how quickly the spread multiplier shrinks
        spread_decay = random.uniform(0.000001, 0.0001)

        parallelism = params["parallelism"]

        # The main loop:
        #   1. create distributions of window sizes around the current centers
        #   2. use workers to simultaneously create many possible edits
        #   3. find the best edit of the lot, determined by the lowest "cost"
        #   4. if the best fits within the desired time range, return it;
        #      otherwise recenter the distributions on the best and repeat
        complete = False
        iterations = 0
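        # create_distribution is assumed to yield `parallelism` window sizes
        # spread around the given center, with the spread scaled by
        # spread_multiplier; shuffling decorrelates the large/small pairings.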
        while not complete:
            large_distribution = create_distribution(
                large_window_center, spread_multiplier, parallelism
            )
            np.random.shuffle(large_distribution)
            small_distribution = create_distribution(
                small_window_center, spread_multiplier, parallelism
            )
            np.random.shuffle(small_distribution)

            # Fire off workers to generate edits; the larger value of each
            # sampled pair always serves as the long window.
            moment_results = []
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = []
                pairs = list(zip(large_distribution, small_distribution))
                for pair in pairs:
                    futures.append(
                        executor.submit(
                            self.edit,
                            pair[0] if pair[0] > pair[1] else pair[1],
                            pair[1] if pair[0] > pair[1] else pair[0],
                            params,
                        )
                    )
                failed = None
                for future in concurrent.futures.as_completed(futures):
                    try:
                        moment_results.append(list(future.result()))
                    except Exception as e:
                        self.logger.exception("error during editing", error=e)
                        failed = e
                if failed is not None:
                    raise failed
            # Score each candidate: drop moments outside the allowed duration
            # bounds, then cost the remaining total runtime against the target.
            costs = []
            durations = []
            for result in moment_results:
                total_duration = 0
                result[0] = filter_moments(
                    result[0], params["mindur"], params["maxdur"]
                )
                for moment in result[0]:
                    total_duration = total_duration + moment.get_duration()
                costs.append(costfunc(desired, total_duration))
                durations.append(total_duration)
            index_min = min(range(len(costs)), key=costs.__getitem__)
            # Recenter the search on the cheapest candidate from this batch.
            large_window_center = moment_results[index_min][1]
            small_window_center = moment_results[index_min][2]
            self.logger.info(
                "batch complete",
                best_large=large_window_center,
                best_small=small_window_center,
                duration=durations[index_min],
                sp_multiplier=spread_multiplier,
            )
            # Accept the edit once it lands within 5% of the desired duration.
            if desired * 0.95 < durations[index_min] < desired * 1.05:
                return moment_results[index_min][0]
            iterations = iterations + parallelism
            if iterations > 50000:
                self.logger.warning(
                    "could not find a viable edit in the target duration, try other params",
                    target=desired,
                )
                return []
            spread_multiplier = spread_multiplier - spread_decay
            if spread_multiplier < 0:
                # The spread collapsed before a viable edit was found; restart
                # from fresh random centers with a faster decay.
                self.logger.warning("spread reached 0, resetting")
                time.sleep(3)
                large_window_center = random.uniform(30, 50)
                small_window_center = random.uniform(5, 15)
                spread_multiplier = random.uniform(0.15, 0.18)
                spread_decay = random.uniform(0.0001, 0.001)
                iterations = int(iterations / 2)
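

# A minimal sketch of a cost function that could be passed to full_edit. This
# helper is illustrative, not part of the original module: the only contract
# assumed is costfunc(desired, actual) -> float, with lower values better.
def example_quadratic_cost(desired: float, actual: float) -> float:
    # Squared deviation from the desired duration, so candidates far from the
    # target are penalized sharply.
    return (desired - actual) ** 2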