from scipy.fft import fft, ifft, fftfreq
import numpy as np
import sys
import matplotlib.pyplot as plt

# --- Sampling -------------------------------------------------------------
# Number of samples in the analysis window (one record = one FFT).
n = 5000
# Length of the analysis window in time units.  f = cos(2*pi*t) completes
# exactly one cycle per unit time, so with per = 1 its spectral line falls
# exactly on FFT bin 1 (the bin subtraction below relies on this).
per = 1
# Angular frequency of the window's fundamental; used to normalise the
# frequency axis into multiples of 1/per.
w = 2.0 * np.pi / per
# Sample the half-open window [0, per).  Including t = per (np.linspace's
# default) would repeat the t = 0 sample of a per-periodic signal, so the
# sampled record would no longer be periodic and f's energy would leak
# into every bin.
t = np.linspace(0, per, n, endpoint=False)
Delta_t = per / n
# FFT bin indices 0 .. n-1 (one per sample).
nn = np.arange(n)

# --- Test signals ---------------------------------------------------------
ff = np.cos(2 * np.pi * t)   # 1 cycle per unit time
ff2 = np.cos(4 * np.pi * t)  # 2 cycles per unit time
ff3 = ff + ff2

# --- Frequency axis -------------------------------------------------------
# Bin k of an n-point DFT sits at k / (n * Delta_t) cycles per unit time;
# converting to angular frequency and dividing by w expresses it in
# multiples of the fundamental 1/per (here simply k).  Bins above n // 2
# are the mirrored negative frequencies of a real signal.  All three
# signals share the same sampling and therefore the same axis.
frq_F = 2.0 * np.pi * nn / (n * Delta_t) / w
frq_F2 = frq_F.copy()
frq_F3 = frq_F.copy()

# --- Spectra --------------------------------------------------------------
# Dividing |X_k| by the FFT length (n) makes each line of a bin-centred
# cosine of amplitude A equal A / 2 (its energy is split between bins k
# and n - k).
ff_F = fft(ff)
amp_F = np.abs(ff_F) / n

ff_F2 = fft(ff2)
amp_F2 = np.abs(ff_F2) / n

# Remove f from f3 in the frequency domain by subtracting f's two spectral
# lines (bins 1 and n - 1).  Because f is bin-centred, those two bins hold
# all of its energy, so what remains is f2's spectrum.
ff_F3 = fft(ff3)
ff_F3[1] = ff_F3[1] - ff_F[1]
ff_F3[-1] = ff_F3[-1] - ff_F[-1]
amp_F3 = np.abs(ff_F3) / n




# One-sided amplitude spectra of f, f2 and f3 (f3 after f's lines were
# subtracted).  For real signals the bins above n // 2 mirror the lower
# half (they are the negative frequencies), so only bins 0 .. n // 2 are
# drawn.  All three spectra have the same length.
pos = slice(0, len(amp_F) // 2 + 1)

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111)
ax.plot(frq_F[pos], amp_F[pos], label='Amplitude Spectrum of f')
ax.plot(frq_F2[pos], amp_F2[pos], label='Amplitude Spectrum of f2')
ax.plot(frq_F3[pos], amp_F3[pos], label='Amplitude Spectrum of f3 (f removed)')
ax.set_xlabel('Frequency (normalized)')
ax.set_ylabel('Amplitude')
ax.legend()
ax.set_yscale('log')
ax.grid()

# Transform each spectrum back to the time domain.  Each spectrum is
# conjugate-symmetric (real inputs; f3's edit subtracts a conjugate pair),
# so ifft's imaginary parts are only rounding noise.  Plot the real part
# explicitly instead of letting matplotlib drop the imaginary part with a
# ComplexWarning.
ff1_t = ifft(ff_F)
ff2_t = ifft(ff_F2)
ff3_t = ifft(ff_F3)

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111)
ax.plot(t, ff1_t.real, label='f')
ax.plot(t, ff2_t.real, label='f2')
ax.plot(t, ff3_t.real, label='f3')
ax.set_xlabel('Time')
ax.set_ylabel('f')
ax.legend()
ax.grid()
# Displays every open figure, including any spectrum figure created earlier.
plt.show()