YOLO 실습 코드 : run_transt_custom.py > 자료실

본문 바로가기

YOLO 실습 코드 : run_transt_custom.py

필기자
2025-06-26 17:54 11 0

본문


import sys, os, cv2, numpy as np                # 로그 및 환경설정 패키지
sys.path.append(os.path.abspath("TransT"))      # 환경변수 경로 지정

#TransT 모델 모듈 가져오기
from pytracking.parameter.transt.transt50 import parameters as transt50
from pytracking.tracker.transt import TransT

#영상 처리 + 배열 연산을 위한 패키지
video_path = 'test_obsticle_video.mp4'          # 영상 데이터 경로와 파일이름
start_frame = 29                                # 트래킹을 시작할 프레임 지정

# OpenCV로 영상 로드
cap = cv2.VideoCapture(video_path)
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
ret, frame = cap.read()
if not ret:
    raise RuntimeError(f"{start_frame}번째 프레임 로딩 실패")

# 마우스 커서로 객체 선택 (바운딩박스 지정)
print("추적할 객체를 마우스로 모두 드래그 후 Enter 누르세요.")
# 마우스로 최초 좌표 지정, 지정 완료 후 스페이스바 or 엔터, 좌표지정 종료 : ESC
bboxes = cv2.selectROIs("Select Objects", frame, fromCenter=False, showCrosshair=True)
cv2.destroyAllWindows()
if len(bboxes) == 0:
    raise RuntimeError("선택된 객체가 없습니다.")

# 라벨링한 객체의 클래스 ID 입력
class_ids = [int(input(f"클래스 ID 입력 (박스 {i}): ")) for i in range(len(bboxes))]

# 트래킹 세팅
trackers = []
for box, cid in zip(bboxes, class_ids):
    tracker = TransT(transt50())
    tracker.initialize(frame, {'init_bbox': np.array(box)})
    tracker.class_id, tracker.bbox_last = cid, box
    trackers.append(tracker)

# 라벨링 된 파일 저장 폴더 생성
os.makedirs("results/images", exist_ok=True)
os.makedirs("results/labels", exist_ok=True)
frame_id = start_frame

# 연산 시작 루프
while cap.isOpened():
    ret, frame = cap.read() # 영상 프레임 읽어오기
    if not ret: break
    frame_copy = frame.copy()
    H, W = frame.shape[:2] # 프레임 높이,넓이 읽기
    labels = []
    for i, trk in enumerate(trackers): # 루프 시작 : 루프횟수 = 지정한 객체 수
        try:
            bbox = trk.track(frame)['target_bbox'] # 모델 추론, 바운딩박스 정보 획득
            trk.bbox_last = bbox
        except:
            bbox = trk.bbox_last
        x, y, w, h = map(int, bbox) # 바운딩 박스 정보를 x, y, w, h로 나누어 각각 저장
        xc, yc, wn, hn = (x + w/2)/W, (y + h/2)/H, w/W, h/H # 좌표값 YOLO 라벨링 방식으로 변환
        labels.append(f"{trk.class_id} {xc:.6f} {yc:.6f} {wn:.6f} {hn:.6f}")

        # 출력 영상에 좌표에 따른 바운딩박스 + 라벨번호 그리기
        cv2.rectangle(frame_copy, (x, y), (x + w, y + h), (0, 255, 0), 2)
        cv2.putText(frame_copy, f"ID {i}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1)

    # 프레임 단위 이미지 + 주석 파일 저장
    cv2.imwrite(f"results/images/frame_{frame_id:04d}.jpg", frame)
    with open(f"results/labels/frame_{frame_id:04d}.txt", "w") as f:
        f.write("\n".join(labels))
    print(f"[{frame_id:04d}] 저장 - 객체 {len(labels)}개")
   
    # 트래킹 된 프레임 출력
    cv2.imshow("Tracking", frame_copy)
    if cv2.waitKey(1) == 27: break # 프로그램 종료 : ESC
    frame_id += 1

cap.release()
cv2.destroyAllWindows()
print("전체 추적 완료")

댓글목록0

등록된 댓글이 없습니다.
게시판 전체검색