more cleanup
This commit is contained in:
parent
12bba40f0a
commit
b49b05c4d3
22
main.py
22
main.py
@ -12,7 +12,7 @@ import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from src.mediautils.audio import extract_audio_from_video
|
from src.mediautils.audio import extract_audio_from_video
|
||||||
from src.mediautils.video import render_moments
|
from src.mediautils.video import render_moments, filter_moments
|
||||||
from src.editors.amplitude.editor import AmplitudeEditor
|
from src.editors.amplitude.editor import AmplitudeEditor
|
||||||
from src.editors.sentiment.editor import SentimentEditor
|
from src.editors.sentiment.editor import SentimentEditor
|
||||||
from src.math.cost import quadratic_loss
|
from src.math.cost import quadratic_loss
|
||||||
@ -58,10 +58,12 @@ def main(args):
|
|||||||
log.info("using cached audio file", cache_path=audio_path)
|
log.info("using cached audio file", cache_path=audio_path)
|
||||||
else:
|
else:
|
||||||
log.info("extracted audio", cache_path=audio_path)
|
log.info("extracted audio", cache_path=audio_path)
|
||||||
|
params = vars(args)
|
||||||
|
params["temp_file_name"] = temp_file_name
|
||||||
|
|
||||||
# Initialize Editor
|
# Initialize Editor
|
||||||
log.info("initializing editor", editor=args.editor)
|
log.info("initializing editor", editor=args.editor)
|
||||||
editor = EDITORS[args.editor](str(in_vid_path.resolve()), audio_path, vars(args))
|
editor = EDITORS[args.editor](str(in_vid_path.resolve()), audio_path, params)
|
||||||
log.info("initialized editor", editor=args.editor)
|
log.info("initialized editor", editor=args.editor)
|
||||||
costfunc = ERROR_FUNCS[args.cost]
|
costfunc = ERROR_FUNCS[args.cost]
|
||||||
desired = args.duration
|
desired = args.duration
|
||||||
@ -74,7 +76,7 @@ def main(args):
|
|||||||
spread_multiplier = random.uniform(0.15, 0.18)
|
spread_multiplier = random.uniform(0.15, 0.18)
|
||||||
|
|
||||||
# The decay rate, or how quickly our spread multiplier decreases as we approach the center of the gradient
|
# The decay rate, or how quickly our spread multiplier decreases as we approach the center of the gradient
|
||||||
spread_decay = random.uniform(0.0001, 0.001)
|
spread_decay = random.uniform(0.000001, 0.0001)
|
||||||
|
|
||||||
parallelism = args.parallelism
|
parallelism = args.parallelism
|
||||||
|
|
||||||
@ -89,7 +91,6 @@ def main(args):
|
|||||||
complete = False
|
complete = False
|
||||||
iterations = 0
|
iterations = 0
|
||||||
while not complete:
|
while not complete:
|
||||||
log.info("creating distributions", large_start=large_window_center, small_start=small_window_center, spread=spread_multiplier, decay=spread_decay)
|
|
||||||
large_distribution = create_distribution(large_window_center, spread_multiplier, parallelism)
|
large_distribution = create_distribution(large_window_center, spread_multiplier, parallelism)
|
||||||
np.random.shuffle(large_distribution)
|
np.random.shuffle(large_distribution)
|
||||||
small_distribution = create_distribution(small_window_center, spread_multiplier, parallelism)
|
small_distribution = create_distribution(small_window_center, spread_multiplier, parallelism)
|
||||||
@ -104,14 +105,14 @@ def main(args):
|
|||||||
futures.append(
|
futures.append(
|
||||||
executor.submit(
|
executor.submit(
|
||||||
editor.edit,
|
editor.edit,
|
||||||
pair[0],
|
pair[0] if pair[0] > pair[1] else pair[1],
|
||||||
pair[1],
|
pair[1] if pair[0] > pair[1] else pair[0],
|
||||||
vars(args)
|
vars(args)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for future in concurrent.futures.as_completed(futures):
|
for future in concurrent.futures.as_completed(futures):
|
||||||
try:
|
try:
|
||||||
moment_results.append(future.result())
|
moment_results.append(list(future.result()))
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("error during editing")
|
log.exception("error during editing")
|
||||||
sys.exit(-2)
|
sys.exit(-2)
|
||||||
@ -120,6 +121,7 @@ def main(args):
|
|||||||
durations = []
|
durations = []
|
||||||
for result in moment_results:
|
for result in moment_results:
|
||||||
total_duration = 0
|
total_duration = 0
|
||||||
|
result[0] = filter_moments(result[0], args.mindur, args.maxdur)
|
||||||
for moment in result[0]:
|
for moment in result[0]:
|
||||||
total_duration = total_duration + moment.get_duration()
|
total_duration = total_duration + moment.get_duration()
|
||||||
costs.append(costfunc(desired, total_duration))
|
costs.append(costfunc(desired, total_duration))
|
||||||
@ -149,12 +151,6 @@ def main(args):
|
|||||||
spread_decay = random.uniform(0.0001, 0.001)
|
spread_decay = random.uniform(0.0001, 0.001)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog="ALE", description="ALE: Automatic Linear Editor.",
|
prog="ALE", description="ALE: Automatic Linear Editor.",
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
import whisper
|
import whisper
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@ -19,10 +22,22 @@ class TextGlob:
|
|||||||
class SentimentEditor:
|
class SentimentEditor:
|
||||||
def __init__(self, video_path, audio_path, params):
|
def __init__(self, video_path, audio_path, params):
|
||||||
self.logger = structlog.get_logger("sentiment")
|
self.logger = structlog.get_logger("sentiment")
|
||||||
|
tempdir = tempfile.gettempdir()
|
||||||
|
dest_location = f"{tempdir}/{params['temp_file_name']}-{params['model_size']}-sentiment.json"
|
||||||
|
if not Path(dest_location).is_file():
|
||||||
self.logger.info("loading whisper model", size=params["model_size"])
|
self.logger.info("loading whisper model", size=params["model_size"])
|
||||||
self.model = whisper.load_model(params["model_size"])
|
self.model = whisper.load_model(params["model_size"])
|
||||||
self.logger.info("transcribing audio", path=audio_path)
|
self.logger.info("transcribing audio", path=audio_path)
|
||||||
self.result = self.model.transcribe(audio_path)
|
self.result = self.model.transcribe(audio_path)
|
||||||
|
|
||||||
|
with open(dest_location, 'w') as fp:
|
||||||
|
json.dump(self.result, fp)
|
||||||
|
else:
|
||||||
|
self.logger.info("cached transcription found", path=dest_location)
|
||||||
|
|
||||||
|
with open(dest_location, 'r') as f:
|
||||||
|
self.result = json.load(f)
|
||||||
|
|
||||||
self.segments = []
|
self.segments = []
|
||||||
for segment in self.result['segments']:
|
for segment in self.result['segments']:
|
||||||
self.segments.append(TextGlob(segment['start'], segment['end'], segment['text'], 0))
|
self.segments.append(TextGlob(segment['start'], segment['end'], segment['text'], 0))
|
||||||
@ -42,7 +57,7 @@ class SentimentEditor:
|
|||||||
def edit(self, large_window, small_window, params):
|
def edit(self, large_window, small_window, params):
|
||||||
end_time = self.segments[-1].stop
|
end_time = self.segments[-1].stop
|
||||||
window_factor = len(self.sentiments) / end_time
|
window_factor = len(self.sentiments) / end_time
|
||||||
long_ma = np_moving_average(self.squared_subsample, large_window * window_factor)
|
long_ma = np_moving_average(self.sentiments, large_window)
|
||||||
short_ma = np_moving_average(self.squared_subsample, small_window * window_factor)
|
short_ma = np_moving_average(self.sentiments, small_window)
|
||||||
highlights = find_moving_average_highlights(short_ma, long_ma, 1.0 / window_factor)
|
highlights = find_moving_average_highlights(short_ma, long_ma, 1.0 / window_factor)
|
||||||
return highlights, large_window, small_window
|
return highlights, large_window, small_window
|
||||||
|
@ -9,5 +9,8 @@ def get_subclips(source_video_path, moments):
|
|||||||
|
|
||||||
def render_moments(moments, input_video_path, output_path):
|
def render_moments(moments, input_video_path, output_path):
|
||||||
clips, vid = get_subclips(input_video_path, moments)
|
clips, vid = get_subclips(input_video_path, moments)
|
||||||
to_render = mp.concatenate_videoclips(clips, logger=None)
|
to_render = mp.concatenate_videoclips(clips)
|
||||||
to_render.write_videofile(output_path, logger=None)
|
to_render.write_videofile(output_path, logger=None)
|
||||||
|
|
||||||
|
def filter_moments(moments, min_length, max_length):
    """Return the moments whose duration lies strictly between two bounds.

    Args:
        moments: Iterable of objects exposing a ``get_duration()`` method.
        min_length: Exclusive lower bound on a moment's duration.
        max_length: Exclusive upper bound on a moment's duration.

    Returns:
        A new list, in the original order, holding only the moments for
        which ``min_length < get_duration() < max_length``.
    """
    # The chained comparison evaluates get_duration() exactly once per moment.
    return [m for m in moments if min_length < m.get_duration() < max_length]
|
||||||
|
Loading…
Reference in New Issue
Block a user