more cleanup

This commit is contained in:
Tanishq Dubey 2023-02-16 19:34:13 -05:00
parent 12bba40f0a
commit b49b05c4d3
3 changed files with 34 additions and 20 deletions

22
main.py
View File

@ -12,7 +12,7 @@ import time
import numpy as np import numpy as np
from src.mediautils.audio import extract_audio_from_video from src.mediautils.audio import extract_audio_from_video
from src.mediautils.video import render_moments from src.mediautils.video import render_moments, filter_moments
from src.editors.amplitude.editor import AmplitudeEditor from src.editors.amplitude.editor import AmplitudeEditor
from src.editors.sentiment.editor import SentimentEditor from src.editors.sentiment.editor import SentimentEditor
from src.math.cost import quadratic_loss from src.math.cost import quadratic_loss
@ -58,10 +58,12 @@ def main(args):
log.info("using cached audio file", cache_path=audio_path) log.info("using cached audio file", cache_path=audio_path)
else: else:
log.info("extracted audio", cache_path=audio_path) log.info("extracted audio", cache_path=audio_path)
params = vars(args)
params["temp_file_name"] = temp_file_name
# Initalize Editor # Initalize Editor
log.info("initializing editor", editor=args.editor) log.info("initializing editor", editor=args.editor)
editor = EDITORS[args.editor](str(in_vid_path.resolve()), audio_path, vars(args)) editor = EDITORS[args.editor](str(in_vid_path.resolve()), audio_path, params)
log.info("initialized editor", editor=args.editor) log.info("initialized editor", editor=args.editor)
costfunc = ERROR_FUNCS[args.cost] costfunc = ERROR_FUNCS[args.cost]
desired = args.duration desired = args.duration
@ -74,7 +76,7 @@ def main(args):
spread_multiplier = random.uniform(0.15, 0.18) spread_multiplier = random.uniform(0.15, 0.18)
# The decay rate, or how quickly our spread multiplier decreases as we approach the center of the gradient # The decay rate, or how quickly our spread multiplier decreases as we approach the center of the gradient
spread_decay = random.uniform(0.0001, 0.001) spread_decay = random.uniform(0.000001, 0.0001)
parallelism = args.parallelism parallelism = args.parallelism
@ -89,7 +91,6 @@ def main(args):
complete = False complete = False
iterations = 0 iterations = 0
while not complete: while not complete:
log.info("creating distributions", large_start=large_window_center, small_start=small_window_center, spread=spread_multiplier, decay=spread_decay)
large_distribution = create_distribution(large_window_center, spread_multiplier, parallelism) large_distribution = create_distribution(large_window_center, spread_multiplier, parallelism)
np.random.shuffle(large_distribution) np.random.shuffle(large_distribution)
small_distribution = create_distribution(small_window_center, spread_multiplier, parallelism) small_distribution = create_distribution(small_window_center, spread_multiplier, parallelism)
@ -104,14 +105,14 @@ def main(args):
futures.append( futures.append(
executor.submit( executor.submit(
editor.edit, editor.edit,
pair[0], pair[0] if pair[0] > pair[1] else pair[1],
pair[1], pair[1] if pair[0] > pair[1] else pair[0],
vars(args) vars(args)
) )
) )
for future in concurrent.futures.as_completed(futures): for future in concurrent.futures.as_completed(futures):
try: try:
moment_results.append(future.result()) moment_results.append(list(future.result()))
except Exception: except Exception:
log.exception("error during editing") log.exception("error during editing")
sys.exit(-2) sys.exit(-2)
@ -120,6 +121,7 @@ def main(args):
durations = [] durations = []
for result in moment_results: for result in moment_results:
total_duration = 0 total_duration = 0
result[0] = filter_moments(result[0], args.mindur, args.maxdur)
for moment in result[0]: for moment in result[0]:
total_duration = total_duration + moment.get_duration() total_duration = total_duration + moment.get_duration()
costs.append(costfunc(desired, total_duration)) costs.append(costfunc(desired, total_duration))
@ -147,12 +149,6 @@ def main(args):
small_window_center = random.uniform(5, 15) small_window_center = random.uniform(5, 15)
spread_multiplier = random.uniform(0.15, 0.18) spread_multiplier = random.uniform(0.15, 0.18)
spread_decay = random.uniform(0.0001, 0.001) spread_decay = random.uniform(0.0001, 0.001)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,4 +1,7 @@
import whisper import whisper
import json
from pathlib import Path
import tempfile
import numpy as np import numpy as np
import structlog import structlog
@ -19,10 +22,22 @@ class TextGlob:
class SentimentEditor: class SentimentEditor:
def __init__(self, video_path, audio_path, params): def __init__(self, video_path, audio_path, params):
self.logger = structlog.get_logger("sentiment") self.logger = structlog.get_logger("sentiment")
self.logger.info("loading whisper model", size=params["model_size"]) tempdir = tempfile.gettempdir()
self.model = whisper.load_model(params["model_size"]) dest_location = f"{tempdir}/{params['temp_file_name']}-{params['model_size']}-sentiment.json"
self.logger.info("transcribing audio", path=audio_path) if not Path(dest_location).is_file():
self.result = self.model.transcribe(audio_path) self.logger.info("loading whisper model", size=params["model_size"])
self.model = whisper.load_model(params["model_size"])
self.logger.info("transcribing audio", path=audio_path)
self.result = self.model.transcribe(audio_path)
with open(dest_location, 'w') as fp:
json.dump(self.result, fp)
else:
self.logger.info("cached transcription found", path=dest_location)
with open(dest_location, 'r') as f:
self.result = json.load(f)
self.segments = [] self.segments = []
for segment in self.result['segments']: for segment in self.result['segments']:
self.segments.append(TextGlob(segment['start'], segment['end'], segment['text'], 0)) self.segments.append(TextGlob(segment['start'], segment['end'], segment['text'], 0))
@ -42,7 +57,7 @@ class SentimentEditor:
def edit(self, large_window, small_window, params): def edit(self, large_window, small_window, params):
end_time = self.segments[-1].stop end_time = self.segments[-1].stop
window_factor = len(self.sentiments) / end_time window_factor = len(self.sentiments) / end_time
long_ma = np_moving_average(self.squared_subsample, large_window * window_factor) long_ma = np_moving_average(self.sentiments, large_window)
short_ma = np_moving_average(self.squared_subsample, small_window * window_factor) short_ma = np_moving_average(self.sentiments, small_window)
highlights = find_moving_average_highlights(short_ma, long_ma, 1.0 / window_factor) highlights = find_moving_average_highlights(short_ma, long_ma, 1.0 / window_factor)
return highlights, large_window, small_window return highlights, large_window, small_window

View File

@ -9,5 +9,8 @@ def get_subclips(source_video_path, moments):
def render_moments(moments, input_video_path, output_path): def render_moments(moments, input_video_path, output_path):
clips, vid = get_subclips(input_video_path, moments) clips, vid = get_subclips(input_video_path, moments)
to_render = mp.concatenate_videoclips(clips, logger=None) to_render = mp.concatenate_videoclips(clips)
to_render.write_videofile(output_path, logger=None) to_render.write_videofile(output_path, logger=None)
def filter_moments(moments, min_length, max_length):
    """Return the moments whose duration lies strictly between the bounds.

    Args:
        moments: Iterable of objects exposing a ``get_duration()`` method.
        min_length: Exclusive lower bound on a moment's duration.
        max_length: Exclusive upper bound on a moment's duration.

    Returns:
        A new list containing, in their original order, the moments for
        which ``min_length < duration < max_length``. The input is not
        modified.
    """
    kept = []
    for moment in moments:
        # Query the duration once per moment; both bounds are exclusive,
        # so a moment exactly at either limit is dropped.
        duration = moment.get_duration()
        if min_length < duration < max_length:
            kept.append(moment)
    return kept